├── NOTICE
├── img
└── high-level-diagram.jpg
├── src
└── synthesizrr
│ ├── __init__.py
│ ├── driver.py
│ ├── corpus.py
│ ├── data.py
│ └── main.py
├── CODE_OF_CONDUCT.md
├── .gitignore
├── .github
└── workflows
│ ├── linting.yml
│ ├── tests.yml
│ └── release.yml
├── pyproject.toml
├── CONTRIBUTING.md
├── README.md
├── requirements.txt
└── LICENSE
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/img/high-level-diagram.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/synthesizrr/HEAD/img/high-level-diagram.jpg
--------------------------------------------------------------------------------
/src/synthesizrr/__init__.py:
--------------------------------------------------------------------------------
1 | ## Import in dependency order:
2 | import synthesizrr.data
3 | import synthesizrr.common
4 | import synthesizrr.generation
5 | import synthesizrr.metrics
6 | import synthesizrr.driver
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *#
3 | *.pyc
4 | *.DS_Store
5 |
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | *.egg-info/
10 | /.coverage
11 | /.coverage.*
12 | /.cache
13 | /.pytest_cache
14 | /.mypy_cache
15 | .idea/*
16 | **/.ipynb_checkpoints/*
17 | /build
18 | /doc/_apidoc/
19 | *.swp
20 | bandit_report_code.txt
21 | bandit_report.json
22 |
--------------------------------------------------------------------------------
/.github/workflows/linting.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | ruff-formatter:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v4
10 |
11 | - uses: astral-sh/ruff-action@v3
12 | with:
13 | version: '0.9.2'
14 | args: format --check
15 | src: './src'
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | jobs:
10 | pytest:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: Check out repo
15 | uses: actions/checkout@v3
16 |
17 | - name: Set up Python
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: "3.11"
21 |
22 | - name: Install test requirements
23 | # Install package dependencies:
24 | run: |
25 | python -m pip install --upgrade pip
26 | python -m pip install --upgrade uv
27 | python -m uv pip install --upgrade pytest
28 | python -m uv pip install -e .
29 |
30 | - name: Run tests
31 | run: |
32 | pytest tests/
33 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 |
3 | on:
4 | release:
5 | types:
6 | - published # Published through GitHub UI: https://github.com/amazon-science/synthesizrr/releases/new
7 |
8 | jobs:
9 | pypi:
10 | runs-on: ubuntu-latest
11 | if: >
12 | ${{ github.event.workflow_run.conclusion == 'success' &&
13 | github.event.workflow_run.head_branch == 'main' }}
14 | steps:
15 | - name: Checkout Repository
16 | uses: actions/checkout@v3
17 | with:
18 | fetch-depth: 0 # Required for Git versioning
19 |
20 | - name: Set up Python
21 | uses: actions/setup-python@v4
22 | with:
23 | python-version: "3.11"
24 |
25 | - name: Install Hatch & Twine
26 | run: |
27 | python -m pip install --upgrade pip
28 | python -m pip install uv
29 | python -m uv pip install hatch twine
30 |
31 | - name: Verify Version from Git Tag
32 | run: hatch version
33 |
34 | - name: Build Package
35 | run: hatch build
36 |
37 | - name: Publish to Test PyPI
38 | env:
39 | TWINE_USERNAME: "__token__"
40 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
41 | run: twine upload --repository testpypi dist/*
42 |
43 | - name: Publish to PyPI
44 | env:
45 | TWINE_USERNAME: "__token__"
46 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
47 | run: twine upload --repository pypi dist/*
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling", "hatch-vcs"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "synthesizrr"
7 | dynamic = ["version"]
8 | authors = [
9 | { name = "Abhishek Divekar", email = "adivekar@utexas.edu" }
10 | ]
11 | description = "Synthesizing realistic and diverse text-datasets from augmented LLMs."
12 | readme = "README.md"
13 | requires-python = ">=3.11"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "Operating System :: OS Independent",
17 | ]
18 | license-files = ["LICENSE"]
19 | dependencies = [
20 | "fmcore[all]",
21 | ]
22 |
23 | [tool.hatch.version]
24 | source = "vcs"
25 |
26 | [tool.pytest.ini_options]
27 | pythonpath = ["src"]
28 |
29 | [tool.ruff]
30 | line-length = 110
31 | fix = true
32 | force-exclude = true
33 | extend-exclude = [
34 | "__init__.py",
35 | ]
36 |
37 | [tool.ruff.lint]
38 | fixable = [
39 | "I", # Add all rules under isort linter: https://docs.astral.sh/ruff/rules/#isort-i
40 | "W", # Add all rules under whitespace: https://docs.astral.sh/ruff/rules/#warning-w
41 | "E401", # multiple-imports-on-one-line: https://docs.astral.sh/ruff/rules/multiple-imports-on-one-line/
42 | "E713", # not-in-test: https://docs.astral.sh/ruff/rules/not-in-test/
43 | "E721", # type-comparison: https://docs.astral.sh/ruff/rules/type-comparison/
44 | "E722", # bare-except: https://docs.astral.sh/ruff/rules/bare-except/
45 | "F401", # unused-import: https://docs.astral.sh/ruff/rules/unused-import/
46 | "F541", # f-string-missing-placeholders: https://docs.astral.sh/ruff/rules/f-string-missing-placeholders/
47 | "F811", # redefined-while-unused: https://docs.astral.sh/ruff/rules/redefined-while-unused/
48 | "F841", # unused-variable: https://docs.astral.sh/ruff/rules/unused-variable/
49 | ]
50 | ignore = [
51 | ## Ignored because it makes the code too verbose:
52 | "E731", # lambda-assignment: https://docs.astral.sh/ruff/rules/lambda-assignment/
53 | "E741", # ambiguous-variable-name: https://docs.astral.sh/ruff/rules/ambiguous-variable-name/
54 |
55 | ## Ignored because of bad interaction with `from typing import *`
56 | "F405", # undefined-local-with-import-star-usage: https://docs.astral.sh/ruff/rules/undefined-local-with-import-star-usage/
57 | "F403", # undefined-local-with-import-star: https://docs.astral.sh/ruff/rules/undefined-local-with-import-star/
58 | ]
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SynthesizRR: Generating Diverse Datasets with Retrieval Augmentation
2 |
3 | This repository contains the implementation of the paper "SynthesizRR: Generating Diverse Datasets with Retrieval Augmentation" (https://arxiv.org/abs/2405.10040)
4 |
5 |
6 | Our proposed approach is below. Refer to Algorithm 1 in the paper for details: https://arxiv.org/abs/2405.10040
7 |
8 | 
9 |
10 | ## Installing dependencies
11 |
12 | We recommend installing required dependencies in a new Conda environment using the commands below.
13 |
14 | These commands were tested to work on `Deep Learning AMI GPU PyTorch 1.13.1 (Amazon Linux 2) 20230221` from AWS.
15 |
16 | Install dependencies:
17 | ```commandline
18 | conda create -n synthesizrr python=3.11.8 --yes
19 | conda activate synthesizrr
20 | pip install uv ## For super-fast installation
21 |
22 | uv pip install -r requirements.txt
23 |
24 | uv pip install "spacy==3.7.4" "spacy-transformers==1.3.5"
25 | uv pip install "setuptools==69.5.1"
26 |
27 | python -m spacy download en_core_web_lg
28 | python -c "import nltk; nltk.download('punkt');"
29 | ```
30 |
31 | ## Code structure
32 |
33 | `synthesizrr/base/` contains utility functions and classes.
34 |
35 | `synthesizrr/expts/` contains code to reproduce the experiments.
36 |
37 | ## Running the code
38 | 1. Setup `DATA_DIR`:
39 | - Download the datasets into a local folder `DATA_DIR`.
40 |    - Inside `synthesizrr/expts/data.py`, set the variable `DATA_DIR` (marked TODO) to the above folder.
41 |
42 | 2. Setup `CORPUS_DIR`:
43 | - Download the corpora into a folder `CORPUS_DIR`.
44 | - We recommend using S3 for this since the corpora are large.
45 |    - Inside `synthesizrr/expts/corpus.py`, set the variable `CORPUS_DIR` (marked TODO) to the above folder.
46 |
47 | 3. Setup `RESULTS_DIR`:
48 |    - Inside `synthesizrr/expts/common.py`, set the variable `RESULTS_DIR` (marked with TODO) to a different folder. Intermediate datasets and metrics will be saved here.
49 | - We recommend using S3 for this since the file-paths are long.
50 |
51 | 4. Start a Ray cluster:
52 | - On the Ray head node, run: `ray start --head`
53 |    - On the Ray worker nodes, run `ray start --address='<head-node-ip>:6379'` (replace `<head-node-ip>` with the IP address of the Ray head node)
54 | - At the top of the files `data.py`, `corpus.py`, `main.py`, add the following to connect to the Ray cluster:
55 | ```commandline
56 | import synthesizrr
57 | import ray
58 | from ray.util.dask import ray_dask_get, enable_dask_on_ray, disable_dask_on_ray
59 | from pprint import pprint
60 | pprint(ray.init(
61 |     address='ray://<head-node-ip>:10001', ## MODIFY THIS: replace <head-node-ip> with the Ray head node's IP address
62 | ignore_reinit_error=True,
63 | _temp_dir=str('/tmp/ray/'),
64 | runtime_env={"py_modules": [
65 | synthesizrr,
66 | ]},
67 | ))
68 | enable_dask_on_ray()
69 | pprint(ray.cluster_resources()) ## Shows you number of cpus and gpus to make sure it is setup properly.
70 | ```
71 |
72 | 5. After modifying the code to set `DATA_DIR`, `CORPUS_DIR` and `RESULTS_DIR`, and starting the Ray cluster, run the following:
73 | - First, run `cd synthesizrr/expts/ && python3 data.py` to create the datasets. (You will need to download certain datasets to `DATA_DIR` folder beforehand).
74 | - Next, run `cd synthesizrr/expts/ && python3 corpus.py` to create the corpora (**warning**, this step needs a lot of compute! Make sure you setup the Ray cluster and use a big machine with at least a few hundred GB of RAM as the head node).
75 | - Finally, run the file `cd synthesizrr/expts/ && python3 main.py` to reproduce the experiments.
76 |
77 |
78 | ## Security
79 |
80 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
81 |
82 | ## License
83 |
84 | This project is licensed under the Apache-2.0 License.
85 |
86 | ## Citing
87 |
88 | If you use or refer to this code in another publication, please cite it using the Bibtex below:
89 |
90 | ```bibtex
91 | @misc{divekar2024synthesizrr,
92 | title={SynthesizRR: Generating Diverse Datasets with Retrieval Augmentation},
93 | author={Abhishek Divekar and Greg Durrett},
94 | year={2024},
95 | eprint={2405.10040},
96 | archivePrefix={arXiv}
97 | }
98 | ```
99 |
100 | ## Acknowledgements
101 | The compute infrastructure used for these experiments was financially supported by the Amazon Central Machine Learning department.
102 |
103 | The following people contributed to the design or implemented smaller components in this codebase:
104 | - [Gaurav Manchanda](https://in.linkedin.com/in/gauravmanchanda)
105 | - [Siba Rajendran](https://www.linkedin.com/in/siba-rajendran-920135156/)
106 | - [Vijit Malik](https://scholar.google.com/citations?user=noW8sb8AAAAJ&hl=en)
107 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate @ git+https://github.com/huggingface/accelerate@b7fa2fa956f40e0b6f650d5eb1764680bf3fd8f7
3 | aim==3.20.1
4 | aim-ui==3.20.1
5 | aimrecords==0.0.7
6 | aimrocks==0.4.0
7 | aiobotocore==2.13.0
8 | aiofiles==23.2.1
9 | aiohttp==3.10.2
10 | aiohttp-cors==0.7.0
11 | aioitertools==0.11.0
12 | aiorwlock==1.4.0
13 | aiosignal==1.3.1
14 | alembic==1.13.1
15 | altair==5.3.0
16 | anyio==4.4.0
17 | argon2-cffi==23.1.0
18 | argon2-cffi-bindings==21.2.0
19 | arrow==1.3.0
20 | art==6.2
21 | asttokens==2.4.1
22 | async-lru==2.0.4
23 | attrs==23.2.0
24 | babel==2.15.0
25 | base58==2.0.1
26 | beautifulsoup4==4.12.3
27 | bitsandbytes==0.43.1
28 | bleach==6.1.0
29 | blessed==1.20.0
30 | blinker==1.8.2
31 | bokeh==3.4.1
32 | boto3==1.34.106
33 | botocore==1.34.106
34 | brotli==1.1.0
35 | cachetools==5.3.3
36 | certifi==2024.7.4
37 | cffi==1.16.0
38 | charset-normalizer==3.3.2
39 | click==8.1.7
40 | cloudpickle==3.0.0
41 | cmake==3.29.3
42 | colorcet==3.1.0
43 | colorful==0.5.6
44 | comm==0.2.2
45 | contourpy==1.2.1
46 | cramjam==2.8.3
47 | cryptography==42.0.8
48 | cycler==0.12.1
49 | dask==2024.2.0
50 | datasets==2.2.1
51 | debugpy==1.8.1
52 | decorator==5.1.1
53 | deepspeed==0.14.2
54 | defusedxml==0.7.1
55 | dill==0.3.8
56 | distlib==0.3.8
57 | distributed==2024.2.0
58 | docker-pycreds==0.4.0
59 | einops==0.8.0
60 | et-xmlfile==1.1.0
61 | evaluate==0.4.2
62 | executing==2.0.1
63 | faiss-cpu==1.8.0
64 | fastapi==0.109.1
65 | fastjsonschema==2.19.1
66 | fastparquet==2024.5.0
67 | filelock==3.14.0
68 | fonttools==4.53.0
69 | fqdn==1.5.1
70 | frozenlist==1.4.1
71 | fsspec==2024.6.0
72 | gitdb==4.0.11
73 | gitpython==3.1.43
74 | google-api-core==2.19.0
75 | google-auth==2.29.0
76 | googleapis-common-protos==1.63.1
77 | gpustat==1.1.1
78 | greenlet==3.0.3
79 | grpcio==1.64.1
80 | h11==0.14.0
81 | hjson==3.1.0
82 | holoviews==1.18.3
83 | httpcore==1.0.5
84 | httptools==0.6.1
85 | httpx==0.27.0
86 | huggingface-hub @ git+https://github.com/huggingface/huggingface_hub@919ce7d0ca281574a26cfa73cf242def95ac0119
87 | hvplot==0.10.0
88 | idna==3.7
89 | imageio==2.34.1
90 | importlib-metadata==7.1.0
91 | ipykernel==6.29.4
92 | ipython==8.25.0
93 | ipywidgets==8.1.3
94 | isoduration==20.11.0
95 | jedi==0.19.1
96 | jinja2==3.1.4
97 | jmespath==1.0.1
98 | joblib==1.4.2
99 | json5==0.9.25
100 | jsonpointer==2.4
101 | jsonschema==4.22.0
102 | jsonschema-specifications==2023.12.1
103 | jupyter==1.0.0
104 | jupyter-client==8.6.2
105 | jupyter-console==6.6.3
106 | jupyter-core==5.7.2
107 | jupyter-events==0.10.0
108 | jupyter-lsp==2.2.5
109 | jupyter-server==2.14.1
110 | jupyter-server-terminals==0.5.3
111 | jupyterlab==4.2.1
112 | jupyterlab-pygments==0.3.0
113 | jupyterlab-server==2.27.2
114 | jupyterlab-widgets==3.0.11
115 | kiwisolver==1.4.5
116 | linkify-it-py==2.0.3
117 | lit==18.1.6
118 | locket==1.0.0
119 | lz4==4.3.3
120 | mako==1.3.5
121 | markdown==3.6
122 | markdown-it-py==3.0.0
123 | markupsafe==2.1.5
124 | matplotlib==3.9.0
125 | matplotlib-inline==0.1.7
126 | mauve-text==0.3.0
127 | mdit-py-plugins==0.4.1
128 | mdurl==0.1.2
129 | mistune==3.0.2
130 | mpmath==1.3.0
131 | msgpack==1.0.8
132 | multidict==6.0.5
133 | multiprocess==0.70.16
134 | nbclient==0.10.0
135 | nbconvert==7.16.4
136 | nbformat==5.10.4
137 | nest-asyncio==1.6.0
138 | networkx==3.3
139 | ninja==1.11.1.1
140 | nltk==3.8.2
141 | notebook==7.2.0
142 | notebook-shim==0.2.4
143 | numpy==1.26.4
144 | nvidia-cublas-cu11==11.10.3.66
145 | nvidia-cuda-cupti-cu11==11.7.101
146 | nvidia-cuda-nvrtc-cu11==11.7.99
147 | nvidia-cuda-runtime-cu11==11.7.99
148 | nvidia-cudnn-cu11==8.5.0.96
149 | nvidia-cufft-cu11==10.9.0.58
150 | nvidia-curand-cu11==10.2.10.91
151 | nvidia-cusolver-cu11==11.4.0.1
152 | nvidia-cusparse-cu11==11.7.4.91
153 | nvidia-ml-py==12.535.161
154 | nvidia-nccl-cu11==2.14.3
155 | nvidia-nvtx-cu11==11.7.91
156 | nvitop==1.3.2
157 | opencensus==0.11.4
158 | opencensus-context==0.1.3
159 | openpyxl==3.1.3
160 | orjson==3.10.3
161 | overrides==7.7.0
162 | packaging==24.0
163 | pandas==1.5.3
164 | pandocfilters==1.5.1
165 | panel==1.4.4
166 | param==2.1.0
167 | parso==0.8.4
168 | partd==1.4.2
169 | patsy==0.5.6
170 | pexpect==4.9.0
171 | pillow==10.3.0
172 | pip==24.0
173 | platformdirs==4.2.2
174 | plotly==5.22.0
175 | plotly-express==0.4.1
176 | prometheus-client==0.20.0
177 | prompt-toolkit==3.0.46
178 | proto-plus==1.23.0
179 | protobuf==4.25.3
180 | psutil==5.9.8
181 | ptyprocess==0.7.0
182 | pure-eval==0.2.2
183 | py-cpuinfo==9.0.0
184 | py-spy==0.3.14
185 | pyarrow==16.1.0
186 | pyarrow-hotfix==0.6
187 | pyasn1==0.6.0
188 | pyasn1-modules==0.4.0
189 | pycparser==2.22
190 | pydantic==1.10.15
191 | pydeck==0.9.1
192 | pygments==2.18.0
193 | pynvml==11.5.0
194 | pyparsing==3.1.2
195 | python-dateutil==2.9.0.post0
196 | python-dotenv==1.0.1
197 | python-json-logger==2.0.7
198 | pytz==2024.1
199 | pyviz-comms==3.0.2
200 | pyyaml==6.0.1
201 | pyzmq==26.0.3
202 | qtconsole==5.5.2
203 | qtpy==2.4.1
204 | ray==2.9.2
205 | referencing==0.35.1
206 | regex==2024.5.15
207 | requests==2.32.3
208 | responses==0.18.0
209 | restrictedpython==7.1
210 | rfc3339-validator==0.1.4
211 | rfc3986-validator==0.1.1
212 | rich==13.7.1
213 | rpds-py==0.18.1
214 | rsa==4.9
215 | s3fs==2024.6.0
216 | s3transfer==0.10.1
217 | safetensors==0.4.3
218 | scikit-learn==1.5.0
219 | scipy==1.13.1
220 | seaborn==0.13.2
221 | send2trash==1.8.3
222 | sentence-transformers==3.0.0
223 | sentencepiece==0.2.0
224 | sentry-sdk==2.8.0
225 | setproctitle==1.3.3
226 | setuptools==70.0.0
227 | six==1.16.0
228 | smart-open==7.0.4
229 | smmap==5.0.1
230 | sniffio==1.3.1
231 | sortedcontainers==2.4.0
232 | soupsieve==2.5
233 | sqlalchemy==2.0.30
234 | stack-data==0.6.3
235 | starlette==0.36.2
236 | statsmodels==0.14.2
237 | streamlit==1.37.0
238 | sympy==1.12.1
239 | tabulate==0.9.0
240 | tblib==3.0.0
241 | tenacity==8.3.0
242 | tensorboard==2.16.2
243 | tensorboard-data-server==0.7.2
244 | tensorboardx==2.6.2.2
245 | termcolor==2.4.0
246 | terminado==0.18.1
247 | threadpoolctl==3.5.0
248 | tiktoken==0.7.0
249 | tinycss2==1.3.0
250 | tokenizers==0.19.1
251 | toml==0.10.2
252 | toolz==0.12.1
253 | torch==2.0.1
254 | tornado==6.4.1
255 | tqdm==4.66.4
256 | traitlets==5.14.3
257 | transformers==4.41.1
258 | triton==2.0.0
259 | types-python-dateutil==2.9.0.20240316
260 | typing-extensions==4.12.1
261 | uc-micro-py==1.0.3
262 | uri-template==1.3.0
263 | urllib3==2.2.2
264 | uv==0.2.6
265 | uvicorn==0.30.1
266 | uvloop==0.19.0
267 | virtualenv==20.26.2
268 | wandb==0.17.0
269 | watchdog==4.0.1
270 | watchfiles==0.22.0
271 | wcwidth==0.2.13
272 | webcolors==1.13
273 | webencodings==0.5.1
274 | websocket-client==1.8.0
275 | websockets==12.0
276 | werkzeug==3.0.3
277 | wheel==0.43.0
278 | widgetsnbextension==4.0.11
279 | wrapt==1.16.0
280 | xlrd==2.0.1
281 | xlsxwriter==3.2.0
282 | xxhash==3.4.1
283 | xyzservices==2024.4.0
284 | yarl==1.9.4
285 | zict==3.0.0
286 | zipp==3.19.2
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
--------------------------------------------------------------------------------
/src/synthesizrr/driver.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import ray
4 | from fmcore.framework import *
5 | from fmcore.framework.dl.torch import *
6 | from fmcore.util import *
7 |
8 | from synthesizrr.common import (
9 | DEFAULT_TEMPERATURE,
10 | DEFAULT_TOP_P,
11 | LABEL_OVERALL,
12 | DatasetFilterParams,
13 | Experiment,
14 | MetricName,
15 | Student,
16 | )
17 | from synthesizrr.generation import (
18 | Corpus,
19 | CreateFewGenDatasets,
20 | CreateSeedSet,
21 | CreateSynthesizRRDatasets,
22 | DatasetName,
23 | EmbedCorpus,
24 | FewGen,
25 | ModelName,
26 | RetrieveFromSeedSet,
27 | Retriever,
28 | SynthesizRR,
29 | )
30 | from synthesizrr.metrics import (
31 | FewGenTextGenMetrics,
32 | GoldDatasetMetrics,
33 | SynthesizRRTextGenMetrics,
34 | )
35 |
36 |
def get_wf(expt: Experiment, metrics: bool) -> Chain:
    """Look up the workflow `Chain` for a given experiment type.

    `metrics` selects whether the metric-computation step is appended after
    the generation steps. Unsupported combinations (e.g. Gold without
    metrics) raise `KeyError`.
    """
    ## Generation steps shared by the with-metrics and without-metrics variants:
    synthesizrr_steps = (
        EmbedCorpus,
        CreateSeedSet,
        RetrieveFromSeedSet,
        CreateSynthesizRRDatasets,
        SynthesizRR,
    )
    fewgen_steps = (
        CreateSeedSet,
        CreateFewGenDatasets,
        FewGen,
    )
    workflows = {
        ## Gold: only a metrics workflow exists (no generation needed).
        (Experiment.Gold, True): Chain.of(GoldDatasetMetrics),
        ## SynthesizRR:
        (Experiment.SynthesizRR, False): Chain.of(*synthesizrr_steps),
        (Experiment.SynthesizRR, True): Chain.of(
            *synthesizrr_steps,
            SynthesizRRTextGenMetrics,
        ),
        ## FewGen:
        (Experiment.FewGen, False): Chain.of(*fewgen_steps),
        (Experiment.FewGen, True): Chain.of(
            *fewgen_steps,
            FewGenTextGenMetrics,
        ),
    }
    return workflows[(expt, metrics)]
75 |
76 |
@safe_validate_arguments
def run_chain(
    *,
    expt: Experiment,
    results_dir: FileMetadata,
    notifier: Optional[Notifier],
    tracker: Optional[Tracker],
    background: bool,
    step_wait: confloat(ge=0.0) = 30,  ## To avoid AWS creds error when running many in parallel
    pause: confloat(ge=0.0) = 3,
    dataset_name: DatasetName,
    model_name: Optional[ModelName] = None,
    num_samples_per_label: Optional[conint(ge=10)] = None,
    seed_type: Optional[Literal["generated", "train_set"]] = None,
    seed_size: Optional[conint(ge=1)] = None,
    seed_set_stratify_on_ground_truth: bool = True,
    seed_generation_params: Optional[Dict] = None,
    top_p: confloat(ge=0.0, le=1.0) = DEFAULT_TOP_P,
    temperature: confloat(ge=0.0, le=1e6) = DEFAULT_TEMPERATURE,
    icl_and_prompt_template: Optional[Dict[str, str]] = None,
    label_verbalizer: Optional[Dict[str, str]] = None,
    num_shots_list: Optional[List[conint(ge=0)]] = None,
    metrics_overall_num_samples_per_label: conint(ge=10),
    metrics_other_label_num_samples_per_label: Optional[conint(ge=10)] = None,
    corpus: Optional[Corpus] = None,
    retriever: Optional[Retriever] = None,
    icl_type: Literal["retrieved", "curated", "seed"] = "retrieved",
    retrieval_top_k: conint(ge=1) = 500,
    retr_icl_top_ks: Tuple[conint(ge=1), ...] = (1, 2),
    retr_icl_distance_range: Tuple[float, float] = (0.5, 0.9),
    retr_icl_token_range: Optional[Tuple[conint(ge=1), conint(ge=1)]] = None,
    synthesizrr_top_k_range: Optional[range] = None,
    synthesizrr_distance_range: Tuple[float, float] = (0.4, 0.9),  ## (0.4, 0.9)
    llm_batch_size: Optional[conint(ge=1)] = None,
    llm_submission_batch_size: conint(ge=1) = 24,
    llm_tracking_batch_size: Optional[conint(ge=1)] = None,
    llm_num_concurrent_preds: conint(ge=1) = 6,
    llm_num_models: Optional[conint(ge=1)] = None,
    llm_resources_per_model: Optional[
        Dict[Literal["cpu", "gpu"], Union[confloat(ge=0.0, lt=1.0), conint(ge=0)]]
    ] = None,
    llm_load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.LEAST_USED,
    llm_evaluation_timeout: confloat(ge=0.0, allow_inf_nan=True) = math.inf,
    text_gens_parser: Optional[Callable] = None,
    text_gens_parser_type: Literal["default", "rejection"] = "rejection",
    filter_params: DatasetFilterParams = DatasetFilterParams(filter_type="none"),
    metrics_calc_overall: bool = True,
    metrics_calc_labels: bool = False,
    metrics_max_parallel: conint(ge=1) = 2,
    metrics_override_row_count: bool = False,
    dataset_cartography_student: Student = Student.DistilBERT,
    dataset_cartography_text: Literal["context", "generations"] = "generations",
    metrics_to_evaluate: Optional[Tuple[MetricName, ...]] = (
        MetricName.RowCount,
        MetricName.TextLength,
        MetricName.EntityCount,
        MetricName.SelfBLEU,
        MetricName.Mauve,
        MetricName.StudentDistilBERT_AttrPromptTable13,
        MetricName.StudentDeBERTaV3Large,
        # MetricName.StudentHPOTinyBert,
        # MetricName.StudentHPOMiniLM,
        # MetricName.StudentHPODistilBERT,
        # MetricName.StudentHPOBERT,
        # MetricName.StudentHPODeBERTaV3Base,
        # MetricName.StudentHPODeBERTaV3Large,
        MetricName.LabelPreservation,
        MetricName.StudentDatasetCartography,
        MetricName.SaveFilteredDataset,
    ),
    metrics_label_distribution: Literal["balanced", "train_set"] = "train_set",
    dry_run: bool = False,
    verbosity: conint(ge=0) = 2,
    cart_frac: Optional[confloat(gt=0.0)] = None,  # 0.83,
):
    """Assemble and run the end-to-end workflow for one experiment.

    Resolves defaults (LLM batch size/resources, label verbalizer, parser,
    num-shots, etc.), builds the keyword-argument dict for the workflow,
    looks up the matching `Chain` via `get_wf`, and runs it.

    Returns:
        The dict of resolved workflow inputs when ``dry_run=True``; otherwise
        whatever ``wf.run(...)`` returns after sleeping ``pause`` seconds.
    """
    if cart_frac is not None:
        ## Dataset-cartography filtering: keep the top `cart_frac` fraction by
        ## confidence (per label), and shrink the overall metrics sample count
        ## proportionally so filtered datasets are compared at the same size.
        filter_params = dict(
            filter_type="cartography",
            cartography_apply="label",
            cartography_confidence_frac=("top", cart_frac),
        )
        metrics_overall_num_samples_per_label: int = int(cart_frac * metrics_overall_num_samples_per_label)

    ## SynthesizRR additionally needs a retrieval corpus and a retriever:
    if expt is Experiment.SynthesizRR:
        assert corpus is not None
        assert retriever is not None

    ## Both generation experiments (SynthesizRR, FewGen) need an LLM and a seed set:
    if expt is not Experiment.Gold:
        assert model_name is not None
        assert num_samples_per_label is not None
        assert seed_type is not None

        if llm_num_models is None:
            ## Fit as many model replicas as the Ray cluster's GPUs allow.
            ray_num_gpus: int = int(ray.cluster_resources()["GPU"])
            llm_num_gpus: int = model_name.llm_resources_per_model().get("gpu", 0)
            assert llm_num_gpus >= 0
            if llm_num_gpus > 0:
                ## LLaMa etc.
                llm_num_models: Optional[int] = ray_num_gpus // llm_num_gpus
            else:
                ## Claude, ChatGPT, etc.
                llm_num_models: Optional[int] = None
        llm_batch_size: int = get_default(
            llm_batch_size,
            model_name.llm_batch_size(dataset_name=dataset_name, expt=expt),
        )
        llm_resources_per_model: Dict[
            Literal["cpu", "gpu"], Union[confloat(ge=0.0, lt=1.0), conint(ge=0)]
        ] = get_default(
            llm_resources_per_model,
            model_name.llm_resources_per_model(),
        )
    ## Per-label sample counts for metric computation; LABEL_OVERALL keys the
    ## dataset-wide metrics.
    metrics_num_samples_per_label: Dict[str, int] = {
        LABEL_OVERALL: metrics_overall_num_samples_per_label,
    }

    label_verbalizer: Dict[str, str] = get_default(label_verbalizer, dataset_name.label_verbalizer())
    if metrics_other_label_num_samples_per_label is not None:
        for label_text in label_verbalizer.keys():
            metrics_num_samples_per_label[label_text] = metrics_other_label_num_samples_per_label

    ## Pick how raw LLM generations are parsed into examples (Gold has none):
    if expt is Experiment.Gold:
        text_gens_parser: Optional[Callable] = None
    elif text_gens_parser is None:
        if text_gens_parser_type == "default":
            text_gens_parser: Callable = dataset_name.text_gens_parser()
        elif text_gens_parser_type == "rejection":
            text_gens_parser: Callable = dataset_name.text_gens_parser_rejection(expt=expt)
        else:
            raise not_impl("text_gens_parser_type", text_gens_parser_type)

    ## Keyword arguments passed to every step of the workflow chain:
    exn_input = dict(
        results_dir=results_dir,
        dataset_name=dataset_name,
        label_verbalizer=label_verbalizer,
        seed_type=seed_type,
        seed_size=get_default(seed_size, dataset_name.seed_size()),
        seed_set_stratify_on_ground_truth=seed_set_stratify_on_ground_truth,
        seed_generation_params=seed_generation_params,
        model_name=model_name,
        num_shots_list=get_default(
            num_shots_list,
            {
                Experiment.Gold: [None],
                Experiment.SynthesizRR: [0, 3],
                Experiment.FewGen: [0, 32],
            }[expt],
        ),
        num_samples_per_label=num_samples_per_label,
        top_p=top_p,
        temperature=temperature,
        **get_default(icl_and_prompt_template, {}),
        ## Experiment-specific inputs; lambdas defer evaluation so only the
        ## selected experiment's expressions run (e.g. `corpus.raw_text_dir()`
        ## would fail for FewGen where corpus is None).
        **(
            {
                Experiment.Gold: lambda: dict(
                    dataset_cartography_student=dataset_cartography_student,
                ),
                Experiment.FewGen: lambda: dict(
                    fewgen_max_tokens=dataset_name.max_num_tokens(),  ## Max number of output tokens
                    dataset_cartography_student=dataset_cartography_student,
                    text_gens_parser=text_gens_parser,
                    filter_params=filter_params,
                ),
                Experiment.SynthesizRR: lambda: dict(
                    synthesizrr_max_tokens=dataset_name.max_num_tokens(),  ## Max number of output tokens
                    corpus=corpus,
                    corpus_raw_text_dir=corpus.raw_text_dir(),
                    retriever=retriever,
                    icl_type=icl_type,
                    retrieval_top_k=retrieval_top_k,
                    retr_icl_top_ks=retr_icl_top_ks,
                    retr_icl_distance_range=retr_icl_distance_range,
                    retr_icl_token_range=get_default(retr_icl_token_range, corpus.context_token_range()),
                    synthesizrr_top_k_range=get_default(
                        synthesizrr_top_k_range, corpus.synthesizrr_top_k_range()
                    ),
                    synthesizrr_distance_range=synthesizrr_distance_range,
                    dataset_cartography_text=dataset_cartography_text,
                    dataset_cartography_student=dataset_cartography_student,
                    text_gens_parser=text_gens_parser,
                    filter_params=filter_params,
                ),
            }[expt]()
        ),
        llm_batch_size=llm_batch_size,
        llm_submission_batch_size=llm_submission_batch_size,
        llm_tracking_batch_size=llm_tracking_batch_size,
        llm_resources_per_model=llm_resources_per_model,
        llm_num_models=llm_num_models,
        llm_num_concurrent_preds=llm_num_concurrent_preds,
        llm_load_balancing_strategy=llm_load_balancing_strategy,
        llm_evaluation_timeout=llm_evaluation_timeout,
        metrics_to_evaluate=get_default(metrics_to_evaluate, []),
        metrics_num_samples_per_label=metrics_num_samples_per_label,
        metrics_label_distribution=metrics_label_distribution,
        metrics_calc_overall=metrics_calc_overall,
        metrics_calc_labels=metrics_calc_labels,
        metrics_max_parallel=metrics_max_parallel,
        metrics_override_row_count=metrics_override_row_count,
    )

    wf: Chain = get_wf(expt, metrics=metrics_to_evaluate is not None)
    ## Debug aid: show any required step inputs that were not supplied above.
    print(set(wf.all_step_inputs(required_only=True)) - set(exn_input.keys()))
    if dry_run:
        return exn_input
    else:
        exn = wf.run(
            **exn_input,
            notifier=notifier,
            tracker=tracker,
            verbosity=verbosity,
            background=background,
            step_wait=step_wait,
        )
        time.sleep(pause)
        return exn
293 |
--------------------------------------------------------------------------------
/src/synthesizrr/corpus.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import io
3 | import math
4 | from typing import *
5 |
6 | import orjson
7 | import pandas as pd
8 | from bs4 import BeautifulSoup as BS
9 | from fmcore.constants import DataLayout
10 | from fmcore.data import FileMetadata, Reader, to_sdf
11 | from fmcore.util import (
12 | ProgressBar,
13 | StringUtil,
14 | Timer,
15 | accumulate,
16 | as_list,
17 | as_tuple,
18 | get_default,
19 | is_null,
20 | run_concurrent,
21 | run_parallel_ray,
22 | wait,
23 | whitespace_normalize,
24 | )
25 |
## Root location under which all corpora are downloaded/built.  Must be set
## before running this module.
CORPUS_DIR: str = ""  # TODO: fill this out! Recommended to use S3.
27 |
28 |
def create_amazon_products():
    """Build the Amazon Products (2018 metadata) retrieval corpus.

    Pipeline:
      1. Stream ``All_Amazon_Meta.json`` (JSON-lines) and flush it out as
         100k-row "split" parquet files to bound peak memory.
      2. Normalize list-valued columns to tuples, re-read the splits and
         de-duplicate on ``asin``.
      3. Render each product's title + description into a cleaned plain-text
         ``product_text`` column and write the "raw-text" parquet files.
      4. Report the corpus size in LLaMa-2 tokens.

    Raises:
        SystemError: if the metadata file has not been downloaded into
            ``{CORPUS_DIR}/data/amazon-reviews/2018/meta/``.
    """
    corpus_dir: FileMetadata = FileMetadata.of(
        f"{CORPUS_DIR}/data/amazon-reviews/2018/meta/"
    )
    corpus_dir.mkdir()
    if len(corpus_dir.list()) == 0:
        raise SystemError(
            f'Expected Amazon Products metadata to be in folder "{corpus_dir.path}". '
            f"Please download the data file All_Amazon_Meta.json.gz and unzip to this directory "
            f"(to get this data you need to submit the form at https://nijianmo.github.io/amazon/index.html#complete-data)"
        )

    source: str = corpus_dir.file_in_dir("All_Amazon_Meta.json")
    all_dfs: List[pd.DataFrame] = []
    buf = []
    df_pbar = ProgressBar.of(unit="file")
    row_pbar = ProgressBar.of(unit="rows", miniters=10_000)
    ## Stream the (very large) JSON-lines file, materializing a DataFrame
    ## every 100k rows so we never hold raw dicts for the whole file.
    with io.open(source, "rb") as inp:
        for line in inp:
            buf.append(orjson.loads(line))
            row_pbar.update(1)
            if len(buf) == 100_000:
                all_dfs.append(pd.DataFrame(buf))
                df_pbar.update(1)
                buf = []
    row_pbar.success()
    all_dfs.append(pd.DataFrame(buf))  ## Flush the final partial batch.
    df_pbar.update(1)
    buf = []
    gc.collect()

    def _convert_row(row):
        ## Normalize list-valued columns to tuples so they round-trip cleanly
        ## through parquet; default missing `details` to an empty dict.
        if not is_null(row["category"]):
            row["category"] = as_tuple(row["category"])
        if not is_null(row["description"]):
            row["description"] = as_tuple(row["description"])
        if not is_null(row["also_buy"]):
            row["also_buy"] = as_tuple(row["also_buy"])
        if not is_null(row["image"]):
            row["image"] = as_tuple(row["image"])
        if not is_null(row["feature"]):
            row["feature"] = as_tuple(row["feature"])
        if not is_null(row["also_view"]):
            row["also_view"] = as_tuple(row["also_view"])
        if not is_null(row["rank"]):
            row["rank"] = as_tuple(row["rank"])
        if is_null(row["details"]):
            row["details"] = {}
        return row

    def _convert_df(df_part):
        ## Apply `_convert_row` row-wise via the fmcore ScalableDataFrame API.
        return (
            to_sdf(df_part)
            .to_layout(DataLayout.LIST_OF_DICT)
            .apply(_convert_row, axis=1)
            .pandas()
        )

    corpus_split_dir: FileMetadata = corpus_dir.subdir_in_dir(
        "split", return_metadata=True
    )
    futs = []
    for df_part_i, df_part in enumerate(all_dfs):
        df_part = _convert_df(df_part)
        dest: str = corpus_split_dir.file_in_dir(
            f"amazon-reviews-2018-meta-part-{StringUtil.pad_zeros(df_part_i)}.parquet"
        )
        futs.append(
            run_concurrent(
                df_part.to_parquet,
                dest,
            )
        )
        print(df_part_i)
    accumulate(futs, progress=dict(desc="Writing", unit="file"))

    with Timer():
        ## Re-read only the columns needed downstream, then de-duplicate
        ## products by their ASIN.
        prods = Reader.of(
            "parquet",
            data_schema={
                "asin": "object",
                # "also_buy": 'object',
                # "also_view": 'object',
                "title": "object",
                "description": "object",
                "brand": "object",
                "category": "object",
                "date": "object",
                # "details": 'object',
                "feature": "object",
                "fit": "object",
                # "image": 'object',
                "main_cat": "object",
                "price": "object",
                # "rank": 'object',
                # "similar_item": 'object',
                # "tech1": 'object',
                # "tech2": 'object',
            },
        ).read(
            corpus_split_dir,
            read_as=DataLayout.PANDAS,
        )
        prods = prods.drop_duplicates("asin").persist(wait=True)

    corpus_raw_text_dir: FileMetadata = corpus_dir.subdir_in_dir(
        "raw-text", return_metadata=True
    )
    with Timer():

        def create_product_text(row):
            ## Bullet-list the title and description lines, strip HTML, make
            ## sure numbered-list markers ("1.", "2.", ...) are followed by a
            ## space, then collapse whitespace.
            product_text: str = ""
            for col in ["title", "description"]:
                for text in as_list(row[col]):
                    ## Fixed: newline belongs inside the f-string as an escape
                    ## (the previous literal was split across lines).
                    product_text += f"- {get_default(text, '')}\n"
            product_text: str = BS(product_text).get_text(separator="\n")
            for i in range(1, 10):
                product_text: str = product_text.replace(f"{i}.", f"{i}. ")
            product_text: str = whitespace_normalize(product_text)
            return product_text

        def set_product_text(_prods_part):
            ## Adds the rendered `product_text` column to one 10k-row chunk.
            _prods_part["product_text"] = _prods_part.apply(create_product_text, axis=1)
            return _prods_part

        ## Render product_text for 10k-row chunks in parallel on Ray:
        prods_list = []
        for prods_part_i, prods_part in ProgressBar.iter(
            enumerate(prods.stream(stream_as=DataLayout.PANDAS, batch_size=10_000)),
            total=math.ceil(len(prods) / 10_000),
        ):
            prods_list.append(run_parallel_ray(set_product_text, prods_part))
        prods_list: List[pd.DataFrame] = accumulate(prods_list, progress=True)

        futs = []
        for prods_part_i, prods_part in ProgressBar.iter(
            enumerate(prods_list),
            total=len(prods_list),
        ):
            prods_part = prods_part.reset_index(drop=True)
            prods_part["asin"] = prods_part["asin"].astype(str)
            fpath: str = corpus_raw_text_dir.file_in_dir(
                f"amazon-products-2018-raw-text-part-{StringUtil.pad_zeros(prods_part_i)}.parquet"
            )
            futs.append(run_concurrent(prods_part.to_parquet, fpath))
        wait(futs, progress=True)
        print(
            f'Done creating Amazon Products corpus, final data is at: "{corpus_raw_text_dir.path}"'
        )

    def amazon_products_count_num_tokens(df_path):
        ## Count LLaMa-2 tokens in title + description for one parquet part.
        ## `transformers` is imported lazily: it is only needed for this count.
        df_part = Reader.of(
            "parquet",
            data_schema={
                "asin": "index",
                "title": "text",
                "description": "text",
            },
        ).read(df_path, raw=True)
        ser_part = (
            df_part["title"].fillna("").astype(str)
            + " "
            + df_part["description"].apply(lambda x: " ".join(x)).astype(str)
        )
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-fp16")
        ## Generator (not a list) avoids materializing all lengths at once.
        sum_is: int = sum(
            len(input_ids)
            for input_ids in tokenizer(ser_part.tolist(), add_special_tokens=False)[
                "input_ids"
            ]
        )
        return sum_is

    counts: List[int] = accumulate(
        [
            run_parallel_ray(
                amazon_products_count_num_tokens,
                df_path=df_path,
            )
            for df_path in FileMetadata.of(
                corpus_raw_text_dir.path,
                file_glob="*.parquet",
            ).list()
        ],
        progress=True,
    )
    print(
        f"Amazon Products corpus has {round(sum(counts) / 1e9, 2)} billion tokens"
    )
228 |
229 |
## RealNews article domains treated as "regional" (non-dominant) news outlets;
## used by create_realnews() to carve out the realnews-regional partition
## (and, by exclusion, realnews-dominant).
REALNEWS_REGIONAL_NEWS_DOMAINS: List[str] = [
    "mid-day.com",
    "financialexpress.com",
    "thenationonlineng.net",
    "livemint.com",
    "hindustantimes.com",
    "vanguardngr.com",
    "capitalfm.co.ke",
    "straitstimes.com",
    "indianexpress.com",
    "nation.com.pk",
    "jamaica-gleaner.com",
    "trend.az",
    "stabroeknews.com",
    "dawn.com",
    "emirates247.com",
    "mangalorean.com",
    "vccircle.com",
    "thisdaylive.com",
    "gulfnews.com",
    "tribune.com.pk",
    "arabnews.com",
    "pakobserver.net",
    "nation.co.ke",
    "eurasiareview.com",
    "thedailystar.net",
    "deccanchronicle.com",
    "jewishpress.com",
    "app.com.pk",
    "err.ee",
    "lankabusinessonline.com",
    "koreatimes.co.kr",
    "newera.com.na",
    "ticotimes.net",
    "codewit.com",
    "sunnewsonline.com",
    "afaqs.com",
    "ameinfo.com",
    "malaysiakini.com",
    "ynetnews.com",
    "palestinechronicle.com",
    "zmescience.com",
    "cyprus-mail.com",
    "colombiareports.com",
    "arabtimesonline.com",
    "bollywoodhungama.com",
    "pattayamail.com",
    "insightcrime.org",
    "medianewsline.com",
    "dailytimes.com.pk",
    "chinadigitaltimes.net",
    "saudigazette.com.sa",
    "newsday.co.zw",
    "sunstar.com.ph",
    "nehandaradio.com",
    "freemalaysiatoday.com",
    "onlanka.com",
    "thezimbabwemail.com",
    "theeastafrican.co.ke",
    "thecitizen.co.tz",
    "lusakatimes.com",
    "orissadiary.com",
    "aljazeera.com",
    "tehrantimes.com",
    "theborneopost.com",
    "morungexpress.com",
    "monitor.co.ug",
    "countercurrents.org",
    "businessworld.in",
    "governancenow.com",
    "itweb.co.za",
    "972mag.com",
    "memeburn.com",
    "themediaonline.co.za",
    "koimoi.com",
    "caribbean360.com",
    "yalibnan.com",
    "milligazette.com",
    "thefrontierpost.com",
    "kuwaittimes.net",
    "somalilandpress.com",
    "thestkittsnevisobserver.com",
    "news24.com",
    "livinginperu.com",
    "journal.com.ph",
    "bworldonline.com",
    "venezuelanalysis.com",
    "businessdayonline.com",
    "macaudailytimes.com.mo",
    "ghanabusinessnews.com",
    "trinidadexpress.com",
    "pmnewsnigeria.com",
    "lankanewspapers.com",
    "asiasentinel.com",
    "maravipost.com",
    "dayafterindia.com",
    "defense-update.com",
    "antiguaobserver.com",
    "newsbytes.ph",
    "truthdive.com",
    "thehimalayantimes.com",
    "standardmedia.co.ke",
    "groundviews.org",
    "japantoday.com",
    "kbc.co.ke",
    "mindanews.com",
    "thejakartaglobe.com",
    "actionforex.com",
    "modernghana.com",
    "newstodaynet.com",
    "centralchronicle.com",
    "dalje.com",
    "escambray.cu",
    "middle-east-online.com",
    "theminaretonline.com",
    "pakistankakhudahafiz.com",
    "meed.com",
    "tribwekchron.com",
    "thenews.com.pk",
    "iafrica.com",
    "philstar.com",
    "praguepost.com",
    "yonhapnews.co.kr",
    "china.org.cn",
    "rtn.asia",
    "nationalturk.com",
    "thebraziltimes.com",
    "businessdailyafrica.com",
    "hku.hk",
    "intifada-palestine.com",
    "realbollywood.com",
    "pak1stanfirst.com",
    "mutiny.in",
    "mareeg.com",
    "paltelegraph.com",
    "pakwatan.com",
    "mybroadband.co.za",
    "african-bulletin.com",
    "thedailynewsegypt.com",
    "7days.ae",
    "dailyforex.com",
    "melodika.net",
]

## Subset of the regional domains that are Indian news outlets; used by
## create_realnews() to carve out the realnews-india partition.
REALNEWS_INDIAN_NEWS_DOMAINS: List[str] = [
    "mid-day.com",
    "financialexpress.com",
    "livemint.com",
    "hindustantimes.com",
    "indianexpress.com",
    "mangalorean.com",
    "vccircle.com",
    "deccanchronicle.com",
    "afaqs.com",
    "bollywoodhungama.com",
    "medianewsline.com",
    "orissadiary.com",
    "morungexpress.com",
    "countercurrents.org",
    "businessworld.in",
    "governancenow.com",
    "koimoi.com",
    "milligazette.com",
    "dayafterindia.com",
    "truthdive.com",
    "newstodaynet.com",
    "centralchronicle.com",
    "dalje.com",
    "rtn.asia",
    "realbollywood.com",
    "mutiny.in",
]
402 |
403 |
def count_num_tokens(ser_part):
    """Count LLaMa-2 tokens across a pandas Series of text.

    Tokenizes the whole Series in one batch without special tokens and sums
    the per-row token counts.  ``transformers`` is imported lazily so this
    module does not require it at import time; the function is intended to be
    shipped to Ray workers / dask partitions.

    Args:
        ser_part: pandas Series of strings.

    Returns:
        int: total number of tokens across all rows.
    """
    from transformers import AutoTokenizer

    ## NOTE: the tokenizer is loaded on every call (cached on disk by
    ## transformers); this keeps the function self-contained for remote
    ## execution.
    tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-fp16")
    ## Generator expression (not a list) avoids materializing an extra list of
    ## lengths for large partitions.
    sum_is: int = sum(
        len(input_ids)
        for input_ids in tokenizer(ser_part.tolist(), add_special_tokens=False)[
            "input_ids"
        ]
    )
    return sum_is
417 |
418 |
def write_realnews_partition(rn_data, *, corpus_dir: FileMetadata, rn_name: str):
    """Write one RealNews partition to parquet and report its token count.

    Streams ``rn_data`` in 100k-row pandas chunks, writes each chunk
    concurrently into ``{corpus_dir}/{rn_name}/``, then counts LLaMa-2 tokens
    over the ``text`` column via dask ``map_partitions``.
    """
    batch_size: int = 100_000
    with Timer(f"Writing partition {rn_name}"):
        corpus_partition_dir: FileMetadata = corpus_dir.subdir_in_dir(
            rn_name, return_metadata=True
        )
        print(f"{rn_name} length: {len(rn_data)}")
        num_parts: int = math.ceil(len(rn_data) / batch_size)
        write_futs = []
        part_stream = enumerate(rn_data.stream(stream_as=DataLayout.PANDAS, batch_size=batch_size))
        for part_idx, part_df in ProgressBar.iter(part_stream, total=num_parts):
            part_df = part_df.reset_index(drop=True)
            part_fname: str = f"realnews-{rn_name}-part-{StringUtil.pad_zeros(part_idx)}.parquet"
            part_path: str = corpus_partition_dir.file_in_dir(part_fname)
            write_futs.append(run_concurrent(part_df.to_parquet, part_path))
        wait(write_futs, progress=True)
        print(
            f'Done creating {rn_name} partition, final data is at: "{corpus_partition_dir.path}"'
        )

    ## Token counting runs per dask partition on the article body text.
    token_counts: pd.Series = rn_data["text"].map_partitions(count_num_tokens).compute()
    print(f"{rn_name} corpus has {round(token_counts.sum() / 1e9, 2)} billion tokens")
442 |
443 |
def create_realnews():
    """Build the RealNews retrieval corpus and its three partitions.

    Streams ``realnews.jsonl`` into 100k-row "split" parquet files, re-reads
    them as a dask DataFrame, then writes three domain-based partitions via
    `write_realnews_partition`: realnews-india, realnews-regional, and
    realnews-dominant (everything not in the regional domain list).

    Raises:
        SystemError: if ``realnews.jsonl`` has not been downloaded into
            ``{CORPUS_DIR}/data/realnews/``.
    """
    corpus_dir: FileMetadata = FileMetadata.of(f"{CORPUS_DIR}/data/realnews/")
    corpus_dir.mkdir()
    if len(corpus_dir.list()) == 0:
        raise SystemError(
            f'Expected RealNews to be in folder "{corpus_dir.path}". '
            f"Please download the data file realnews.jsonl to this directory "
            f"(to get this data you need to submit the form at https://github.com/rowanz/grover/tree/master/realnews)"
        )

    with Timer("Reading and splitting realnews.jsonl"):
        source: str = corpus_dir.file_in_dir("realnews.jsonl")
        all_dfs: List[pd.DataFrame] = []
        buf = []
        df_pbar = ProgressBar.of(unit="file")
        row_pbar = ProgressBar.of(unit="rows", miniters=10_000)
        row_idx = 0
        ## Stream the JSON-lines file, tagging each row with a global "idx"
        ## and flushing a DataFrame every 100k rows to bound peak memory.
        with io.open(source, "rb") as inp:
            for line in inp:
                buf.append(orjson.loads(line))
                buf[-1]["idx"] = row_idx
                row_idx += 1
                row_pbar.update(1)
                if len(buf) == 100_000:
                    all_dfs.append(pd.DataFrame(buf))
                    df_pbar.update(1)
                    buf = []
        row_pbar.success()
        all_dfs.append(pd.DataFrame(buf))  ## Flush the final partial batch.
        df_pbar.update(1)
        buf = []
        gc.collect()

        corpus_split_dir: FileMetadata = corpus_dir.subdir_in_dir(
            "split", return_metadata=True
        )
        ## Write split files concurrently:
        futs = []
        for df_part_i, df_part in enumerate(all_dfs):
            dest: str = corpus_split_dir.file_in_dir(
                f"realnews-part-{StringUtil.pad_zeros(df_part_i)}.parquet"
            )
            futs.append(
                run_concurrent(
                    df_part.to_parquet,
                    dest,
                )
            )
            print(df_part_i)
        accumulate(futs, progress=dict(desc="Writing", unit="file"))

    with Timer("Reading split files"):
        realnews_data_schema: Dict = {
            "idx": "index",
            "title": "object",
            "text": "text",
            "summary": "object",
            "authors": "categorical",
            "publish_date": "object",
            "status": "categorical",
            "url": "categorical",
            "domain": "categorical",
            "warc_date": "object",
            "split": "categorical",
        }
        ## Re-read all splits lazily as a dask DataFrame for partition queries.
        realnews = Reader.of(
            "parquet",
            data_schema=realnews_data_schema,
        ).read(
            FileMetadata.of(
                corpus_split_dir.path,
                file_format="parquet",
            ),
            read_as=DataLayout.DASK,
        )

    ## Partition 1: Indian news outlets only.
    realnews_india = realnews.query(
        f"domain in {REALNEWS_INDIAN_NEWS_DOMAINS}"
    ).persist(wait=True)
    write_realnews_partition(
        realnews_india, corpus_dir=corpus_dir, rn_name="realnews-india"
    )

    ## Partition 2: all regional news outlets (superset of the Indian list).
    realnews_regional = realnews.query(
        f"domain in {REALNEWS_REGIONAL_NEWS_DOMAINS}"
    ).persist(wait=True)
    write_realnews_partition(
        realnews_regional, corpus_dir=corpus_dir, rn_name="realnews-regional"
    )

    ## Partition 3: dominant outlets = everything NOT in the regional list.
    realnews_dominant = realnews.query(
        f"domain not in {REALNEWS_REGIONAL_NEWS_DOMAINS}"
    ).persist(wait=True)
    write_realnews_partition(
        realnews_dominant, corpus_dir=corpus_dir, rn_name="realnews-dominant"
    )
539 |
540 |
def create_cmu_movies():
    """Build the CMU Movie Summaries retrieval corpus.

    Merges ``plot_summaries.txt`` with ``movie.metadata.tsv`` on the Wikipedia
    movie id, renames the columns to the corpus convention (``idx``/``text``),
    writes a single raw-text parquet file, and reports the token count.

    Raises:
        SystemError: if the "MovieSummaries" folder has not been extracted
            into ``{CORPUS_DIR}/data/cmu_movies/``.
    """
    corpus_dir: FileMetadata = FileMetadata.of(f"{CORPUS_DIR}/data/cmu_movies/")
    corpus_dir.mkdir()
    if len(corpus_dir.list()) == 0:
        raise SystemError(
            f'Expected CMU Movies to be in folder "{corpus_dir.path}". '
            f"Please download the data from https://www.cs.cmu.edu/~ark/personas/ and extract it. "
            f'You should get the folder "MovieSummaries".'
        )
    with Timer("Reading and merging plot_summaries.txt and movie.metadata.tsv"):
        ## Both files are headerless TSVs; column names come from the CMU
        ## Movie Summaries README.
        movie_plots: pd.DataFrame = pd.read_csv(
            corpus_dir.subdir_in_dir(
                "MovieSummaries", return_metadata=True
            ).file_in_dir("plot_summaries.txt"),
            sep="\t",
            header=None,
            names=[
                "wiki_movie_id",
                "plot_summary",
            ],
        )
        movie_meta: pd.DataFrame = pd.read_csv(
            corpus_dir.subdir_in_dir(
                "MovieSummaries", return_metadata=True
            ).file_in_dir("movie.metadata.tsv"),
            sep="\t",
            header=None,
            names=[
                "wiki_movie_id",
                "freebase_movie_id",
                "title",
                "release_date",
                "box_office_revenue",
                "runtime",
                "languages",
                "countries",
                "genres",
            ],
        )
        ## Inner-join on the wiki id, then rename to the corpus convention:
        ## plot_summary -> "text", wiki_movie_id -> "idx".
        movies: pd.DataFrame = (
            movie_meta.merge(movie_plots, on="wiki_movie_id")
            .reset_index(drop=True)
            .rename(columns=dict(plot_summary="text", wiki_movie_id="idx"))
        )
    corpus_raw_text_dir: FileMetadata = corpus_dir.subdir_in_dir(
        "raw-text", return_metadata=True
    )
    movies.to_parquet(corpus_raw_text_dir.file_in_dir("cmu-movie-summary.parquet"))
    ## Fixed typo in user-facing message: "Moveis" -> "Movies".
    print(
        f'Done creating CMU Movies corpus, final data is at: "{corpus_raw_text_dir.path}"'
    )

    def cmu_movies_count_num_tokens(df_path):
        ## Count LLaMa-2 tokens over the plot summaries in one parquet file.
        ## Fixed: the parquet written above has the RENAMED columns idx/text
        ## (see the `.rename(...)` call); the previous schema referenced the
        ## pre-rename names wiki_movie_id/plot_summary, which do not exist in
        ## the written file.
        df_part = Reader.of(
            "parquet",
            data_schema={
                "idx": "index",
                "text": "text",
            },
        ).read(df_path, raw=True)
        ser_part = df_part["text"]
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("TheBloke/Llama-2-13B-fp16")
        ## Generator (not a list) avoids materializing all lengths at once.
        sum_is: int = sum(
            len(input_ids)
            for input_ids in tokenizer(ser_part.tolist(), add_special_tokens=False)[
                "input_ids"
            ]
        )
        return sum_is

    counts: List[int] = accumulate(
        [
            run_parallel_ray(
                cmu_movies_count_num_tokens,
                df_path=df_path,
            )
            for df_path in FileMetadata.of(
                corpus_raw_text_dir.path,
                file_glob="*.parquet",
            ).list()
        ],
        progress=True,
    )
    print(f"CMU Movies corpus has {round(sum(counts) / 1e6, 2)} million tokens")
629 |
630 |
631 | if __name__ == "__main__":
632 | create_amazon_products()
633 | gc.collect()
634 | create_realnews()
635 | gc.collect()
636 | create_cmu_movies()
637 | gc.collect()
638 |
--------------------------------------------------------------------------------
/src/synthesizrr/data.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import pandas as pd
4 | from datasets import load_dataset as hf_load_dataset
5 | from fmcore.data.writer import Writer
6 | from fmcore.framework.dl.torch import *
7 | from fmcore.framework.task_data import Dataset, Datasets, DataSplit
8 | from sklearn.model_selection import train_test_split
9 |
## Root directory under which each dataset's splits are written; an S3 URI is
## recommended.  Must be set before using this module.
## Fixed: the constant was redundantly assigned three times in a row.
DATA_DIR: str = ""  # TODO: fill this out!
14 |
15 |
class SynthesizRRDataset(Registry):
    """Registry base class for datasets used in SynthesizRR experiments.

    Subclasses declare ``name``, ``task`` and ``data_schema``; this base class
    derives on-disk paths and supervised/unsupervised schemas from them, and
    registers each subclass under its ``name``.
    """

    name: ClassVar[str]  ## Unique dataset identifier; also the registry key.
    task: ClassVar[Task]  ## The ML task for this dataset.
    data_schema: ClassVar[Dict[str, MLType]]  ## Column-name -> MLType mapping.
20 |
21 | @classmethod
22 | def _registry_keys(cls) -> Optional[Union[List[Any], Any]]:
23 | return cls.name
24 |
25 | @classmethod
26 | def get(cls, name: str) -> Type["SynthesizRRDataset"]:
27 | return cls.get_subclass(name)
28 |
29 | @classproperty
30 | def dataset_dir(cls) -> FileMetadata:
31 | dataset_dir: FileMetadata = FileMetadata.of(f"{DATA_DIR}/{cls.name}/")
32 | dataset_dir.mkdir()
33 | return dataset_dir
34 |
35 | @classproperty
36 | def train_path(cls) -> str:
37 | return cls.dataset_dir.file_in_dir(f"{cls.name}_train.parquet")
38 |
39 | @classproperty
40 | def validation_path(cls) -> str:
41 | return cls.dataset_dir.file_in_dir(f"{cls.name}_validation.parquet")
42 |
43 | @classproperty
44 | def test_path(cls) -> str:
45 | return cls.dataset_dir.file_in_dir(f"{cls.name}_test.parquet")
46 |
47 | @classproperty
48 | def unsupervised_path(cls) -> str:
49 | return cls.dataset_dir.file_in_dir(f"{cls.name}_unsupervised.parquet")
50 |
51 | @classproperty
52 | def schema_path(cls) -> str:
53 | return cls.dataset_dir.file_in_dir(f"{cls.name}_schema.json")
54 |
55 | @classproperty
56 | def unsupervised_schema_path(cls) -> str:
57 | return cls.dataset_dir.file_in_dir(f"{cls.name}_unsupervised_schema.json")
58 |
59 | @classproperty
60 | def supervised_data_schema(cls) -> MLTypeSchema:
61 | return copy.deepcopy(MLType.convert_values(cls.data_schema))
62 |
63 | @classproperty
64 | def unsupervised_data_schema(cls) -> MLTypeSchema:
65 | data_schema: Dict[str, MLType] = copy.deepcopy(
66 | MLType.convert_values(cls.data_schema)
67 | )
68 | data_schema: MLTypeSchema = {
69 | col: mltype
70 | for col, mltype in data_schema.items()
71 | if mltype not in GROUND_TRUTH_ML_TYPES
72 | }
73 | return data_schema
74 |
75 | @classmethod
76 | def raw_train(cls) -> Optional[pd.DataFrame]:
77 | return None
78 |
79 | @classmethod
80 | def has_train(cls) -> bool:
81 | return "raw_train" in cls.__dict__
82 |
83 | @classmethod
84 | def raw_validation(cls) -> Optional[pd.DataFrame]:
85 | return None
86 |
87 | @classmethod
88 | def has_validation(cls) -> bool:
89 | return "raw_validation" in cls.__dict__
90 |
91 | @classmethod
92 | def raw_test(cls) -> Optional[pd.DataFrame]:
93 | return None
94 |
95 | @classmethod
96 | def has_test(cls) -> bool:
97 | return "raw_test" in cls.__dict__
98 |
99 | @classmethod
100 | def raw_unsupervised(cls) -> Optional[pd.DataFrame]:
101 | return None
102 |
103 | @classmethod
104 | def has_unsupervised(cls) -> bool:
105 | return "raw_unsupervised" in cls.__dict__
106 |
107 | @classmethod
108 | def decode_labels(cls, row) -> pd.Series:
109 | raise NotImplementedError()
110 |
111 | @classmethod
112 | def has_decode_labels(cls) -> bool:
113 | return "decode_labels" in cls.__dict__
114 |
115 | @classmethod
116 | def setup(cls):
117 | if cls.has_train():
118 | Writer.of("json").write(
119 | data=cls.supervised_data_schema,
120 | destination=cls.schema_path,
121 | overwrite=True,
122 | )
123 | idx_col: str = only_item(
124 | [
125 | col
126 | for col, mltype in cls.supervised_data_schema.items()
127 | if mltype is MLType.INDEX
128 | ]
129 | )
130 | train: pd.DataFrame = cls.raw_train()
131 | if cls.has_decode_labels():
132 | train: pd.DataFrame = train.apply(cls.decode_labels, axis=1)
133 | if len(train) != train[idx_col].nunique():
134 | raise ValueError(
135 | f'Expected unique index column "{idx_col}" for split="train".'
136 | )
137 | train.reset_index(drop=True).to_parquet(cls.train_path)
138 | if cls.has_validation():
139 | idx_col: str = only_item(
140 | [
141 | col
142 | for col, mltype in cls.supervised_data_schema.items()
143 | if mltype is MLType.INDEX
144 | ]
145 | )
146 | validation: pd.DataFrame = cls.raw_validation()
147 | if cls.has_decode_labels():
148 | validation: pd.DataFrame = validation.apply(cls.decode_labels, axis=1)
149 | if len(validation) != validation[idx_col].nunique():
150 | raise ValueError(
151 | f'Expected unique index column "{idx_col}" for split="validation".'
152 | )
153 | validation.reset_index(drop=True).to_parquet(cls.validation_path)
154 | if cls.has_test():
155 | idx_col: str = only_item(
156 | [
157 | col
158 | for col, mltype in cls.supervised_data_schema.items()
159 | if mltype is MLType.INDEX
160 | ]
161 | )
162 | test: pd.DataFrame = cls.raw_test()
163 | if cls.has_decode_labels():
164 | test: pd.DataFrame = test.apply(cls.decode_labels, axis=1)
165 | if len(test) != test[idx_col].nunique():
166 | raise ValueError(
167 | f'Expected unique index column "{idx_col}" for split="test".'
168 | )
169 | test.reset_index(drop=True).to_parquet(cls.test_path)
170 | if cls.has_unsupervised():
171 | idx_col: str = only_item(
172 | [
173 | col
174 | for col, mltype in cls.unsupervised_data_schema.items()
175 | if mltype is MLType.INDEX
176 | ]
177 | )
178 | Writer.of("json").write(
179 | data=cls.unsupervised_data_schema,
180 | destination=cls.unsupervised_schema_path,
181 | overwrite=True,
182 | )
183 | unsupervised: pd.DataFrame = cls.raw_unsupervised()
184 | if len(unsupervised) != unsupervised[idx_col].nunique():
185 | raise ValueError(
186 | f'Expected unique index column "{idx_col}" for split="unsupervised".'
187 | )
188 | unsupervised.reset_index(drop=True).to_parquet(cls.unsupervised_path)
189 |
190 | @classproperty
191 | def train(cls) -> Optional[Dataset]:
192 | return cls.datasets.train
193 |
194 | @classmethod
195 | @safe_validate_arguments
196 | def create_seed_set(
197 | cls,
198 | seed_size: int,
199 | *,
200 | data_split: DataSplit = DataSplit.TRAIN,
201 | random_state: int = 42,
202 | stratify_on_ground_truth: bool = False,
203 | ) -> Dataset:
204 | dataset: Dataset = cls.datasets[data_split].read(read_as=DataLayout.PANDAS)
205 | dataset_df: pd.DataFrame = dataset.data.pandas()
206 | gt_col: str = only_key(dataset.data_schema.ground_truths_schema)
207 | if stratify_on_ground_truth:
208 | _, seed_dataset_df = train_test_split(
209 | dataset_df,
210 | test_size=seed_size,
211 | random_state=random_state,
212 | stratify=dataset_df[gt_col],
213 | )
214 | else:
215 | _, seed_dataset_df = train_test_split(
216 | dataset_df,
217 | test_size=seed_size,
218 | random_state=random_state,
219 | )
220 | return dataset.update_params(data=seed_dataset_df)
221 |
222 | @classproperty
223 | def validation(cls) -> Optional[Dataset]:
224 | return cls.datasets.validation
225 |
226 | @classproperty
227 | def test(cls) -> Optional[Dataset]:
228 | return cls.datasets.test
229 |
230 | @classproperty
231 | def unsupervised(cls) -> Optional[Dataset]:
232 | return cls.datasets.unsupervised
233 |
234 | @classproperty
235 | def datasets(cls) -> Datasets:
236 | datasets: Dict[str, Dataset] = {}
237 | if cls.has_train():
238 | datasets[DataSplit.TRAIN] = Dataset.of(
239 | data_split=DataSplit.TRAIN,
240 | task=cls.task,
241 | data=FileMetadata.of(cls.train_path),
242 | data_schema=cls.supervised_data_schema,
243 | )
244 | if cls.has_validation():
245 | datasets[DataSplit.VALIDATION] = Dataset.of(
246 | data_split=DataSplit.VALIDATION,
247 | task=cls.task,
248 | data=FileMetadata.of(cls.validation_path),
249 | data_schema=cls.supervised_data_schema,
250 | )
251 | if cls.has_test():
252 | datasets[DataSplit.TEST] = Dataset.of(
253 | data_split=DataSplit.TEST,
254 | task=cls.task,
255 | data=FileMetadata.of(cls.test_path),
256 | data_schema=cls.supervised_data_schema,
257 | )
258 | if cls.has_unsupervised():
259 | datasets[DataSplit.UNSUPERVISED] = Dataset.of(
260 | data_split=DataSplit.UNSUPERVISED,
261 | task=cls.task,
262 | data=FileMetadata.of(cls.unsupervised_path),
263 | data_schema=cls.unsupervised_data_schema,
264 | )
265 | return Datasets.of(**datasets)
266 |
267 | @classmethod
268 | def setup_datasets(cls):
269 | for SynthesizRRDatasetSubclass in cls.subclasses():
270 | assert issubclass(SynthesizRRDatasetSubclass, SynthesizRRDataset)
271 | SynthesizRRDatasetSubclass.setup()
272 |
273 | @classproperty
274 | def label_verbalizer(cls) -> Optional[Dict[str, str]]:
275 | return None
276 |
277 |
class HyperpartisanNewsDataset(SynthesizRRDataset):
    """Binary classification of news articles from the HF dataset `zapsdcn/hyperpartisan_news`."""

    name = "hyperpartisan_news"
    task = Task.BINARY_CLASSIFICATION
    data_schema = dict(
        id=MLType.INDEX,
        text=MLType.TEXT,
        label_text=MLType.GROUND_TRUTH,
    )

    @classmethod
    def _load_split(cls, split: str) -> pd.DataFrame:
        ## All three splits come from the same HF dataset and need the same rename.
        split_df: pd.DataFrame = hf_load_dataset(
            "zapsdcn/hyperpartisan_news", split=split
        ).to_pandas()
        return split_df.rename(columns={"label": "label_text"})

    @classmethod
    def raw_train(cls) -> Optional[pd.DataFrame]:
        return cls._load_split("train")

    @classmethod
    def raw_validation(cls) -> Optional[pd.DataFrame]:
        return cls._load_split("validation")

    @classmethod
    def raw_test(cls) -> Optional[pd.DataFrame]:
        return cls._load_split("test")

    @classproperty
    def label_verbalizer(cls) -> Optional[Dict[str, str]]:
        ## Natural-language descriptions of the two labels, used during generation.
        return {
            "true": "using harsh political language, using a mocking tone and toxic commentary",
            "false": "using neutral language, using a reasonable tone and politically correct commentary",
        }
317 |
318 |
class AGNewsDataset(SynthesizRRDataset):
    """4-way topic classification of AG News articles (HF dataset `zapsdcn/ag`)."""

    name = "ag_news"
    task = Task.MULTI_CLASS_CLASSIFICATION
    data_schema = dict(
        id=MLType.INDEX,
        text=MLType.TEXT,
        # headline=MLType.TEXT,
        label_text=MLType.GROUND_TRUTH,
    )

    @classproperty
    def label_verbalizer(cls) -> Optional[Dict[str, str]]:
        return {
            "Business": "about companies, industries, markets, trade, investments, entrepreneurship, economic policies, and other business-related developments",
            "World": "about international news, such as politics, diplomacy, conflicts, global events, international relations, human rights issues, and significant global trends",
            "Sci/Tech": "about scientific discoveries, technological advancements, innovations, research breakthroughs",
            "Sports": "related to coverage of professional sports leagues, major tournaments, athletes, teams, match results, player transfers, coaching changes, sports-related controversies",
        }

    @classmethod
    def _features(cls) -> "ds.Features":
        ## Explicit schema for `zapsdcn/ag`: the label is an int, everything else a string.
        ## (`ds` is the `datasets` package; the module previously used it without importing it.)
        return ds.Features(
            {
                "label": ds.Value("int64"),
                "text": ds.Value("string"),
                "headline": ds.Value("string"),
                "id": ds.Value("string"),
            }
        )

    @classmethod
    def _load_split(cls, split: str) -> pd.DataFrame:
        ## Single loader for all splits; previously the Features literal was triplicated.
        return hf_load_dataset(
            "zapsdcn/ag",
            split=split,
            features=cls._features(),
        ).to_pandas()

    @classmethod
    def raw_train(cls) -> Optional[pd.DataFrame]:
        return cls._load_split("train")

    @classmethod
    def raw_validation(cls) -> Optional[pd.DataFrame]:
        return cls._load_split("validation")

    @classmethod
    def raw_test(cls) -> Optional[pd.DataFrame]:
        test_df: pd.DataFrame = cls._load_split("test")
        ## Replace the original ids with fresh sequential "idtest<N>" ids:
        test_df = (
            test_df.drop(["id"], axis=1)
            .reset_index(drop=True)
            .reset_index()
            .rename(columns={"index": "id"})
        )
        test_df["id"] = "idtest" + test_df["id"].astype(str)
        return test_df

    @classmethod
    def decode_labels(cls, row) -> pd.Series:
        """Map the 1-4 integer label of a raw row to its AG News class name."""
        ## Ref: https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset
        lb_decoder = {
            1: "World",
            2: "Sports",
            3: "Business",
            4: "Sci/Tech",
        }
        row["label_text"] = lb_decoder[row["label"]]
        return row
402 |
403 |
class AmazonReviewsPolarity(SynthesizRRDataset):
    """
    Binary sentiment classification of Amazon product reviews.

    Uses a balanced "mini" subsample of the HuggingFace `amazon_polarity` dataset,
    which `create_dataset()` materializes as Parquet files under
    `{DATA_DIR}/data/amazon-polarity-mini/`.
    """

    name = "amazon-polarity"
    task = Task.BINARY_CLASSIFICATION
    data_schema = dict(
        idx=MLType.INDEX,
        # title=MLType.TEXT,
        text=MLType.TEXT,
        label_text=MLType.GROUND_TRUTH,
    )

    @classproperty
    def label_verbalizer(cls) -> Optional[Dict[str, str]]:
        return {
            "positive": "what the reviewer liked about the product, how the reviewer found it easy to use the product, or the reviewer's positive experience with the product",
            "negative": "what the reviewer disliked about the product, how the reviewer found it challenging to use the product, or the reviewer's negative experience with the product",
        }

    @classmethod
    def _read_split(cls, split: str) -> pd.DataFrame:
        ## All three splits share the same path pattern and column renames
        ## (previously copy-pasted in raw_train/raw_validation/raw_test).
        return pd.read_parquet(
            f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_{split}.parquet"
        ).rename(columns={"index": "idx", "content": "text"})

    @classmethod
    def raw_train(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("train")

    @classmethod
    def raw_validation(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("validation")

    @classmethod
    def raw_test(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("test")

    @staticmethod
    def create_dataset():
        """
        Build the balanced "amazon-polarity-mini" splits from HF `amazon_polarity`:
        72k train / 3.6k validation / 40k test rows, each a 50/50 positive/negative mix.
        """
        from datasets import load_dataset as hf_load_dataset

        def _load_raw(split: str) -> pd.DataFrame:
            ## Add a sequential `idx` column and decode the 0/1 label into a string.
            split_df = (
                hf_load_dataset("amazon_polarity", split=split)
                .to_pandas()
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={"index": "idx"})
            )
            split_df["label_text"] = split_df["label"].map(
                {
                    1: "positive",
                    0: "negative",
                }
            )
            return split_df

        def _balanced_sample(df: pd.DataFrame, n: int) -> pd.DataFrame:
            ## Deterministically sample n//2 rows of each polarity (seed fixed at 42).
            return pd.concat(
                [
                    df.query('label_text == "positive"')
                    .sample(
                        n=n // 2,
                        random_state=42,
                    )
                    .reset_index(drop=True),
                    df.query('label_text == "negative"')
                    .sample(
                        n=n // 2,
                        random_state=42,
                    )
                    .reset_index(drop=True),
                ]
            ).reset_index(drop=True)

        amazon_polarity_train = _load_raw("train")
        amazon_polarity_test = _load_raw("test")
        amazon_polarity_mini_train = _balanced_sample(amazon_polarity_train, 72000)
        amazon_polarity_mini_test = _balanced_sample(amazon_polarity_test, 40_000)

        ## Validation is drawn from the training rows NOT used in the mini train split.
        ## Use a vectorized .isin() membership test instead of the previous
        ## query(f"idx not in {72k-element list}"), which forced pandas to parse a
        ## multi-megabyte expression string.
        amazon_polarity_train_non_mini = amazon_polarity_train[
            ~amazon_polarity_train["idx"].isin(set(amazon_polarity_mini_train["idx"]))
        ]
        amazon_polarity_mini_validation = _balanced_sample(
            amazon_polarity_train_non_mini, 3600
        )

        FileMetadata.of(f"{DATA_DIR}/data/amazon-polarity-mini/").mkdir()
        amazon_polarity_mini_train.to_parquet(
            f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_train.parquet"
        )
        amazon_polarity_mini_validation.to_parquet(
            f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_validation.parquet"
        )
        amazon_polarity_mini_test.to_parquet(
            f"{DATA_DIR}/data/amazon-polarity-mini/amazon-polarity-mini_test.parquet"
        )
531 |
532 |
class AmazonReviewsProductCategory(SynthesizRRDataset):
    """
    23-way classification of Amazon reviews into product categories.

    Built by `create_dataset()` from the JHU multi-domain sentiment ("sorted_data")
    dump and materialized as Parquet splits under
    `{DATA_DIR}/data/amazon-reviews-category/`.
    """

    name = "amazon-reviews-category"
    task = Task.MULTI_CLASS_CLASSIFICATION
    data_schema = dict(
        idx=MLType.INDEX,
        # asin=MLType.CATEGORICAL,
        # product_name=MLType.TEXT,
        # product_type=MLType.CATEGORICAL,
        # helpful=MLType.CATEGORICAL,
        # rating=MLType.CATEGORICAL,
        # title=MLType.TEXT,
        # date=MLType.CATEGORICAL,
        # reviewer=MLType.TEXT,
        # reviewer_location=MLType.CATEGORICAL,
        text=MLType.TEXT,
        label_text=MLType.GROUND_TRUTH,
        # sentiment=MLType.CATEGORICAL,
    )

    @classproperty
    def label_verbalizer(cls) -> Optional[Dict[str, str]]:
        return {
            "magazines": "magazines or periodicals covering various topics",
            "camera_photo": "photography gear including cameras, lenses, accessories, or photo editing tools",
            "office_products": "office supplies or equipment for professional and home office setups",
            "kitchen": "kitchenware, appliances, or culinary tools for cooking and dining",
            "cell_phones_service": "cell phone service accessories or service plans for communication and connectivity",
            "computer_video_games": "computers, gaming consoles, video games, or related accessories",
            "grocery_and_gourmet_food": "groceries, fruits and vegetables, gourmet treats, or specialty food items",
            "tools_hardware": "tools, hardware, or equipment for DIY projects and home repairs",
            "automotive": "auto parts, accessories, or tools for vehicle maintenance and enhancements",
            "music_album": "music albums spanning various genres and artists",
            "health_and_personal_care": "healthcare products, personal care items, or wellness essentials",
            "electronics": "electronic devices, gadgets, personal tech, or home electronics",
            "outdoor_living": "products for outdoor activities, gardening, or patio living",
            "video": "movies, TV shows, and documentaries spanning various genres and artists",
            "apparel": "clothing including casual wear, formal attire, seasonal outfits, activewear, or fashion accessories for men, women, and children",
            "toys_games": "fun or educational toys and games for kids of all ages",
            "sports_outdoors": "products for various sports and outdoor activities",
            "books": "books in various genres and formats",
            "software": "computer software for productivity or gaming covering either personal or professional needs",
            "baby": "baby essentials, gear, or toys for infants and toddlers",
            "musical_and_instruments": "musical instruments, accessories, or music production equipment",
            "beauty": "beauty products, cosmetics, or skincare essentials, makeup, hair care, fragrances, or grooming essentials",
            "jewelry_and_watches": "watches or jewelry pieces such as necklaces, bracelets, earrings, or rings, crafted in precious metals or adorned with gemstones for special occasions",
        }

    @classmethod
    def _read_split(cls, split: str) -> pd.DataFrame:
        ## Same path pattern and column renames for all three splits
        ## (previously copy-pasted in raw_train/raw_validation/raw_test).
        return pd.read_parquet(
            f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-{split}.parquet"
        ).rename(
            columns={
                "unique_id": "idx",
                "review_text": "text",
                "product_category": "label_text",
            }
        )

    @classmethod
    def raw_train(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("train")

    @classmethod
    def raw_validation(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("validation")

    @classmethod
    def raw_test(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("test")

    @staticmethod
    def create_dataset():
        """Parse the raw pseudo-XML review files and write train/validation/test Parquet splits."""
        from bs4 import BeautifulSoup

        dataset_dir: FileMetadata = FileMetadata.of(
            f"{DATA_DIR}/raw-data/amazon-reviews-category/sorted_data/"
        ).mkdir(return_metadata=True)
        if len(dataset_dir.list()) == 0:
            raise SystemError(
                f'Expected Amazon Reviews Category data to be in folder "{dataset_dir.path}". '
                f"Please download and unzip the data from "
                f"https://www.cs.jhu.edu/~mdredze/datasets/sentiment/unprocessed.tar.gz"
            )

        def parse_review(text) -> pd.DataFrame:
            ## Each file holds many <review> elements with flat child tags; each becomes a dict row.
            ## NOTE(review): BeautifulSoup is called without an explicit parser, so the
            ## result depends on bs4's installed default — confirm this is acceptable.
            df = []
            for review_BS in BeautifulSoup(text).find_all("review"):
                d = {}
                for child in review_BS.children:
                    k = child.name
                    if k is not None:
                        v = child.text.strip()
                        if k in ["product_type", "unique_id"]:  ## Allow list
                            d.setdefault(k, [])
                        if isinstance(d.get(k), list):
                            d[k].append(v)
                        else:
                            if k in d:
                                raise ValueError(f'"{k}" key already exists')
                            d[k] = v
                ## Collapse repeated tags into single string values:
                if "unique_id" in d:
                    d["unique_id"] = "-".join(d["unique_id"])
                if "product_type" in d:
                    d["product_type"] = ",".join(set(d["product_type"]))
                if len(d) > 0:
                    df.append(d)
            return pd.DataFrame(df)

        dfs = []
        from pathlib import Path

        for fpath in dataset_dir.list(only_subdirs=True):
            category = Path(fpath).stem
            neg = FileSystemUtil.get_file_str(
                str(Path(fpath) / "negative.review"),
                encoding="cp1252",
                raise_error=True,
            )
            neg_df = parse_review(neg)
            neg_df["product_category"] = category
            neg_df["sentiment"] = "negative"
            dfs.append(neg_df)

            ## NOTE(review): positive files are read WITHOUT encoding="cp1252", unlike the
            ## negative files above — confirm whether this asymmetry is intentional.
            pos = FileSystemUtil.get_file_str(
                str(Path(fpath) / "positive.review"),
            )
            pos_df = parse_review(pos)
            pos_df["product_category"] = category
            pos_df["sentiment"] = "positive"

            dfs.append(pos_df)
        reviews = pd.concat(dfs).reset_index(drop=True)
        ## The raw dump splits groceries across two folders; merge them into one label:
        reviews["product_category"] = reviews["product_category"].replace(
            {
                "grocery": "gourmet_food_and_grocery".replace(
                    "gourmet_food_and_grocery", "grocery_and_gourmet_food"
                ),
                "gourmet_food": "grocery_and_gourmet_food",
            }
        )
        attrprompt_cats = {
            "magazines",
            "camera_photo",
            "office_products",
            "kitchen",
            "cell_phones_service",
            "computer_video_games",
            "grocery_and_gourmet_food",
            "tools_hardware",
            "automotive",
            "music_album",
            "health_and_personal_care",
            "electronics",
            "outdoor_living",
            "video",
            "apparel",
            "toys_games",
            "sports_outdoors",
            "books",
            "software",
            "baby",
            "musical_and_instruments",
            "beauty",
            "jewelry_and_watches",
        }
        ## Keep only the categories used by AttrPrompt; .isin() replaces the old
        ## query(f"product_category in {list}") string-parsing approach:
        reviews = reviews[
            reviews["product_category"].isin(attrprompt_cats)
        ].reset_index(drop=True)
        assert set(reviews["product_category"]) == attrprompt_cats
        assert reviews["unique_id"].nunique() == len(reviews)
        # reviews['product_category'].value_counts()
        reviews_train = reviews.sample(n=30_000, random_state=42).reset_index(drop=True)
        print(reviews_train.shape)

        ## Vectorized membership tests instead of embedding ~30k ids into query strings:
        reviews_test = (
            reviews[~reviews["unique_id"].isin(set(reviews_train["unique_id"]))]
            .sample(n=2_400, random_state=42)
            .reset_index(drop=True)
        )
        print(reviews_test.shape)

        held_out_ids = set(reviews_train["unique_id"]) | set(reviews_test["unique_id"])
        reviews_validation = reviews[
            ~reviews["unique_id"].isin(held_out_ids)
        ].reset_index(drop=True)
        print(reviews_validation.shape)

        FileMetadata.of(f"{DATA_DIR}/data/amazon-reviews-category/").mkdir()
        reviews_train.to_parquet(
            f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-train.parquet"
        )
        reviews_validation.to_parquet(
            f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-validation.parquet"
        )
        reviews_test.to_parquet(
            f"{DATA_DIR}/data/amazon-reviews-category/amazon-reviews-category-test.parquet"
        )
740 |
741 |
class AmazonHumorousProductQuestions(SynthesizRRDataset):
    """
    Binary classification of Amazon product questions as humorous vs. solemn.

    Built by `create_dataset()` from the AWS Open Data "humor detection" CSVs and
    materialized as Parquet splits under `{DATA_DIR}/data/amazon-humor/`.
    """

    name = "amazon-humor"
    task = Task.BINARY_CLASSIFICATION
    data_schema = dict(
        idx=MLType.INDEX,
        text=MLType.TEXT,
        # product_description=MLType.TEXT,
        # image_url=MLType.URL,
        label_text=MLType.GROUND_TRUTH,
    )

    @classproperty
    def label_verbalizer(cls) -> Optional[Dict[str, str]]:
        return {
            "non_humorous": "solemn",
            "humorous": "humorous",
        }

    @classmethod
    def _read_split(cls, split: str) -> pd.DataFrame:
        ## Same path pattern and column rename for all three splits
        ## (previously copy-pasted in raw_train/raw_validation/raw_test).
        return pd.read_parquet(
            f"{DATA_DIR}/data/amazon-humor/amazon-humor_{split}.parquet"
        ).rename(columns={"question": "text"})

    @classmethod
    def raw_train(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("train")

    @classmethod
    def raw_validation(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("validation")

    @classmethod
    def raw_test(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("test")

    @staticmethod
    def create_dataset():
        """
        Build balanced train (15k), validation (1142) and test splits from the raw
        Humorous / Non-humorous CSVs, then shuffle and write them to Parquet.
        """
        dataset_dir: FileMetadata = FileMetadata.of(
            f"{DATA_DIR}/raw-data/amazon-humor/"
        ).mkdir(return_metadata=True)
        if len(dataset_dir.list()) == 0:
            raise SystemError(
                f'Expected Amazon Humor data to be in folder "{dataset_dir.path}". '
                f"Please download the data files from https://registry.opendata.aws/humor-detection/"
            )

        def _read_labeled_csv(fname: str) -> pd.DataFrame:
            ## Add a sequential `idx` column and decode the 0/1 label into a string.
            label_df = (
                pd.read_csv(dataset_dir.file_in_dir(fname))
                .reset_index(drop=True)
                .reset_index()
                .rename(columns={"index": "idx"})
            )
            label_df["label_text"] = label_df["label"].map(
                {
                    1: "humorous",
                    0: "non_humorous",
                }
            )
            return label_df

        def _three_way_split(
            label_df: pd.DataFrame,
        ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
            ## Per-file split: half of each quota comes from this file; the remainder
            ## after train+validation sampling becomes test. The previous version
            ## duplicated this block verbatim for pos/neg and embedded thousands of
            ## ids into query strings; .isin() is the vectorized equivalent.
            train = label_df.sample(n=15_000 // 2, random_state=42).reset_index(
                drop=True
            )
            validation = (
                label_df[~label_df["idx"].isin(set(train["idx"]))]
                .sample(n=1142 // 2, random_state=42)
                .reset_index(drop=True)
            )
            test = label_df[
                ~label_df["idx"].isin(set(train["idx"]) | set(validation["idx"]))
            ].reset_index(drop=True)
            return train, validation, test

        def _merge_and_shuffle(
            pos_df: pd.DataFrame, neg_df: pd.DataFrame
        ) -> pd.DataFrame:
            ## Interleave the two label files with a deterministic shuffle.
            return (
                pd.concat([pos_df, neg_df])
                .sample(frac=1, random_state=42)
                .reset_index(drop=True)
            )

        humor_pos = _read_labeled_csv("Humorous.csv")
        humor_neg = _read_labeled_csv("Non-humorous-biased.csv")
        humor_pos_train, humor_pos_validation, humor_pos_test = _three_way_split(
            humor_pos
        )
        humor_neg_train, humor_neg_validation, humor_neg_test = _three_way_split(
            humor_neg
        )

        humor_train = _merge_and_shuffle(humor_pos_train, humor_neg_train)
        humor_validation = _merge_and_shuffle(humor_pos_validation, humor_neg_validation)
        humor_test = _merge_and_shuffle(humor_pos_test, humor_neg_test)

        ## idx restarts at 0 in each CSV; suffix the decoded label to disambiguate rows:
        for split_df in (humor_train, humor_validation, humor_test):
            split_df["idx"] = (
                split_df["idx"].astype(str) + "-" + split_df["label_text"].astype(str)
            )

        FileMetadata.of(f"{DATA_DIR}/data/amazon-humor/").mkdir()
        humor_train.to_parquet(
            f"{DATA_DIR}/data/amazon-humor/amazon-humor_train.parquet"
        )
        humor_validation.to_parquet(
            f"{DATA_DIR}/data/amazon-humor/amazon-humor_validation.parquet"
        )
        humor_test.to_parquet(f"{DATA_DIR}/data/amazon-humor/amazon-humor_test.parquet")
892 |
893 |
class ToiHeadlinesDataset(SynthesizRRDataset):
    """
    10-way topic classification of Times of India news headlines.

    Built by `create_dataset()` from the Harvard Dataverse "India News Headlines"
    CSV and materialized as label-balanced Parquet splits under
    `{DATA_DIR}/data/toi_headlines/`.
    """

    name = "toi_headlines"
    task = Task.MULTI_CLASS_CLASSIFICATION
    data_schema = dict(
        idx=MLType.INDEX,
        text=MLType.TEXT,
        # publish_date=MLType.TEXT,
        # headline_category=MLType.CATEGORICAL,
        # headline_text_len=MLType.INT,
        label_text=MLType.GROUND_TRUTH,
    )

    @classmethod
    def _read_split(cls, split: str) -> pd.DataFrame:
        ## Same path pattern and column renames for all three splits
        ## (previously copy-pasted in raw_train/raw_validation/raw_test).
        return pd.read_parquet(
            f"{DATA_DIR}/data/toi_headlines/toi_headlines_{split}.parquet"
        ).rename(columns={"headline_text": "text", "headline_root": "label_text"})

    @classmethod
    def raw_train(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("train")

    @classmethod
    def raw_validation(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("validation")

    @classmethod
    def raw_test(cls) -> Optional[pd.DataFrame]:
        return cls._read_split("test")

    @classproperty
    def label_verbalizer(cls) -> Optional[Dict[str, str]]:
        ## NOTE(review): unlike the base class (which returns None), this raises —
        ## confirm no caller expects a verbalizer for this dataset.
        raise NotImplementedError()

    @staticmethod
    def create_dataset():
        """
        Build label-balanced train (52k) / test (10k) / validation splits from the raw
        headlines CSV, keeping only headlines of at least 40 characters.
        """
        label_space: List[str] = [
            "sports",
            "life-style",
            "education",
            "entertainment",
            "business",
            "city",
            "environment",
            "tech",
            "elections",
            "world",
        ]
        dataset_dir: FileMetadata = FileMetadata.of(
            f"{DATA_DIR}/raw-data/toi_headlines/"
        ).mkdir(return_metadata=True)
        if len(dataset_dir.list()) == 0:
            raise SystemError(
                f'Expected ToI Headlines data to be in folder "{dataset_dir.path}". '
                f"Please download the CSV file from "
                f"https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DPQMQH"
            )
        df = pd.read_csv(dataset_dir.file_in_dir("india-news-headlines.csv"))
        ## e.g. "home.sports.cricket" -> "sports":
        df["headline_root"] = df["headline_category"].apply(
            lambda x: x.strip().removeprefix("home.").split(".")[0]
        )
        df["headline_text_len"] = df["headline_text"].apply(len)
        df = (
            df.query("headline_text_len >= 40")
            .reset_index(drop=True)
            .reset_index()
            .rename(columns=dict(index="idx"))
        )
        ## One uniform label per row of the target series => balanced label distribution.
        balanced_target: pd.Series = pd.Series(label_space)
        full_idxs: np.ndarray = sample_idxs_match_distribution(
            source=df["headline_root"],
            target=balanced_target,  ## Balanced
            n=None,
            seed=42,
        )
        full = df.loc[full_idxs].sample(frac=1, random_state=42).reset_index(drop=True)
        train = full.loc[
            sample_idxs_match_distribution(
                full["headline_root"],
                target=balanced_target,  ## Balanced
                n=52_000,
                seed=42,
            )
        ].reset_index(drop=True)
        ## Vectorized .isin() membership tests replace the previous
        ## query(f"idx not in {52k-element list}") string-parsing approach:
        remaining_wo_train = full[~full["idx"].isin(set(train["idx"]))].reset_index(
            drop=True
        )
        test = remaining_wo_train.loc[
            sample_idxs_match_distribution(
                remaining_wo_train["headline_root"],
                target=balanced_target,  ## Balanced
                n=10_000,
                seed=42,
            )
        ].reset_index(drop=True)
        validation = full[
            ~full["idx"].isin(set(train["idx"]) | set(test["idx"]))
        ].reset_index(drop=True)
        ## Prefix idx with the split name so ids are unique across splits:
        train["idx"] = train["idx"].apply(lambda x: f"train-{x}")
        test["idx"] = test["idx"].apply(lambda x: f"test-{x}")
        validation["idx"] = validation["idx"].apply(lambda x: f"validation-{x}")
        print(train["headline_root"].value_counts())
        print(validation["headline_root"].value_counts())
        print(test["headline_root"].value_counts())
        FileMetadata.of(f"{DATA_DIR}/data/toi_headlines/").mkdir()
        train.to_parquet(f"{DATA_DIR}/data/toi_headlines/toi_headlines_train.parquet")
        validation.to_parquet(
            f"{DATA_DIR}/data/toi_headlines/toi_headlines_validation.parquet"
        )
        test.to_parquet(f"{DATA_DIR}/data/toi_headlines/toi_headlines_test.parquet")
1003 |
1004 |
if __name__ == "__main__":
    ## First materialize each dataset from its raw download...
    for dataset_cls in (
        ToiHeadlinesDataset,
        AmazonReviewsProductCategory,
        AmazonReviewsPolarity,
        AmazonHumorousProductQuestions,
    ):
        dataset_cls.create_dataset()
    ## ...then copy schema and files to be accessible to all workers:
    for dataset_cls in (
        HyperpartisanNewsDataset,
        AGNewsDataset,
        ToiHeadlinesDataset,
        AmazonReviewsProductCategory,
        AmazonReviewsPolarity,
        AmazonHumorousProductQuestions,
    ):
        dataset_cls.setup()
1017 |
--------------------------------------------------------------------------------
/src/synthesizrr/main.py:
--------------------------------------------------------------------------------
from typing import *

from fmcore.framework import Status, Tracker

from synthesizrr.common import (
    RESULTS_DIR,
    Corpus,
    DatasetName,
    Experiment,
    ModelName,
    Retriever,
)
from synthesizrr.driver import run_chain
14 |
## Single TRACKER assignment: the original had a duplicated assignment plus two
## stray, unbalanced fragments of the call (a hard SyntaxError). Execution
## outputs will get logged to this file:
TRACKER = Tracker.of("log", path="~/synthesizrr_run.log")
## Passed as `background=` to each run_chain call below:
BACKGROUND: bool = False
CART_FRAC: Optional[float] = 0.83  ## Make None to disable Cartography filtering.
23 |
24 | if __name__ == "__main__":
25 | """
26 | _ _ _ _
27 | | || | _ _ _ __ ___ _ _ _ __ __ _ _ _ | |_ (_) ___ __ _ _ _
28 | | __ || || || '_ \/ -_)| '_|| '_ \/ _` || '_|| _|| |(_- _` || ' \
29 | |_||_| \_, || .__/\___||_| | .__/\__,_||_| \__||_|/__/\__,_||_||_|
30 | |__/ |_| |_|
31 | """
32 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
33 | if (
34 | "fewgen_hyperpartisan_news_llama_2_13b_chat_exn" not in globals()
35 | or fewgen_hyperpartisan_news_llama_2_13b_chat_exn.status is Status.FAILED
36 | ):
37 | fewgen_hyperpartisan_news_llama_2_13b_chat_exn = run_chain(
38 | results_dir=RESULTS_DIR,
39 | expt=Experiment.FewGen,
40 | dataset_name=DatasetName.HyperpartisanNews,
41 | model_name=ModelName.LLaMa_2_13B_Chat,
42 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
43 | num_samples_per_label=3_000,
44 | seed_type="train_set",
45 | seed_set_stratify_on_ground_truth=False,
46 | llm_num_models=48,
47 | metrics_overall_num_samples_per_label=2_000,
48 | metrics_max_parallel=3,
49 | metrics_label_distribution="train_set",
50 | # metrics_to_evaluate=None,
51 | tracker=TRACKER,
52 | background=BACKGROUND,
53 | verbosity=1,
54 | step_wait=5,
55 | cart_frac=CART_FRAC,
56 | # dry_run=True,
57 | )
58 |
59 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
60 | if (
61 | "synthesizrr_retr_icl_hyperpartisan_news_llama_2_13b_chat_exn" not in globals()
62 | or synthesizrr_retr_icl_hyperpartisan_news_llama_2_13b_chat_exn.status
63 | is Status.FAILED
64 | ):
65 | synthesizrr_retr_icl_hyperpartisan_news_llama_2_13b_chat_exn = run_chain(
66 | results_dir=RESULTS_DIR,
67 | expt=Experiment.SynthesizRR,
68 | dataset_name=DatasetName.HyperpartisanNews,
69 | model_name=ModelName.LLaMa_2_13B_Chat,
70 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
71 | corpus=Corpus.RealNews_Dominant,
72 | retriever=Retriever.Contriever,
73 | num_samples_per_label=3_000,
74 | seed_type="train_set",
75 | seed_set_stratify_on_ground_truth=False,
76 | llm_batch_size=1,
77 | llm_submission_batch_size=12,
78 | llm_num_models=48,
79 | llm_num_concurrent_preds=2,
80 | metrics_overall_num_samples_per_label=2_000,
81 | metrics_max_parallel=3,
82 | metrics_label_distribution="train_set",
83 | # metrics_to_evaluate=None,
84 | tracker=TRACKER,
85 | background=BACKGROUND,
86 | verbosity=1,
87 | step_wait=5,
88 | cart_frac=CART_FRAC,
89 | # dry_run=True,
90 | )
91 |
92 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
93 | if (
94 | "synthesizrr_no_retr_icl_hyperpartisan_news_llama_2_13b_chat_exn"
95 | not in globals()
96 | or synthesizrr_no_retr_icl_hyperpartisan_news_llama_2_13b_chat_exn.status
97 | is Status.FAILED
98 | ):
99 | synthesizrr_no_retr_icl_hyperpartisan_news_llama_2_13b_chat_exn = run_chain(
100 | results_dir=RESULTS_DIR,
101 | expt=Experiment.SynthesizRR,
102 | dataset_name=DatasetName.HyperpartisanNews,
103 | model_name=ModelName.LLaMa_2_13B_Chat,
104 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
105 | corpus=Corpus.RealNews_Dominant,
106 | retriever=Retriever.Contriever,
107 | num_samples_per_label=3_000,
108 | seed_type="train_set",
109 | seed_set_stratify_on_ground_truth=False,
110 | icl_type="seed",
111 | llm_batch_size=1,
112 | llm_submission_batch_size=12,
113 | llm_num_models=48,
114 | llm_num_concurrent_preds=2,
115 | metrics_overall_num_samples_per_label=2_000,
116 | metrics_max_parallel=3,
117 | metrics_label_distribution="train_set",
118 | # metrics_to_evaluate=None,
119 | icl_and_prompt_template=dict(
120 | icl_template="""
121 | Rewritten Article:
122 | {{icl[example_text]}}""",
123 | prompt_template="""
124 | {{icl_examples}}
125 |
126 | News Article:
127 | {{retrieved_context}}
128 |
129 | Rewrite the above news article {label_verbalization}. The rewritten article should be 2 to 3 paragraphs long.
130 | Rewritten Article: """.strip()
131 | + "\n",
132 | ),
133 | tracker=TRACKER,
134 | background=BACKGROUND,
135 | verbosity=1,
136 | step_wait=5,
137 | cart_frac=CART_FRAC,
138 | # dry_run=True,
139 | )
140 |
141 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
142 | if (
143 | "fewgen_hyperpartisan_news_claude_instant_v1_exn" not in globals()
144 | or fewgen_hyperpartisan_news_claude_instant_v1_exn.status is Status.FAILED
145 | ):
146 | fewgen_hyperpartisan_news_claude_instant_v1_exn = run_chain(
147 | results_dir=RESULTS_DIR,
148 | expt=Experiment.FewGen,
149 | dataset_name=DatasetName.HyperpartisanNews,
150 | model_name=ModelName.Claude_Instant_v1,
151 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
152 | num_samples_per_label=3_000,
153 | seed_type="train_set",
154 | seed_set_stratify_on_ground_truth=False,
155 | llm_batch_size=1,
156 | llm_submission_batch_size=12,
157 | llm_num_models=1,
158 | llm_num_concurrent_preds=6,
159 | metrics_overall_num_samples_per_label=2_000,
160 | metrics_max_parallel=3,
161 | metrics_label_distribution="train_set",
162 | # metrics_to_evaluate=None,
163 | tracker=TRACKER,
164 | background=BACKGROUND,
165 | verbosity=1,
166 | step_wait=5,
167 | cart_frac=CART_FRAC,
168 | # dry_run=True,
169 | )
170 |
171 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
172 | if (
173 | "synthesizrr_retr_icl_hyperpartisan_news_claude_instant_v1_exn" not in globals()
174 | or synthesizrr_retr_icl_hyperpartisan_news_claude_instant_v1_exn.status
175 | is Status.FAILED
176 | ):
177 | synthesizrr_retr_icl_hyperpartisan_news_claude_instant_v1_exn = run_chain(
178 | results_dir=RESULTS_DIR,
179 | expt=Experiment.SynthesizRR,
180 | dataset_name=DatasetName.HyperpartisanNews,
181 | model_name=ModelName.Claude_Instant_v1,
182 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
183 | corpus=Corpus.RealNews_Dominant,
184 | retriever=Retriever.Contriever,
185 | num_samples_per_label=3_000,
186 | seed_type="train_set",
187 | seed_set_stratify_on_ground_truth=False,
188 | llm_batch_size=1,
189 | llm_submission_batch_size=12,
190 | llm_num_models=1,
191 | llm_num_concurrent_preds=6,
192 | metrics_overall_num_samples_per_label=2_000,
193 | metrics_max_parallel=3,
194 | metrics_label_distribution="train_set",
195 | # metrics_to_evaluate=None,
196 | tracker=TRACKER,
197 | background=BACKGROUND,
198 | verbosity=1,
199 | step_wait=5,
200 | cart_frac=CART_FRAC,
201 | # dry_run=True,
202 | )
203 |
204 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
205 | if (
206 | "synthesizrr_no_retr_icl_hyperpartisan_news_claude_instant_v1_exn"
207 | not in globals()
208 | or synthesizrr_no_retr_icl_hyperpartisan_news_claude_instant_v1_exn.status
209 | is Status.FAILED
210 | ):
211 | synthesizrr_no_retr_icl_hyperpartisan_news_claude_instant_v1_exn = run_chain(
212 | results_dir=RESULTS_DIR,
213 | expt=Experiment.SynthesizRR,
214 | dataset_name=DatasetName.HyperpartisanNews,
215 | model_name=ModelName.Claude_Instant_v1,
216 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
217 | corpus=Corpus.RealNews_Dominant,
218 | retriever=Retriever.Contriever,
219 | num_samples_per_label=3_000,
220 | seed_type="train_set",
221 | seed_set_stratify_on_ground_truth=False,
222 | icl_type="seed",
223 | llm_batch_size=1,
224 | llm_submission_batch_size=12,
225 | llm_num_models=1,
226 | llm_num_concurrent_preds=6,
227 | metrics_overall_num_samples_per_label=2_000,
228 | metrics_max_parallel=3,
229 | metrics_label_distribution="train_set",
230 | # metrics_to_evaluate=None,
231 | icl_and_prompt_template=dict(
232 | icl_template="""
233 | Rewritten Article by Assistant:
234 | {{icl[example_text]}}""".strip()
235 | + "\n",
236 | prompt_template="""
237 | Human:
238 | {{icl_examples}}
239 |
240 | News Article:
241 | {{retrieved_context}}
242 |
243 | Rewrite the above news article {label_verbalization}. The rewritten article should be 2 to 3 paragraphs long.
244 | Rewritten Article by Assistant: """.strip()
245 | + "\n",
246 | ),
247 | tracker=TRACKER,
248 | background=BACKGROUND,
249 | verbosity=1,
250 | step_wait=5,
251 | cart_frac=CART_FRAC,
252 | # dry_run=True,
253 | )
254 |
255 | """
256 | _ ___ _ _
257 | /_\ / __| | \| | ___ __ __ __ ___
258 | / _ \| (_ | | .` |/ -_)\ V V /(_-<
259 | /_/ \_\\___| |_|\_|\___| \_/\_/ /__/
260 | """
261 |
262 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
263 | if (
264 | "fewgen_ag_news_llama_2_13b_chat_exn" not in globals()
265 | or fewgen_ag_news_llama_2_13b_chat_exn.status is Status.FAILED
266 | ):
267 | fewgen_ag_news_llama_2_13b_chat_exn = run_chain(
268 | results_dir=RESULTS_DIR,
269 | expt=Experiment.FewGen,
270 | dataset_name=DatasetName.AgNews,
271 | model_name=ModelName.LLaMa_2_13B_Chat,
272 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
273 | num_samples_per_label=3_000,
274 | seed_type="train_set",
275 | seed_set_stratify_on_ground_truth=False,
276 | llm_num_models=48,
277 | metrics_overall_num_samples_per_label=8_000,
278 | metrics_max_parallel=3,
279 | metrics_label_distribution="train_set",
280 | # metrics_to_evaluate=None,
281 | tracker=TRACKER,
282 | background=BACKGROUND,
283 | verbosity=1,
284 | step_wait=5,
285 | cart_frac=CART_FRAC,
286 | # dry_run=True,
287 | )
288 |
289 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
290 | if (
291 | "synthesizrr_retr_icl_ag_news_llama_2_13b_chat_exn" not in globals()
292 | or synthesizrr_retr_icl_ag_news_llama_2_13b_chat_exn.status is Status.FAILED
293 | ):
294 | synthesizrr_retr_icl_ag_news_llama_2_13b_chat_exn = run_chain(
295 | results_dir=RESULTS_DIR,
296 | expt=Experiment.SynthesizRR,
297 | dataset_name=DatasetName.AgNews,
298 | model_name=ModelName.LLaMa_2_13B_Chat,
299 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
300 | corpus=Corpus.RealNews_Dominant,
301 | retriever=Retriever.Contriever,
302 | num_samples_per_label=3_000,
303 | seed_type="train_set",
304 | seed_set_stratify_on_ground_truth=False,
305 | llm_num_models=48,
306 | metrics_overall_num_samples_per_label=8_000,
307 | metrics_max_parallel=3,
308 | metrics_label_distribution="train_set",
309 | # metrics_to_evaluate=None,
310 | tracker=TRACKER,
311 | background=BACKGROUND,
312 | verbosity=1,
313 | step_wait=5,
314 | cart_frac=CART_FRAC,
315 | # dry_run=True,
316 | )
317 |
318 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
319 | if (
320 | "synthesizrr_no_retr_icl_ag_news_llama_2_13b_chat_exn" not in globals()
321 | or synthesizrr_no_retr_icl_ag_news_llama_2_13b_chat_exn.status is Status.FAILED
322 | ):
323 | synthesizrr_no_retr_icl_ag_news_llama_2_13b_chat_exn = run_chain(
324 | results_dir=RESULTS_DIR,
325 | expt=Experiment.SynthesizRR,
326 | dataset_name=DatasetName.AgNews,
327 | model_name=ModelName.LLaMa_2_13B_Chat,
328 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
329 | corpus=Corpus.RealNews_Dominant,
330 | retriever=Retriever.Contriever,
331 | num_samples_per_label=3_000,
332 | seed_type="train_set",
333 | seed_set_stratify_on_ground_truth=False,
334 | icl_type="seed",
335 | llm_batch_size=1,
336 | llm_submission_batch_size=24,
337 | llm_num_models=48,
338 | llm_num_concurrent_preds=4,
339 | metrics_overall_num_samples_per_label=8_000,
340 | metrics_max_parallel=3,
341 | metrics_label_distribution="train_set",
342 | # metrics_to_evaluate=None,
343 | icl_and_prompt_template=dict(
344 | icl_template="""
345 | Summary: {{icl[example_text]}}""",
346 | prompt_template="""
347 | {{icl_examples}}
348 |
349 | News Article:
350 | {{retrieved_context}}
351 |
352 | Write a summary for the above news article {label_verbalization}. The summary should be one or two short sentences.
353 | Summary: """.strip()
354 | + " ",
355 | ),
356 | tracker=TRACKER,
357 | background=BACKGROUND,
358 | verbosity=1,
359 | step_wait=5,
360 | cart_frac=CART_FRAC,
361 | # dry_run=True,
362 | )
363 |
364 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
365 | if (
366 | "fewgen_ag_news_claude_instant_v1_exn" not in globals()
367 | or fewgen_ag_news_claude_instant_v1_exn.status is Status.FAILED
368 | ):
369 | fewgen_ag_news_claude_instant_v1_exn = run_chain(
370 | results_dir=RESULTS_DIR,
371 | expt=Experiment.FewGen,
372 | dataset_name=DatasetName.AgNews,
373 | model_name=ModelName.Claude_Instant_v1,
374 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
375 | num_samples_per_label=3_000,
376 | seed_type="train_set",
377 | seed_set_stratify_on_ground_truth=False,
378 | llm_batch_size=1,
379 | llm_submission_batch_size=12,
380 | llm_num_models=1,
381 | llm_num_concurrent_preds=6,
382 | metrics_overall_num_samples_per_label=8_000,
383 | metrics_max_parallel=3,
384 | metrics_label_distribution="train_set",
385 | # metrics_to_evaluate=None,
386 | tracker=TRACKER,
387 | background=BACKGROUND,
388 | verbosity=1,
389 | step_wait=5,
390 | cart_frac=CART_FRAC,
391 | # dry_run=True,
392 | )
393 |
394 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
395 | if (
396 | "synthesizrr_no_retr_icl_ag_news_claude_instant_v1_exn" not in globals()
397 | or synthesizrr_no_retr_icl_ag_news_claude_instant_v1_exn.status is Status.FAILED
398 | ):
399 | synthesizrr_no_retr_icl_ag_news_claude_instant_v1_exn = run_chain(
400 | results_dir=RESULTS_DIR,
401 | expt=Experiment.SynthesizRR,
402 | dataset_name=DatasetName.AgNews,
403 | model_name=ModelName.Claude_Instant_v1,
404 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
405 | corpus=Corpus.RealNews_Dominant,
406 | retriever=Retriever.Contriever,
407 | num_samples_per_label=3_000,
408 | seed_type="train_set",
409 | seed_set_stratify_on_ground_truth=False,
410 | icl_type="seed",
411 | llm_batch_size=1,
412 | llm_submission_batch_size=12,
413 | llm_num_models=1,
414 | llm_num_concurrent_preds=6,
415 | metrics_overall_num_samples_per_label=8_000,
416 | metrics_max_parallel=3,
417 | metrics_label_distribution="train_set",
418 | # metrics_to_evaluate=None,
419 | icl_and_prompt_template=dict(
420 | icl_template="""
421 | Summary by Assistant: {{icl[example_text]}}""".strip()
422 | + " ",
423 | prompt_template="""
424 | Human:
425 | {{icl_examples}}
426 |
427 | News Article:
428 | {{retrieved_context}}
429 |
430 | Write a summary for the above news article {label_verbalization}. The summary should be one or two short sentences.
431 | Summary by Assistant: """.strip()
432 | + " ",
433 | ),
434 | tracker=TRACKER,
435 | background=BACKGROUND,
436 | verbosity=1,
437 | step_wait=5,
438 | cart_frac=CART_FRAC,
439 | # dry_run=True,
440 | )
441 |
442 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
443 | if (
444 | "synthesizrr_retr_icl_ag_news_claude_instant_v1_exn" not in globals()
445 | or synthesizrr_retr_icl_ag_news_claude_instant_v1_exn.status is Status.FAILED
446 | ):
447 | synthesizrr_retr_icl_ag_news_claude_instant_v1_exn = run_chain(
448 | results_dir=RESULTS_DIR,
449 | expt=Experiment.SynthesizRR,
450 | dataset_name=DatasetName.AgNews,
451 | model_name=ModelName.Claude_Instant_v1,
452 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
453 | corpus=Corpus.RealNews_Dominant,
454 | retriever=Retriever.Contriever,
455 | num_samples_per_label=3_000,
456 | seed_type="train_set",
457 | seed_set_stratify_on_ground_truth=False,
458 | llm_batch_size=1,
459 | llm_submission_batch_size=12,
460 | llm_num_models=1,
461 | llm_num_concurrent_preds=8,
462 | metrics_overall_num_samples_per_label=8_000,
463 | metrics_max_parallel=3,
464 | metrics_label_distribution="train_set",
465 | # metrics_to_evaluate=None,
466 | tracker=TRACKER,
467 | background=BACKGROUND,
468 | verbosity=1,
469 | step_wait=5,
470 | cart_frac=CART_FRAC,
471 | # dry_run=True,
472 | )
473 |
474 | """
475 | _____ ___ _ _ _ _ _
476 | |_ _|___ |_ _| | || | ___ __ _ __| || |(_) _ _ ___ ___
477 | | | / _ \ | | | __ |/ -_)/ _` |/ _` || || || ' \ / -_)(_-<
478 | |_| \___/|___| |_||_|\___|\__,_|\__,_||_||_||_||_|\___|/__/
479 | """
480 |
481 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
482 | if (
483 | "fewgen_toi_headlines_llama_2_13b_chat_exn" not in globals()
484 | or fewgen_toi_headlines_llama_2_13b_chat_exn.status is Status.FAILED
485 | ):
486 | fewgen_toi_headlines_llama_2_13b_chat_exn = run_chain(
487 | results_dir=RESULTS_DIR,
488 | expt=Experiment.FewGen,
489 | dataset_name=DatasetName.ToiHeadlines,
490 | model_name=ModelName.LLaMa_2_13B_Chat,
491 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
492 | num_samples_per_label=2_000,
493 | seed_type="train_set",
494 | seed_set_stratify_on_ground_truth=False,
495 | llm_batch_size=1,
496 | llm_submission_batch_size=12,
497 | llm_num_models=48,
498 | llm_num_concurrent_preds=20,
499 | metrics_overall_num_samples_per_label=8_000,
500 | metrics_max_parallel=3,
501 | metrics_label_distribution="train_set",
502 | # metrics_to_evaluate=None,
503 | tracker=TRACKER,
504 | background=BACKGROUND,
505 | verbosity=2,
506 | step_wait=5,
507 | cart_frac=CART_FRAC,
508 | # dry_run=True,
509 | )
510 |
511 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
512 | if (
513 | "synthesizrr_retr_icl_toi_headlines_llama_2_13b_chat_exn" not in globals()
514 | or synthesizrr_retr_icl_toi_headlines_llama_2_13b_chat_exn.status
515 | is Status.FAILED
516 | ):
517 | synthesizrr_retr_icl_toi_headlines_llama_2_13b_chat_exn = run_chain(
518 | results_dir=RESULTS_DIR,
519 | expt=Experiment.SynthesizRR,
520 | dataset_name=DatasetName.ToiHeadlines,
521 | model_name=ModelName.LLaMa_2_13B_Chat,
522 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
523 | corpus=Corpus.RealNews_India,
524 | retriever=Retriever.Contriever,
525 | num_samples_per_label=2_000,
526 | seed_type="train_set",
527 | seed_set_stratify_on_ground_truth=False,
528 | llm_batch_size=1,
529 | llm_submission_batch_size=20,
530 | llm_tracking_batch_size=100,
531 | llm_num_models=48,
532 | llm_num_concurrent_preds=20,
533 | metrics_overall_num_samples_per_label=8_000,
534 | metrics_max_parallel=3,
535 | metrics_label_distribution="train_set",
536 | # metrics_to_evaluate=None,
537 | tracker=TRACKER,
538 | background=BACKGROUND,
539 | verbosity=2,
540 | step_wait=5,
541 | cart_frac=CART_FRAC,
542 | # dry_run=True,
543 | )
544 |
545 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
546 | if (
547 | "synthesizrr_no_retr_icl_toi_headlines_llama_2_13b_chat_exn" not in globals()
548 | or synthesizrr_no_retr_icl_toi_headlines_llama_2_13b_chat_exn.status
549 | is Status.FAILED
550 | ):
551 | synthesizrr_no_retr_icl_toi_headlines_llama_2_13b_chat_exn = run_chain(
552 | results_dir=RESULTS_DIR,
553 | expt=Experiment.SynthesizRR,
554 | dataset_name=DatasetName.ToiHeadlines,
555 | model_name=ModelName.LLaMa_2_13B_Chat,
556 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
557 | corpus=Corpus.RealNews_India,
558 | retriever=Retriever.Contriever,
559 | num_samples_per_label=2_000,
560 | seed_type="train_set",
561 | seed_set_stratify_on_ground_truth=False,
562 | icl_type="seed",
563 | llm_batch_size=1,
564 | llm_submission_batch_size=16,
565 | llm_num_models=48,
566 | llm_num_concurrent_preds=10,
567 | metrics_overall_num_samples_per_label=8_000,
568 | metrics_max_parallel=3,
569 | metrics_label_distribution="train_set",
570 | # metrics_to_evaluate=None,
571 | icl_and_prompt_template=dict(
572 | icl_template="""
573 | Headline: {{icl[example_text]}}""",
574 | prompt_template="""
575 | {{icl_examples}}
576 |
577 | News Article:
578 | {{retrieved_context}}
579 |
580 | Write a headline for the above news article about {label_verbalization}. The headline should be a single sentence.
581 | Headline: """.strip()
582 | + " ",
583 | ),
584 | tracker=TRACKER,
585 | background=BACKGROUND,
586 | verbosity=2,
587 | step_wait=5,
588 | cart_frac=CART_FRAC,
589 | # dry_run=True,
590 | )
591 |
592 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
593 | if (
594 | "fewgen_toi_headlines_claude_instant_v1_exn" not in globals()
595 | or fewgen_toi_headlines_claude_instant_v1_exn.status is Status.FAILED
596 | ):
597 | fewgen_toi_headlines_claude_instant_v1_exn = run_chain(
598 | results_dir=RESULTS_DIR,
599 | expt=Experiment.FewGen,
600 | dataset_name=DatasetName.ToiHeadlines,
601 | model_name=ModelName.Claude_Instant_v1,
602 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
603 | num_samples_per_label=2_000,
604 | seed_type="train_set",
605 | seed_set_stratify_on_ground_truth=False,
606 | llm_batch_size=1,
607 | llm_submission_batch_size=12,
608 | llm_num_models=1,
609 | llm_num_concurrent_preds=10,
610 | metrics_overall_num_samples_per_label=8_000,
611 | metrics_max_parallel=3,
612 | metrics_label_distribution="train_set",
613 | # metrics_to_evaluate=None,
614 | tracker=TRACKER,
615 | background=BACKGROUND,
616 | verbosity=1,
617 | step_wait=5,
618 | cart_frac=CART_FRAC,
619 | # dry_run=True,
620 | )
621 |
622 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
623 | if (
624 | "synthesizrr_no_retr_icl_toi_headlines_claude_instant_v1_exn" not in globals()
625 | or synthesizrr_no_retr_icl_toi_headlines_claude_instant_v1_exn.status
626 | is Status.FAILED
627 | ):
628 | synthesizrr_no_retr_icl_toi_headlines_claude_instant_v1_exn = run_chain(
629 | results_dir=RESULTS_DIR,
630 | expt=Experiment.SynthesizRR,
631 | dataset_name=DatasetName.ToiHeadlines,
632 | model_name=ModelName.Claude_Instant_v1,
633 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
634 | corpus=Corpus.RealNews_India,
635 | retriever=Retriever.Contriever,
636 | num_samples_per_label=2_000,
637 | seed_type="train_set",
638 | seed_set_stratify_on_ground_truth=False,
639 | icl_type="seed",
640 | llm_batch_size=1,
641 | llm_submission_batch_size=12,
642 | llm_num_models=1,
643 | llm_num_concurrent_preds=6,
644 | metrics_overall_num_samples_per_label=8_000,
645 | metrics_max_parallel=3,
646 | metrics_label_distribution="train_set",
647 | # metrics_to_evaluate=None,
648 | icl_and_prompt_template=dict(
649 | icl_template="""
650 | Headline by Assistant: {{icl[example_text]}}""".strip()
651 | + " ",
652 | prompt_template="""
653 | Human:
654 | {{icl_examples}}
655 |
656 | News Article:
657 | {{retrieved_context}}
658 |
659 | Write a headline for the above news article about {label_verbalization}. The headline should be a single sentence.
660 | Headline by Assistant: """.strip()
661 | + " ",
662 | ),
663 | tracker=TRACKER,
664 | background=BACKGROUND,
665 | verbosity=2,
666 | step_wait=5,
667 | cart_frac=CART_FRAC,
668 | # dry_run=True,
669 | )
670 |
671 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
672 | if (
673 | "synthesizrr_retr_icl_toi_headlines_claude_instant_v1_exn" not in globals()
674 | or synthesizrr_retr_icl_toi_headlines_claude_instant_v1_exn.status
675 | is Status.FAILED
676 | ):
677 | synthesizrr_retr_icl_toi_headlines_claude_instant_v1_exn = run_chain(
678 | results_dir=RESULTS_DIR,
679 | expt=Experiment.SynthesizRR,
680 | dataset_name=DatasetName.ToiHeadlines,
681 | model_name=ModelName.Claude_Instant_v1,
682 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
683 | corpus=Corpus.RealNews_India,
684 | retriever=Retriever.Contriever,
685 | num_samples_per_label=2_000,
686 | seed_type="train_set",
687 | seed_set_stratify_on_ground_truth=False,
688 | llm_batch_size=1,
689 | llm_submission_batch_size=12,
690 | llm_num_models=1,
691 | llm_num_concurrent_preds=10,
692 | metrics_overall_num_samples_per_label=8_000,
693 | metrics_max_parallel=3,
694 | metrics_label_distribution="train_set",
695 | # metrics_to_evaluate=None,
696 | tracker=TRACKER,
697 | background=BACKGROUND,
698 | verbosity=2,
699 | step_wait=5,
700 | cart_frac=CART_FRAC,
701 | # dry_run=True,
702 | )
703 |
704 | """
705 | ___ _
706 | / __| __ _ | |_ ___ __ _ ___ _ _ _ _
707 | | (__ / _` || _|/ -_)/ _` |/ _ \| '_|| || |
708 | \___|\__,_| \__|\___|\__, |\___/|_| \_, |
709 | |___/ |__/
710 | """
711 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
712 | if (
713 | "fewgen_amazon_reviews_category_llama_2_13b_chat_exn" not in globals()
714 | or fewgen_amazon_reviews_category_llama_2_13b_chat_exn.status is Status.FAILED
715 | ):
716 | fewgen_amazon_reviews_category_llama_2_13b_chat_exn = run_chain(
717 | results_dir=RESULTS_DIR,
718 | expt=Experiment.FewGen,
719 | dataset_name=DatasetName.AmazonReviewsProductCategory,
720 | model_name=ModelName.LLaMa_2_13B_Chat,
721 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
722 | num_samples_per_label=1_000,
723 | seed_type="train_set",
724 | seed_set_stratify_on_ground_truth=False,
725 | llm_num_models=48,
726 | metrics_overall_num_samples_per_label=8_000,
727 | metrics_max_parallel=3,
728 | metrics_label_distribution="train_set",
729 | # metrics_to_evaluate=None,
730 | tracker=TRACKER,
731 | background=BACKGROUND,
732 | verbosity=1,
733 | step_wait=5,
734 | cart_frac=CART_FRAC,
735 | # dry_run=True,
736 | )
737 |
738 | SYNTHESIZRR_NUM_SHOTS_LIST = [32]
739 | if (
740 | "synthesizrr_no_retr_icl_amazon_reviews_category_llama_2_13b_chat_exn"
741 | not in globals()
742 | or synthesizrr_no_retr_icl_amazon_reviews_category_llama_2_13b_chat_exn.status
743 | is Status.FAILED
744 | ):
745 | synthesizrr_no_retr_icl_amazon_reviews_category_llama_2_13b_chat_exn = (
746 | run_chain(
747 | results_dir=RESULTS_DIR,
748 | expt=Experiment.SynthesizRR,
749 | dataset_name=DatasetName.AmazonReviewsProductCategory,
750 | model_name=ModelName.LLaMa_2_13B_Chat,
751 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
752 | corpus=Corpus.AmazonProducts,
753 | retriever=Retriever.Contriever,
754 | num_samples_per_label=1_000,
755 | seed_type="train_set",
756 | seed_set_stratify_on_ground_truth=False,
757 | icl_type="seed",
758 | llm_batch_size=1,
759 | llm_submission_batch_size=12,
760 | llm_num_models=48,
761 | llm_num_concurrent_preds=23,
762 | metrics_overall_num_samples_per_label=8_000,
763 | metrics_max_parallel=3,
764 | metrics_label_distribution="train_set",
765 | # metrics_to_evaluate=None,
766 | icl_and_prompt_template=dict(
767 | icl_template="""
768 | Review: {{icl[example_text]}}""".strip()
769 | + " ",
770 | prompt_template="""
771 | {{icl_examples}}
772 |
773 | Product details:
774 | {{retrieved_context}}
775 |
776 | Write a product review about the above product which is in the category of {label_verbalization}. Include relevant product details which are mentioned above. The review should only be a single short sentence, or a single paragraph of 3 to 4 sentences. Add very minor typos.
777 | Review: """.strip()
778 | + " ",
779 | ),
780 | tracker=TRACKER,
781 | background=BACKGROUND,
782 | verbosity=1,
783 | step_wait=5,
784 | cart_frac=CART_FRAC,
785 | # dry_run=True,
786 | )
787 | )
788 |
789 | SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
790 | if (
791 | "synthesizrr_retr_icl_amazon_reviews_category_llama_2_13b_chat_exn"
792 | not in globals()
793 | or synthesizrr_retr_icl_amazon_reviews_category_llama_2_13b_chat_exn.status
794 | is Status.FAILED
795 | ):
796 | synthesizrr_retr_icl_amazon_reviews_category_llama_2_13b_chat_exn = run_chain(
797 | results_dir=RESULTS_DIR,
798 | expt=Experiment.SynthesizRR,
799 | dataset_name=DatasetName.AmazonReviewsProductCategory,
800 | model_name=ModelName.LLaMa_2_13B_Chat,
801 | num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
802 | corpus=Corpus.AmazonProducts,
803 | retriever=Retriever.Contriever,
804 | num_samples_per_label=1_000,
805 | seed_type="train_set",
806 | seed_set_stratify_on_ground_truth=False,
807 | llm_num_models=48,
808 | metrics_overall_num_samples_per_label=8_000,
809 | metrics_max_parallel=3,
810 | metrics_label_distribution="train_set",
811 | # metrics_to_evaluate=None,
812 | tracker=TRACKER,
813 | background=BACKGROUND,
814 | verbosity=1,
815 | step_wait=5,
816 | cart_frac=CART_FRAC,
817 | # dry_run=True,
818 | )
819 |
820 | FEWGEN_NUM_SHOTS_LIST = [0, 32]
821 | if (
822 | "fewgen_amazon_reviews_category_claude_instant_v1_exn" not in globals()
823 | or fewgen_amazon_reviews_category_claude_instant_v1_exn.status is Status.FAILED
824 | ):
825 | fewgen_amazon_reviews_category_claude_instant_v1_exn = run_chain(
826 | results_dir=RESULTS_DIR,
827 | expt=Experiment.FewGen,
828 | dataset_name=DatasetName.AmazonReviewsProductCategory,
829 | model_name=ModelName.Claude_Instant_v1,
830 | num_shots_list=FEWGEN_NUM_SHOTS_LIST,
831 | num_samples_per_label=1_000,
832 | seed_type="train_set",
833 | seed_set_stratify_on_ground_truth=False,
834 | llm_batch_size=1,
835 | llm_submission_batch_size=12,
836 | llm_num_models=1,
837 | llm_num_concurrent_preds=6,
838 | metrics_overall_num_samples_per_label=8_000,
839 | metrics_max_parallel=3,
840 | metrics_label_distribution="train_set",
841 | # metrics_to_evaluate=None,
842 | tracker=TRACKER,
843 | background=BACKGROUND,
844 | verbosity=1,
845 | step_wait=5,
846 | cart_frac=CART_FRAC,
847 | # dry_run=True,
848 | )
849 |
# SynthesizRR without retrieval-augmented ICL ("no_retr_icl"): in-context
# examples are drawn from the seed set (icl_type="seed"), while the generation
# context is still a retrieved Amazon product document.
# Dataset: Amazon Reviews (product category); teacher LLM: Claude Instant v1.
SYNTHESIZRR_NUM_SHOTS_LIST = [32]
# Guard: only (re)launch the chain when it has never run in this session, or
# the previous attempt FAILED. The short-circuiting `or` ensures the variable
# is not referenced before it exists.
if (
    "synthesizrr_no_retr_icl_amazon_reviews_category_claude_instant_v1_exn"
    not in globals()
    or synthesizrr_no_retr_icl_amazon_reviews_category_claude_instant_v1_exn.status
    is Status.FAILED
):
    synthesizrr_no_retr_icl_amazon_reviews_category_claude_instant_v1_exn = (
        run_chain(
            results_dir=RESULTS_DIR,
            expt=Experiment.SynthesizRR,
            dataset_name=DatasetName.AmazonReviewsProductCategory,
            model_name=ModelName.Claude_Instant_v1,
            num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
            corpus=Corpus.AmazonProducts,
            retriever=Retriever.Contriever,
            num_samples_per_label=1_000,
            seed_type="train_set",
            seed_set_stratify_on_ground_truth=False,
            icl_type="seed",
            # Claude-specific throughput settings: a single model endpoint,
            # batch size 1, several concurrent predictions — TODO confirm
            # against run_chain's handling of API-based models.
            llm_batch_size=1,
            llm_submission_batch_size=12,
            llm_num_models=1,
            llm_num_concurrent_preds=6,
            metrics_overall_num_samples_per_label=8_000,
            metrics_max_parallel=3,
            metrics_label_distribution="train_set",
            # metrics_to_evaluate=None,
            # Anthropic Human/Assistant prompt wrapping. The trailing space
            # appended after "Assistant:" is intentional: generation continues
            # on the same line, so do not strip it.
            icl_and_prompt_template=dict(
                icl_template="""
Review by Assistant: {{icl[example_text]}}""".strip()
                + " ",
                prompt_template="""
Human:
{{icl_examples}}

Product details:
{{retrieved_context}}

Write a product review about the above product which is in the category of {label_verbalization}. Include relevant product details which are mentioned above. The review should only be a single short sentence, or a single paragraph of 3 to 4 sentences. Add very minor typos.
Review by Assistant: """.strip()
                + " ",
            ),
            tracker=TRACKER,
            background=BACKGROUND,
            verbosity=1,
            step_wait=5,
            cart_frac=CART_FRAC,
            # dry_run=True,
        )
    )
901 |
# SynthesizRR with retrieval-augmented ICL for Amazon Reviews (product
# category) using Claude Instant v1. The launch is skipped when a non-failed
# run of this experiment already exists in the session namespace.
SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
if not (
    "synthesizrr_retr_icl_amazon_reviews_category_claude_instant_v1_exn"
    in globals()
    and synthesizrr_retr_icl_amazon_reviews_category_claude_instant_v1_exn.status
    is not Status.FAILED
):
    synthesizrr_retr_icl_amazon_reviews_category_claude_instant_v1_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonReviewsProductCategory,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        # Retrieval configuration.
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        # Seed / sampling configuration.
        num_samples_per_label=1000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # Claude API concurrency: one endpoint, small submission batches.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=8000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
935 |
936 | """
937 | _ _
938 | | || | _ _ _ __ ___ _ _
939 | | __ || || || ' \ / _ \| '_|
940 | |_||_| \_,_||_|_|_|\___/|_|
941 | """
942 |
# FewGen baseline (no retrieval): few-shot generation of humorous product
# questions with LLaMa-2 13B Chat. Skipped when a non-failed run already
# exists in the session namespace.
FEWGEN_NUM_SHOTS_LIST = [0, 32]
if not (
    "fewgen_amazon_humor_llama_2_13b_chat_exn" in globals()
    and fewgen_amazon_humor_llama_2_13b_chat_exn.status is not Status.FAILED
):
    fewgen_amazon_humor_llama_2_13b_chat_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.FewGen,
        dataset_name=DatasetName.AmazonHumorousProductQuestions,
        model_name=ModelName.LLaMa_2_13B_Chat,
        num_shots_list=FEWGEN_NUM_SHOTS_LIST,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # 48 parallel model replicas (presumably self-hosted — confirm
        # against run_chain).
        llm_num_models=48,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=2000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
969 |
# SynthesizRR without retrieval-augmented ICL ("no_retr_icl"): in-context
# examples come from the seed set (icl_type="seed"), while generation is still
# grounded in a retrieved Amazon product document.
# Dataset: Amazon Humorous Product Questions; teacher LLM: LLaMa-2 13B Chat.
SYNTHESIZRR_NUM_SHOTS_LIST = [32]
# Guard: only (re)launch when the experiment has never run in this session or
# the previous attempt FAILED; `or` short-circuits so the variable is not
# referenced before it exists.
if (
    "synthesizrr_no_retr_icl_amazon_humor_llama_2_13b_chat_exn" not in globals()
    or synthesizrr_no_retr_icl_amazon_humor_llama_2_13b_chat_exn.status
    is Status.FAILED
):
    synthesizrr_no_retr_icl_amazon_humor_llama_2_13b_chat_exn = run_chain(
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonHumorousProductQuestions,
        model_name=ModelName.LLaMa_2_13B_Chat,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        num_samples_per_label=5_000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        icl_type="seed",
        # 48 model replicas, batch size 1, two concurrent predictions per
        # model — presumably a self-hosted LLaMa fleet; confirm with run_chain.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=48,
        llm_num_concurrent_preds=2,
        metrics_overall_num_samples_per_label=2_000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # metrics_to_evaluate=None,
        # Plain (non-Anthropic) prompt format. The trailing space appended
        # after "Product Question:" is intentional: generation continues on
        # the same line, so do not strip it.
        icl_and_prompt_template=dict(
            icl_template="""
Product Question: {{icl[example_text]}}""".strip()
            + " ",
            prompt_template="""
{{icl_examples}}

Product details:
{{retrieved_context}}

Write a short {label_verbalization} question about the above product on Amazon. Only include the question.
Product Question: """.strip()
            + " ",
        ),
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
        # dry_run=True,
    )
1017 |
# SynthesizRR with retrieval-augmented ICL for Amazon Humorous Product
# Questions using LLaMa-2 13B Chat. Skipped when a non-failed run already
# exists in the session namespace.
SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
if not (
    "synthesizrr_retr_icl_amazon_humor_llama_2_13b_chat_exn" in globals()
    and synthesizrr_retr_icl_amazon_humor_llama_2_13b_chat_exn.status
    is not Status.FAILED
):
    synthesizrr_retr_icl_amazon_humor_llama_2_13b_chat_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonHumorousProductQuestions,
        model_name=ModelName.LLaMa_2_13B_Chat,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        # Retrieval configuration.
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # 48 parallel model replicas.
        llm_num_models=48,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=2000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1047 |
# FewGen baseline (no retrieval): few-shot generation of humorous product
# questions with Claude Instant v1. Skipped when a non-failed run already
# exists in the session namespace.
FEWGEN_NUM_SHOTS_LIST = [0, 32]
if not (
    "fewgen_amazon_humor_claude_instant_v1_exn" in globals()
    and fewgen_amazon_humor_claude_instant_v1_exn.status is not Status.FAILED
):
    fewgen_amazon_humor_claude_instant_v1_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.FewGen,
        dataset_name=DatasetName.AmazonHumorousProductQuestions,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=FEWGEN_NUM_SHOTS_LIST,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # Claude API concurrency: one endpoint, small submission batches.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=2000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1077 |
# SynthesizRR without retrieval-augmented ICL ("no_retr_icl"): in-context
# examples come from the seed set (icl_type="seed"), while generation is still
# grounded in a retrieved Amazon product document.
# Dataset: Amazon Humorous Product Questions; teacher LLM: Claude Instant v1.
SYNTHESIZRR_NUM_SHOTS_LIST = [32]
# Guard: only (re)launch when the experiment has never run in this session or
# the previous attempt FAILED; `or` short-circuits so the variable is not
# referenced before it exists.
if (
    "synthesizrr_no_retr_icl_amazon_humor_claude_instant_v1_exn" not in globals()
    or synthesizrr_no_retr_icl_amazon_humor_claude_instant_v1_exn.status
    is Status.FAILED
):
    synthesizrr_no_retr_icl_amazon_humor_claude_instant_v1_exn = run_chain(
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonHumorousProductQuestions,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        num_samples_per_label=5_000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        icl_type="seed",
        # Claude-specific throughput settings: a single model endpoint,
        # batch size 1, several concurrent predictions.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        metrics_overall_num_samples_per_label=2_000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # metrics_to_evaluate=None,
        # Anthropic Human/Assistant prompt wrapping. The trailing space
        # appended after "Assistant:" is intentional: generation continues on
        # the same line, so do not strip it.
        icl_and_prompt_template=dict(
            icl_template="""
Product Question by Assistant: {{icl[example_text]}}""".strip()
            + " ",
            prompt_template="""
Human:
{{icl_examples}}

Product details:
{{retrieved_context}}

Write a short {label_verbalization} question about the above product on Amazon. Only include the question.
Product Question by Assistant: """.strip()
            + " ",
        ),
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
        # dry_run=True,
    )
1126 |
# SynthesizRR with retrieval-augmented ICL for Amazon Humorous Product
# Questions using Claude Instant v1. Skipped when a non-failed run already
# exists in the session namespace.
SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
if not (
    "synthesizrr_retr_icl_amazon_humor_claude_instant_v1_exn" in globals()
    and synthesizrr_retr_icl_amazon_humor_claude_instant_v1_exn.status
    is not Status.FAILED
):
    synthesizrr_retr_icl_amazon_humor_claude_instant_v1_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonHumorousProductQuestions,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        # Retrieval configuration.
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # Claude API concurrency: one endpoint, small submission batches.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=2000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1159 |
1160 | """
1161 | ___ _ _ _
1162 | | _ \ ___ | | __ _ _ _ (_)| |_ _ _
1163 | | _// _ \| |/ _` || '_|| || _|| || |
1164 | |_| \___/|_|\__,_||_| |_| \__| \_, |
1165 | |__/
1166 | """
1167 |
# FewGen baseline (no retrieval): few-shot generation of polarity-labelled
# Amazon reviews with LLaMa-2 13B Chat. Skipped when a non-failed run already
# exists in the session namespace.
FEWGEN_NUM_SHOTS_LIST = [0, 32]
if not (
    "fewgen_amazon_polarity_llama_2_13b_chat_exn" in globals()
    and fewgen_amazon_polarity_llama_2_13b_chat_exn.status is not Status.FAILED
):
    fewgen_amazon_polarity_llama_2_13b_chat_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.FewGen,
        dataset_name=DatasetName.AmazonReviewsPolarity,
        model_name=ModelName.LLaMa_2_13B_Chat,
        num_shots_list=FEWGEN_NUM_SHOTS_LIST,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # 48 parallel model replicas.
        llm_num_models=48,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=4000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1194 |
# SynthesizRR without retrieval-augmented ICL ("no_retr_icl"): in-context
# examples come from the seed set (icl_type="seed"), while generation is still
# grounded in a retrieved Amazon product document.
# Dataset: Amazon Reviews Polarity; teacher LLM: LLaMa-2 13B Chat.
SYNTHESIZRR_NUM_SHOTS_LIST = [32]
# Guard: only (re)launch when the experiment has never run in this session or
# the previous attempt FAILED; `or` short-circuits so the variable is not
# referenced before it exists.
if (
    "synthesizrr_no_retr_icl_amazon_polarity_llama_2_13b_chat_exn" not in globals()
    or synthesizrr_no_retr_icl_amazon_polarity_llama_2_13b_chat_exn.status
    is Status.FAILED
):
    synthesizrr_no_retr_icl_amazon_polarity_llama_2_13b_chat_exn = run_chain(
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonReviewsPolarity,
        model_name=ModelName.LLaMa_2_13B_Chat,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        num_samples_per_label=5_000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        icl_type="seed",
        # 48 model replicas, batch size 1, two concurrent predictions per
        # model — presumably a self-hosted LLaMa fleet; confirm with run_chain.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=48,
        llm_num_concurrent_preds=2,
        metrics_overall_num_samples_per_label=4_000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # metrics_to_evaluate=None,
        # Plain (non-Anthropic) prompt format. The trailing space appended
        # after "Review:" is intentional: generation continues on the same
        # line, so do not strip it.
        icl_and_prompt_template=dict(
            icl_template="""
Review: {{icl[example_text]}}""".strip()
            + " ",
            prompt_template="""
{{icl_examples}}

Product details:
{{retrieved_context}}

Write a review about the above product on Amazon which discusses {label_verbalization}. Include relevant product details which are mentioned above. The review should only be a single short sentence, or a single paragraph of 3 to 4 sentences. Add very minor typos.
Review: """.strip()
            + " ",
        ),
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
        # dry_run=True,
    )
1242 |
# SynthesizRR with retrieval-augmented ICL for Amazon Reviews Polarity using
# LLaMa-2 13B Chat. Skipped when a non-failed run already exists in the
# session namespace.
SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
if not (
    "synthesizrr_retr_icl_amazon_polarity_llama_2_13b_chat_exn" in globals()
    and synthesizrr_retr_icl_amazon_polarity_llama_2_13b_chat_exn.status
    is not Status.FAILED
):
    synthesizrr_retr_icl_amazon_polarity_llama_2_13b_chat_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonReviewsPolarity,
        model_name=ModelName.LLaMa_2_13B_Chat,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        # Retrieval configuration.
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # 48 parallel model replicas.
        llm_num_models=48,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=4000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1272 |
# FewGen baseline (no retrieval): few-shot generation of polarity-labelled
# Amazon reviews with Claude Instant v1. Skipped when a non-failed run
# already exists in the session namespace.
FEWGEN_NUM_SHOTS_LIST = [0, 32]
if not (
    "fewgen_amazon_polarity_claude_instant_v1_exn" in globals()
    and fewgen_amazon_polarity_claude_instant_v1_exn.status is not Status.FAILED
):
    fewgen_amazon_polarity_claude_instant_v1_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.FewGen,
        dataset_name=DatasetName.AmazonReviewsPolarity,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=FEWGEN_NUM_SHOTS_LIST,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # Claude API concurrency: one endpoint, small submission batches.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=4000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1302 |
# SynthesizRR without retrieval-augmented ICL ("no_retr_icl"): in-context
# examples come from the seed set (icl_type="seed"), while generation is still
# grounded in a retrieved Amazon product document.
# Dataset: Amazon Reviews Polarity; teacher LLM: Claude Instant v1.
SYNTHESIZRR_NUM_SHOTS_LIST = [32]
# Guard: only (re)launch when the experiment has never run in this session or
# the previous attempt FAILED; `or` short-circuits so the variable is not
# referenced before it exists.
if (
    "synthesizrr_no_retr_icl_amazon_polarity_claude_instant_v1_exn" not in globals()
    or synthesizrr_no_retr_icl_amazon_polarity_claude_instant_v1_exn.status
    is Status.FAILED
):
    synthesizrr_no_retr_icl_amazon_polarity_claude_instant_v1_exn = run_chain(
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonReviewsPolarity,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        num_samples_per_label=5_000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        icl_type="seed",
        # Claude-specific throughput settings: a single model endpoint,
        # batch size 1, several concurrent predictions.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        metrics_overall_num_samples_per_label=4_000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # metrics_to_evaluate=None,
        # Anthropic Human/Assistant prompt wrapping. The trailing space
        # appended after "Assistant:" is intentional: generation continues on
        # the same line, so do not strip it.
        icl_and_prompt_template=dict(
            icl_template="""
Review by Assistant: {{icl[example_text]}}""".strip()
            + " ",
            prompt_template="""
Human:
{{icl_examples}}

Product details:
{{retrieved_context}}

Write a review about the above product on Amazon which discusses {label_verbalization}. Include relevant product details which are mentioned above. The review should only be a single short sentence, or a single paragraph of 3 to 4 sentences. Add very minor typos.
Review by Assistant: """.strip()
            + " ",
        ),
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
        # dry_run=True,
    )
1351 |
# SynthesizRR with retrieval-augmented ICL for Amazon Reviews Polarity using
# Claude Instant v1. Skipped when a non-failed run already exists in the
# session namespace.
SYNTHESIZRR_NUM_SHOTS_LIST = [0, 3]
if not (
    "synthesizrr_retr_icl_amazon_polarity_claude_instant_v1_exn" in globals()
    and synthesizrr_retr_icl_amazon_polarity_claude_instant_v1_exn.status
    is not Status.FAILED
):
    synthesizrr_retr_icl_amazon_polarity_claude_instant_v1_exn = run_chain(
        # Experiment identity.
        results_dir=RESULTS_DIR,
        expt=Experiment.SynthesizRR,
        dataset_name=DatasetName.AmazonReviewsPolarity,
        model_name=ModelName.Claude_Instant_v1,
        num_shots_list=SYNTHESIZRR_NUM_SHOTS_LIST,
        # Retrieval configuration.
        corpus=Corpus.AmazonProducts,
        retriever=Retriever.Contriever,
        # Seed / sampling configuration.
        num_samples_per_label=5000,
        seed_type="train_set",
        seed_set_stratify_on_ground_truth=False,
        # Claude API concurrency: one endpoint, small submission batches.
        llm_batch_size=1,
        llm_submission_batch_size=12,
        llm_num_models=1,
        llm_num_concurrent_preds=6,
        # Metrics configuration.
        metrics_overall_num_samples_per_label=4000,
        metrics_max_parallel=3,
        metrics_label_distribution="train_set",
        # Tracking / execution.
        tracker=TRACKER,
        background=BACKGROUND,
        verbosity=1,
        step_wait=5,
        cart_frac=CART_FRAC,
    )
1383 |
--------------------------------------------------------------------------------