├── NOTICE ├── img └── high-level-diagram.jpg ├── src └── synthesizrr │ ├── __init__.py │ ├── driver.py │ ├── corpus.py │ ├── data.py │ └── main.py ├── CODE_OF_CONDUCT.md ├── .gitignore ├── .github └── workflows │ ├── linting.yml │ ├── tests.yml │ └── release.yml ├── pyproject.toml ├── CONTRIBUTING.md ├── README.md ├── requirements.txt └── LICENSE /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /img/high-level-diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/synthesizrr/HEAD/img/high-level-diagram.jpg -------------------------------------------------------------------------------- /src/synthesizrr/__init__.py: -------------------------------------------------------------------------------- 1 | ## Import in dependency order: 2 | import synthesizrr.data 3 | import synthesizrr.common 4 | import synthesizrr.generation 5 | import synthesizrr.metrics 6 | import synthesizrr.driver -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *# 3 | *.pyc 4 | *.DS_Store 5 | 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | *.egg-info/ 10 | /.coverage 11 | /.coverage.* 12 | /.cache 13 | /.pytest_cache 14 | /.mypy_cache 15 | .idea/* 16 | **/.ipynb_checkpoints/* 17 | /build 18 | /doc/_apidoc/ 19 | *.swp 20 | bandit_report_code.txt 21 | bandit_report.json 22 | -------------------------------------------------------------------------------- /.github/workflows/linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | ruff-formatter: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | 11 | - uses: astral-sh/ruff-action@v3 12 | with: 13 | version: '0.9.2' 14 | args: format --check 15 | src: './src' -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | pytest: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Check out repo 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: "3.11" 21 | 22 | - name: Install test requirements 23 | # Install package dependencies: 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install --upgrade uv 27 | python -m uv pip install --upgrade pytest 28 | python -m uv pip install -e . 
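# Note: the editable install above also resolves the package's runtime dependency (fmcore[all]) declared in pyproject.toml, so this job does not install requirements.txt separately.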
29 | 30 | - name: Run tests 31 | run: | 32 | pytest tests/ 33 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | release: 5 | types: 6 | - published # Published through GitHub UI: https://github.com/amazon-science/synthesizrr/releases/new 7 | 8 | jobs: 9 | pypi: 10 | runs-on: ubuntu-latest 11 | if: > 12 | ${{ github.event.workflow_run.conclusion == 'success' && 13 | github.event.workflow_run.head_branch == 'main' }} 14 | steps: 15 | - name: Checkout Repository 16 | uses: actions/checkout@v3 17 | with: 18 | fetch-depth: 0 # Required for Git versioning 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: "3.11" 24 | 25 | - name: Install Hatch & Twine 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install uv 29 | python -m uv pip install hatch twine 30 | 31 | - name: Verify Version from Git Tag 32 | run: hatch version 33 | 34 | - name: Build Package 35 | run: hatch build 36 | 37 | - name: Publish to Test PyPI 38 | env: 39 | TWINE_USERNAME: "__token__" 40 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} 41 | run: twine upload --repository testpypi dist/* 42 | 43 | - name: Publish to PyPI 44 | env: 45 | TWINE_USERNAME: "__token__" 46 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 47 | run: twine upload --repository pypi dist/* -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "synthesizrr" 7 | dynamic = ["version"] 8 | authors = [ 9 | { name = "Abhishek Divekar", email = "adivekar@utexas.edu" } 10 | ] 11 | description = "Synthesizing realistic and diverse text-datasets from augmented LLMs." 
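# Note: "version" is listed under `dynamic` above; hatch-vcs derives it from Git tags at build time (see [tool.hatch.version] below).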
12 | readme = "README.md" 13 | requires-python = ">=3.11" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "Operating System :: OS Independent", 17 | ] 18 | license-files = ["LICENSE"] 19 | dependencies = [ 20 | "fmcore[all]", 21 | ] 22 | 23 | [tool.hatch.version] 24 | source = "vcs" 25 | 26 | [tool.pytest.ini_options] 27 | pythonpath = ["src"] 28 | 29 | [tool.ruff] 30 | line-length = 110 31 | fix = true 32 | force-exclude = true 33 | extend-exclude = [ 34 | "__init__.py", 35 | ] 36 | 37 | [tool.ruff.lint] 38 | fixable = [ 39 | "I", # Add all rules under isort linter: https://docs.astral.sh/ruff/rules/#isort-i 40 | "W", # Add all rules under whitespace: https://docs.astral.sh/ruff/rules/#warning-w 41 | "E401", # multiple-imports-on-one-line: https://docs.astral.sh/ruff/rules/multiple-imports-on-one-line/ 42 | "E713", # not-in-test: https://docs.astral.sh/ruff/rules/not-in-test/ 43 | "E721", # type-comparison: https://docs.astral.sh/ruff/rules/type-comparison/ 44 | "E722", # bare-except: https://docs.astral.sh/ruff/rules/bare-except/ 45 | "F401", # unused-import: https://docs.astral.sh/ruff/rules/unused-import/ 46 | "F541", # f-string-missing-placeholders: https://docs.astral.sh/ruff/rules/f-string-missing-placeholders/ 47 | "F811", # redefined-while-unused: https://docs.astral.sh/ruff/rules/redefined-while-unused/ 48 | "F841", # unused-variable: https://docs.astral.sh/ruff/rules/unused-variable/ 49 | ] 50 | ignore = [ 51 | ## Ignored because it makes the code too verbose: 52 | "E731", # lambda-assignment: https://docs.astral.sh/ruff/rules/lambda-assignment/ 53 | "E741", # ambiguous-variable-name: https://docs.astral.sh/ruff/rules/ambiguous-variable-name/ 54 | 55 | ## Ignored because of bad interaction with `from typing import *` 56 | "F405", # undefined-local-with-import-star-usage: https://docs.astral.sh/ruff/rules/undefined-local-with-import-star-usage/ 57 | "F403", # undefined-local-with-import-star: https://docs.astral.sh/ruff/rules/undefined-local-with-import-star/ 58 | ] -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. 
You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SynthesizRR: Generating Diverse Datasets with Retrieval Augmentation 2 | 3 | This repository contains the implementation of the paper "SynthesizRR: Generating Diverse Datasets with Retrieval Augmentation" (https://arxiv.org/abs/2405.10040). 4 | 5 | 6 | Our proposed approach is shown below. Refer to Algorithm 1 in the paper for details: https://arxiv.org/abs/2405.10040 7 | 8 | ![SynthesizRR High Level Diagram](img/high-level-diagram.jpg) 9 | 10 | ## Installing dependencies 11 | 12 | We recommend installing the required dependencies in a new Conda environment using the commands below. 13 | 14 | These commands were tested to work on `Deep Learning AMI GPU PyTorch 1.13.1 (Amazon Linux 2) 20230221` from AWS. 
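If you are running on a different machine than the AMI above, you may want to first confirm that a GPU and NVIDIA driver are visible (an optional sanity check; not part of the original setup steps):
```commandline
nvidia-smi
```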
15 | 16 | Install dependencies: 17 | ```commandline 18 | conda create -n synthesizrr python=3.11.8 --yes 19 | conda activate synthesizrr 20 | pip install uv ## For super-fast installation 21 | 22 | uv pip install -r requirements.txt 23 | 24 | uv pip install "spacy==3.7.4" "spacy-transformers==1.3.5" 25 | uv pip install "setuptools==69.5.1" 26 | 27 | python -m spacy download en_core_web_lg 28 | python -c "import nltk; nltk.download('punkt');" 29 | ``` 30 | 31 | ## Code structure 32 | 33 | `synthesizrr/base/` contains utility functions and classes. 34 | 35 | `synthesizrr/expts/` contains code to reproduce the experiments. 36 | 37 | ## Running the code 38 | 1. Set up `DATA_DIR`: 39 | - Download the datasets into a local folder `DATA_DIR`. 40 | - Inside `synthesizrr/expts/data.py`, set the variable `DATA_DIR` (marked TODO) to the above folder. 41 | 42 | 2. Set up `CORPUS_DIR`: 43 | - Download the corpora into a folder `CORPUS_DIR`. 44 | - We recommend using S3 for this since the corpora are large. 45 | - Inside `synthesizrr/expts/corpus.py`, set the variable `CORPUS_DIR` (marked TODO) to the above folder. 46 | 47 | 3. Set up `RESULTS_DIR`: 48 | - Inside `synthesizrr/expts/common.py`, set the variable `RESULTS_DIR` (marked TODO) to a separate folder. Intermediate datasets and metrics will be saved here. 49 | - We recommend using S3 for this since the file-paths are long. 50 | 51 | 4. Start a Ray cluster: 52 | - On the Ray head node, run: `ray start --head` 53 | - On the Ray worker nodes, run: `ray start --address=':6379'` 54 | - At the top of the files `data.py`, `corpus.py`, `main.py`, add the following to connect to the Ray cluster: 55 | ```python 56 | import synthesizrr 57 | import ray 58 | from ray.util.dask import ray_dask_get, enable_dask_on_ray, disable_dask_on_ray 59 | from pprint import pprint 60 | pprint(ray.init( 61 | address='ray://:10001', ## MODIFY THIS 62 | ignore_reinit_error=True, 63 | _temp_dir=str('/tmp/ray/'), 64 | runtime_env={"py_modules": [ 65 | synthesizrr, 66 | ]}, 67 | )) 68 | enable_dask_on_ray() 69 | pprint(ray.cluster_resources()) ## Shows the number of CPUs and GPUs, to make sure the cluster is set up properly. 70 | ``` 71 | 72 | 5. After modifying the code to set `DATA_DIR`, `CORPUS_DIR` and `RESULTS_DIR`, and starting the Ray cluster, run the following: 73 | - First, run `cd synthesizrr/expts/ && python3 data.py` to create the datasets. (You will need to download certain datasets to the `DATA_DIR` folder beforehand.) 74 | - Next, run `cd synthesizrr/expts/ && python3 corpus.py` to create the corpora (**warning**: this step needs a lot of compute! Make sure you set up the Ray cluster and use a big machine with at least a few hundred GB of RAM as the head node). 75 | - Finally, run `cd synthesizrr/expts/ && python3 main.py` to reproduce the experiments. 76 | 77 | 78 | ## Security 79 | 80 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 81 | 82 | ## License 83 | 84 | This project is licensed under the Apache-2.0 License. 
85 | 86 | ## Citing 87 | 88 | If you use or refer to this code in another publication, please cite it using the Bibtex below: 89 | 90 | ```bibtex 91 | @misc{divekar2024synthesizrr, 92 | title={SynthesizRR: Generating Diverse Datasets with Retrieval Augmentation}, 93 | author={Abhishek Divekar and Greg Durrett}, 94 | year={2024}, 95 | eprint={2405.10040}, 96 | archivePrefix={arXiv} 97 | } 98 | ``` 99 | 100 | ## Acknowledgements 101 | The compute infrastructure used for these experiments was financially supported by the Amazon Central Machine Learning department. 102 | 103 | The following people contributed to the design or implemented smaller components in this codebase: 104 | - [Gaurav Manchanda](https://in.linkedin.com/in/gauravmanchanda) 105 | - [Siba Rajendran](https://www.linkedin.com/in/siba-rajendran-920135156/) 106 | - [Vijit Malik](https://scholar.google.com/citations?user=noW8sb8AAAAJ&hl=en) 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate @ git+https://github.com/huggingface/accelerate@b7fa2fa956f40e0b6f650d5eb1764680bf3fd8f7 3 | aim==3.20.1 4 | aim-ui==3.20.1 5 | aimrecords==0.0.7 6 | aimrocks==0.4.0 7 | aiobotocore==2.13.0 8 | aiofiles==23.2.1 9 | aiohttp==3.10.2 10 | aiohttp-cors==0.7.0 11 | aioitertools==0.11.0 12 | aiorwlock==1.4.0 13 | aiosignal==1.3.1 14 | alembic==1.13.1 15 | altair==5.3.0 16 | anyio==4.4.0 17 | argon2-cffi==23.1.0 18 | argon2-cffi-bindings==21.2.0 19 | arrow==1.3.0 20 | art==6.2 21 | asttokens==2.4.1 22 | async-lru==2.0.4 23 | attrs==23.2.0 24 | babel==2.15.0 25 | base58==2.0.1 26 | beautifulsoup4==4.12.3 27 | bitsandbytes==0.43.1 28 | bleach==6.1.0 29 | blessed==1.20.0 30 | blinker==1.8.2 31 | bokeh==3.4.1 32 | boto3==1.34.106 33 | botocore==1.34.106 34 | brotli==1.1.0 35 | cachetools==5.3.3 36 | certifi==2024.7.4 37 | cffi==1.16.0 38 | charset-normalizer==3.3.2 39 | click==8.1.7 40 | cloudpickle==3.0.0 41 | cmake==3.29.3 42 | colorcet==3.1.0 43 | colorful==0.5.6 44 | comm==0.2.2 45 | contourpy==1.2.1 46 | cramjam==2.8.3 47 | cryptography==42.0.8 48 | cycler==0.12.1 49 | dask==2024.2.0 50 | datasets==2.2.1 51 | debugpy==1.8.1 52 | decorator==5.1.1 53 | deepspeed==0.14.2 54 | defusedxml==0.7.1 55 | dill==0.3.8 56 | distlib==0.3.8 57 | distributed==2024.2.0 58 | docker-pycreds==0.4.0 59 | einops==0.8.0 60 | et-xmlfile==1.1.0 61 | evaluate==0.4.2 62 | executing==2.0.1 63 | faiss-cpu==1.8.0 64 | fastapi==0.109.1 65 | fastjsonschema==2.19.1 66 | fastparquet==2024.5.0 67 | filelock==3.14.0 68 | fonttools==4.53.0 69 | fqdn==1.5.1 70 | frozenlist==1.4.1 71 | fsspec==2024.6.0 72 | gitdb==4.0.11 73 | gitpython==3.1.43 74 | google-api-core==2.19.0 75 | google-auth==2.29.0 76 | googleapis-common-protos==1.63.1 77 | gpustat==1.1.1 78 | greenlet==3.0.3 79 | grpcio==1.64.1 80 | h11==0.14.0 81 | hjson==3.1.0 82 | holoviews==1.18.3 83 | httpcore==1.0.5 84 | httptools==0.6.1 85 | httpx==0.27.0 86 | huggingface-hub @ git+https://github.com/huggingface/huggingface_hub@919ce7d0ca281574a26cfa73cf242def95ac0119 87 | hvplot==0.10.0 88 | idna==3.7 89 | imageio==2.34.1 90 | importlib-metadata==7.1.0 91 | ipykernel==6.29.4 92 | ipython==8.25.0 93 | ipywidgets==8.1.3 94 | isoduration==20.11.0 95 | jedi==0.19.1 96 | jinja2==3.1.4 97 | jmespath==1.0.1 98 | joblib==1.4.2 99 | json5==0.9.25 100 | jsonpointer==2.4 101 | jsonschema==4.22.0 102 | jsonschema-specifications==2023.12.1 103 | jupyter==1.0.0 104 | 
jupyter-client==8.6.2 105 | jupyter-console==6.6.3 106 | jupyter-core==5.7.2 107 | jupyter-events==0.10.0 108 | jupyter-lsp==2.2.5 109 | jupyter-server==2.14.1 110 | jupyter-server-terminals==0.5.3 111 | jupyterlab==4.2.1 112 | jupyterlab-pygments==0.3.0 113 | jupyterlab-server==2.27.2 114 | jupyterlab-widgets==3.0.11 115 | kiwisolver==1.4.5 116 | linkify-it-py==2.0.3 117 | lit==18.1.6 118 | locket==1.0.0 119 | lz4==4.3.3 120 | mako==1.3.5 121 | markdown==3.6 122 | markdown-it-py==3.0.0 123 | markupsafe==2.1.5 124 | matplotlib==3.9.0 125 | matplotlib-inline==0.1.7 126 | mauve-text==0.3.0 127 | mdit-py-plugins==0.4.1 128 | mdurl==0.1.2 129 | mistune==3.0.2 130 | mpmath==1.3.0 131 | msgpack==1.0.8 132 | multidict==6.0.5 133 | multiprocess==0.70.16 134 | nbclient==0.10.0 135 | nbconvert==7.16.4 136 | nbformat==5.10.4 137 | nest-asyncio==1.6.0 138 | networkx==3.3 139 | ninja==1.11.1.1 140 | nltk==3.8.2 141 | notebook==7.2.0 142 | notebook-shim==0.2.4 143 | numpy==1.26.4 144 | nvidia-cublas-cu11==11.10.3.66 145 | nvidia-cuda-cupti-cu11==11.7.101 146 | nvidia-cuda-nvrtc-cu11==11.7.99 147 | nvidia-cuda-runtime-cu11==11.7.99 148 | nvidia-cudnn-cu11==8.5.0.96 149 | nvidia-cufft-cu11==10.9.0.58 150 | nvidia-curand-cu11==10.2.10.91 151 | nvidia-cusolver-cu11==11.4.0.1 152 | nvidia-cusparse-cu11==11.7.4.91 153 | nvidia-ml-py==12.535.161 154 | nvidia-nccl-cu11==2.14.3 155 | nvidia-nvtx-cu11==11.7.91 156 | nvitop==1.3.2 157 | opencensus==0.11.4 158 | opencensus-context==0.1.3 159 | openpyxl==3.1.3 160 | orjson==3.10.3 161 | overrides==7.7.0 162 | packaging==24.0 163 | pandas==1.5.3 164 | pandocfilters==1.5.1 165 | panel==1.4.4 166 | param==2.1.0 167 | parso==0.8.4 168 | partd==1.4.2 169 | patsy==0.5.6 170 | pexpect==4.9.0 171 | pillow==10.3.0 172 | pip==24.0 173 | platformdirs==4.2.2 174 | plotly==5.22.0 175 | plotly-express==0.4.1 176 | prometheus-client==0.20.0 177 | prompt-toolkit==3.0.46 178 | proto-plus==1.23.0 179 | protobuf==4.25.3 180 | psutil==5.9.8 181 | ptyprocess==0.7.0 182 | pure-eval==0.2.2 183 | py-cpuinfo==9.0.0 184 | py-spy==0.3.14 185 | pyarrow==16.1.0 186 | pyarrow-hotfix==0.6 187 | pyasn1==0.6.0 188 | pyasn1-modules==0.4.0 189 | pycparser==2.22 190 | pydantic==1.10.15 191 | pydeck==0.9.1 192 | pygments==2.18.0 193 | pynvml==11.5.0 194 | pyparsing==3.1.2 195 | python-dateutil==2.9.0.post0 196 | python-dotenv==1.0.1 197 | python-json-logger==2.0.7 198 | pytz==2024.1 199 | pyviz-comms==3.0.2 200 | pyyaml==6.0.1 201 | pyzmq==26.0.3 202 | qtconsole==5.5.2 203 | qtpy==2.4.1 204 | ray==2.9.2 205 | referencing==0.35.1 206 | regex==2024.5.15 207 | requests==2.32.3 208 | responses==0.18.0 209 | restrictedpython==7.1 210 | rfc3339-validator==0.1.4 211 | rfc3986-validator==0.1.1 212 | rich==13.7.1 213 | rpds-py==0.18.1 214 | rsa==4.9 215 | s3fs==2024.6.0 216 | s3transfer==0.10.1 217 | safetensors==0.4.3 218 | scikit-learn==1.5.0 219 | scipy==1.13.1 220 | seaborn==0.13.2 221 | send2trash==1.8.3 222 | sentence-transformers==3.0.0 223 | sentencepiece==0.2.0 224 | sentry-sdk==2.8.0 225 | setproctitle==1.3.3 226 | setuptools==70.0.0 227 | six==1.16.0 228 | smart-open==7.0.4 229 | smmap==5.0.1 230 | sniffio==1.3.1 231 | sortedcontainers==2.4.0 232 | soupsieve==2.5 233 | sqlalchemy==2.0.30 234 | stack-data==0.6.3 235 | starlette==0.36.2 236 | statsmodels==0.14.2 237 | streamlit==1.37.0 238 | sympy==1.12.1 239 | tabulate==0.9.0 240 | tblib==3.0.0 241 | tenacity==8.3.0 242 | tensorboard==2.16.2 243 | tensorboard-data-server==0.7.2 244 | tensorboardx==2.6.2.2 245 | termcolor==2.4.0 246 | terminado==0.18.1 
247 | threadpoolctl==3.5.0 248 | tiktoken==0.7.0 249 | tinycss2==1.3.0 250 | tokenizers==0.19.1 251 | toml==0.10.2 252 | toolz==0.12.1 253 | torch==2.0.1 254 | tornado==6.4.1 255 | tqdm==4.66.4 256 | traitlets==5.14.3 257 | transformers==4.41.1 258 | triton==2.0.0 259 | types-python-dateutil==2.9.0.20240316 260 | typing-extensions==4.12.1 261 | uc-micro-py==1.0.3 262 | uri-template==1.3.0 263 | urllib3==2.2.2 264 | uv==0.2.6 265 | uvicorn==0.30.1 266 | uvloop==0.19.0 267 | virtualenv==20.26.2 268 | wandb==0.17.0 269 | watchdog==4.0.1 270 | watchfiles==0.22.0 271 | wcwidth==0.2.13 272 | webcolors==1.13 273 | webencodings==0.5.1 274 | websocket-client==1.8.0 275 | websockets==12.0 276 | werkzeug==3.0.3 277 | wheel==0.43.0 278 | widgetsnbextension==4.0.11 279 | wrapt==1.16.0 280 | xlrd==2.0.1 281 | xlsxwriter==3.2.0 282 | xxhash==3.4.1 283 | xyzservices==2024.4.0 284 | yarl==1.9.4 285 | zict==3.0.0 286 | zipp==3.19.2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /src/synthesizrr/driver.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import ray 4 | from fmcore.framework import * 5 | from fmcore.framework.dl.torch import * 6 | from fmcore.util import * 7 | 8 | from synthesizrr.common import ( 9 | DEFAULT_TEMPERATURE, 10 | DEFAULT_TOP_P, 11 | LABEL_OVERALL, 12 | DatasetFilterParams, 13 | Experiment, 14 | MetricName, 15 | Student, 16 | ) 17 | from synthesizrr.generation import ( 18 | Corpus, 19 | CreateFewGenDatasets, 20 | CreateSeedSet, 21 | CreateSynthesizRRDatasets, 22 | DatasetName, 23 | EmbedCorpus, 24 | FewGen, 25 | ModelName, 26 | RetrieveFromSeedSet, 27 | Retriever, 28 | SynthesizRR, 29 | ) 30 | from synthesizrr.metrics import ( 31 | FewGenTextGenMetrics, 32 | GoldDatasetMetrics, 33 | SynthesizRRTextGenMetrics, 34 | ) 35 | 36 | 37 | def get_wf(expt: Experiment, metrics: bool) -> Chain: 38 | return { 39 | ## Gold: 40 | (Experiment.Gold, True): Chain.of( 41 | GoldDatasetMetrics, 42 | ), 43 | ## Gold workflow without metrics does not exist. 
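## Each (Experiment, metrics) key in this dict maps to a Chain of workflow steps; looking up an unsupported combination (e.g. Gold without metrics) raises a KeyError.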
44 | ## SynthesizRR: 45 | (Experiment.SynthesizRR, False): Chain.of( 46 | EmbedCorpus, 47 | CreateSeedSet, 48 | RetrieveFromSeedSet, 49 | CreateSynthesizRRDatasets, 50 | SynthesizRR, 51 | # SynthesizRRTextGenMetrics, 52 | ), 53 | (Experiment.SynthesizRR, True): Chain.of( 54 | EmbedCorpus, 55 | CreateSeedSet, 56 | RetrieveFromSeedSet, 57 | CreateSynthesizRRDatasets, 58 | SynthesizRR, 59 | SynthesizRRTextGenMetrics, 60 | ), 61 | ## FewGen: 62 | (Experiment.FewGen, False): Chain.of( 63 | CreateSeedSet, 64 | CreateFewGenDatasets, 65 | FewGen, 66 | # FewGenTextGenMetrics, 67 | ), 68 | (Experiment.FewGen, True): Chain.of( 69 | CreateSeedSet, 70 | CreateFewGenDatasets, 71 | FewGen, 72 | FewGenTextGenMetrics, 73 | ), 74 | }[(expt, metrics)] 75 | 76 | 77 | @safe_validate_arguments 78 | def run_chain( 79 | *, 80 | expt: Experiment, 81 | results_dir: FileMetadata, 82 | notifier: Optional[Notifier], 83 | tracker: Optional[Tracker], 84 | background: bool, 85 | step_wait: confloat(ge=0.0) = 30, ## To avoid AWS creds error when running many in parallel 86 | pause: confloat(ge=0.0) = 3, 87 | dataset_name: DatasetName, 88 | model_name: Optional[ModelName] = None, 89 | num_samples_per_label: Optional[conint(ge=10)] = None, 90 | seed_type: Optional[Literal["generated", "train_set"]] = None, 91 | seed_size: Optional[conint(ge=1)] = None, 92 | seed_set_stratify_on_ground_truth: bool = True, 93 | seed_generation_params: Optional[Dict] = None, 94 | top_p: confloat(ge=0.0, le=1.0) = DEFAULT_TOP_P, 95 | temperature: confloat(ge=0.0, le=1e6) = DEFAULT_TEMPERATURE, 96 | icl_and_prompt_template: Optional[Dict[str, str]] = None, 97 | label_verbalizer: Optional[Dict[str, str]] = None, 98 | num_shots_list: Optional[List[conint(ge=0)]] = None, 99 | metrics_overall_num_samples_per_label: conint(ge=10), 100 | metrics_other_label_num_samples_per_label: Optional[conint(ge=10)] = None, 101 | corpus: Optional[Corpus] = None, 102 | retriever: Optional[Retriever] = None, 103 | icl_type: Literal["retrieved", "curated", "seed"] = "retrieved", 104 | retrieval_top_k: conint(ge=1) = 500, 105 | retr_icl_top_ks: Tuple[conint(ge=1), ...] 
= (1, 2), 106 | retr_icl_distance_range: Tuple[float, float] = (0.5, 0.9), 107 | retr_icl_token_range: Optional[Tuple[conint(ge=1), conint(ge=1)]] = None, 108 | synthesizrr_top_k_range: Optional[range] = None, 109 | synthesizrr_distance_range: Tuple[float, float] = (0.4, 0.9), ## (0.4, 0.9) 110 | llm_batch_size: Optional[conint(ge=1)] = None, 111 | llm_submission_batch_size: conint(ge=1) = 24, 112 | llm_tracking_batch_size: Optional[conint(ge=1)] = None, 113 | llm_num_concurrent_preds: conint(ge=1) = 6, 114 | llm_num_models: Optional[conint(ge=1)] = None, 115 | llm_resources_per_model: Optional[ 116 | Dict[Literal["cpu", "gpu"], Union[confloat(ge=0.0, lt=1.0), conint(ge=0)]] 117 | ] = None, 118 | llm_load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.LEAST_USED, 119 | llm_evaluation_timeout: confloat(ge=0.0, allow_inf_nan=True) = math.inf, 120 | text_gens_parser: Optional[Callable] = None, 121 | text_gens_parser_type: Literal["default", "rejection"] = "rejection", 122 | filter_params: DatasetFilterParams = DatasetFilterParams(filter_type="none"), 123 | metrics_calc_overall: bool = True, 124 | metrics_calc_labels: bool = False, 125 | metrics_max_parallel: conint(ge=1) = 2, 126 | metrics_override_row_count: bool = False, 127 | dataset_cartography_student: Student = Student.DistilBERT, 128 | dataset_cartography_text: Literal["context", "generations"] = "generations", 129 | metrics_to_evaluate: Optional[Tuple[MetricName, ...]] = ( 130 | MetricName.RowCount, 131 | MetricName.TextLength, 132 | MetricName.EntityCount, 133 | MetricName.SelfBLEU, 134 | MetricName.Mauve, 135 | MetricName.StudentDistilBERT_AttrPromptTable13, 136 | MetricName.StudentDeBERTaV3Large, 137 | # MetricName.StudentHPOTinyBert, 138 | # MetricName.StudentHPOMiniLM, 139 | # MetricName.StudentHPODistilBERT, 140 | # MetricName.StudentHPOBERT, 141 | # MetricName.StudentHPODeBERTaV3Base, 142 | # MetricName.StudentHPODeBERTaV3Large, 143 | MetricName.LabelPreservation, 144 | MetricName.StudentDatasetCartography, 145 | MetricName.SaveFilteredDataset, 146 | ), 147 | metrics_label_distribution: Literal["balanced", "train_set"] = "train_set", 148 | dry_run: bool = False, 149 | verbosity: conint(ge=0) = 2, 150 | cart_frac: Optional[confloat(gt=0.0)] = None, # 0.83, 151 | ): 152 | if cart_frac is not None: 153 | filter_params = dict( 154 | filter_type="cartography", 155 | cartography_apply="label", 156 | cartography_confidence_frac=("top", cart_frac), 157 | ) 158 | metrics_overall_num_samples_per_label: int = int(cart_frac * metrics_overall_num_samples_per_label) 159 | 160 | if expt is Experiment.SynthesizRR: 161 | assert corpus is not None 162 | assert retriever is not None 163 | 164 | if expt is not Experiment.Gold: 165 | assert model_name is not None 166 | assert num_samples_per_label is not None 167 | assert seed_type is not None 168 | 169 | if llm_num_models is None: 170 | ray_num_gpus: int = int(ray.cluster_resources()["GPU"]) 171 | llm_num_gpus: int = model_name.llm_resources_per_model().get("gpu", 0) 172 | assert llm_num_gpus >= 0 173 | if llm_num_gpus > 0: 174 | ## LLaMa etc. 175 | llm_num_models: Optional[int] = ray_num_gpus // llm_num_gpus 176 | else: 177 | ## Claude, ChatGPT, etc. 
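## API-hosted models reserve no local GPUs, so llm_num_models is left as None here.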
178 | llm_num_models: Optional[int] = None 179 | llm_batch_size: int = get_default( 180 | llm_batch_size, 181 | model_name.llm_batch_size(dataset_name=dataset_name, expt=expt), 182 | ) 183 | llm_resources_per_model: Dict[ 184 | Literal["cpu", "gpu"], Union[confloat(ge=0.0, lt=1.0), conint(ge=0)] 185 | ] = get_default( 186 | llm_resources_per_model, 187 | model_name.llm_resources_per_model(), 188 | ) 189 | metrics_num_samples_per_label: Dict[str, int] = { 190 | LABEL_OVERALL: metrics_overall_num_samples_per_label, 191 | } 192 | 193 | label_verbalizer: Dict[str, str] = get_default(label_verbalizer, dataset_name.label_verbalizer()) 194 | if metrics_other_label_num_samples_per_label is not None: 195 | for label_text in label_verbalizer.keys(): 196 | metrics_num_samples_per_label[label_text] = metrics_other_label_num_samples_per_label 197 | 198 | if expt is Experiment.Gold: 199 | text_gens_parser: Optional[Callable] = None 200 | elif text_gens_parser is None: 201 | if text_gens_parser_type == "default": 202 | text_gens_parser: Callable = dataset_name.text_gens_parser() 203 | elif text_gens_parser_type == "rejection": 204 | text_gens_parser: Callable = dataset_name.text_gens_parser_rejection(expt=expt) 205 | else: 206 | raise not_impl("text_gens_parser_type", text_gens_parser_type) 207 | 208 | exn_input = dict( 209 | results_dir=results_dir, 210 | dataset_name=dataset_name, 211 | label_verbalizer=label_verbalizer, 212 | seed_type=seed_type, 213 | seed_size=get_default(seed_size, dataset_name.seed_size()), 214 | seed_set_stratify_on_ground_truth=seed_set_stratify_on_ground_truth, 215 | seed_generation_params=seed_generation_params, 216 | model_name=model_name, 217 | num_shots_list=get_default( 218 | num_shots_list, 219 | { 220 | Experiment.Gold: [None], 221 | Experiment.SynthesizRR: [0, 3], 222 | Experiment.FewGen: [0, 32], 223 | }[expt], 224 | ), 225 | num_samples_per_label=num_samples_per_label, 226 | top_p=top_p, 227 | temperature=temperature, 228 | **get_default(icl_and_prompt_template, {}), 229 | **( 230 | { 231 | Experiment.Gold: lambda: dict( 232 | dataset_cartography_student=dataset_cartography_student, 233 | ), 234 | Experiment.FewGen: lambda: dict( 235 | fewgen_max_tokens=dataset_name.max_num_tokens(), ## Max number of output tokens 236 | dataset_cartography_student=dataset_cartography_student, 237 | text_gens_parser=text_gens_parser, 238 | filter_params=filter_params, 239 | ), 240 | Experiment.SynthesizRR: lambda: dict( 241 | synthesizrr_max_tokens=dataset_name.max_num_tokens(), ## Max number of output tokens 242 | corpus=corpus, 243 | corpus_raw_text_dir=corpus.raw_text_dir(), 244 | retriever=retriever, 245 | icl_type=icl_type, 246 | retrieval_top_k=retrieval_top_k, 247 | retr_icl_top_ks=retr_icl_top_ks, 248 | retr_icl_distance_range=retr_icl_distance_range, 249 | retr_icl_token_range=get_default(retr_icl_token_range, corpus.context_token_range()), 250 | synthesizrr_top_k_range=get_default( 251 | synthesizrr_top_k_range, corpus.synthesizrr_top_k_range() 252 | ), 253 | synthesizrr_distance_range=synthesizrr_distance_range, 254 | dataset_cartography_text=dataset_cartography_text, 255 | dataset_cartography_student=dataset_cartography_student, 256 | text_gens_parser=text_gens_parser, 257 | filter_params=filter_params, 258 | ), 259 | }[expt]() 260 | ), 261 | llm_batch_size=llm_batch_size, 262 | llm_submission_batch_size=llm_submission_batch_size, 263 | llm_tracking_batch_size=llm_tracking_batch_size, 264 | llm_resources_per_model=llm_resources_per_model, 265 | 
llm_num_models=llm_num_models, 266 | llm_num_concurrent_preds=llm_num_concurrent_preds, 267 | llm_load_balancing_strategy=llm_load_balancing_strategy, 268 | llm_evaluation_timeout=llm_evaluation_timeout, 269 | metrics_to_evaluate=get_default(metrics_to_evaluate, []), 270 | metrics_num_samples_per_label=metrics_num_samples_per_label, 271 | metrics_label_distribution=metrics_label_distribution, 272 | metrics_calc_overall=metrics_calc_overall, 273 | metrics_calc_labels=metrics_calc_labels, 274 | metrics_max_parallel=metrics_max_parallel, 275 | metrics_override_row_count=metrics_override_row_count, 276 | ) 277 | 278 | wf: Chain = get_wf(expt, metrics=metrics_to_evaluate is not None) 279 | print(set(wf.all_step_inputs(required_only=True)) - set(exn_input.keys())) 280 | if dry_run: 281 | return exn_input 282 | else: 283 | exn = wf.run( 284 | **exn_input, 285 | notifier=notifier, 286 | tracker=tracker, 287 | verbosity=verbosity, 288 | background=background, 289 | step_wait=step_wait, 290 | ) 291 | time.sleep(pause) 292 | return exn 293 | -------------------------------------------------------------------------------- /src/synthesizrr/corpus.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import io 3 | import math 4 | from typing import * 5 | 6 | import orjson 7 | import pandas as pd 8 | from bs4 import BeautifulSoup as BS 9 | from fmcore.constants import DataLayout 10 | from fmcore.data import FileMetadata, Reader, to_sdf 11 | from fmcore.util import ( 12 | ProgressBar, 13 | StringUtil, 14 | Timer, 15 | accumulate, 16 | as_list, 17 | as_tuple, 18 | get_default, 19 | is_null, 20 | run_concurrent, 21 | run_parallel_ray, 22 | wait, 23 | whitespace_normalize, 24 | ) 25 | 26 | CORPUS_DIR: str = "" # TODO: fill this out! Recommended to use S3. 27 | 28 | 29 | def create_amazon_products(): 30 | corpus_dir: FileMetadata = FileMetadata.of( 31 | f"{CORPUS_DIR}/data/amazon-reviews/2018/meta/" 32 | ) 33 | corpus_dir.mkdir() 34 | if len(corpus_dir.list()) == 0: 35 | raise SystemError( 36 | f'Expected Amazon Products metadata to be in folder "{corpus_dir.path}". 
' 37 | f"Please download the data file All_Amazon_Meta.json.gz and unzip to this directory " 38 | f"(to get this data you need to submit the form at https://nijianmo.github.io/amazon/index.html#complete-data)" 39 | ) 40 | 41 | source: str = corpus_dir.file_in_dir("All_Amazon_Meta.json") 42 | all_dfs: List[pd.DataFrame] = [] 43 | buf = [] 44 | df_pbar = ProgressBar.of(unit="file") 45 | row_pbar = ProgressBar.of(unit="rows", miniters=10_000) 46 | with io.open(source, "rb") as inp: 47 | for line in inp: 48 | buf.append(orjson.loads(line)) 49 | row_pbar.update(1) 50 | if len(buf) == 100_000: 51 | all_dfs.append(pd.DataFrame(buf)) 52 | df_pbar.update(1) 53 | buf = [] 54 | row_pbar.success() 55 | all_dfs.append(pd.DataFrame(buf)) 56 | df_pbar.update(1) 57 | # futs.append(run_concurrent( 58 | # write_df, 59 | # df=all_dfs[-1], 60 | # n=df_pbar.pbar.n, 61 | # )) 62 | buf = [] 63 | gc.collect() 64 | 65 | # fpaths = accumulate(futs, progress=dict(desc='Writing')) 66 | 67 | def _convert_row(row): 68 | if not is_null(row["category"]): 69 | row["category"] = as_tuple(row["category"]) 70 | if not is_null(row["description"]): 71 | row["description"] = as_tuple(row["description"]) 72 | if not is_null(row["also_buy"]): 73 | row["also_buy"] = as_tuple(row["also_buy"]) 74 | if not is_null(row["image"]): 75 | row["image"] = as_tuple(row["image"]) 76 | if not is_null(row["feature"]): 77 | row["feature"] = as_tuple(row["feature"]) 78 | if not is_null(row["also_view"]): 79 | row["also_view"] = as_tuple(row["also_view"]) 80 | if not is_null(row["rank"]): 81 | row["rank"] = as_tuple(row["rank"]) 82 | if is_null(row["details"]): 83 | row["details"] = {} 84 | return row 85 | 86 | def _convert_df(df_part): 87 | return ( 88 | to_sdf(df_part) 89 | .to_layout(DataLayout.LIST_OF_DICT) 90 | .apply(_convert_row, axis=1) 91 | .pandas() 92 | ) 93 | 94 | corpus_split_dir: FileMetadata = corpus_dir.subdir_in_dir( 95 | "split", return_metadata=True 96 | ) 97 | futs = [] 98 | for df_part_i, df_part in enumerate(all_dfs): 99 | df_part = _convert_df(df_part) 100 | dest: str = corpus_split_dir.file_in_dir( 101 | f"amazon-reviews-2018-meta-part-{StringUtil.pad_zeros(df_part_i)}.parquet" 102 | ) 103 | futs.append( 104 | run_concurrent( 105 | df_part.to_parquet, 106 | dest, 107 | ) 108 | ) 109 | print(df_part_i) 110 | accumulate(futs, progress=dict(desc="Writing", unit="file")) 111 | 112 | with Timer(): 113 | prods = Reader.of( 114 | "parquet", 115 | data_schema={ 116 | "asin": "object", 117 | # "also_buy": 'object', 118 | # "also_view": 'object', 119 | "title": "object", 120 | "description": "object", 121 | "brand": "object", 122 | "category": "object", 123 | "date": "object", 124 | # "details": 'object', 125 | "feature": "object", 126 | "fit": "object", 127 | # "image": 'object', 128 | "main_cat": "object", 129 | "price": "object", 130 | # "rank": 'object', 131 | # "similar_item": 'object', 132 | # "tech1": 'object', 133 | # "tech2": 'object', 134 | }, 135 | ).read( 136 | corpus_split_dir, 137 | read_as=DataLayout.PANDAS, 138 | ) 139 | prods = prods.drop_duplicates("asin").persist(wait=True) 140 | 141 | corpus_raw_text_dir: FileMetadata = corpus_dir.subdir_in_dir( 142 | "raw-text", return_metadata=True 143 | ) 144 | with Timer(): 145 | 146 | def create_product_text(row): 147 | product_text: str = "" 148 | for col in ["title", "description"]: 149 | for text in as_list(row[col]): 150 | product_text += f"{get_default(text, '')}
" 151 | product_text: str = BS(product_text).get_text(separator="\n") 152 | for i in range(1, 10): 153 | product_text: str = product_text.replace(f"{i}.", f"{i}. ") 154 | product_text: str = whitespace_normalize(product_text) 155 | return product_text 156 | 157 | def set_product_text(_prods_part): 158 | _prods_part["product_text"] = _prods_part.apply(create_product_text, axis=1) 159 | return _prods_part 160 | 161 | prods_list = [] 162 | for prods_part_i, prods_part in ProgressBar.iter( 163 | enumerate(prods.stream(stream_as=DataLayout.PANDAS, batch_size=10_000)), 164 | total=math.ceil(len(prods) / 10_000), 165 | ): 166 | prods_list.append(run_parallel_ray(set_product_text, prods_part)) 167 | prods_list: List[pd.DataFrame] = accumulate(prods_list, progress=True) 168 | 169 | futs = [] 170 | for prods_part_i, prods_part in ProgressBar.iter( 171 | enumerate(prods_list), 172 | total=len(prods_list), 173 | ): 174 | prods_part = prods_part.reset_index(drop=True) 175 | prods_part["asin"] = prods_part["asin"].astype(str) 176 | fpath: str = corpus_raw_text_dir.file_in_dir( 177 | f"amazon-products-2018-raw-text-part-{StringUtil.pad_zeros(prods_part_i)}.parquet" 178 | ) 179 | futs.append(run_concurrent(prods_part.to_parquet, fpath)) 180 | wait(futs, progress=True) 181 | print( 182 | f'Done creating Amazon Products corpus, final data is at: "{corpus_raw_text_dir.path}"' 183 | ) 184 | 185 | def amazon_products_count_num_tokens(df_path): 186 | df_part = Reader.of( 187 | "parquet", 188 | data_schema={ 189 | "asin": "index", 190 | "title": "text", 191 | "description": "text", 192 | }, 193 | ).read(df_path, raw=True) 194 | ser_part = ( 195 | df_part["title"].fillna("").astype(str) 196 | + " " 197 | + df_part["description"].apply(lambda x: " ".join(x)).astype(str) 198 | ) 199 | from transformers import AutoTokenizer 200 | 201 | tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-fp16") 202 | sum_is: int = sum( 203 | [ 204 | len(input_ids) 205 | for input_ids in tokenizer( 206 | ser_part.tolist(), add_special_tokens=False 207 | )["input_ids"] 208 | ] 209 | ) 210 | return sum_is 211 | 212 | counts: List[int] = accumulate( 213 | [ 214 | run_parallel_ray( 215 | amazon_products_count_num_tokens, 216 | df_path=df_path, 217 | ) 218 | for df_path in FileMetadata.of( 219 | corpus_raw_text_dir.path, 220 | file_glob="*.parquet", 221 | ).list() 222 | ], 223 | progress=True, 224 | ) 225 | print( 226 | f"Amazon Products corpus has {round(sum(counts) / 1e9, 2)} billion tokens" 227 | ) 228 | 229 | 230 | REALNEWS_REGIONAL_NEWS_DOMAINS: List[str] = [ 231 | "mid-day.com", 232 | "financialexpress.com", 233 | "thenationonlineng.net", 234 | "livemint.com", 235 | "hindustantimes.com", 236 | "vanguardngr.com", 237 | "capitalfm.co.ke", 238 | "straitstimes.com", 239 | "indianexpress.com", 240 | "nation.com.pk", 241 | "jamaica-gleaner.com", 242 | "trend.az", 243 | "stabroeknews.com", 244 | "dawn.com", 245 | "emirates247.com", 246 | "mangalorean.com", 247 | "vccircle.com", 248 | "thisdaylive.com", 249 | "gulfnews.com", 250 | "tribune.com.pk", 251 | "arabnews.com", 252 | "pakobserver.net", 253 | "nation.co.ke", 254 | "eurasiareview.com", 255 | "thedailystar.net", 256 | "deccanchronicle.com", 257 | "jewishpress.com", 258 | "app.com.pk", 259 | "err.ee", 260 | "lankabusinessonline.com", 261 | "koreatimes.co.kr", 262 | "newera.com.na", 263 | "ticotimes.net", 264 | "codewit.com", 265 | "sunnewsonline.com", 266 | "afaqs.com", 267 | "ameinfo.com", 268 | "malaysiakini.com", 269 | "ynetnews.com", 270 | "palestinechronicle.com", 271 | 
"zmescience.com", 272 | "cyprus-mail.com", 273 | "colombiareports.com", 274 | "arabtimesonline.com", 275 | "bollywoodhungama.com", 276 | "pattayamail.com", 277 | "insightcrime.org", 278 | "medianewsline.com", 279 | "dailytimes.com.pk", 280 | "chinadigitaltimes.net", 281 | "saudigazette.com.sa", 282 | "newsday.co.zw", 283 | "sunstar.com.ph", 284 | "nehandaradio.com", 285 | "freemalaysiatoday.com", 286 | "onlanka.com", 287 | "thezimbabwemail.com", 288 | "theeastafrican.co.ke", 289 | "thecitizen.co.tz", 290 | "lusakatimes.com", 291 | "orissadiary.com", 292 | "aljazeera.com", 293 | "tehrantimes.com", 294 | "theborneopost.com", 295 | "morungexpress.com", 296 | "monitor.co.ug", 297 | "countercurrents.org", 298 | "businessworld.in", 299 | "governancenow.com", 300 | "itweb.co.za", 301 | "972mag.com", 302 | "memeburn.com", 303 | "themediaonline.co.za", 304 | "koimoi.com", 305 | "caribbean360.com", 306 | "yalibnan.com", 307 | "milligazette.com", 308 | "thefrontierpost.com", 309 | "kuwaittimes.net", 310 | "somalilandpress.com", 311 | "thestkittsnevisobserver.com", 312 | "news24.com", 313 | "livinginperu.com", 314 | "journal.com.ph", 315 | "bworldonline.com", 316 | "venezuelanalysis.com", 317 | "businessdayonline.com", 318 | "macaudailytimes.com.mo", 319 | "ghanabusinessnews.com", 320 | "trinidadexpress.com", 321 | "pmnewsnigeria.com", 322 | "lankanewspapers.com", 323 | "asiasentinel.com", 324 | "maravipost.com", 325 | "dayafterindia.com", 326 | "defense-update.com", 327 | "antiguaobserver.com", 328 | "newsbytes.ph", 329 | "truthdive.com", 330 | "thehimalayantimes.com", 331 | "standardmedia.co.ke", 332 | "groundviews.org", 333 | "japantoday.com", 334 | "kbc.co.ke", 335 | "mindanews.com", 336 | "thejakartaglobe.com", 337 | "actionforex.com", 338 | "modernghana.com", 339 | "newstodaynet.com", 340 | "centralchronicle.com", 341 | "dalje.com", 342 | "escambray.cu", 343 | "middle-east-online.com", 344 | "theminaretonline.com", 345 | "pakistankakhudahafiz.com", 346 | "meed.com", 347 | "tribwekchron.com", 348 | "thenews.com.pk", 349 | "iafrica.com", 350 | "philstar.com", 351 | "praguepost.com", 352 | "yonhapnews.co.kr", 353 | "china.org.cn", 354 | "rtn.asia", 355 | "nationalturk.com", 356 | "thebraziltimes.com", 357 | "businessdailyafrica.com", 358 | "hku.hk", 359 | "intifada-palestine.com", 360 | "realbollywood.com", 361 | "pak1stanfirst.com", 362 | "mutiny.in", 363 | "mareeg.com", 364 | "paltelegraph.com", 365 | "pakwatan.com", 366 | "mybroadband.co.za", 367 | "african-bulletin.com", 368 | "thedailynewsegypt.com", 369 | "7days.ae", 370 | "dailyforex.com", 371 | "melodika.net", 372 | ] 373 | 374 | REALNEWS_INDIAN_NEWS_DOMAINS: List[str] = [ 375 | "mid-day.com", 376 | "financialexpress.com", 377 | "livemint.com", 378 | "hindustantimes.com", 379 | "indianexpress.com", 380 | "mangalorean.com", 381 | "vccircle.com", 382 | "deccanchronicle.com", 383 | "afaqs.com", 384 | "bollywoodhungama.com", 385 | "medianewsline.com", 386 | "orissadiary.com", 387 | "morungexpress.com", 388 | "countercurrents.org", 389 | "businessworld.in", 390 | "governancenow.com", 391 | "koimoi.com", 392 | "milligazette.com", 393 | "dayafterindia.com", 394 | "truthdive.com", 395 | "newstodaynet.com", 396 | "centralchronicle.com", 397 | "dalje.com", 398 | "rtn.asia", 399 | "realbollywood.com", 400 | "mutiny.in", 401 | ] 402 | 403 | 404 | def count_num_tokens(ser_part): 405 | from transformers import AutoTokenizer 406 | 407 | tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-fp16") 408 | sum_is: int = sum( 409 | [ 410 | 
len(input_ids) 411 | for input_ids in tokenizer(ser_part.tolist(), add_special_tokens=False)[ 412 | "input_ids" 413 | ] 414 | ] 415 | ) 416 | return sum_is 417 | 418 | 419 | def write_realnews_partition(rn_data, *, corpus_dir: FileMetadata, rn_name: str): 420 | with Timer(f"Writing partition {rn_name}"): 421 | corpus_partition_dir: FileMetadata = corpus_dir.subdir_in_dir( 422 | rn_name, return_metadata=True 423 | ) 424 | print(f"{rn_name} length: {len(rn_data)}") 425 | futs = [] 426 | for rn_part_i, rn_part in ProgressBar.iter( 427 | enumerate(rn_data.stream(stream_as=DataLayout.PANDAS, batch_size=100_000)), 428 | total=math.ceil(len(rn_data) / 100_000), 429 | ): 430 | rn_part = rn_part.reset_index(drop=True) 431 | fpath: str = corpus_partition_dir.file_in_dir( 432 | f"realnews-{rn_name}-part-{StringUtil.pad_zeros(rn_part_i)}.parquet" 433 | ) 434 | futs.append(run_concurrent(rn_part.to_parquet, fpath)) 435 | wait(futs, progress=True) 436 | print( 437 | f'Done creating {rn_name} partition, final data is at: "{corpus_partition_dir.path}"' 438 | ) 439 | 440 | counts: pd.Series = rn_data["text"].map_partitions(count_num_tokens).compute() 441 | print(f"{rn_name} corpus has {round(counts.sum() / 1e9, 2)} billion tokens") 442 | 443 | 444 | def create_realnews(): 445 | corpus_dir: FileMetadata = FileMetadata.of(f"{CORPUS_DIR}/data/realnews/") 446 | corpus_dir.mkdir() 447 | if len(corpus_dir.list()) == 0: 448 | raise SystemError( 449 | f'Expected RealNews to be in folder "{corpus_dir.path}". ' 450 | f"Please download the data file realnews.jsonl to this directory " 451 | f"(to get this data you need to submit the form at https://github.com/rowanz/grover/tree/master/realnews)" 452 | ) 453 | 454 | with Timer("Reading and splitting realnews.jsonl"): 455 | source: str = corpus_dir.file_in_dir("realnews.jsonl") 456 | all_dfs: List[pd.DataFrame] = [] 457 | buf = [] 458 | df_pbar = ProgressBar.of(unit="file") 459 | row_pbar = ProgressBar.of(unit="rows", miniters=10_000) 460 | row_idx = 0 461 | with io.open(source, "rb") as inp: 462 | for line in inp: 463 | buf.append(orjson.loads(line)) 464 | buf[-1]["idx"] = row_idx 465 | row_idx += 1 466 | row_pbar.update(1) 467 | if len(buf) == 100_000: 468 | all_dfs.append(pd.DataFrame(buf)) 469 | df_pbar.update(1) 470 | buf = [] 471 | row_pbar.success() 472 | all_dfs.append(pd.DataFrame(buf)) 473 | df_pbar.update(1) 474 | buf = [] 475 | gc.collect() 476 | 477 | corpus_split_dir: FileMetadata = corpus_dir.subdir_in_dir( 478 | "split", return_metadata=True 479 | ) 480 | futs = [] 481 | for df_part_i, df_part in enumerate(all_dfs): 482 | dest: str = corpus_split_dir.file_in_dir( 483 | f"realnews-part-{StringUtil.pad_zeros(df_part_i)}.parquet" 484 | ) 485 | futs.append( 486 | run_concurrent( 487 | df_part.to_parquet, 488 | dest, 489 | ) 490 | ) 491 | print(df_part_i) 492 | accumulate(futs, progress=dict(desc="Writing", unit="file")) 493 | 494 | with Timer("Reading split files"): 495 | realnews_data_schema: Dict = { 496 | "idx": "index", 497 | "title": "object", 498 | "text": "text", 499 | "summary": "object", 500 | "authors": "categorical", 501 | "publish_date": "object", 502 | "status": "categorical", 503 | "url": "categorical", 504 | "domain": "categorical", 505 | "warc_date": "object", 506 | "split": "categorical", 507 | } 508 | realnews = Reader.of( 509 | "parquet", 510 | data_schema=realnews_data_schema, 511 | ).read( 512 | FileMetadata.of( 513 | corpus_split_dir.path, 514 | file_format="parquet", 515 | ), 516 | read_as=DataLayout.DASK, 517 | ) 518 | 519 | 
realnews_india = realnews.query( 520 | f"domain in {REALNEWS_INDIAN_NEWS_DOMAINS}" 521 | ).persist(wait=True) 522 | write_realnews_partition( 523 | realnews_india, corpus_dir=corpus_dir, rn_name="realnews-india" 524 | ) 525 | 526 | realnews_regional = realnews.query( 527 | f"domain in {REALNEWS_REGIONAL_NEWS_DOMAINS}" 528 | ).persist(wait=True) 529 | write_realnews_partition( 530 | realnews_regional, corpus_dir=corpus_dir, rn_name="realnews-regional" 531 | ) 532 | 533 | realnews_dominant = realnews.query( 534 | f"domain not in {REALNEWS_REGIONAL_NEWS_DOMAINS}" 535 | ).persist(wait=True) 536 | write_realnews_partition( 537 | realnews_dominant, corpus_dir=corpus_dir, rn_name="realnews-dominant" 538 | ) 539 | 540 | 541 | def create_cmu_movies(): 542 | corpus_dir: FileMetadata = FileMetadata.of(f"{CORPUS_DIR}/data/cmu_movies/") 543 | corpus_dir.mkdir() 544 | if len(corpus_dir.list()) == 0: 545 | raise SystemError( 546 | f'Expected CMU Movies to be in folder "{corpus_dir.path}". ' 547 | f"Please download the data from https://www.cs.cmu.edu/~ark/personas/ and extract it. " 548 | f'You should get the folder "MovieSummaries".' 549 | ) 550 | with Timer("Reading and merging plot_summaries.txt and movie.metadata.tsv"): 551 | movie_plots: pd.DataFrame = pd.read_csv( 552 | corpus_dir.subdir_in_dir( 553 | "MovieSummaries", return_metadata=True 554 | ).file_in_dir("plot_summaries.txt"), 555 | sep="\t", 556 | header=None, 557 | names=[ 558 | "wiki_movie_id", 559 | "plot_summary", 560 | ], 561 | ) 562 | movie_meta: pd.DataFrame = pd.read_csv( 563 | corpus_dir.subdir_in_dir( 564 | "MovieSummaries", return_metadata=True 565 | ).file_in_dir("movie.metadata.tsv"), 566 | sep="\t", 567 | header=None, 568 | names=[ 569 | "wiki_movie_id", 570 | "freebase_movie_id", 571 | "title", 572 | "release_date", 573 | "box_office_revenue", 574 | "runtime", 575 | "languages", 576 | "countries", 577 | "genres", 578 | ], 579 | ) 580 | movies: pd.DataFrame = ( 581 | movie_meta.merge(movie_plots, on="wiki_movie_id") 582 | .reset_index(drop=True) 583 | .rename(columns=dict(plot_summary="text", wiki_movie_id="idx")) 584 | ) 585 | corpus_raw_text_dir: FileMetadata = corpus_dir.subdir_in_dir( 586 | "raw-text", return_metadata=True 587 | ) 588 | movies.to_parquet(corpus_raw_text_dir.file_in_dir("cmu-movie-summary.parquet")) 589 | print( 590 | f'Done creating CMU Movies corpus, final data is at: "{corpus_raw_text_dir.path}"' 591 | ) 592 | 593 | def cmu_movies_count_num_tokens(df_path): 594 | df_part = Reader.of( 595 | "parquet", 596 | data_schema={ 597 | "idx": "index", ## Columns were renamed to "idx"/"text" when the parquet was written above. 598 | "text": "text", 599 | }, 600 | ).read(df_path, raw=True) 601 | ser_part = df_part["text"] 602 | from transformers import AutoTokenizer 603 | 604 | tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-fp16") 605 | sum_is: int = sum( 606 | [ 607 | len(input_ids) 608 | for input_ids in tokenizer( 609 | ser_part.tolist(), add_special_tokens=False 610 | )["input_ids"] 611 | ] 612 | ) 613 | return sum_is 614 | 615 | counts: List[int] = accumulate( 616 | [ 617 | run_parallel_ray( 618 | cmu_movies_count_num_tokens, 619 | df_path=df_path, 620 | ) 621 | for df_path in FileMetadata.of( 622 | corpus_raw_text_dir.path, 623 | file_glob="*.parquet", 624 | ).list() 625 | ], 626 | progress=True, 627 | ) 628 | print(f"CMU Movies corpus has {round(sum(counts) / 1e6, 2)} million tokens") 629 | 630 | 631 | if __name__ == "__main__": 632 | create_amazon_products() 633 | gc.collect() 634 | create_realnews() 635 | gc.collect() 636 |
create_cmu_movies() 637 | gc.collect() 638 | -------------------------------------------------------------------------------- /src/synthesizrr/data.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import datasets as ds ## Hugging Face "datasets" package; used below as ds.Features / ds.Value. 4 | import pandas as pd 5 | from datasets import load_dataset as hf_load_dataset 6 | from fmcore.data.writer import Writer 7 | from fmcore.framework.dl.torch import * 8 | from fmcore.framework.task_data import Dataset, Datasets, DataSplit 9 | from sklearn.model_selection import train_test_split 10 | 11 | DATA_DIR: str = "" # TODO: fill this out! 12 | 13 | 14 | 15 | 16 | class SynthesizRRDataset(Registry): 17 | name: ClassVar[str] 18 | task: ClassVar[Task] 19 | data_schema: ClassVar[Dict[str, MLType]] 20 | 21 | @classmethod 22 | def _registry_keys(cls) -> Optional[Union[List[Any], Any]]: 23 | return cls.name 24 | 25 | @classmethod 26 | def get(cls, name: str) -> Type["SynthesizRRDataset"]: 27 | return cls.get_subclass(name) 28 | 29 | @classproperty 30 | def dataset_dir(cls) -> FileMetadata: 31 | dataset_dir: FileMetadata = FileMetadata.of(f"{DATA_DIR}/{cls.name}/") 32 | dataset_dir.mkdir() 33 | return dataset_dir 34 | 35 | @classproperty 36 | def train_path(cls) -> str: 37 | return cls.dataset_dir.file_in_dir(f"{cls.name}_train.parquet") 38 | 39 | @classproperty 40 | def validation_path(cls) -> str: 41 | return cls.dataset_dir.file_in_dir(f"{cls.name}_validation.parquet") 42 | 43 | @classproperty 44 | def test_path(cls) -> str: 45 | return cls.dataset_dir.file_in_dir(f"{cls.name}_test.parquet") 46 | 47 | @classproperty 48 | def unsupervised_path(cls) -> str: 49 | return cls.dataset_dir.file_in_dir(f"{cls.name}_unsupervised.parquet") 50 | 51 | @classproperty 52 | def schema_path(cls) -> str: 53 | return cls.dataset_dir.file_in_dir(f"{cls.name}_schema.json") 54 | 55 | @classproperty 56 | def unsupervised_schema_path(cls) -> str: 57 | return cls.dataset_dir.file_in_dir(f"{cls.name}_unsupervised_schema.json") 58 | 59 | @classproperty 60 | def supervised_data_schema(cls) -> MLTypeSchema: 61 | return copy.deepcopy(MLType.convert_values(cls.data_schema)) 62 | 63 | @classproperty 64 | def unsupervised_data_schema(cls) -> MLTypeSchema: 65 | data_schema: Dict[str, MLType] = copy.deepcopy( 66 | MLType.convert_values(cls.data_schema) 67 | ) 68 | data_schema: MLTypeSchema = { 69 | col: mltype 70 | for col, mltype in data_schema.items() 71 | if mltype not in GROUND_TRUTH_ML_TYPES 72 | } 73 | return data_schema 74 | 75 | @classmethod 76 | def raw_train(cls) -> Optional[pd.DataFrame]: 77 | return None 78 | 79 | @classmethod 80 | def has_train(cls) -> bool: 81 | return "raw_train" in cls.__dict__ 82 | 83 | @classmethod 84 | def raw_validation(cls) -> Optional[pd.DataFrame]: 85 | return None 86 | 87 | @classmethod 88 | def has_validation(cls) -> bool: 89 | return "raw_validation" in cls.__dict__ 90 | 91 | @classmethod 92 | def raw_test(cls) -> Optional[pd.DataFrame]: 93 | return None 94 | 95 | @classmethod 96 | def has_test(cls) -> bool: 97 | return "raw_test" in cls.__dict__ 98 | 99 | @classmethod 100 | def raw_unsupervised(cls) -> Optional[pd.DataFrame]: 101 | return None 102 | 103 | @classmethod 104 | def has_unsupervised(cls) -> bool: 105 | return "raw_unsupervised" in cls.__dict__ 106 | 107 | @classmethod 108 | def decode_labels(cls, row) -> pd.Series: 109 | raise NotImplementedError() 110 | 111 | @classmethod 112 | def has_decode_labels(cls) -> bool: 113 |
return "decode_labels" in cls.__dict__ 114 | 115 | @classmethod 116 | def setup(cls): 117 | if cls.has_train(): 118 | Writer.of("json").write( 119 | data=cls.supervised_data_schema, 120 | destination=cls.schema_path, 121 | overwrite=True, 122 | ) 123 | idx_col: str = only_item( 124 | [ 125 | col 126 | for col, mltype in cls.supervised_data_schema.items() 127 | if mltype is MLType.INDEX 128 | ] 129 | ) 130 | train: pd.DataFrame = cls.raw_train() 131 | if cls.has_decode_labels(): 132 | train: pd.DataFrame = train.apply(cls.decode_labels, axis=1) 133 | if len(train) != train[idx_col].nunique(): 134 | raise ValueError( 135 | f'Expected unique index column "{idx_col}" for split="train".' 136 | ) 137 | train.reset_index(drop=True).to_parquet(cls.train_path) 138 | if cls.has_validation(): 139 | idx_col: str = only_item( 140 | [ 141 | col 142 | for col, mltype in cls.supervised_data_schema.items() 143 | if mltype is MLType.INDEX 144 | ] 145 | ) 146 | validation: pd.DataFrame = cls.raw_validation() 147 | if cls.has_decode_labels(): 148 | validation: pd.DataFrame = validation.apply(cls.decode_labels, axis=1) 149 | if len(validation) != validation[idx_col].nunique(): 150 | raise ValueError( 151 | f'Expected unique index column "{idx_col}" for split="validation".' 152 | ) 153 | validation.reset_index(drop=True).to_parquet(cls.validation_path) 154 | if cls.has_test(): 155 | idx_col: str = only_item( 156 | [ 157 | col 158 | for col, mltype in cls.supervised_data_schema.items() 159 | if mltype is MLType.INDEX 160 | ] 161 | ) 162 | test: pd.DataFrame = cls.raw_test() 163 | if cls.has_decode_labels(): 164 | test: pd.DataFrame = test.apply(cls.decode_labels, axis=1) 165 | if len(test) != test[idx_col].nunique(): 166 | raise ValueError( 167 | f'Expected unique index column "{idx_col}" for split="test".' 168 | ) 169 | test.reset_index(drop=True).to_parquet(cls.test_path) 170 | if cls.has_unsupervised(): 171 | idx_col: str = only_item( 172 | [ 173 | col 174 | for col, mltype in cls.unsupervised_data_schema.items() 175 | if mltype is MLType.INDEX 176 | ] 177 | ) 178 | Writer.of("json").write( 179 | data=cls.unsupervised_data_schema, 180 | destination=cls.unsupervised_schema_path, 181 | overwrite=True, 182 | ) 183 | unsupervised: pd.DataFrame = cls.raw_unsupervised() 184 | if len(unsupervised) != unsupervised[idx_col].nunique(): 185 | raise ValueError( 186 | f'Expected unique index column "{idx_col}" for split="unsupervised".' 
187 | ) 188 | unsupervised.reset_index(drop=True).to_parquet(cls.unsupervised_path) 189 | 190 | @classproperty 191 | def train(cls) -> Optional[Dataset]: 192 | return cls.datasets.train 193 | 194 | @classmethod 195 | @safe_validate_arguments 196 | def create_seed_set( 197 | cls, 198 | seed_size: int, 199 | *, 200 | data_split: DataSplit = DataSplit.TRAIN, 201 | random_state: int = 42, 202 | stratify_on_ground_truth: bool = False, 203 | ) -> Dataset: 204 | dataset: Dataset = cls.datasets[data_split].read(read_as=DataLayout.PANDAS) 205 | dataset_df: pd.DataFrame = dataset.data.pandas() 206 | gt_col: str = only_key(dataset.data_schema.ground_truths_schema) 207 | if stratify_on_ground_truth: 208 | _, seed_dataset_df = train_test_split( 209 | dataset_df, 210 | test_size=seed_size, 211 | random_state=random_state, 212 | stratify=dataset_df[gt_col], 213 | ) 214 | else: 215 | _, seed_dataset_df = train_test_split( 216 | dataset_df, 217 | test_size=seed_size, 218 | random_state=random_state, 219 | ) 220 | return dataset.update_params(data=seed_dataset_df) 221 | 222 | @classproperty 223 | def validation(cls) -> Optional[Dataset]: 224 | return cls.datasets.validation 225 | 226 | @classproperty 227 | def test(cls) -> Optional[Dataset]: 228 | return cls.datasets.test 229 | 230 | @classproperty 231 | def unsupervised(cls) -> Optional[Dataset]: 232 | return cls.datasets.unsupervised 233 | 234 | @classproperty 235 | def datasets(cls) -> Datasets: 236 | datasets: Dict[str, Dataset] = {} 237 | if cls.has_train(): 238 | datasets[DataSplit.TRAIN] = Dataset.of( 239 | data_split=DataSplit.TRAIN, 240 | task=cls.task, 241 | data=FileMetadata.of(cls.train_path), 242 | data_schema=cls.supervised_data_schema, 243 | ) 244 | if cls.has_validation(): 245 | datasets[DataSplit.VALIDATION] = Dataset.of( 246 | data_split=DataSplit.VALIDATION, 247 | task=cls.task, 248 | data=FileMetadata.of(cls.validation_path), 249 | data_schema=cls.supervised_data_schema, 250 | ) 251 | if cls.has_test(): 252 | datasets[DataSplit.TEST] = Dataset.of( 253 | data_split=DataSplit.TEST, 254 | task=cls.task, 255 | data=FileMetadata.of(cls.test_path), 256 | data_schema=cls.supervised_data_schema, 257 | ) 258 | if cls.has_unsupervised(): 259 | datasets[DataSplit.UNSUPERVISED] = Dataset.of( 260 | data_split=DataSplit.UNSUPERVISED, 261 | task=cls.task, 262 | data=FileMetadata.of(cls.unsupervised_path), 263 | data_schema=cls.unsupervised_data_schema, 264 | ) 265 | return Datasets.of(**datasets) 266 | 267 | @classmethod 268 | def setup_datasets(cls): 269 | for SynthesizRRDatasetSubclass in cls.subclasses(): 270 | assert issubclass(SynthesizRRDatasetSubclass, SynthesizRRDataset) 271 | SynthesizRRDatasetSubclass.setup() 272 | 273 | @classproperty 274 | def label_verbalizer(cls) -> Optional[Dict[str, str]]: 275 | return None 276 | 277 | 278 | class HyperpartisanNewsDataset(SynthesizRRDataset): 279 | name = "hyperpartisan_news" 280 | task = Task.BINARY_CLASSIFICATION 281 | data_schema = dict( 282 | id=MLType.INDEX, 283 | text=MLType.TEXT, 284 | label_text=MLType.GROUND_TRUTH, 285 | ) 286 | 287 | @classmethod 288 | def raw_train(cls) -> Optional[pd.DataFrame]: 289 | return ( 290 | hf_load_dataset("zapsdcn/hyperpartisan_news", split="train") 291 | .to_pandas() 292 | .rename(columns={"label": "label_text"}) 293 | ) 294 | 295 | @classmethod 296 | def raw_validation(cls) -> Optional[pd.DataFrame]: 297 | return ( 298 | hf_load_dataset("zapsdcn/hyperpartisan_news", split="validation") 299 | .to_pandas() 300 | .rename(columns={"label": "label_text"}) 301 | ) 302 
| 303 | @classmethod 304 | def raw_test(cls) -> Optional[pd.DataFrame]: 305 | return ( 306 | hf_load_dataset("zapsdcn/hyperpartisan_news", split="test") 307 | .to_pandas() 308 | .rename(columns={"label": "label_text"}) 309 | ) 310 | 311 | @classproperty 312 | def label_verbalizer(cls) -> Optional[Dict[str, str]]: 313 | return { 314 | "true": "using harsh political language, using a mocking tone and toxic commentary", 315 | "false": "using neutral language, using a reasonable tone and politically correct commentary", 316 | } 317 | 318 | 319 | class AGNewsDataset(SynthesizRRDataset): 320 | name = "ag_news" 321 | task = Task.MULTI_CLASS_CLASSIFICATION 322 | data_schema = dict( 323 | id=MLType.INDEX, 324 | text=MLType.TEXT, 325 | # headline=MLType.TEXT, 326 | label_text=MLType.GROUND_TRUTH, 327 | ) 328 | 329 | @classproperty 330 | def label_verbalizer(cls) -> Optional[Dict[str, str]]: 331 | return { 332 | "Business": "about companies, industries, markets, trade, investments, entrepreneurship, economic policies, and other business-related developments", 333 | "World": "about international news, such as politics, diplomacy, conflicts, global events, international relations, human rights issues, and significant global trends", 334 | "Sci/Tech": "about scientific discoveries, technological advancements, innovations, research breakthroughs", 335 | "Sports": "related to coverage of professional sports leagues, major tournaments, athletes, teams, match results, player transfers, coaching changes, sports-related controversies", 336 | } 337 | 338 | @classmethod 339 | def raw_train(cls) -> Optional[pd.DataFrame]: 340 | return hf_load_dataset( 341 | "zapsdcn/ag", 342 | split="train", 343 | features=ds.Features( 344 | { 345 | "label": ds.Value("int64"), 346 | "text": ds.Value("string"), 347 | "headline": ds.Value("string"), 348 | "id": ds.Value("string"), 349 | } 350 | ), 351 | ).to_pandas() 352 | 353 | @classmethod 354 | def raw_validation(cls) -> Optional[pd.DataFrame]: 355 | return hf_load_dataset( 356 | "zapsdcn/ag", 357 | split="validation", 358 | features=ds.Features( 359 | { 360 | "label": ds.Value("int64"), 361 | "text": ds.Value("string"), 362 | "headline": ds.Value("string"), 363 | "id": ds.Value("string"), 364 | } 365 | ), 366 | ).to_pandas() 367 | 368 | @classmethod 369 | def raw_test(cls) -> Optional[pd.DataFrame]: 370 | test_df: pd.DataFrame = hf_load_dataset( 371 | "zapsdcn/ag", 372 | split="test", 373 | features=ds.Features( 374 | { 375 | "label": ds.Value("int64"), 376 | "text": ds.Value("string"), 377 | "headline": ds.Value("string"), 378 | "id": ds.Value("string"), 379 | } 380 | ), 381 | ).to_pandas() 382 | test_df = ( 383 | test_df.drop(["id"], axis=1) 384 | .reset_index(drop=True) 385 | .reset_index() 386 | .rename(columns={"index": "id"}) 387 | ) 388 | test_df["id"] = "idtest" + test_df["id"].astype(str) 389 | return test_df 390 | 391 | @classmethod 392 | def decode_labels(cls, row) -> pd.Series: 393 | ## Ref: https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset 394 | lb_decoder = { 395 | 1: "World", 396 | 2: "Sports", 397 | 3: "Business", 398 | 4: "Sci/Tech", 399 | } 400 | row["label_text"] = lb_decoder[row["label"]] 401 | return row 402 | 403 | 404 | class AmazonReviewsPolarity(SynthesizRRDataset): 405 | name = "amazon-polarity" 406 | task = Task.BINARY_CLASSIFICATION 407 | data_schema = dict( 408 | idx=MLType.INDEX, 409 | # title=MLType.TEXT, 410 | text=MLType.TEXT, 411 | label_text=MLType.GROUND_TRUTH, 412 | ) 413 | 414 | @classproperty 415 | def 
label_verbalizer(cls) -> Optional[Dict[str, str]]: 416 | return { 417 | "positive": "what the reviewer liked about the product, how the reviewer found it easy to use the product, or the reviewer's positive experience with the product", 418 | "negative": "what the reviewer disliked about the product, how the reviewer found it challenging to use the product, or the reviewer's negative experience with the product", 419 | } 420 | 421 | @classmethod 422 | def raw_train(cls) -> Optional[pd.DataFrame]: 423 | return pd.read_parquet( 424 | f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_train.parquet" 425 | ).rename(columns={"index": "idx", "content": "text"}) 426 | 427 | @classmethod 428 | def raw_validation(cls) -> Optional[pd.DataFrame]: 429 | return pd.read_parquet( 430 | f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_validation.parquet" 431 | ).rename(columns={"index": "idx", "content": "text"}) 432 | 433 | @classmethod 434 | def raw_test(cls) -> Optional[pd.DataFrame]: 435 | return pd.read_parquet( 436 | f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_test.parquet" 437 | ).rename(columns={"index": "idx", "content": "text"}) 438 | 439 | @staticmethod 440 | def create_dataset(): 441 | from datasets import load_dataset as hf_load_dataset 442 | 443 | amazon_polarity_train = ( 444 | hf_load_dataset("amazon_polarity", split="train") 445 | .to_pandas() 446 | .reset_index(drop=True) 447 | .reset_index() 448 | .rename(columns={"index": "idx"}) 449 | ) 450 | amazon_polarity_test = ( 451 | hf_load_dataset("amazon_polarity", split="test") 452 | .to_pandas() 453 | .reset_index(drop=True) 454 | .reset_index() 455 | .rename(columns={"index": "idx"}) 456 | ) 457 | amazon_polarity_train["label_text"] = amazon_polarity_train["label"].map( 458 | { 459 | 1: "positive", 460 | 0: "negative", 461 | } 462 | ) 463 | amazon_polarity_test["label_text"] = amazon_polarity_test["label"].map( 464 | { 465 | 1: "positive", 466 | 0: "negative", 467 | } 468 | ) 469 | amazon_polarity_mini_train = pd.concat( 470 | [ 471 | amazon_polarity_train.query('label_text == "positive"') 472 | .sample( 473 | n=72000 // 2, 474 | random_state=42, 475 | ) 476 | .reset_index(drop=True), 477 | amazon_polarity_train.query('label_text == "negative"') 478 | .sample( 479 | n=72000 // 2, 480 | random_state=42, 481 | ) 482 | .reset_index(drop=True), 483 | ] 484 | ).reset_index(drop=True) 485 | amazon_polarity_mini_test = pd.concat( 486 | [ 487 | amazon_polarity_test.query('label_text == "positive"') 488 | .sample( 489 | n=40_000 // 2, 490 | random_state=42, 491 | ) 492 | .reset_index(drop=True), 493 | amazon_polarity_test.query('label_text == "negative"') 494 | .sample( 495 | n=40_000 // 2, 496 | random_state=42, 497 | ) 498 | .reset_index(drop=True), 499 | ] 500 | ).reset_index(drop=True) 501 | amazon_polarity_train_non_mini = amazon_polarity_train.query( 502 | f"idx not in {amazon_polarity_mini_train['idx'].tolist()}" 503 | ) 504 | 505 | amazon_polarity_mini_validation = pd.concat( 506 | [ 507 | amazon_polarity_train_non_mini.query('label_text == "positive"') 508 | .sample( 509 | n=3600 // 2, 510 | random_state=42, 511 | ) 512 | .reset_index(drop=True), 513 | amazon_polarity_train_non_mini.query('label_text == "negative"') 514 | .sample( 515 | n=3600 // 2, 516 | random_state=42, 517 | ) 518 | .reset_index(drop=True), 519 | ] 520 | ).reset_index(drop=True) 521 | FileMetadata.of(f"{DATA_DIR}/data/amazon-polarity-mini/").mkdir() 522 | amazon_polarity_mini_train.to_parquet( 523 | 
f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_train.parquet" 524 | ) 525 | amazon_polarity_mini_validation.to_parquet( 526 | f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_validation.parquet" 527 | ) 528 | amazon_polarity_mini_test.to_parquet( 529 | f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_test.parquet" 530 | ) 531 | 532 | 533 | class AmazonReviewsProductCategory(SynthesizRRDataset): 534 | name = "amazon-reviews-category" 535 | task = Task.MULTI_CLASS_CLASSIFICATION 536 | data_schema = dict( 537 | idx=MLType.INDEX, 538 | # asin=MLType.CATEGORICAL, 539 | # product_name=MLType.TEXT, 540 | # product_type=MLType.CATEGORICAL, 541 | # helpful=MLType.CATEGORICAL, 542 | # rating=MLType.CATEGORICAL, 543 | # title=MLType.TEXT, 544 | # date=MLType.CATEGORICAL, 545 | # reviewer=MLType.TEXT, 546 | # reviewer_location=MLType.CATEGORICAL, 547 | text=MLType.TEXT, 548 | label_text=MLType.GROUND_TRUTH, 549 | # sentiment=MLType.CATEGORICAL, 550 | ) 551 | 552 | @classproperty 553 | def label_verbalizer(cls) -> Optional[Dict[str, str]]: 554 | return { 555 | "magazines": "magazines or periodicals covering various topics", 556 | "camera_photo": "photography gear including cameras, lenses, accessories, or photo editing tools", 557 | "office_products": "office supplies or equipment for professional and home office setups", 558 | "kitchen": "kitchenware, appliances, or culinary tools for cooking and dining", 559 | "cell_phones_service": "cell phone service accessories or service plans for communication and connectivity", 560 | "computer_video_games": "computers, gaming consoles, video games, or related accessories", 561 | "grocery_and_gourmet_food": "groceries, fruits and vegetables, gourmet treats, or specialty food items", 562 | "tools_hardware": "tools, hardware, or equipment for DIY projects and home repairs", 563 | "automotive": "auto parts, accessories, or tools for vehicle maintenance and enhancements", 564 | "music_album": "music albums spanning various genres and artists", 565 | "health_and_personal_care": "healthcare products, personal care items, or wellness essentials", 566 | "electronics": "electronic devices, gadgets, personal tech, or home electronics", 567 | "outdoor_living": "products for outdoor activities, gardening, or patio living", 568 | "video": "movies, TV shows, and documentaries spanning various genres and artists", 569 | "apparel": "clothing including casual wear, formal attire, seasonal outfits, activewear, or fashion accessories for men, women, and children", 570 | "toys_games": "fun or educational toys and games for kids of all ages", 571 | "sports_outdoors": "products for various sports and outdoor activities", 572 | "books": "books in various genres and formats", 573 | "software": "computer software for productivity or gaming covering either personal or professional needs", 574 | "baby": "baby essentials, gear, or toys for infants and toddlers", 575 | "musical_and_instruments": "musical instruments, accessories, or music production equipment", 576 | "beauty": "beauty products, cosmetics, or skincare essentials, makeup, hair care, fragrances, or grooming essentials", 577 | "jewelry_and_watches": "watches or jewelry pieces such as necklaces, bracelets, earrings, or rings, crafted in precious metals or adorned with gemstones for special occasions", 578 | } 579 | 580 | @classmethod 581 | def raw_train(cls) -> Optional[pd.DataFrame]: 582 | return pd.read_parquet( 583 | 
f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-train.parquet" 584 | ).rename( 585 | columns={ 586 | "unique_id": "idx", 587 | "review_text": "text", 588 | "product_category": "label_text", 589 | } 590 | ) 591 | 592 | @classmethod 593 | def raw_validation(cls) -> Optional[pd.DataFrame]: 594 | return pd.read_parquet( 595 | f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-validation.parquet" 596 | ).rename( 597 | columns={ 598 | "unique_id": "idx", 599 | "review_text": "text", 600 | "product_category": "label_text", 601 | } 602 | ) 603 | 604 | @classmethod 605 | def raw_test(cls) -> Optional[pd.DataFrame]: 606 | return pd.read_parquet( 607 | f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-test.parquet" 608 | ).rename( 609 | columns={ 610 | "unique_id": "idx", 611 | "review_text": "text", 612 | "product_category": "label_text", 613 | } 614 | ) 615 | 616 | @staticmethod 617 | def create_dataset(): 618 | from bs4 import BeautifulSoup 619 | 620 | dataset_dir: FileMetadata = FileMetadata.of( 621 | f"{DATA_DIR}/raw-data/amazon-reviews-category/sorted_data/" 622 | ).mkdir(return_metadata=True) 623 | if len(dataset_dir.list()) == 0: 624 | raise SystemError( 625 | f'Expected Amazon Reviews Category data to be in folder "{dataset_dir.path}". ' 626 | f"Please download and unzip the data from " 627 | f"https://www.cs.jhu.edu/~mdredze/datasets/sentiment/unprocessed.tar.gz" 628 | ) 629 | 630 | def parse_review(text) -> pd.DataFrame: 631 | df = [] 632 | for review_BS in BeautifulSoup(text).find_all("review"): 633 | d = {} 634 | for child in review_BS.children: 635 | k = child.name 636 | if k is not None: 637 | v = child.text.strip() 638 | if k in ["product_type", "unique_id"]: ## Allow list 639 | d.setdefault(k, []) 640 | if isinstance(d.get(k), list): 641 | d[k].append(v) 642 | else: 643 | if k in d: 644 | raise ValueError(f'"{k}" key already exists') 645 | d[k] = v 646 | if "unique_id" in d: 647 | d["unique_id"] = "-".join(d["unique_id"]) 648 | if "product_type" in d: 649 | d["product_type"] = ",".join(set(d["product_type"])) 650 | if len(d) > 0: 651 | df.append(d) 652 | return pd.DataFrame(df) 653 | 654 | dfs = [] 655 | from pathlib import Path 656 | 657 | for fpath in dataset_dir.list(only_subdirs=True): 658 | category = Path(fpath).stem 659 | neg = FileSystemUtil.get_file_str( 660 | str(Path(fpath) / "negative.review"), 661 | encoding="cp1252", 662 | raise_error=True, 663 | ) 664 | neg_df = parse_review(neg) 665 | neg_df["product_category"] = category 666 | neg_df["sentiment"] = "negative" 667 | dfs.append(neg_df) 668 | 669 | pos = FileSystemUtil.get_file_str( 670 | str(Path(fpath) / "positive.review"), 671 | ) 672 | pos_df = parse_review(pos) 673 | pos_df["product_category"] = category 674 | pos_df["sentiment"] = "positive" 675 | 676 | dfs.append(pos_df) 677 | reviews = pd.concat(dfs).reset_index(drop=True) 678 | reviews["product_category"] = reviews["product_category"].replace( 679 | { 680 | "grocery": "grocery_and_gourmet_food", 681 | "gourmet_food": "grocery_and_gourmet_food", 682 | } 683 | ) 684 | attrprompt_cats = { 685 | "magazines", 686 | "camera_photo", 687 | "office_products", 688 | "kitchen", 689 | "cell_phones_service", 690 | "computer_video_games", 691 | "grocery_and_gourmet_food", 692 | "tools_hardware", 693 | "automotive", 694 | "music_album", 695 | "health_and_personal_care", 696 | "electronics", 697 | "outdoor_living", 698 | "video", 699 | "apparel", 700 | "toys_games", 701 | "sports_outdoors", 702 | "books", 703 | "software", 704 | 
"baby", 705 | "musical_and_instruments", 706 | "beauty", 707 | "jewelry_and_watches", 708 | } 709 | reviews = reviews.query( 710 | f"product_category in {list(attrprompt_cats)}" 711 | ).reset_index(drop=True) 712 | assert set(reviews["product_category"]) == attrprompt_cats 713 | assert reviews["unique_id"].nunique() == len(reviews) 714 | # reviews['product_category'].value_counts() 715 | reviews_train = reviews.sample(n=30_000, random_state=42).reset_index(drop=True) 716 | print(reviews_train.shape) 717 | 718 | reviews_test = ( 719 | reviews.query(f"unique_id not in {reviews_train['unique_id'].tolist()}") 720 | .sample(n=2_400, random_state=42) 721 | .reset_index(drop=True) 722 | ) 723 | print(reviews_test.shape) 724 | 725 | reviews_validation = reviews.query( 726 | f"unique_id not in {reviews_train['unique_id'].tolist() + reviews_test['unique_id'].tolist()}" 727 | ).reset_index(drop=True) 728 | print(reviews_validation.shape) 729 | 730 | FileMetadata.of(f"{DATA_DIR}/data/amazon-reviews-category/").mkdir() 731 | reviews_train.to_parquet( 732 | f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-train.parquet" 733 | ) 734 | reviews_validation.to_parquet( 735 | f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-validation.parquet" 736 | ) 737 | reviews_test.to_parquet( 738 | f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-test.parquet" 739 | ) 740 | 741 | 742 | class AmazonHumorousProductQuestions(SynthesizRRDataset): 743 | name = "amazon-humor" 744 | task = Task.BINARY_CLASSIFICATION 745 | data_schema = dict( 746 | idx=MLType.INDEX, 747 | text=MLType.TEXT, 748 | # product_description=MLType.TEXT, 749 | # image_url=MLType.URL, 750 | label_text=MLType.GROUND_TRUTH, 751 | ) 752 | 753 | @classproperty 754 | def label_verbalizer(cls) -> Optional[Dict[str, str]]: 755 | return { 756 | "non_humorous": "solemn", 757 | "humorous": "humorous", 758 | } 759 | 760 | @classmethod 761 | def raw_train(cls) -> Optional[pd.DataFrame]: 762 | return pd.read_parquet( 763 | f"{DATA_DIR}/data/amazon-humor/amazon-humor_train.parquet" 764 | ).rename(columns={"question": "text"}) 765 | 766 | @classmethod 767 | def raw_validation(cls) -> Optional[pd.DataFrame]: 768 | return pd.read_parquet( 769 | f"{DATA_DIR}/data/amazon-humor/amazon-humor_validation.parquet" 770 | ).rename(columns={"question": "text"}) 771 | 772 | @classmethod 773 | def raw_test(cls) -> Optional[pd.DataFrame]: 774 | return pd.read_parquet( 775 | f"{DATA_DIR}/data/amazon-humor/amazon-humor_test.parquet" 776 | ).rename(columns={"question": "text"}) 777 | 778 | @staticmethod 779 | def create_dataset(): 780 | dataset_dir: FileMetadata = FileMetadata.of( 781 | f"{DATA_DIR}/raw-data/amazon-humor/" 782 | ).mkdir(return_metadata=True) 783 | if len(dataset_dir.list()) == 0: 784 | raise SystemError( 785 | f'Expected Amazon Humor data to be in folder "{dataset_dir.path}". 
' 786 | f"Please download the data files from https://registry.opendata.aws/humor-detection/" 787 | ) 788 | 789 | humor_pos = ( 790 | pd.read_csv(dataset_dir.file_in_dir("Humorous.csv")) 791 | .reset_index(drop=True) 792 | .reset_index() 793 | .rename(columns={"index": "idx"}) 794 | ) 795 | humor_pos["label_text"] = humor_pos["label"].map( 796 | { 797 | 1: "humorous", 798 | 0: "non_humorous", 799 | } 800 | ) 801 | 802 | humor_neg = ( 803 | pd.read_csv(dataset_dir.file_in_dir("Non-humorous-biased.csv")) 804 | .reset_index(drop=True) 805 | .reset_index() 806 | .rename(columns={"index": "idx"}) 807 | ) 808 | humor_neg["label_text"] = humor_neg["label"].map( 809 | { 810 | 1: "humorous", 811 | 0: "non_humorous", 812 | } 813 | ) 814 | 815 | humor_pos_train = humor_pos.sample(n=15_000 // 2, random_state=42).reset_index( 816 | drop=True 817 | ) 818 | # print(f'humor_pos_train: {len(humor_pos_train)}') 819 | 820 | humor_pos_validation = ( 821 | humor_pos.query(f"idx not in {humor_pos_train['idx'].tolist()}") 822 | .sample(n=1142 // 2, random_state=42) 823 | .reset_index(drop=True) 824 | ) 825 | # print(f'humor_pos_validation: {len(humor_pos_validation)}') 826 | 827 | humor_pos_test = humor_pos.query( 828 | f"idx not in {humor_pos_train['idx'].tolist() + humor_pos_validation['idx'].tolist()}" 829 | ).reset_index(drop=True) 830 | # print(f'humor_pos_test: {len(humor_pos_test)}') 831 | 832 | humor_neg_train = humor_neg.sample(n=15_000 // 2, random_state=42).reset_index( 833 | drop=True 834 | ) 835 | # print(f'humor_neg_train: {len(humor_neg_train)}') 836 | 837 | humor_neg_validation = ( 838 | humor_neg.query(f"idx not in {humor_neg_train['idx'].tolist()}") 839 | .sample(n=1142 // 2, random_state=42) 840 | .reset_index(drop=True) 841 | ) 842 | # print(f'humor_neg_validation: {len(humor_neg_validation)}') 843 | 844 | humor_neg_test = humor_neg.query( 845 | f"idx not in {humor_neg_train['idx'].tolist() + humor_neg_validation['idx'].tolist()}" 846 | ).reset_index(drop=True) 847 | # print(f'humor_neg_test: {len(humor_neg_test)}') 848 | 849 | humor_train = ( 850 | pd.concat([humor_pos_train, humor_neg_train]) 851 | .sample(frac=1, random_state=42) 852 | .reset_index(drop=True) 853 | ) 854 | # print(f'humor_train: {len(humor_train)}') 855 | # display(humor_train.head(3)) 856 | 857 | humor_validation = ( 858 | pd.concat([humor_pos_validation, humor_neg_validation]) 859 | .sample(frac=1, random_state=42) 860 | .reset_index(drop=True) 861 | ) 862 | # print(f'humor_validation: {len(humor_validation)}') 863 | # display(humor_validation.head(3)) 864 | 865 | humor_test = ( 866 | pd.concat([humor_pos_test, humor_neg_test]) 867 | .sample(frac=1, random_state=42) 868 | .reset_index(drop=True) 869 | ) 870 | # print(f'humor_test: {len(humor_test)}') 871 | # display(humor_test.head(3)) 872 | humor_train["idx"] = ( 873 | humor_train["idx"].astype(str) + "-" + humor_train["label_text"].astype(str) 874 | ) 875 | humor_validation["idx"] = ( 876 | humor_validation["idx"].astype(str) 877 | + "-" 878 | + humor_validation["label_text"].astype(str) 879 | ) 880 | humor_test["idx"] = ( 881 | humor_test["idx"].astype(str) + "-" + humor_test["label_text"].astype(str) 882 | ) 883 | 884 | FileMetadata.of(f"{DATA_DIR}/data/amazon-humor/").mkdir() 885 | humor_train.to_parquet( 886 | f"{DATA_DIR}/data/amazon-humor/amazon-humor_train.parquet" 887 | ) 888 | humor_validation.to_parquet( 889 | f"{DATA_DIR}/data/amazon-humor/amazon-humor_validation.parquet" 890 | ) 891 | 
humor_test.to_parquet(f"{DATA_DIR}/data/amazon-humor/amazon-humor_test.parquet") 892 | 893 | 894 | class ToiHeadlinesDataset(SynthesizRRDataset): 895 | name = "toi_headlines" 896 | task = Task.MULTI_CLASS_CLASSIFICATION 897 | data_schema = dict( 898 | idx=MLType.INDEX, 899 | text=MLType.TEXT, 900 | # publish_date=MLType.TEXT, 901 | # headline_category=MLType.CATEGORICAL, 902 | # headline_text_len=MLType.INT, 903 | label_text=MLType.GROUND_TRUTH, 904 | ) 905 | 906 | @classmethod 907 | def raw_train(cls) -> Optional[pd.DataFrame]: 908 | return pd.read_parquet( 909 | f"{DATA_DIR}/data/toi_headlines/toi_headlines_train.parquet" 910 | ).rename(columns={"headline_text": "text", "headline_root": "label_text"}) 911 | 912 | @classmethod 913 | def raw_validation(cls) -> Optional[pd.DataFrame]: 914 | return pd.read_parquet( 915 | f"{DATA_DIR}/data/toi_headlines/toi_headlines_validation.parquet" 916 | ).rename(columns={"headline_text": "text", "headline_root": "label_text"}) 917 | 918 | @classmethod 919 | def raw_test(cls) -> Optional[pd.DataFrame]: 920 | return pd.read_parquet( 921 | f"{DATA_DIR}/data/toi_headlines/toi_headlines_test.parquet" 922 | ).rename(columns={"headline_text": "text", "headline_root": "label_text"}) 923 | 924 | @classproperty 925 | def label_verbalizer(cls) -> Optional[Dict[str, str]]: 926 | raise NotImplementedError() 927 | 928 | @staticmethod 929 | def create_dataset(): 930 | label_space: List[str] = [ 931 | "sports", 932 | "life-style", 933 | "education", 934 | "entertainment", 935 | "business", 936 | "city", 937 | "environment", 938 | "tech", 939 | "elections", 940 | "world", 941 | ] 942 | dataset_dir: FileMetadata = FileMetadata.of( 943 | f"{DATA_DIR}/raw-data/toi_headlines/" 944 | ).mkdir(return_metadata=True) 945 | if len(dataset_dir.list()) == 0: 946 | raise SystemError( 947 | f'Expected ToI Headlines data to be in folder "{dataset_dir.path}". 
' 948 | f"Please download the CSV file from " 949 | f"https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DPQMQH" 950 | ) 951 | df = pd.read_csv(dataset_dir.file_in_dir("india-news-headlines.csv")) 952 | df["headline_root"] = df["headline_category"].apply( 953 | lambda x: x.strip().removeprefix("home.").split(".")[0] 954 | ) 955 | df["headline_text_len"] = df["headline_text"].apply(len) 956 | df = ( 957 | df.query("headline_text_len >= 40") 958 | .reset_index(drop=True) 959 | .reset_index() 960 | .rename(columns=dict(index="idx")) 961 | ) 962 | full_idxs: np.ndarray = sample_idxs_match_distribution( 963 | source=df["headline_root"], 964 | target=pd.Series([lb for lb in label_space]), ## Balanced 965 | n=None, 966 | seed=42, 967 | ) 968 | full = df.loc[full_idxs].sample(frac=1, random_state=42).reset_index(drop=True) 969 | train = full.loc[ 970 | sample_idxs_match_distribution( 971 | full["headline_root"], 972 | target=pd.Series([lb for lb in label_space]), ## Balanced 973 | n=52_000, 974 | seed=42, 975 | ) 976 | ].reset_index(drop=True) 977 | remaining_wo_train = full.query( 978 | f"idx not in {train['idx'].tolist()}" 979 | ).reset_index(drop=True) 980 | test = remaining_wo_train.loc[ 981 | sample_idxs_match_distribution( 982 | remaining_wo_train["headline_root"], 983 | target=pd.Series([lb for lb in label_space]), ## Balanced 984 | n=10_000, 985 | seed=42, 986 | ) 987 | ].reset_index(drop=True) 988 | validation = full.query( 989 | f"idx not in {train['idx'].tolist() + test['idx'].tolist()}" 990 | ).reset_index(drop=True) 991 | train["idx"] = train["idx"].apply(lambda x: f"train-{x}") 992 | test["idx"] = test["idx"].apply(lambda x: f"test-{x}") 993 | validation["idx"] = validation["idx"].apply(lambda x: f"validation-{x}") 994 | print(train["headline_root"].value_counts()) 995 | print(validation["headline_root"].value_counts()) 996 | print(test["headline_root"].value_counts()) 997 | FileMetadata.of(f"{DATA_DIR}/data/toi_headlines/").mkdir() 998 | train.to_parquet(f"{DATA_DIR}/data/toi_headlines/toi_headlines_train.parquet") 999 | validation.to_parquet( 1000 | f"{DATA_DIR}/data/toi_headlines/toi_headlines_validation.parquet" 1001 | ) 1002 | test.to_parquet(f"{DATA_DIR}/data/toi_headlines/toi_headlines_test.parquet") 1003 | 1004 | 1005 | if __name__ == "__main__": 1006 | ToiHeadlinesDataset.create_dataset() 1007 | AmazonReviewsProductCategory.create_dataset() 1008 | AmazonReviewsPolarity.create_dataset() 1009 | AmazonHumorousProductQuestions.create_dataset() 1010 | ## Copy schema and files to be accessible to all workers: 1011 | HyperpartisanNewsDataset.setup() 1012 | AGNewsDataset.setup() 1013 | ToiHeadlinesDataset.setup() 1014 | AmazonReviewsProductCategory.setup() 1015 | AmazonReviewsPolarity.setup() 1016 | AmazonHumorousProductQuestions.setup() 1017 | -------------------------------------------------------------------------------- /src/synthesizrr/main.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | from fmcore.framework import Tracker 4 | 5 | from synthesizrr.common import ( 6 | RESULTS_DIR, 7 | Corpus, 8 | DatasetName, 9 | Experiment, 10 | ModelName, 11 | Retriever, 12 | ) 13 | from synthesizrr.driver import run_chain 14 | 15 | ## Execution outputs will get logged to this file: 16 | TRACKER = Tracker.of( 17 | "log", path="~/synthesizrr_run.log" 18 | )
19 | "log", path="~/synthesizrr_run.log" 20 | ) ## Execution outputs will get logged to this file. 21 | BACKGROUND: bool = False 22 | CART_FRAC: Optional[float] = 0.83 ## Make None to Cartography filtering. 23 | 24 | if __name__ == "__main__": 25 | """ 26 | _ _ _ _ 27 | | || | _ _ _ __ ___ _ _ _ __ __ _ _ _ | |_ (_) ___ __ _ _ _ 28 | | __ || || || '_ \/ -_)| '_|| '_ \/ _` || '_|| _|| |(_-