├── .flake8
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
└── settings.json
├── CITATION.cff
├── Dockerfile.streamlit
├── LICENSE
├── Makefile
├── README.md
├── data
├── .gitignore
└── derived
│ ├── generations
│ └── math_df.csv
│ ├── mn_general_student_queries.csv
│ └── outcome_df.csv
├── figures
├── dhf-logo-vector-blue.svg
├── dhf-poster-logo.png
├── faithfulness_by_guidance.pdf
├── guidance_autometrics.pdf
├── mean_faithfulness.pdf
├── mean_relevance.pdf
├── mean_relevance2.pdf
├── pairwise_ranks.pdf
├── pairwise_ranks_poster.png
├── rank_distribution.pdf
├── relevance_rank_comparison.pdf
├── relevance_x_faithfulness.pdf
├── rori-rag-final-5.png
├── rori-rag-final.drawio
├── slides_groundedness_v_relevance.png
├── slides_human_groundedness.png
├── slides_kf1.png
├── slides_kf1_hist.png
├── slides_kf1_hist_low.png
├── slides_kf1_hist_low_high.png
├── slides_kf1_no_ir.png
├── slides_preference.png
├── system-diagram-poster.png
└── system-diagram.svg
├── notebooks
├── EvaluateRetrievalImpact.ipynb
├── MathNationDataExploration.ipynb
├── MetricTesting.ipynb
├── Qualtrics.ipynb
└── SurveyDataAnalysis.ipynb
├── poetry.lock
├── poetry.toml
├── pyproject.toml
├── src
├── experiment
│ ├── auth.py
│ ├── completion_utils.py
│ ├── generate.py
│ ├── guidance_conditions.py
│ ├── metrics.py
│ ├── qualtrics.py
│ └── tokenize.py
├── rag
│ ├── __init__.py
│ ├── embedding_utils.py
│ ├── gpf_utils.py
│ ├── logit_bias.py
│ ├── misconceptions.py
│ ├── prompt_utils.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── hints.py
│ │ └── mathqa.py
│ ├── resources
│ │ ├── dolma_stopwords.json
│ │ └── misconceptions.ndjson
│ ├── retrieval.py
│ └── retrieval_strategies.py
└── streamlit
│ ├── Annotation_tool.py
│ └── pages
│ └── Relevance.py
└── tests
└── unit
├── conftest.py
├── rag
├── conftest.py
├── test_embedding_utils.py
├── test_gpf_utils.py
├── test_logit_bias.py
├── test_misconceptions.py
├── test_prompt_utils.py
├── test_retrieval.py
└── test_retrieval_strategies.py
├── test_auth.py
├── test_completion_utils.py
├── test_generate.py
├── test_metrics.py
├── test_qualtrics.py
└── test_tokenize.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | exclude = venv
3 | ignore = E203, E501, W503
4 | max-line-length = 120
5 |
6 | # E203: Whitespace before ':'
7 | # E501: Line too long
8 | # W503: https://github.com/psf/black/issues/52 and https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project-specific
2 | .DS_Store
3 | .streamlit/secrets.toml
4 | .env
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.6.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - repo: https://github.com/pycqa/isort
9 | rev: 5.13.2
10 | hooks:
11 | - id: isort
12 | - repo: https://github.com/asottile/add-trailing-comma
13 | rev: v3.1.0
14 | hooks:
15 | - id: add-trailing-comma
16 | args: [--py36-plus]
17 | - repo: https://github.com/asottile/pyupgrade
18 | rev: v3.17.0
19 | hooks:
20 | - id: pyupgrade
21 | args: [--py37-plus]
22 | - repo: https://github.com/psf/black-pre-commit-mirror
23 | rev: 24.8.0
24 | hooks:
25 | - id: black
26 | language_version: python3.9
27 | - repo: https://github.com/nbQA-dev/nbQA
28 | rev: 1.8.7
29 | hooks:
30 | - id: nbqa-black
31 | - id: nbqa-pyupgrade
32 | args: ["--py37-plus"]
33 | - id: nbqa-isort
34 | args: ["--float-to-top"]
35 | - repo: https://github.com/PyCQA/flake8
36 | rev: 7.1.1
37 | hooks:
38 | - id: flake8
39 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "tests"
4 | ],
5 | "python.testing.unittestEnabled": false,
6 | "python.testing.pytestEnabled": true
7 | }
8 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite the paper as below."
3 | authors:
4 | - family-names: Levonian
5 | given-names: Zachary
6 | orcid: https://orcid.org/0000-0002-8932-1489
7 | - family-names: Li
8 | given-names: Chenglu
9 | - family-names: Zhu
10 | given-names: Wangda
11 | - family-names: Gade
12 | given-names: Anoushka
13 | - family-names: Henkel
14 | given-names: Owen
15 | - family-names: Postle
16 | given-names: Millie-Ellen
17 | - family-names: Xing
18 | given-names: Wanli
19 | date-released: 2023-10-04
20 | repository-code: "https://github.com/DigitalHarborFoundation/rag-for-math-qa"
21 | preferred-citation:
22 | type: conference-paper
23 | title: "Retrieval-augmented Generation to Improve Math Question-Answering: Trade-offs Between Groundedness and Human Preference"
24 | abstract: "For middle-school math students, interactive question-answering (QA) with tutors is an effective way to learn. The flexibility and emergent capabilities of generative large language models (LLMs) has led to a surge of interest in automating portions of the tutoring process - including interactive QA to support conceptual discussion of mathematical concepts. However, LLM responses to math questions can be incorrect or mismatched to the educational context - such as being misaligned with a school's curriculum. One potential solution is retrieval-augmented generation (RAG), which involves incorporating a vetted external knowledge source in the LLM prompt to increase response quality. In this paper, we designed prompts that retrieve and use content from a high-quality open-source math textbook to generate responses to real student questions. We evaluate the efficacy of this RAG system for middle-school algebra and geometry QA by administering a multi-condition survey, finding that humans prefer responses generated using RAG, but not when responses are too grounded in the textbook content. We argue that while RAG is able to improve response quality, designers of math QA systems must consider trade-offs between generating responses preferred by students and responses closely matched to specific educational resources."
25 | doi: 10.48550/arXiv.2310.03184
26 | year: 2023
27 | conference:
28 | name: "NeurIPS'23 Workshop on Generative AI for Education (GAIED)"
29 | city: "New Orleans"
30 | country: "US"
31 | date-start: "2023-12-15"
32 | date-end: "2023-12-15"
33 | authors:
34 | - family-names: Levonian
35 | given-names: Zachary
36 | orcid: https://orcid.org/0000-0002-8932-1489
37 | - family-names: Li
38 | given-names: Chenglu
39 | - family-names: Zhu
40 | given-names: Wangda
41 | - family-names: Gade
42 | given-names: Anoushka
43 | - family-names: Henkel
44 | given-names: Owen
45 | - family-names: Postle
46 | given-names: Millie-Ellen
47 | - family-names: Xing
48 | given-names: Wanli
49 |
50 |
51 |
--------------------------------------------------------------------------------
/Dockerfile.streamlit:
--------------------------------------------------------------------------------
1 | FROM python:3.10
2 |
3 | RUN python -m pip install --upgrade pip
4 | RUN curl -sSL https://install.python-poetry.org | python -
5 |
6 | WORKDIR /usr/app
7 |
8 | COPY ./src ./src
9 | COPY ./tests ./tests
10 |
11 | # note: README.md is required for Poetry
12 | COPY README.md README.md
13 | COPY pyproject.toml pyproject.toml
14 | COPY poetry.lock poetry.lock
15 | RUN pip install .
16 |
17 | EXPOSE 8502
18 | ENTRYPOINT ["streamlit", "run", "src/streamlit/Annotation_tool.py", "--server.port=8502", "--server.address=0.0.0.0"]
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Zachary Levonian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: help install ensure-poetry install-precommits test run-streamlit build-docker run-docker remove-docker
2 |
3 | export PATH := $(HOME)/.local/bin:$(PATH)
4 |
5 | tt:
6 | @which poetry
7 | @poetry install
8 |
9 | help:
10 | @echo "Relevant targets are 'install' and 'test'."
11 |
12 | install:
13 | @$(MAKE) ensure-poetry
14 | @$(MAKE) install-precommits
15 | @poetry build
16 |
17 | ensure-poetry:
18 | @# see issue: https://stackoverflow.com/questions/77019756/make-not-finding-executable-added-to-path-in-makefile
19 | @if ! command -v poetry &> /dev/null; then \
20 | echo "Installing poetry"; \
21 | curl -sSL https://install.python-poetry.org | python - ; \
22 | echo "Poetry installed, but you might need to update your PATH before make will detect it."; \
23 | fi
24 | @poetry install
25 |
26 | install-precommits:
27 | @poetry run pre-commit autoupdate
28 | @poetry run pre-commit install --overwrite --install-hooks
29 |
30 | jupyter:
31 | poetry run jupyter lab
32 |
33 | test:
34 | @poetry run pytest --cov=src --cov-report term-missing
35 |
36 | build-docker:
37 | @poetry build
38 | @docker build -t streamlit -f Dockerfile.streamlit .
39 |
40 | run-docker:
41 | @$(MAKE) remove-docker
42 | @docker run \
43 | --name streamlit_container \
44 | -p 8502:8502 \
45 | -v ./.streamlit/secrets.toml:/usr/app/.streamlit/secrets.toml \
46 | streamlit:latest
47 |
48 | remove-docker:
49 | @if docker ps -q --filter "name=streamlit_container" | grep -q .; then \
50 | echo "Stopping streamlit container"; \
51 | docker stop streamlit_container; \
52 | fi
53 | @if docker ps -a -q --filter "name=streamlit_container" | grep -q .; then \
54 | echo "Removing streamlit container"; \
55 | docker remove --volumes streamlit_container; \
56 | fi
57 |
58 | run-streamlit:
59 | @streamlit run src/streamlit/Annotation_tool.py --
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Retrieval-augmented generation to improve math question-answering: trade-offs between groundedness and human preference
2 |
3 | [](https://arxiv.org/abs/2310.03184)
4 | [](https://github.com/DigitalHarborFoundation/rag-for-math-qa/blob/main/LICENSE)
5 |
6 |
7 | This repository contains analysis code, prompts, surveys, figures, and data for the paper "Retrieval-augmented generation to improve math question-answering: trade-offs between groundedness and human preference".
8 |
9 | This repository forks the [`llm-math-education`](https://github.com/DigitalHarborFoundation/llm-math-education) package.
10 |
11 | Cite [the paper](https://arxiv.org/abs/2310.03184) using the CITATION.cff file and dropdown:
12 |
13 | >Zachary Levonian, Chenglu Li, Wangda Zhu, Anoushka Gade, Owen Henkel, Millie-Ellen Postle, and Wanli Xing. 2023. Retrieval-augmented Generation to Improve Math Question-Answering: Trade-offs Between Groundedness and Human Preference. In _NeurIPS’23 Workshop on Generative AI for Education (GAIED)_, New Orleans, USA. DOI:https://doi.org/10.48550/arXiv.2310.03184
14 |
15 | ## Development
16 |
17 | Primary code contributor:
18 |
19 | - Zachary Levonian
20 |
21 | ## Local development setup
22 |
23 | This project uses `make` and `Poetry` to manage and install dependencies.
24 |
25 | On Windows, you'll need to use WSL and maybe make some other changes.
26 |
27 | ### Python development
28 |
29 | Use `make install` to install all needed dependencies (including the pre-commit hooks and Poetry).
30 |
31 | You'll probably need to manually add Poetry to your PATH, e.g. by updating your `.bashrc` (or relevant equivalent):
32 |
33 | ```bash
34 | export PATH="$HOME/.local/bin:$PATH"
35 | ```
36 |
37 | ### Run tests
38 |
39 | ```bash
40 | make test
41 | ```
42 |
43 | ### Run Jupyter Lab
44 |
45 | ```bash
46 | make jupyter
47 | ```
48 |
49 | Which really just runs `poetry run jupyter lab`, so feel free to customize your Jupyter experience.
50 |
51 | ### Other useful commands
52 |
53 | - `poetry run <command>` - Run the given command, e.g. `poetry run pytest` invokes the tests.
54 | - `poetry add <package>` - Add the given package as a dependency. Use flag `-G dev` to add it as a development dependency.
55 |
56 | ## Other notes
57 |
58 | ### Poster figures
59 |
60 | Some logos are present in the `figures` directory.
61 |
62 | The Digital Harbor Foundation logo was created using [`rsvg-convert`](https://man.archlinux.org/man/rsvg-convert.1.en), installed via `brew`.
63 |
64 | I manually adjusted the source svg (dhf-logo-vector-blue.svg) to use DHF blue (#0091c9) rather than black (#010101).
65 |
66 | ```bash
67 | brew install librsvg
68 | rsvg-convert -d 150 -p 150 -h 2in figures/dhf-logo-vector-blue.svg > figures/dhf-poster-logo.png
69 | ```
70 |
71 | Converting the system diagram (converted from draw.io as an SVG, with embedded fonts):
72 |
73 | ```bash
74 | rsvg-convert -d 150 -p 150 -h 4in figures/system-diagram.svg > figures/system-diagram-poster.png
75 | ```
76 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | raw/*
2 | derived/*
3 |
--------------------------------------------------------------------------------
/data/derived/mn_general_student_queries.csv:
--------------------------------------------------------------------------------
1 | post_id,subject_name,post_content,is_respondable_query
2 | 19856,Algebra 1,how do you a graph function rule?!,general
3 | 114676,Algebra 1,what is the quadratic formula?,general
4 | 124734,Algebra 1,what is monomial,general
5 | 199138,Algebra 1,Can someone give me a factoring example I don't get it and I re watched the videos and I still don't get it.,general
6 | 285849,Algebra 1,if all the x values in a function are differnt its a function right?,general
7 | 294570,Algebra 1,Give me some examples of an equation that is not a linear function!,general
8 | 394330,Algebra 1,"what are regression and median-fit lines?
9 | [Continued:] I really don't get them.(it) (that)",general
10 | 503676,Algebra 1,"how do you oslve with negative exponents (for example 4^-9)
11 | [Continued:] *solve",general
12 | 591424,Algebra 1,How do you find the parent function of a graph?,general
13 | 645067,Algebra 1,after you distribute a negative does it disappear or do you keep it there?,general
14 | 674022,Algebra 1,I always forget the difference between commutative and associative. Does anybody know a way to make me remember?,general
15 | 726510,Algebra 1,How do you find the radius?,general
16 | 859909,Algebra 1,How do you find the domain and range that are reasonable for a certain function?,general
17 | 1105388,Pre-Algebra,Hey guys! What is the quotient rule??,general
18 | 1140336,Algebra 1,how do i solve square root functions,general
19 | 1179906,Pre-Algebra,How do I multiply fractions???????,general
20 | 1254378,Algebra 1,"where can I find completing the square?
21 | ",general
22 | 1420227,Algebra 1,How does standard deivation work?,general
23 | 1423442,Pre-Algebra,How do you multiply fractions?!?!?,general
24 | 1487901,Algebra 1,"how far can a polynomial go
25 | ",general
26 | 1790360,Algebra 1,What is vertex form and how do you solve for it?,general
27 | 1857118,Algebra 1,what is the difference between communtative property of addition and addition property of equality,general
28 | 1900270,Algebra 1,"Just checking, If I have a problem like x to the 4th in parentheses and there's a exponent out of the parentheses lets say 2 would you multiply the exponent in it by the one out making it x to the 8th power?",general
29 | 1928752,Algebra 1,"Can I get the steps for factoring quadratics
30 | ",general
31 | 1942173,Algebra 1,What is the domain and range? How do I find it?,general
32 | 2002640,Algebra 1,"Can someone help me with what an irrational number is vs a rational number
33 | ",general
34 | 2015759,Algebra 1,"I need help with this question:
35 |
36 | An entertainment firm offers several DJ choices and light shows that range in price based on the rental time period. The DJ's cost between $219.00 and $369.00 per night and the light shows cost between $159.00 and $309.00 per night. If you are booking both a DJ and a light show, write a compound inequality that represents the possible total amount you would pay, x.",problem
37 | 2132595,Algebra 1,What is a leading coefficient?,general
38 | 2153141,Algebra 1,wait... so if im using pemdas how does that work with fractions ?,general
39 | 2166131,Algebra 1,"I need a standard way of recognizing polynomials and their degrees, please help me figure
40 | it out. ",general
41 | 2207350,Algebra 1,"Does anyone know how to find the domain and range of a problem like f(x̄)=x̄²+8?
42 | ","general,problem"
43 | 2311169,Algebra 1,What is the difference between commutative and associative property?,general
44 | 2353255,Algebra 1,"I know that this has been a consistent question but I have a test on line of best fit. I've been given multiple different answers on the algebra wall and some make sense while others don't. My teacher never gave me an equation so I don't think I'm supposed to be using one, so what is the best way to find the accurate line of best fit. ",general
45 | 2389599,Algebra 1,what is the difference between recursive formula and explicit form.,general
46 | 2400730,Algebra 1,how do you know if a graph is an absolute value graph?,general
47 | 2547127,Algebra 1,i need help on how to graph quadratic funtions,general
48 | 2571397,Algebra 1,could i recieve help with turning a trinomial into grouping,general
49 | 2660369,Geometry,Is supplementary angles are always adjacent,general
50 | 2672704,Algebra 1,"i have a problem is this equation linear?
51 | 7x + y + 3 = y
52 | if its linear or not linear can somebody help me on how you can tell if its linear or not","general,problem"
53 | 2675038,Algebra 1,Can function Notation be negative?,general
54 | 2689932,Algebra 1,I don't understand how to find x if there are two x's,general
55 | 2690581,Algebra 1,How do you know if a number is a constant?,general
56 | 2727327,Geometry,How do you translate a shape,general
57 | 2750298,Algebra 1,What is a function notation?,general
58 | 2921833,Algebra 1,How do i graph an inequality with a less than or equal to sign?,general
59 | 2928270,Geometry,Why does the sine of an angle equal the sine of its supplement?,general
60 | 2933099,Algebra 1,Is a rational and irrational number always irrational?,general
61 | 2965705,Geometry,How do I add line segments again??,general
62 | 3033408,Algebra 1,How do you find the zeros of a function by graphing?,general
63 | 3034048,Geometry,I dont get how to get the length of the sides without knowing the perimeter.,general
64 | 3069503,Algebra 1,How can you use polynomials to find the perimeter of a square for a example someone giving you a side length of (3s-5),"general,problem"
65 | 3120570,Geometry,"I know that sin A=opposite/hypotenuse, cos A=adjacent/hypotenuse, and tan A=opposite/adjacent. Is there any other ratios like those?",general
66 | 3293299,Geometry,"I'm still a little lost on how you find midpoints on graphs. I'm currently on Topic 4 of Section 1, and I'm watching the video, but I'm still a little confuse on how they find it. Here is the question. If anyone knows the steps to finding the midpoint on a graph, thank you! ",general
67 | 3311347,Geometry,Can someone explain like interior and exterior angles? thanks,general
68 | 3322367,Geometry,"how do i find the slope of a line in slope intercept form.
69 | ",general
70 |
--------------------------------------------------------------------------------
/figures/dhf-logo-vector-blue.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
7 |
8 |
9 |
12 |
13 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
24 |
28 |
31 |
33 |
34 |
37 |
39 |
41 |
44 |
45 |
46 |
47 |
50 |
52 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/figures/dhf-poster-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/dhf-poster-logo.png
--------------------------------------------------------------------------------
/figures/faithfulness_by_guidance.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/faithfulness_by_guidance.pdf
--------------------------------------------------------------------------------
/figures/guidance_autometrics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/guidance_autometrics.pdf
--------------------------------------------------------------------------------
/figures/mean_faithfulness.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/mean_faithfulness.pdf
--------------------------------------------------------------------------------
/figures/mean_relevance.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/mean_relevance.pdf
--------------------------------------------------------------------------------
/figures/mean_relevance2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/mean_relevance2.pdf
--------------------------------------------------------------------------------
/figures/pairwise_ranks.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/pairwise_ranks.pdf
--------------------------------------------------------------------------------
/figures/pairwise_ranks_poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/pairwise_ranks_poster.png
--------------------------------------------------------------------------------
/figures/rank_distribution.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/rank_distribution.pdf
--------------------------------------------------------------------------------
/figures/relevance_rank_comparison.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/relevance_rank_comparison.pdf
--------------------------------------------------------------------------------
/figures/relevance_x_faithfulness.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/relevance_x_faithfulness.pdf
--------------------------------------------------------------------------------
/figures/rori-rag-final-5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/rori-rag-final-5.png
--------------------------------------------------------------------------------
/figures/rori-rag-final.drawio:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
/figures/slides_groundedness_v_relevance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_groundedness_v_relevance.png
--------------------------------------------------------------------------------
/figures/slides_human_groundedness.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_human_groundedness.png
--------------------------------------------------------------------------------
/figures/slides_kf1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1.png
--------------------------------------------------------------------------------
/figures/slides_kf1_hist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_hist.png
--------------------------------------------------------------------------------
/figures/slides_kf1_hist_low.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_hist_low.png
--------------------------------------------------------------------------------
/figures/slides_kf1_hist_low_high.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_hist_low_high.png
--------------------------------------------------------------------------------
/figures/slides_kf1_no_ir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_no_ir.png
--------------------------------------------------------------------------------
/figures/slides_preference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_preference.png
--------------------------------------------------------------------------------
/figures/system-diagram-poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/system-diagram-poster.png
--------------------------------------------------------------------------------
/notebooks/MathNationDataExploration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "f05760e9-f206-4fc7-b515-cf8516fb05bc",
6 | "metadata": {},
7 | "source": [
8 | "MathNation Data Exploration\n",
9 | "===\n",
10 | "Exploring a sample of anonymized MathNation data."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "6cb45b53-f77e-48e3-a197-d02df2349542",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from datetime import datetime\n",
21 | "from pathlib import Path\n",
22 | "\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "import numpy as np\n",
25 | "import pandas as pd"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "id": "8378ffc6-9752-4c30-b5dd-6cb60c26561a",
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "data_dir = Path(\"../data\")\n",
36 | "assert data_dir.exists()"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 4,
42 | "id": "350228a7-5f41-413b-917b-a37ffda3ef0c",
43 | "metadata": {},
44 | "outputs": [
45 | {
46 | "data": {
47 | "text/plain": [
48 | "(152844, 9)"
49 | ]
50 | },
51 | "execution_count": 4,
52 | "metadata": {},
53 | "output_type": "execute_result"
54 | }
55 | ],
56 | "source": [
57 | "mn_discussion_filepath = data_dir / \"raw\" / \"math_nation\" / \"mn_discussion_20230914.csv\"\n",
58 | "assert mn_discussion_filepath.exists()\n",
59 | "mn_df = pd.read_csv(mn_discussion_filepath)\n",
60 | "mn_df.shape"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 5,
66 | "id": "44dfb03d-902a-490a-ac85-9840d9f11dad",
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "data": {
71 | "text/html": [
72 | "\n",
73 | "\n",
86 | "
\n",
87 | " \n",
88 | " \n",
89 | " \n",
90 | " reply_id \n",
91 | " reply_user \n",
92 | " reply_content \n",
93 | " reply_ts_created \n",
94 | " post_id \n",
95 | " subject_name \n",
96 | " post_content \n",
97 | " post_user \n",
98 | " post_ts_created \n",
99 | " \n",
100 | " \n",
101 | " \n",
102 | " \n",
103 | " 89529 \n",
104 | " 2468990 \n",
105 | " 1412702 \n",
106 | " Examples \n",
107 | " 2018-03-13 17:18:45 \n",
108 | " 2468982 \n",
109 | " Algebra 1 \n",
110 | " degree and terms of a polynomial?\\n \n",
111 | " 7395540 \n",
112 | " 2018-03-13 17:15:02 \n",
113 | " \n",
114 | " \n",
115 | " 145191 \n",
116 | " 974764 \n",
117 | " 667659 \n",
118 | " MAFS Section 1 video 6 \n",
119 | " 2015-10-13 21:25:03 \n",
120 | " 974748 \n",
121 | " Algebra 1 \n",
122 | " Is there any videos on multiplying and dividin... \n",
123 | " 2252406 \n",
124 | " 2015-10-13 21:23:07 \n",
125 | " \n",
126 | " \n",
127 | " 62185 \n",
128 | " 1172038 \n",
129 | " 681091 \n",
130 | " Jack, lets see! Lets multiply those to see if ... \n",
131 | " 2016-01-30 21:49:40 \n",
132 | " 1172013 \n",
133 | " Algebra 1 \n",
134 | " Help! solve each equation and check for soluti... \n",
135 | " 3584741 \n",
136 | " 2016-01-30 21:40:17 \n",
137 | " \n",
138 | " \n",
139 | "
\n",
140 | "
"
141 | ],
142 | "text/plain": [
143 | " reply_id reply_user \\\n",
144 | "89529 2468990 1412702 \n",
145 | "145191 974764 667659 \n",
146 | "62185 1172038 681091 \n",
147 | "\n",
148 | " reply_content \\\n",
149 | "89529 Examples \n",
150 | "145191 MAFS Section 1 video 6 \n",
151 | "62185 Jack, lets see! Lets multiply those to see if ... \n",
152 | "\n",
153 | " reply_ts_created post_id subject_name \\\n",
154 | "89529 2018-03-13 17:18:45 2468982 Algebra 1 \n",
155 | "145191 2015-10-13 21:25:03 974748 Algebra 1 \n",
156 | "62185 2016-01-30 21:49:40 1172013 Algebra 1 \n",
157 | "\n",
158 | " post_content post_user \\\n",
159 | "89529 degree and terms of a polynomial?\\n 7395540 \n",
160 | "145191 Is there any videos on multiplying and dividin... 2252406 \n",
161 | "62185 Help! solve each equation and check for soluti... 3584741 \n",
162 | "\n",
163 | " post_ts_created \n",
164 | "89529 2018-03-13 17:15:02 \n",
165 | "145191 2015-10-13 21:23:07 \n",
166 | "62185 2016-01-30 21:40:17 "
167 | ]
168 | },
169 | "execution_count": 5,
170 | "metadata": {},
171 | "output_type": "execute_result"
172 | }
173 | ],
174 | "source": [
175 | "mn_df.sample(3)"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 7,
181 | "id": "f6e3d1b6-f26a-4f17-9b6b-1182d8b9fbe7",
182 | "metadata": {},
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/plain": [
187 | "subject_name\n",
188 | "Algebra 1 15943\n",
189 | "Geometry 133\n",
190 | "Pre-Algebra 24\n",
191 | "8th Grade Math 1\n",
192 | "6th Grade Math 1\n",
193 | "Name: count, dtype: int64"
194 | ]
195 | },
196 | "execution_count": 7,
197 | "metadata": {},
198 | "output_type": "execute_result"
199 | }
200 | ],
201 | "source": [
202 | "mn_df.groupby(\"post_id\").subject_name.head(1).value_counts()"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 8,
208 | "id": "84d289fe-8a65-47f3-ad6d-bafc6e79c923",
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "# for row in mn_df[mn_df.subject_name == \"Pre-Algebra\"].drop_duplicates(subset=\"post_id\", keep=\"first\").itertuples():\n",
213 | "# print(row.post_content)"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 9,
219 | "id": "00686dc9-36eb-434d-b6e6-4526ff366002",
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "# manually extracted student queries that are relevant for testing\n",
224 | "prealgebra_student_queries = [\n",
225 | " \"What is the quotient rule??\",\n",
226 | " \"How do I multiply fractions???????\",\n",
227 | " \"How do you multiply fractions?!?!?\",\n",
228 | "]"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": 9,
234 | "id": "a0b194a0-9c06-49f4-8ba6-e97fc246a181",
235 | "metadata": {},
236 | "outputs": [
237 | {
238 | "data": {
239 | "text/plain": [
240 | "0.08228791454477705"
241 | ]
242 | },
243 | "execution_count": 9,
244 | "metadata": {},
245 | "output_type": "execute_result"
246 | }
247 | ],
248 | "source": [
249 | "first_reply_df = mn_df.sort_values(by=[\"post_id\", \"reply_ts_created\"]).drop_duplicates(subset=\"post_id\", keep=\"first\")\n",
250 | "is_extended_post = first_reply_df.reply_user == first_reply_df.post_user\n",
251 | "is_extended_post.sum() / len(first_reply_df)"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": 10,
257 | "id": "6c1a08d9-0501-42c1-b19d-e088cf0c9dba",
258 | "metadata": {},
259 | "outputs": [
260 | {
261 | "data": {
262 | "text/html": [
263 | "\n",
264 | "\n",
277 | "
\n",
278 | " \n",
279 | " \n",
280 | " \n",
281 | " reply_id \n",
282 | " reply_user \n",
283 | " reply_content \n",
284 | " reply_ts_created \n",
285 | " post_id \n",
286 | " subject_name \n",
287 | " post_content \n",
288 | " post_user \n",
289 | " post_ts_created \n",
290 | " \n",
291 | " \n",
292 | " \n",
293 | " \n",
294 | " 84643 \n",
295 | " 1867 \n",
296 | " 945440 \n",
297 | " someone please help\\n \n",
298 | " 2013-10-22 23:35:31 \n",
299 | " 1865 \n",
300 | " Algebra 1 \n",
301 | " NaN \n",
302 | " 945440 \n",
303 | " 2013-10-22 23:35:14 \n",
304 | " \n",
305 | " \n",
306 | " 2633 \n",
307 | " 2065 \n",
308 | " 939329 \n",
309 | " Can somebody help me please \n",
310 | " 2013-10-23 01:08:46 \n",
311 | " 2045 \n",
312 | " Algebra 1 \n",
313 | " If 9a+6b+8c=−3 ,\\n\\nwhat is 54a+48c+36b? \n",
314 | " 939329 \n",
315 | " 2013-10-23 01:03:49 \n",
316 | " \n",
317 | " \n",
318 | " 50640 \n",
319 | " 3277 \n",
320 | " 572466 \n",
321 | " need help \n",
322 | " 2013-10-26 17:33:16 \n",
323 | " 3276 \n",
324 | " Algebra 1 \n",
325 | " \\ny + 2x = −1\\n3y − x = \n",
326 | " 572466 \n",
327 | " 2013-10-26 17:32:44 \n",
328 | " \n",
329 | " \n",
330 | " 37363 \n",
331 | " 3487 \n",
332 | " 1032512 \n",
333 | " Any takers? \n",
334 | " 2013-10-28 01:49:26 \n",
335 | " 3485 \n",
336 | " Algebra 1 \n",
337 | " Challenge problem! Suppose the polynomial:\\n\\n... \n",
338 | " 1032512 \n",
339 | " 2013-10-28 01:27:00 \n",
340 | " \n",
341 | " \n",
342 | " 6423 \n",
343 | " 3733 \n",
344 | " 524723 \n",
345 | " how would you solve this?\\n \n",
346 | " 2013-10-28 22:52:32 \n",
347 | " 3732 \n",
348 | " Algebra 1 \n",
349 | " NaN \n",
350 | " 524723 \n",
351 | " 2013-10-28 22:52:14 \n",
352 | " \n",
353 | " \n",
354 | " ... \n",
355 | " ... \n",
356 | " ... \n",
357 | " ... \n",
358 | " ... \n",
359 | " ... \n",
360 | " ... \n",
361 | " ... \n",
362 | " ... \n",
363 | " ... \n",
364 | " \n",
365 | " \n",
366 | " 1252 \n",
367 | " 3305954 \n",
368 | " 4599973 \n",
369 | " *part \n",
370 | " 2021-09-11 21:17:51 \n",
371 | " 3305953 \n",
372 | " Geometry \n",
373 | " May someone please help me with let b and c \n",
374 | " 4599973 \n",
375 | " 2021-09-11 21:17:40 \n",
376 | " \n",
377 | " \n",
378 | " 26794 \n",
379 | " 3306558 \n",
380 | " 5197522 \n",
381 | " i think it is but i dont now\\n \n",
382 | " 2021-09-13 20:59:33 \n",
383 | " 3306554 \n",
384 | " Algebra 1 \n",
385 | " If the question has a<25 is a<=25 the sa... \n",
386 | " 5197522 \n",
387 | " 2021-09-13 20:58:48 \n",
388 | " \n",
389 | " \n",
390 | " 74748 \n",
391 | " 3308394 \n",
392 | " 5223632 \n",
393 | " here is the paper \n",
394 | " 2021-09-15 23:05:46 \n",
395 | " 3308391 \n",
396 | " Algebra 1 \n",
397 | " for the first box I got x>0 and x<6 \n",
398 | " 5223632 \n",
399 | " 2021-09-15 23:05:04 \n",
400 | " \n",
401 | " \n",
402 | " 36997 \n",
403 | " 3312002 \n",
404 | " 4519295 \n",
405 | " anybody? \n",
406 | " 2021-09-21 23:51:27 \n",
407 | " 3311997 \n",
408 | " Algebra 1 \n",
409 | " help \n",
410 | " 4519295 \n",
411 | " 2021-09-21 23:49:09 \n",
412 | " \n",
413 | " \n",
414 | " 41571 \n",
415 | " 3320636 \n",
416 | " 4807952 \n",
417 | " How do I do this, \n",
418 | " 2021-10-12 23:10:24 \n",
419 | " 3320632 \n",
420 | " Algebra 1 \n",
421 | " Can somebody help me? \n",
422 | " 4807952 \n",
423 | " 2021-10-12 23:08:45 \n",
424 | " \n",
425 | " \n",
426 | "
\n",
427 | "
1325 rows × 9 columns
\n",
428 | "
"
429 | ],
430 | "text/plain": [
431 | " reply_id reply_user reply_content \\\n",
432 | "84643 1867 945440 someone please help\\n \n",
433 | "2633 2065 939329 Can somebody help me please \n",
434 | "50640 3277 572466 need help \n",
435 | "37363 3487 1032512 Any takers? \n",
436 | "6423 3733 524723 how would you solve this?\\n \n",
437 | "... ... ... ... \n",
438 | "1252 3305954 4599973 *part \n",
439 | "26794 3306558 5197522 i think it is but i dont now\\n \n",
440 | "74748 3308394 5223632 here is the paper \n",
441 | "36997 3312002 4519295 anybody? \n",
442 | "41571 3320636 4807952 How do I do this, \n",
443 | "\n",
444 | " reply_ts_created post_id subject_name \\\n",
445 | "84643 2013-10-22 23:35:31 1865 Algebra 1 \n",
446 | "2633 2013-10-23 01:08:46 2045 Algebra 1 \n",
447 | "50640 2013-10-26 17:33:16 3276 Algebra 1 \n",
448 | "37363 2013-10-28 01:49:26 3485 Algebra 1 \n",
449 | "6423 2013-10-28 22:52:32 3732 Algebra 1 \n",
450 | "... ... ... ... \n",
451 | "1252 2021-09-11 21:17:51 3305953 Geometry \n",
452 | "26794 2021-09-13 20:59:33 3306554 Algebra 1 \n",
453 | "74748 2021-09-15 23:05:46 3308391 Algebra 1 \n",
454 | "36997 2021-09-21 23:51:27 3311997 Algebra 1 \n",
455 | "41571 2021-10-12 23:10:24 3320632 Algebra 1 \n",
456 | "\n",
457 | " post_content post_user \\\n",
458 | "84643 NaN 945440 \n",
459 | "2633 If 9a+6b+8c=−3 ,\\n\\nwhat is 54a+48c+36b? 939329 \n",
460 | "50640 \\ny + 2x = −1\\n3y − x = 572466 \n",
461 | "37363 Challenge problem! Suppose the polynomial:\\n\\n... 1032512 \n",
462 | "6423 NaN 524723 \n",
463 | "... ... ... \n",
464 | "1252 May someone please help me with let b and c 4599973 \n",
465 | "26794 If the question has a<25 is a<=25 the sa... 5197522 \n",
466 | "74748 for the first box I got x>0 and x<6 5223632 \n",
467 | "36997 help 4519295 \n",
468 | "41571 Can somebody help me? 4807952 \n",
469 | "\n",
470 | " post_ts_created \n",
471 | "84643 2013-10-22 23:35:14 \n",
472 | "2633 2013-10-23 01:03:49 \n",
473 | "50640 2013-10-26 17:32:44 \n",
474 | "37363 2013-10-28 01:27:00 \n",
475 | "6423 2013-10-28 22:52:14 \n",
476 | "... ... \n",
477 | "1252 2021-09-11 21:17:40 \n",
478 | "26794 2021-09-13 20:58:48 \n",
479 | "74748 2021-09-15 23:05:04 \n",
480 | "36997 2021-09-21 23:49:09 \n",
481 | "41571 2021-10-12 23:08:45 \n",
482 | "\n",
483 | "[1325 rows x 9 columns]"
484 | ]
485 | },
486 | "execution_count": 10,
487 | "metadata": {},
488 | "output_type": "execute_result"
489 | }
490 | ],
491 | "source": [
492 | "first_reply_df[is_extended_post]"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": 11,
498 | "id": "abc41f40-32e4-43ef-8bf7-6d8c05c64e7c",
499 | "metadata": {},
500 | "outputs": [
501 | {
502 | "data": {
503 | "text/plain": [
504 | "16102"
505 | ]
506 | },
507 | "execution_count": 11,
508 | "metadata": {},
509 | "output_type": "execute_result"
510 | }
511 | ],
512 | "source": [
513 | "posts = []\n",
514 | "for post_id, group in mn_df.sort_values(by=[\"post_id\", \"reply_ts_created\"]).groupby(\"post_id\"):\n",
515 | " post_content = str(group.iloc[0].post_content)\n",
516 | " post_user = group.iloc[0].post_user\n",
517 | " for row in group.itertuples():\n",
518 | " if row.reply_user != post_user:\n",
519 | " break # reply from non-OP\n",
520 | " else:\n",
521 | " # continuation of the post in a reply\n",
522 | " post_content += \"\\n[Continued:] \" + str(row.reply_content)\n",
523 | " prev_row = row\n",
524 | " posts.append(\n",
525 | " {\n",
526 | " \"post_id\": post_id,\n",
527 | " \"post_user\": post_user,\n",
528 | " \"subject_name\": group.iloc[0].subject_name,\n",
529 | " \"post_ts_created\": group.iloc[0].post_ts_created,\n",
530 | " \"post_content\": post_content,\n",
531 | " }\n",
532 | " )\n",
533 | "len(posts)"
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": 12,
539 | "id": "77538bb6-0129-40cd-bf40-8c1280ca81d4",
540 | "metadata": {},
541 | "outputs": [
542 | {
543 | "data": {
544 | "text/html": [
545 | "\n",
546 | "\n",
559 | "
\n",
560 | " \n",
561 | " \n",
562 | " \n",
563 | " post_id \n",
564 | " post_user \n",
565 | " subject_name \n",
566 | " post_ts_created \n",
567 | " post_content \n",
568 | " \n",
569 | " \n",
570 | " \n",
571 | " \n",
572 | " 3777 \n",
573 | " 706733 \n",
574 | " 1218779 \n",
575 | " Algebra 1 \n",
576 | " 2015-04-09 21:51:42 \n",
577 | " what is 5349 to the third power equal???? HELP... \n",
578 | " \n",
579 | " \n",
580 | " 11824 \n",
581 | " 2261468 \n",
582 | " 3440812 \n",
583 | " Algebra 1 \n",
584 | " 2017-10-31 13:47:47 \n",
585 | " =| =] =) =} ^o^ ^0^ ^@^ ^u^ all mojys are not ... \n",
586 | " \n",
587 | " \n",
588 | " 3957 \n",
589 | " 739247 \n",
590 | " 2647012 \n",
591 | " Algebra 1 \n",
592 | " 2015-04-18 15:51:10 \n",
593 | " Melissa who's after ya? \n",
594 | " \n",
595 | " \n",
596 | "
\n",
597 | "
"
598 | ],
599 | "text/plain": [
600 | " post_id post_user subject_name post_ts_created \\\n",
601 | "3777 706733 1218779 Algebra 1 2015-04-09 21:51:42 \n",
602 | "11824 2261468 3440812 Algebra 1 2017-10-31 13:47:47 \n",
603 | "3957 739247 2647012 Algebra 1 2015-04-18 15:51:10 \n",
604 | "\n",
605 | " post_content \n",
606 | "3777 what is 5349 to the third power equal???? HELP... \n",
607 | "11824 =| =] =) =} ^o^ ^0^ ^@^ ^u^ all mojys are not ... \n",
608 | "3957 Melissa who's after ya? "
609 | ]
610 | },
611 | "execution_count": 12,
612 | "metadata": {},
613 | "output_type": "execute_result"
614 | }
615 | ],
616 | "source": [
617 | "pd.DataFrame(posts).sample(3)"
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "execution_count": 44,
623 | "id": "4be3b675-fec3-4d09-b1ce-d1c2fba0c378",
624 | "metadata": {},
625 | "outputs": [],
626 | "source": [
627 | "pd.DataFrame(posts).to_csv(data_dir / \"derived\" / \"mn_student_queries_raw3.csv\")"
628 | ]
629 | },
630 | {
631 | "cell_type": "markdown",
632 | "id": "1431a6e0-477c-402e-938b-9d8083e201fc",
633 | "metadata": {},
634 | "source": [
635 | "The original sample used for annotation was generated with this code:\n",
636 | "\n",
637 | "```\n",
638 | "sdf = mn_df[mn_df.subject_name.isin([\"Algebra 1\", \"Geometry\"])]\n",
639 | "sdf = sdf.drop_duplicates(subset=\"post_id\", keep=\"first\")\n",
640 | "sdf[[\"post_id\", \"subject_name\", \"post_content\", \"post_user\", \"post_ts_created\"]].to_csv(data_dir / \"derived\" / \"mn_student_queries_raw.csv\")\n",
641 | "```\n",
642 | "\n",
643 | "But I threw those annotations away."
644 | ]
645 | },
646 | {
647 | "cell_type": "markdown",
648 | "id": "f4c123a5-2c44-4935-b218-dcc4e50cb509",
649 | "metadata": {},
650 | "source": [
651 | "### Loading annotated data"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": 46,
657 | "id": "27bbdd84-b403-4620-baa5-48e8e0f5122a",
658 | "metadata": {},
659 | "outputs": [
660 | {
661 | "data": {
662 | "text/plain": [
663 | "(553, 8)"
664 | ]
665 | },
666 | "execution_count": 46,
667 | "metadata": {},
668 | "output_type": "execute_result"
669 | }
670 | ],
671 | "source": [
672 | "adf = pd.read_csv(data_dir / \"derived\" / \"mn_student_queries_raw2_annotated.csv\")\n",
673 | "adf.shape"
674 | ]
675 | },
676 | {
677 | "cell_type": "code",
678 | "execution_count": 47,
679 | "id": "ec091d67-9c49-4880-96cd-510e5f665f6b",
680 | "metadata": {},
681 | "outputs": [
682 | {
683 | "data": {
684 | "text/html": [
685 | "\n",
686 | "\n",
699 | "
\n",
700 | " \n",
701 | " \n",
702 | " \n",
703 | " index \n",
704 | " post_id \n",
705 | " post_user \n",
706 | " subject_name \n",
707 | " post_ts_created \n",
708 | " post_content \n",
709 | " is_respondable_query \n",
710 | " notes \n",
711 | " \n",
712 | " \n",
713 | " \n",
714 | " \n",
715 | " 2 \n",
716 | " 2 \n",
717 | " 19856 \n",
718 | " 848987 \n",
719 | " Algebra 1 \n",
720 | " 2013-11-16 23:25:43 \n",
721 | " how do you a graph function rule?! \n",
722 | " general \n",
723 | " NaN \n",
724 | " \n",
725 | " \n",
726 | " 462 \n",
727 | " 462 \n",
728 | " 3119582 \n",
729 | " 3073615 \n",
730 | " Geometry \n",
731 | " 2020-03-31 19:38:40 \n",
732 | " is anybody there\\n \n",
733 | " NaN \n",
734 | " NaN \n",
735 | " \n",
736 | " \n",
737 | " 405 \n",
738 | " 405 \n",
739 | " 2966682 \n",
740 | " 8261996 \n",
741 | " Geometry \n",
742 | " 2019-09-11 21:16:05 \n",
743 | " Wouldn't it be the same thing?? \n",
744 | " NaN \n",
745 | " NaN \n",
746 | " \n",
747 | " \n",
748 | "
\n",
749 | "
"
750 | ],
751 | "text/plain": [
752 | " index post_id post_user subject_name post_ts_created \\\n",
753 | "2 2 19856 848987 Algebra 1 2013-11-16 23:25:43 \n",
754 | "462 462 3119582 3073615 Geometry 2020-03-31 19:38:40 \n",
755 | "405 405 2966682 8261996 Geometry 2019-09-11 21:16:05 \n",
756 | "\n",
757 | " post_content is_respondable_query notes \n",
758 | "2 how do you a graph function rule?! general NaN \n",
759 | "462 is anybody there\\n NaN NaN \n",
760 | "405 Wouldn't it be the same thing?? NaN NaN "
761 | ]
762 | },
763 | "execution_count": 47,
764 | "metadata": {},
765 | "output_type": "execute_result"
766 | }
767 | ],
768 | "source": [
769 | "adf.sample(n=3)"
770 | ]
771 | },
772 | {
773 | "cell_type": "code",
774 | "execution_count": 48,
775 | "id": "787a7ee0-3f39-4de1-af8d-b1ce0d4fb7eb",
776 | "metadata": {},
777 | "outputs": [
778 | {
779 | "data": {
780 | "text/plain": [
781 | "is_respondable_query\n",
782 | "problem 62\n",
783 | "general 52\n",
784 | "confirm 8\n",
785 | "resource request 4\n",
786 | "general,problem 3\n",
787 | "advice 2\n",
788 | "wrong but don't know why 1\n",
789 | "clarify question 1\n",
790 | "stuck 1\n",
791 | "Name: count, dtype: int64"
792 | ]
793 | },
794 | "execution_count": 48,
795 | "metadata": {},
796 | "output_type": "execute_result"
797 | }
798 | ],
799 | "source": [
800 | "adf.is_respondable_query.value_counts()"
801 | ]
802 | },
803 | {
804 | "cell_type": "code",
805 | "execution_count": 53,
806 | "id": "495fb9fc-46d1-4f18-bf88-9729b26bafb7",
807 | "metadata": {},
808 | "outputs": [
809 | {
810 | "data": {
811 | "text/html": [
812 | "\n",
813 | "\n",
826 | "
\n",
827 | " \n",
828 | " \n",
829 | " subject_name \n",
830 | " 6th Grade Math \n",
831 | " Algebra 1 \n",
832 | " Geometry \n",
833 | " Pre-Algebra \n",
834 | " All \n",
835 | " \n",
836 | " \n",
837 | " is_respondable_query \n",
838 | " \n",
839 | " \n",
840 | " \n",
841 | " \n",
842 | " \n",
843 | " \n",
844 | " \n",
845 | " \n",
846 | " \n",
847 | " All \n",
848 | " 1 \n",
849 | " 111 \n",
850 | " 19 \n",
851 | " 3 \n",
852 | " 134 \n",
853 | " \n",
854 | " \n",
855 | " problem \n",
856 | " 1 \n",
857 | " 54 \n",
858 | " 7 \n",
859 | " 0 \n",
860 | " 62 \n",
861 | " \n",
862 | " \n",
863 | " general \n",
864 | " 0 \n",
865 | " 40 \n",
866 | " 9 \n",
867 | " 3 \n",
868 | " 52 \n",
869 | " \n",
870 | " \n",
871 | " confirm \n",
872 | " 0 \n",
873 | " 6 \n",
874 | " 2 \n",
875 | " 0 \n",
876 | " 8 \n",
877 | " \n",
878 | " \n",
879 | " resource request \n",
880 | " 0 \n",
881 | " 3 \n",
882 | " 1 \n",
883 | " 0 \n",
884 | " 4 \n",
885 | " \n",
886 | " \n",
887 | " general,problem \n",
888 | " 0 \n",
889 | " 3 \n",
890 | " 0 \n",
891 | " 0 \n",
892 | " 3 \n",
893 | " \n",
894 | " \n",
895 | " advice \n",
896 | " 0 \n",
897 | " 2 \n",
898 | " 0 \n",
899 | " 0 \n",
900 | " 2 \n",
901 | " \n",
902 | " \n",
903 | " clarify question \n",
904 | " 0 \n",
905 | " 1 \n",
906 | " 0 \n",
907 | " 0 \n",
908 | " 1 \n",
909 | " \n",
910 | " \n",
911 | " stuck \n",
912 | " 0 \n",
913 | " 1 \n",
914 | " 0 \n",
915 | " 0 \n",
916 | " 1 \n",
917 | " \n",
918 | " \n",
919 | " wrong but don't know why \n",
920 | " 0 \n",
921 | " 1 \n",
922 | " 0 \n",
923 | " 0 \n",
924 | " 1 \n",
925 | " \n",
926 | " \n",
927 | "
\n",
928 | "
"
929 | ],
930 | "text/plain": [
931 | "subject_name 6th Grade Math Algebra 1 Geometry Pre-Algebra \\\n",
932 | "is_respondable_query \n",
933 | "All 1 111 19 3 \n",
934 | "problem 1 54 7 0 \n",
935 | "general 0 40 9 3 \n",
936 | "confirm 0 6 2 0 \n",
937 | "resource request 0 3 1 0 \n",
938 | "general,problem 0 3 0 0 \n",
939 | "advice 0 2 0 0 \n",
940 | "clarify question 0 1 0 0 \n",
941 | "stuck 0 1 0 0 \n",
942 | "wrong but don't know why 0 1 0 0 \n",
943 | "\n",
944 | "subject_name All \n",
945 | "is_respondable_query \n",
946 | "All 134 \n",
947 | "problem 62 \n",
948 | "general 52 \n",
949 | "confirm 8 \n",
950 | "resource request 4 \n",
951 | "general,problem 3 \n",
952 | "advice 2 \n",
953 | "clarify question 1 \n",
954 | "stuck 1 \n",
955 | "wrong but don't know why 1 "
956 | ]
957 | },
958 | "execution_count": 53,
959 | "metadata": {},
960 | "output_type": "execute_result"
961 | }
962 | ],
963 | "source": [
964 | "pd.crosstab(adf.subject_name, adf.is_respondable_query, margins=True).T.sort_values(by=\"All\", ascending=False)"
965 | ]
966 | },
967 | {
968 | "cell_type": "code",
969 | "execution_count": 57,
970 | "id": "8c65b4a6-42a5-42d5-ab88-19a74edf0596",
971 | "metadata": {},
972 | "outputs": [
973 | {
974 | "data": {
975 | "text/html": [
976 | "\n",
977 | "\n",
990 | "
\n",
991 | " \n",
992 | " \n",
993 | " \n",
994 | " post_id \n",
995 | " subject_name \n",
996 | " post_content \n",
997 | " is_respondable_query \n",
998 | " \n",
999 | " \n",
1000 | " \n",
1001 | " \n",
1002 | " 29 \n",
1003 | " 199138 \n",
1004 | " Algebra 1 \n",
1005 | " Can someone give me a factoring example I don... \n",
1006 | " general \n",
1007 | " \n",
1008 | " \n",
1009 | " 138 \n",
1010 | " 1105388 \n",
1011 | " Pre-Algebra \n",
1012 | " Hey guys! What is the quotient rule?? \n",
1013 | " general \n",
1014 | " \n",
1015 | " \n",
1016 | " 110 \n",
1017 | " 859909 \n",
1018 | " Algebra 1 \n",
1019 | " How do you find the domain and range that are ... \n",
1020 | " general \n",
1021 | " \n",
1022 | " \n",
1023 | "
\n",
1024 | "
"
1025 | ],
1026 | "text/plain": [
1027 | " post_id subject_name post_content \\\n",
1028 | "29 199138 Algebra 1 Can someone give me a factoring example I don... \n",
1029 | "138 1105388 Pre-Algebra Hey guys! What is the quotient rule?? \n",
1030 | "110 859909 Algebra 1 How do you find the domain and range that are ... \n",
1031 | "\n",
1032 | " is_respondable_query \n",
1033 | "29 general \n",
1034 | "138 general \n",
1035 | "110 general "
1036 | ]
1037 | },
1038 | "execution_count": 57,
1039 | "metadata": {},
1040 | "output_type": "execute_result"
1041 | }
1042 | ],
1043 | "source": [
1044 | "general_df = adf[adf.is_respondable_query.map(lambda s: pd.notna(s) and \"general\" in s)][\n",
1045 | " [\"post_id\", \"subject_name\", \"post_content\", \"is_respondable_query\"]\n",
1046 | "]\n",
1047 | "general_df.sample(n=3)"
1048 | ]
1049 | },
1050 | {
1051 | "cell_type": "code",
1052 | "execution_count": 58,
1053 | "id": "b137f401-9593-4c2e-9f7a-4aff448bb077",
1054 | "metadata": {},
1055 | "outputs": [],
1056 | "source": [
1057 | "general_df.to_csv(data_dir / \"derived\" / \"mn_general_student_queries.csv\", index=False)"
1058 | ]
1059 | },
1060 | {
1061 | "cell_type": "code",
1062 | "execution_count": 62,
1063 | "id": "cf10ce6d-1c63-4792-8f42-9dda5a0e0f9c",
1064 | "metadata": {},
1065 | "outputs": [
1066 | {
1067 | "name": "stdout",
1068 | "output_type": "stream",
1069 | "text": [
1070 | "MathNation queries: (55, 4)\n"
1071 | ]
1072 | },
1073 | {
1074 | "data": {
1075 | "text/html": [
1076 | "\n",
1077 | "\n",
1090 | "
\n",
1091 | " \n",
1092 | " \n",
1093 | " \n",
1094 | " post_id \n",
1095 | " subject_name \n",
1096 | " post_content \n",
1097 | " is_respondable_query \n",
1098 | " \n",
1099 | " \n",
1100 | " \n",
1101 | " \n",
1102 | " 38 \n",
1103 | " 2672704 \n",
1104 | " Algebra 1 \n",
1105 | " i have a problem is this equation linear?\\n7x ... \n",
1106 | " general,problem \n",
1107 | " \n",
1108 | " \n",
1109 | " 41 \n",
1110 | " 2690581 \n",
1111 | " Algebra 1 \n",
1112 | " How do you know if a number is a constant? \n",
1113 | " general \n",
1114 | " \n",
1115 | " \n",
1116 | " 20 \n",
1117 | " 1790360 \n",
1118 | " Algebra 1 \n",
1119 | " What is vertex form and how do you solve for it? \n",
1120 | " general \n",
1121 | " \n",
1122 | " \n",
1123 | "
\n",
1124 | "
"
1125 | ],
1126 | "text/plain": [
1127 | " post_id subject_name post_content \\\n",
1128 | "38 2672704 Algebra 1 i have a problem is this equation linear?\\n7x ... \n",
1129 | "41 2690581 Algebra 1 How do you know if a number is a constant? \n",
1130 | "20 1790360 Algebra 1 What is vertex form and how do you solve for it? \n",
1131 | "\n",
1132 | " is_respondable_query \n",
1133 | "38 general,problem \n",
1134 | "41 general \n",
1135 | "20 general "
1136 | ]
1137 | },
1138 | "execution_count": 62,
1139 | "metadata": {},
1140 | "output_type": "execute_result"
1141 | }
1142 | ],
1143 | "source": [
1144 | "# load the mathnation query data\n",
1145 | "mn_general_student_queries_filepath = data_dir / \"derived\" / \"mn_general_student_queries.csv\"\n",
1146 | "query_df = pd.read_csv(mn_general_student_queries_filepath)\n",
1147 | "print(f\"MathNation queries: {query_df.shape}\")\n",
1148 | "query_df.sample(n=3)"
1149 | ]
1150 | },
1151 | {
1152 | "cell_type": "code",
1153 | "execution_count": null,
1154 | "id": "05c868c2-6288-4173-afb4-1dabe9cd3e4b",
1155 | "metadata": {},
1156 | "outputs": [],
1157 | "source": []
1158 | },
1159 | {
1160 | "cell_type": "code",
1161 | "execution_count": null,
1162 | "id": "87f3cab2-cdc8-49e3-acb5-378dc32857bc",
1163 | "metadata": {},
1164 | "outputs": [],
1165 | "source": []
1166 | },
1167 | {
1168 | "cell_type": "code",
1169 | "execution_count": null,
1170 | "id": "c3abf13e-9498-46c3-b16b-0ac69b411ed0",
1171 | "metadata": {},
1172 | "outputs": [],
1173 | "source": []
1174 | }
1175 | ],
1176 | "metadata": {
1177 | "kernelspec": {
1178 | "display_name": "llm-math-education",
1179 | "language": "python",
1180 | "name": "llm-math-education"
1181 | },
1182 | "language_info": {
1183 | "codemirror_mode": {
1184 | "name": "ipython",
1185 | "version": 3
1186 | },
1187 | "file_extension": ".py",
1188 | "mimetype": "text/x-python",
1189 | "name": "python",
1190 | "nbconvert_exporter": "python",
1191 | "pygments_lexer": "ipython3",
1192 | "version": "3.10.11"
1193 | }
1194 | },
1195 | "nbformat": 4,
1196 | "nbformat_minor": 5
1197 | }
1198 |
--------------------------------------------------------------------------------
/notebooks/MetricTesting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "a3b12196-aa39-4532-87ce-3f6f41ec411e",
6 | "metadata": {},
7 | "source": [
8 | "Metric testing\n",
9 | "==="
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "id": "4379991e-c03a-434d-9daa-314507f635df",
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import datasets\n",
20 | "import evaluate\n",
21 | "\n",
22 | "from experiment import metrics"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "id": "47955103-9b48-4bd9-88e3-1a3999c429ab",
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "INFO:tensorflow:Reading checkpoint /Users/zacharylevonian/.cache/huggingface/datasets/downloads/extracted/0b5f615fbe4df81a585448a4e6f47b4bb3af737cc290a4d96effa6ef1840ea73/bleurt-base-512.\n",
36 | "INFO:tensorflow:Config file found, reading.\n",
37 | "INFO:tensorflow:Will load checkpoint bert_custom\n",
38 | "INFO:tensorflow:Loads full paths and checks that files exists.\n",
39 | "INFO:tensorflow:... name:bert_custom\n",
40 | "INFO:tensorflow:... vocab_file:vocab.txt\n",
41 | "INFO:tensorflow:... bert_config_file:bert_config.json\n",
42 | "INFO:tensorflow:... do_lower_case:True\n",
43 | "INFO:tensorflow:... max_seq_length:512\n",
44 | "INFO:tensorflow:Creating BLEURT scorer.\n",
45 | "INFO:tensorflow:Creating WordPiece tokenizer.\n",
46 | "INFO:tensorflow:WordPiece tokenizer instantiated.\n",
47 | "INFO:tensorflow:Creating Eager Mode predictor.\n",
48 | "INFO:tensorflow:Loading model.\n",
49 | "INFO:tensorflow:BLEURT initialized.\n"
50 | ]
51 | },
52 | {
53 | "name": "stderr",
54 | "output_type": "stream",
55 | "text": [
56 | "INFO:tensorflow:BLEURT initialized.\n"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "# https://github.com/huggingface/evaluate/issues/428\n",
62 | "bleurt = evaluate.load(\n",
63 | " \"bleurt\", \"bleurt-base-512\", module_type=\"metric\", download_config=datasets.DownloadConfig(use_etag=False)\n",
64 | ")"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "6d6faff9-ea7b-4339-b689-251b280ce6a3",
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "data": {
75 | "text/plain": [
76 | "{'scores': [1.175394892692566, -1.1031553745269775]}"
77 | ]
78 | },
79 | "execution_count": 3,
80 | "metadata": {},
81 | "output_type": "execute_result"
82 | }
83 | ],
84 | "source": [
85 | "predictions = [\"hello there\", \"general skywalker\"]\n",
86 | "references = [\"hello there\", \"general kenobi\"]\n",
87 | "bleurt.compute(predictions=predictions, references=references)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 4,
93 | "id": "f4a07f69-0f92-4056-aee3-a7c98f27d8d5",
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "data": {
98 | "text/plain": [
99 | "0.4088791608810425"
100 | ]
101 | },
102 | "execution_count": 4,
103 | "metadata": {},
104 | "output_type": "execute_result"
105 | }
106 | ],
107 | "source": [
108 | "def compute_bleurt(passages: list[str], generation: str):\n",
109 | " references = passages + [\n",
110 | " \"\\n\".join(passages),\n",
111 | " ]\n",
112 | " predictions = [\n",
113 | " generation,\n",
114 | " ] * (len(passages) + 1)\n",
115 | " scores = bleurt.compute(predictions=predictions, references=references)[\"scores\"]\n",
116 | " return max(scores)\n",
117 | "\n",
118 | "\n",
119 | "compute_bleurt([\"The alphabet is 26 letters long.\", \"Math is not so easy.\"], \"The English alphabet is 26 letters long.\")"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "a4daae3f-e042-4f94-a4b0-c9e93aa876dd",
126 | "metadata": {},
127 | "outputs": [],
128 | "source": []
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 5,
133 | "id": "e19b8934-8eb7-43c2-ab12-8725de26d0fb",
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "bert_score = metrics.get_bertscore_metric_object()"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 7,
143 | "id": "6b9bd28b-60d3-4e9d-919a-ce8604baef7e",
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "name": "stderr",
148 | "output_type": "stream",
149 | "text": [
150 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n",
151 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
152 | ]
153 | },
154 | {
155 | "data": {
156 | "text/plain": [
157 | "{'precision': [0.9999998807907104, 0.9180971384048462],\n",
158 | " 'recall': [0.9999998807907104, 0.8901697397232056],\n",
159 | " 'f1': [0.9999998807907104, 0.9039177894592285],\n",
160 | " 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.33.1)'}"
161 | ]
162 | },
163 | "execution_count": 7,
164 | "metadata": {},
165 | "output_type": "execute_result"
166 | }
167 | ],
168 | "source": [
169 | "predictions = [\"hello there\", \"general skywalker\"]\n",
170 | "references = [\"hello there\", \"general kenobi\"]\n",
171 | "bert_score.compute(predictions=predictions, references=references, lang=\"en\")"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 9,
177 | "id": "1776095e-f9df-4628-8736-49fbfb38275e",
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "data": {
182 | "text/plain": [
183 | "{'precision': [0.8900542259216309, 0.9747406840324402],\n",
184 | " 'recall': [0.8820334672927856, 0.9553087949752808],\n",
185 | " 'f1': [0.8860256671905518, 0.9649269580841064],\n",
186 | " 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.33.1)'}"
187 | ]
188 | },
189 | "execution_count": 9,
190 | "metadata": {},
191 | "output_type": "execute_result"
192 | }
193 | ],
194 | "source": [
195 | "# must match counts\n",
196 | "bert_score.compute(\n",
197 | " predictions=[\"This is a test.\"] * 2,\n",
198 | " references=[\"Two reference sentences.\", \"Second is a test sentence.\"],\n",
199 | " lang=\"en\",\n",
200 | ")"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "id": "136a9fa6-743a-4dcc-ac54-e6287574cbd1",
207 | "metadata": {},
208 | "outputs": [],
209 | "source": []
210 | }
211 | ],
212 | "metadata": {
213 | "kernelspec": {
214 | "display_name": "Python 3 (ipykernel)",
215 | "language": "python",
216 | "name": "python3"
217 | },
218 | "language_info": {
219 | "codemirror_mode": {
220 | "name": "ipython",
221 | "version": 3
222 | },
223 | "file_extension": ".py",
224 | "mimetype": "text/x-python",
225 | "name": "python",
226 | "nbconvert_exporter": "python",
227 | "pygments_lexer": "ipython3",
228 | "version": "3.10.11"
229 | }
230 | },
231 | "nbformat": 4,
232 | "nbformat_minor": 5
233 | }
234 |
--------------------------------------------------------------------------------
/notebooks/Qualtrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "1fec94f6-6847-4198-b6ba-ec45ea561b8c",
6 | "metadata": {},
7 | "source": [
8 | "# Qualtrics survey creation"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 96,
14 | "id": "1b24f48a-aa65-44a0-84d7-e125da064978",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import html\n",
19 | "import json\n",
20 | "import re\n",
21 | "from pathlib import Path\n",
22 | "\n",
23 | "import pandas as pd"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 97,
29 | "id": "8a8c6b72-7478-4843-bab5-887137394024",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "data_dir = Path(\"../data\")\n",
34 | "assert data_dir.exists()\n",
35 | "figures_dir = Path(\"../figures\")\n",
36 | "figures_dir.mkdir(exist_ok=True)\n",
37 | "assert figures_dir.exists()"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 98,
43 | "id": "7dfa556b-fce1-4deb-97a6-25c9a2b0340c",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "template_filepath = data_dir / \"raw\" / \"qualtrics\" / \"Rori_ranking_annotations_-_template.qsf\"\n",
48 | "with open(template_filepath) as infile:\n",
49 | " survey_text = infile.read()\n",
50 | " assert json.loads(survey_text)"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "id": "fbe34bd3-16fb-4a7d-aa58-1494991e7a9e",
56 | "metadata": {},
57 | "source": [
58 | "Keys to fill:\n",
59 | "\n",
60 | " - RoriSurveyId\n",
61 | " - Response{1,2,3}Q*\n",
62 | " - QueryTextQ*\n",
63 | " - DocumentQ*\n"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 99,
69 | "id": "56091b06-eff6-4325-bad0-1abd1069d71e",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "response_keys = [\"Response1Q\", \"Response2Q\", \"Response3Q\"]\n",
74 | "query_text_key = \"QueryTextQ\"\n",
75 | "document_key = \"DocumentQ\"\n",
76 | "\n",
77 | "# validate expected keys\n",
78 | "expected_survey_size = 15\n",
79 | "for key in response_keys + [query_text_key, document_key]:\n",
80 | " for i in range(1, expected_survey_size + 1):\n",
81 | " qkey = key + str(i)\n",
82 | " assert qkey in survey_text, qkey"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 100,
88 | "id": "5efb33ad-6cf4-4903-86dc-641fa4da9ae3",
89 | "metadata": {},
90 | "outputs": [
91 | {
92 | "name": "stdout",
93 | "output_type": "stream",
94 | "text": [
95 | "\n",
96 | "Display\":\"Response1Q2\"},\"\n"
97 | ]
98 | },
99 | {
100 | "data": {
101 | "text/plain": [
102 | "['Response1Q2']"
103 | ]
104 | },
105 | "execution_count": 100,
106 | "metadata": {},
107 | "output_type": "execute_result"
108 | }
109 | ],
110 | "source": [
111 | "for result in re.finditer(\"Response1Q2(?![0-9])\", survey_text):\n",
112 | " print(result)\n",
113 | " ind = result.span()[0]\n",
114 | " print(survey_text[ind - 10 : ind + 15])\n",
115 | "re.findall(\"Response1Q2(?![0-9])\", survey_text)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 101,
121 | "id": "887a2aaf-b173-46ff-82cd-93d1e63775af",
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "def convert_text(text, use_br=True):\n",
126 | " text = html.escape(text.replace(\"\\\\\", \"/\"))\n",
127 | " # text = \"\" + \"<\\\\/p>
\".join(text.split(\"\\n\")) + \"<\\\\/p>\"\n",
128 | " if use_br:\n",
129 | " text = \"
\" + \"
\".join(text.split(\"\\n\")) + \"
\"\n",
130 | " else:\n",
131 | " text = \"\" + \"
\".join(text.split(\"\\n\")) + \"
\"\n",
132 | " return text\n",
133 | "\n",
134 | "\n",
135 | "expected_survey_size = 15\n",
136 | "for i in range(1, expected_survey_size + 1):\n",
137 | " r1 = \"R1 Multi-line string\\n\\nSeveral bits here are normal:\\n - 1\\n - 2\\n - 3\"\n",
138 | " r2 = r\"R2 Single line string, with some maybe-problematic characters: /\\!@#$%^&*()_-+\"\n",
139 | " r3 = \"R3\"\n",
140 | " responses = [r1, r2, r3]\n",
141 | " query = f\"Test query for page {i}\"\n",
142 | " document = \"Paragraph 1: text goes here\\nParagraph 2: Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\"\n",
143 | " for key, response in zip(response_keys, responses):\n",
144 | " qkey = key + str(i)\n",
145 | " text = convert_text(response, use_br=False)\n",
146 | " survey_text = survey_text.replace(qkey, text, 1)\n",
147 | " survey_text = survey_text.replace(query_text_key + str(i), convert_text(query), 1)\n",
148 | " survey_text = survey_text.replace(document_key + str(i), convert_text(document), 1)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 102,
154 | "id": "9dfe6092-dc16-41a7-8f63-6bebb43c24b1",
155 | "metadata": {},
156 | "outputs": [
157 | {
158 | "data": {
159 | "text/plain": [
160 | "'lti-line string
Several bit'"
161 | ]
162 | },
163 | "execution_count": 102,
164 | "metadata": {},
165 | "output_type": "execute_result"
166 | }
167 | ],
168 | "source": [
169 | "ind = 67956\n",
170 | "band = 20\n",
171 | "survey_text[ind - band : ind + band]"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 103,
177 | "id": "733f3f68-1a55-45cb-86e7-78744834e0a3",
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "# verify that we've created valid JSON\n",
182 | "survey_text = json.dumps(json.loads(survey_text))"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 104,
188 | "id": "626257d8-0dfd-447d-82b3-7bfd8cd4d265",
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "survey_dir = data_dir / \"derived\" / \"qualtrics\"\n",
193 | "survey_dir.mkdir(exist_ok=True)\n",
194 | "survey_filepath = survey_dir / \"Rori_ranking_annotations_-_survey1.qsf\"\n",
195 | "with open(survey_filepath, \"w\") as outfile:\n",
196 | " outfile.write(survey_text)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "id": "bece161b-aefd-4cd4-8f04-001b61622674",
203 | "metadata": {},
204 | "outputs": [],
205 | "source": []
206 | }
207 | ],
208 | "metadata": {
209 | "kernelspec": {
210 | "display_name": "Python 3 (ipykernel)",
211 | "language": "python",
212 | "name": "python3"
213 | },
214 | "language_info": {
215 | "codemirror_mode": {
216 | "name": "ipython",
217 | "version": 3
218 | },
219 | "file_extension": ".py",
220 | "mimetype": "text/x-python",
221 | "name": "python",
222 | "nbconvert_exporter": "python",
223 | "pygments_lexer": "ipython3",
224 | "version": "3.10.11"
225 | }
226 | },
227 | "nbformat": 4,
228 | "nbformat_minor": 5
229 | }
230 |
--------------------------------------------------------------------------------
/poetry.toml:
--------------------------------------------------------------------------------
1 | [virtualenvs]
2 | create = true
3 | in-project = true
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "rag-for-math-qa"
3 | version = "0.2.0"
4 | description = "Experimentation/analysis code and scripts"
5 | authors = [
6 | "Zachary Levonian "
7 | ]
8 | license = "MIT"
9 | readme = "README.md"
10 | packages = [{include = "experiment", from = "src"}, {include = "rag", from = "src"}]
11 | repository = "https://github.com/DigitalHarborFoundation/rag-for-math-qa.git"
12 |
13 | [tool.poetry.dependencies]
14 | python = ">=3.10,<3.12"
15 | poetry = "1.6.1"
16 | streamlit = "^1.23.1"
17 | pandas = "^2.0.2"
18 | scipy = "^1.11.0"
19 | argon2-cffi = "^21.3.0"
20 | sqlalchemy = "^2.0.20"
21 | psycopg2-binary = "^2.9.7"
22 | numpy = "^1.24.3"
23 | spacy = "^3.6.1"
24 | statsmodels = "^0.14.0"
25 | tabulate = "^0.9.0"
26 | scikit-learn = "^1.3.0"
27 | openai = "^0.28.0"
28 | tiktoken = "^0.4.0"
29 | python-dotenv = "^1.0.0"
30 | evaluate = "^0.4.0"
31 | bert-score = "^0.3.13"
32 | bleurt = {git = "https://github.com/google-research/bleurt.git"}
33 | tensorflow = {version = "^2.13.0" }
34 | tensorflow-macos = { version = "^2.13.0", platform = "darwin", markers = "platform_machine=='arm64'" }
35 | tensorflow-intel = { version = "^2.13.0", platform = "win32" }
36 | tensorflow-cpu = [
37 | { version = "^2.13.0", platform = "linux", markers = "platform_machine!='arm64' and platform_machine!='aarch64'" },
38 | { version = "^2.13.0", platform = "darwin", markers = "platform_machine!='arm64' and platform_machine!='aarch64'" },]
39 | tensorflow-cpu-aws = { version = "^2.13.0", platform = "linux", markers = "platform_machine=='arm64' or platform_machine=='aarch64'" }
40 | # https://github.com/tensorflow/tensorflow/blob/adb39b04e9cb116df4659a7e2de9eea27e62f25c/tensorflow/tools/pip_package/setup.py#L107-L108
41 | # https://github.com/python-poetry/poetry/issues/8271#issuecomment-1697740447
42 | tensorflow-io-gcs-filesystem = [
43 | { version = ">= 0.23.1", markers = "platform_machine!='arm64' or platform_system!='Darwin'" },
44 | { version = "< 0.32.0", markers = "platform_system == 'Windows'" }
45 | ]
46 | datasets = "2.10.0"
47 | krippendorff = "^0.6.0"
48 | skrub = {git = "https://github.com/skrub-data/skrub.git"}
49 |
50 | [tool.poetry.group.dev.dependencies]
51 | jupyter = "^1.0.0"
52 | matplotlib = "^3.7.1"
53 | black = "^22.12.0"
54 | isort = "^5.12"
55 | flake8 = "^6.0.0"
56 | nbqa = "^1.6.0"
57 | pre-commit = "^2.21.0"
58 | pytest = "^7.2.1"
59 | pytest-cov = "^4.0.0"
60 | jupyterlab = "^4.0.2"
61 |
62 | [build-system]
63 | requires = ["poetry-core"]
64 | build-backend = "poetry.core.masonry.api"
65 |
66 | [tool.black]
67 | line-length = 120
68 | include = '\.pyi?$'
69 | exclude = '''
70 | /(
71 | .eggs # exclude a few common directories in the
72 | | .git # root of the project
73 | | .github
74 | | .gitignore
75 | | .hg
76 | | .mypy_cache
77 | | .tox
78 | | .venv
79 | | venv
80 | | _build
81 | | buck-out
82 | | build
83 | | ci
84 | | data
85 | | dist
86 | | docs
87 | | docsrc
88 | )/
89 | '''
90 |
91 | [tool.isort]
92 | profile = "black"
93 | line_length = 79
94 | multi_line_output = 3
95 | include_trailing_comma = true
96 | virtual_env = "venv"
97 |
--------------------------------------------------------------------------------
/src/experiment/auth.py:
--------------------------------------------------------------------------------
def generate_auth_token() -> str:
    """Return a random 32-character lowercase-hex authentication token.

    Uses the stdlib ``secrets`` module (PEP 506), which is the recommended
    source for security-sensitive randomness, instead of hand-rolling
    ``binascii.b2a_hex(os.urandom(16))``; the output format (16 random
    bytes rendered as 32 hex characters) is unchanged.
    """
    import secrets

    return secrets.token_hex(16)
7 |
8 |
9 | def cast_unicode(s: bytes | str, encoding: str) -> str:
10 | if isinstance(s, bytes):
11 | return s.decode(encoding, "replace")
12 | return s
13 |
14 |
def passwd_hash(passphrase: str) -> str:
    """Hash *passphrase* with Argon2 and return it prefixed with "argon2:".

    Parameters (memory_cost/time_cost/parallelism) are fixed here so hashes
    are produced consistently across callers.
    """
    from argon2 import PasswordHasher

    hasher = PasswordHasher(
        memory_cost=10240,
        time_cost=10,
        parallelism=8,
    )
    hashed = hasher.hash(passphrase)
    return ":".join(("argon2", cast_unicode(hashed, "ascii")))
26 |
27 |
def passwd_check(hashed_passphrase: str, passphrase: str) -> bool:
    """Verify *passphrase* against an "argon2:"-prefixed stored hash.

    Returns False on verification failure instead of raising.
    Modification of source provided with a BSD 3-Clause License,
    Copyright (c) 2015-, Jupyter Development Team (notebook.auth.security).
    """
    assert hashed_passphrase.startswith("argon2:")
    import argon2
    import argon2.exceptions

    hasher = argon2.PasswordHasher()
    stored_hash = hashed_passphrase[7:]  # strip the "argon2:" prefix
    try:
        return hasher.verify(stored_hash, passphrase)
    except argon2.exceptions.VerificationError:
        return False
40 |
--------------------------------------------------------------------------------
/src/experiment/completion_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import openai
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
def get_completion_noraise(
    messages: list,
    sleep: float = 0.1,
    should_log_successful: bool = False,
    **kwargs,
) -> str | None:
    """Exception-swallowing wrapper around get_completion_with_wait.

    Intended for use with multiprocessing, where a raised exception would
    poison the whole pool map; failures are logged and reported as None.

    Returns:
        str | None: The completion, or None if an exception was raised.
    """
    try:
        result = get_completion_with_wait(messages, sleep=sleep, **kwargs)
        if should_log_successful:
            logger.info("Successful completion.")
        return result
    except Exception as ex:
        logger.warning(f"get_completion_noraise returning None due to {type(ex).__name__} error: {ex}")
        return None
28 |
29 |
def get_completion_with_wait(messages: list, sleep: float = 0.1, **kwargs) -> str:
    """Fetch a completion, then pause briefly to be polite on repeated API calls."""
    result = get_completion_with_retries(messages, **kwargs)
    if sleep:
        time.sleep(sleep)
    return result
35 |
36 |
def get_completion_with_retries(
    messages: list,
    max_attempts: int = 3,
    sleep_time_between_attempts: float = 5,
    **kwargs,
) -> str:
    """Call get_completion, retrying on any exception with linear backoff.

    Could use a library for this, but let's keep it simple.

    Args:
        messages (list): Message list to pass to the OpenAI API.
        max_attempts (int, optional): Maximum number of attempts. Defaults to 3.
        sleep_time_between_attempts (float, optional): Base wait between attempts,
            in seconds; attempt n is followed by a wait of n * this value.
            Defaults to 5.
        Other keyword args are passed through to get_completion.

    Raises:
        Exception: Whatever get_completion last raised, once max_attempts is exhausted.
        ValueError: Only if the retry loop exits without returning or raising
            (e.g. max_attempts <= 0), which indicates an unexpected logical flow.

    Returns:
        str: The completion
    """
    n_attempts = 0
    while n_attempts < max_attempts:
        n_attempts += 1
        try:
            return get_completion(messages, **kwargs)
        except Exception as ex:
            logger.warning(f"Failure on attempt {n_attempts} / {max_attempts}: {type(ex).__name__} {ex}")
            if n_attempts == max_attempts:
                # bare raise preserves the original traceback (unlike `raise ex`)
                raise
            time.sleep(sleep_time_between_attempts * n_attempts)
    raise ValueError(
        f"Exceeded max attempts ({max_attempts}), base sleep interval {sleep_time_between_attempts}s; this error indicates an unexpected logical flow",
    )
66 |
67 |
def get_completion(messages: list, model_name: str = "gpt-3.5-turbo-0613", request_timeout: float = 20) -> str:
    """Request a single ChatCompletion and return the assistant message content."""
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=messages,
        request_timeout=request_timeout,
    )
    return response["choices"][0]["message"]["content"]
76 |
--------------------------------------------------------------------------------
/src/experiment/generate.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import json
3 | import multiprocessing as mp
4 | import time
5 | from pathlib import Path
6 | from typing import Callable
7 |
8 | from experiment import completion_utils
9 |
10 |
def is_valid_generation(generation: dict):
    """Return True when *generation* carries both required keys: 'messages' and 'generation'."""
    return all(key in generation for key in ("messages", "generation"))
13 |
14 |
class GenerationCorpus:
    """Append-only store of chat completions, persisted as an ndjson file.

    Each generation is a metadata dict that (once saved) contains at least
    'messages' (the prompt message list) and 'generation' (the completion
    text, possibly None on failure). Records live in memory in
    self.generations and are mirrored to {key}_generations.ndjson
    under output_dir.
    """

    def __init__(self, output_dir: Path, key: str, overwrite: bool = False):
        # Creates only the leaf directory; the parent must already exist.
        output_dir.mkdir(exist_ok=True)
        self.output_dir = output_dir
        self.key = key
        self.generation_filepath = output_dir / f"{key}_generations.ndjson"
        self.generations = []
        # overwrite=True only skips loading the existing file; new records are
        # still appended to it unless overwrite() is called explicitly.
        if not overwrite:
            self.load()

    def load(self):
        """Read any previously saved generations from disk into memory (one JSON object per line)."""
        if self.generation_filepath.exists():
            with open(self.generation_filepath) as infile:
                for line in infile:
                    d = json.loads(line)
                    self.generations.append(d)

    def overwrite(self):
        """Rewrite the backing file from scratch with the current in-memory generations."""
        self._save_generations(self.generations, write_mode="w")

    def filter_generations(
        self,
        should_include_func: Callable = is_valid_generation,
        should_remove_func: Callable | None = None,
    ) -> int:
        """Drop in-memory generations that fail the given predicates.

        A generation is kept iff should_include_func(generation) is truthy
        and should_remove_func (if provided) is falsy for it. Only
        self.generations is modified; call overwrite() to persist the result.

        Returns:
            int: Number of generations removed.
        """
        filtered_generations = []
        for generation in self.generations:
            if should_include_func(generation) and (should_remove_func is None or not should_remove_func(generation)):
                filtered_generations.append(generation)
        n_removed = len(self.generations) - len(filtered_generations)
        self.generations = filtered_generations
        return n_removed

    def _save_generation(self, metadata: dict):
        # Append a single record as one JSON line (ndjson).
        with open(self.generation_filepath, "a") as outfile:
            outfile.write(json.dumps(metadata) + "\n")

    def _save_generations(self, metadata_list: list[dict], write_mode: str = "a"):
        # Write a batch of records; write_mode="w" replaces the file (see overwrite()).
        with open(self.generation_filepath, write_mode) as outfile:
            for metadata in metadata_list:
                outfile.write(json.dumps(metadata) + "\n")

    def is_already_generated(
        self,
        messages: list,
        metadata: dict | None,
        exclude_keys: set[str] = {"generation", "messages"},
    ) -> bool:
        """Determine if a generation was already created for this set of messages (and corresponding metadata).

        Args:
            messages (list): Message list to pass to the OpenAI API.
            metadata (dict | None): Optional metadata associated with the generation.
            exclude_keys (set[str], optional): Metadata keys to ignore when determining if this is a duplicate. Defaults to {"generation", "messages"}.

        Returns:
            bool: True if already in self.generations, False otherwise.
        """
        # NOTE(review): the mutable default set is safe only because it is
        # never mutated here; do not modify exclude_keys in place.
        if metadata is None:
            metadata = {}
        for generation in self.generations:
            assert "messages" in generation
            assert "generation" in generation
            # Only records with a non-None completion count as duplicates,
            # so previously failed generations can be retried.
            if generation["messages"] == messages and generation["generation"] is not None:
                is_metadata_match = True
                for key, value in metadata.items():
                    if key in exclude_keys:
                        continue
                    if key not in generation or generation[key] != value:
                        is_metadata_match = False
                if is_metadata_match:
                    return True
        return False

    def generate(
        self,
        messages: list,
        metadata: dict | None,
        sleep: float | None = 0.1,
    ) -> bool:
        """Generate a new ChatCompletion.

        Skips the API call entirely when an equivalent generation already
        exists (see is_already_generated). On success, the metadata dict is
        mutated in place with 'messages' and 'generation' keys, appended to
        self.generations, and saved to disk.

        Args:
            messages (list): List of messages.
            metadata (dict | None): Metadata to save with the completion.
            sleep (float | None, optional): Time to wait after this request, in seconds. Defaults to 0.1.

        Returns:
            bool: True if a new generation was created and saved, False otherwise.
        """
        if metadata is None:
            metadata = {}
        if self.is_already_generated(messages, metadata):
            return False
        generation = completion_utils.get_completion_with_retries(messages)
        metadata["messages"] = messages
        metadata["generation"] = generation
        self.generations.append(metadata)
        self._save_generation(metadata)
        if sleep:
            time.sleep(sleep)  # being a bit polite on repeated api calls
        return True

    def batch_filter_not_already_generated(self, metadata_list: list[dict]) -> list[dict]:
        """Return the subset of metadata_list that has no matching generation yet.

        Raises:
            ValueError: If any metadata dict lacks a 'messages' key.
        """
        metadata_to_process = []
        for metadata in metadata_list:
            if "messages" not in metadata:
                raise ValueError("Expected 'messages' in all provided metadata.")
            if not self.is_already_generated(metadata["messages"], metadata):
                metadata_to_process.append(metadata)
        return metadata_to_process

    def get_nonmatching_generations(
        self,
        metadata_list: list[dict],
        exclude_keys: set[str] = {"generation", "messages"},
        should_remove_nonmatching: bool = False,
    ) -> list[dict]:
        """Return stored generations whose metadata matches none of metadata_list.

        A stored generation "matches" a metadata dict when every key/value
        pair (outside exclude_keys) agrees. With should_remove_nonmatching,
        the non-matching records are also removed from self.generations
        (in-memory only; call overwrite() to persist).
        """
        nonmatching_generations = []
        nonmatching_inds = []
        for i, generation in enumerate(self.generations):
            generation_match_found = False
            for metadata in metadata_list:
                is_metadata_match = True
                for key, value in metadata.items():
                    if key in exclude_keys:
                        continue
                    if key not in generation or generation[key] != value:
                        is_metadata_match = False
                        break
                if is_metadata_match:
                    generation_match_found = True
                    break
            if not generation_match_found:
                nonmatching_generations.append(generation)
                nonmatching_inds.append(i)
        if should_remove_nonmatching:
            # Pop from the end first so earlier indices stay valid.
            for ind in sorted(nonmatching_inds, reverse=True):
                self.generations.pop(ind)
        return nonmatching_generations

    def batch_generate(
        self,
        metadata_list: list[dict],
        n_processes: int = 4,
        sleep: float = 0.1,
        completion_func: Callable = completion_utils.get_completion_noraise,
        **kwargs,
    ) -> int:
        """Generate completions for all not-yet-generated metadata, in parallel.

        Args:
            metadata_list (list[dict]): List of metadata dictionaries, that must each include a 'messages' key with a list of messages.
            n_processes (int, optional): # processes to spawn in the pool. Defaults to 4.
            sleep (float, optional): Time to sleep between requests IN EACH PROCESS. Defaults to 0.1.
            completion_func (Callable, optional): Function to use to produce generations from a list of messages. Defaults to completion_utils.get_completion_noraise.
            Other keyword args are passed to completion_func

        Raises:
            ValueError: If 'messages' is not included in one of the metadata dicts.

        Returns:
            int: Number of new generations. Note this MAY imply generations failed if < len(metadata_list), but only if no metadata were already generated.
        """
        metadata_to_process = self.batch_filter_not_already_generated(metadata_list)
        if len(metadata_to_process) == 0:
            return 0
        get_completion_func = functools.partial(completion_func, sleep=sleep, **kwargs)
        with mp.Pool(processes=n_processes) as pool:
            message_lists = (md["messages"] for md in metadata_to_process)
            results = pool.map(get_completion_func, message_lists)
        assert len(results) == len(metadata_to_process)
        metadata_completed = []
        # completion_func returns None on failure; keep only successes.
        for metadata, result in zip(metadata_to_process, results):
            if result is not None:
                metadata["generation"] = result
                metadata_completed.append(metadata)
        if len(metadata_completed) > 0:
            self.generations.extend(metadata_completed)
            self._save_generations(metadata_completed)
        return len(metadata_completed)
196 |
--------------------------------------------------------------------------------
/src/experiment/guidance_conditions.py:
--------------------------------------------------------------------------------
# Prompt-message templates for each guidance condition used in the generation
# experiment, ordered roughly by how strongly the model is steered toward the
# retrieved context. The {rori_microlesson_texts}, {openstax_subsection_texts},
# and {user_query} placeholders are presumably filled via str.format by the
# prompt-construction code — TODO confirm against the caller.

# "none": baseline system prompt; no retrieved context is injected.
none = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9 and lives in Ghana.
You will be encouraging and factual.
Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
]

# "low": retrieved context is offered as optional formatting/language help.
low = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9 and lives in Ghana.
You will be encouraging and factual.

Only if it is relevant, examples and language from the section below may be helpful to format your response:
===
{rori_microlesson_texts}
{openstax_subsection_texts}
===

Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
]

# "medium": use of the retrieved context is an instruction, not a suggestion.
# (Currently commented out of guidance_condition_messages_map below.)
medium = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9 and lives in Ghana.
You will be encouraging and factual.

Use examples and language from the section below to format your response:
===
{rori_microlesson_texts}
{openstax_subsection_texts}
===

Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
]

# "high": strongest guidance — a second (user) turn requires referencing the
# textbook section and ending with a textbook example or definition.
high = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9.
This student lives in Ghana or Nigeria.
You will be encouraging and factual.
Prefer simple, short responses based on the textbook.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
    {
        "role": "user",
        "content": """Answer the following question: {user_query}

Reference content from this textbook section in your response:
{openstax_subsection_texts}

End your response by relating the question to an example or definition in the textbook.""",
    },
]

# "extract_relevant": not a tutoring prompt — asks the model to select and
# repeat verbatim the single most relevant retrieved paragraph for the query.
extract_relevant = [
    {
        "role": "user",
        "content": """Given a middle-school math student's question, you will identify the most relevant section from a textbook.

Student question: {user_query}

Repeat the student's question and then repeat in full the most relevant paragraph from my math textbook. If none of them seem relevant, take a deep breath and output the most relevant. Don't say anything else.

Textbook paragraphs:

{rori_microlesson_texts}
{openstax_subsection_texts}""",
    },
]
80 |
# TODO consider if these guidance conditions could include their own dbinfo settings as well
# Maps condition name -> message-template list; "medium" is deliberately excluded.
guidance_condition_messages_map = {
    "none": none,
    "low": low,
    # "medium": medium,
    "high": high,
    "extract_relevant": extract_relevant,
}
# list(d) iterates keys directly; no need for a comprehension over .keys().
guidance_condition_name_list = list(guidance_condition_messages_map)
90 |
--------------------------------------------------------------------------------
/src/experiment/metrics.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import evaluate
4 |
5 | import experiment.tokenize
6 |
7 |
def compute_macro_f1(passages: list[str], generation: str, discount_text: str | None = None) -> float:
    """Return the maximum unigram-overlap F1 between the generation and any passage.

    Depending on arguments, this can be Knowledge F1 or just F1.

    SQuAD paper (http://arxiv.org/abs/1606.05250):
    "This metric measures the average overlap between the prediction and ground truth answer.
    We treat the prediction and ground truth as bags of tokens, and compute their F1.
    We take the maximum F1 over all of the ground truth answers for a given question, and then average over all of the questions."

    K-F1++ (https://aclanthology.org/2023.findings-acl.60):
    "Knowledge-F1 (K-F1) ... calculates the unigram overlap between the response and a knowledge snippet K,
    providing a verbatim measure of grounding to the input source.
    We propose K-F1++, a variant of K-F1,
    that captures only the novel information in the generated response and discounts any lexical alignment to the question:
    it calculates the unigram overlap between the response and K,
    after subtracting any tokens appearing in the question from the response."
    To use K-F1++, pass in the text to ignore to discount_text.

    Raises:
        ValueError: If the generation (after discounting) has no tokens, or
            every passage (after discounting) is empty.
    """
    # Shared tokenizer: lowercased, alphanumeric-only tokens.
    tokenize = functools.partial(experiment.tokenize.get_tokens, lower=True, remove_nonalphanumeric_tokens=True)
    discount_tokens = set(tokenize(discount_text)) if discount_text else set()
    predicted_tokens = set(tokenize(generation)) - discount_tokens
    n_predicted_tokens = len(predicted_tokens)
    if n_predicted_tokens == 0:
        raise ValueError("Expected generation to be non-empty.")
    best_f1 = None
    for passage in passages:
        reference_tokens = set(tokenize(passage)) - discount_tokens
        n_reference_tokens = len(reference_tokens)
        if n_reference_tokens == 0:
            # Empty (post-discount) passages contribute no score.
            continue
        n_overlap = len(reference_tokens & predicted_tokens)
        precision = n_overlap / n_predicted_tokens
        recall = n_overlap / n_reference_tokens
        if precision + recall == 0:
            f1 = 0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)
        if best_f1 is None or f1 > best_f1:
            best_f1 = f1
    if best_f1 is None:
        raise ValueError("No non-empty passages.")
    return best_f1
55 |
56 |
@functools.cache
def get_bertscore_metric_object() -> evaluate.EvaluationModule:
    """Load the BERTScore metric object (once, via functools.cache).

    See https://huggingface.co/spaces/evaluate-metric/bertscore
    """
    return evaluate.load("bertscore")
63 |
64 |
@functools.cache
def get_bleurt_metric_object(checkpoint: str = "bleurt-20") -> evaluate.EvaluationModule:
    """Load the BLEURT metric object for the given checkpoint (cached per checkpoint).

    See https://huggingface.co/spaces/evaluate-metric/bleurt

    According to https://github.com/google-research/bleurt, the current recommended checkpoint
    is BLEURT-20, which is already this function's default.

    Args:
        checkpoint (str, optional): bleurt-base-512, bleurt-large-512, etc. Defaults to "bleurt-20".

    Returns:
        evaluate.EvaluationModule: The metric object
    """
    return evaluate.load(
        "bleurt",
        checkpoint=checkpoint,
        module_type="metric",
        # previously experimented with a custom download config; kept for reference:
        # download_config=datasets.DownloadConfig(use_etag=False),
    )
84 |
85 |
def compute_bertscore(passages: list[str], generation: str):
    """Return the best BERTScore F1 of the generation against any of the passages."""
    scorer = get_bertscore_metric_object()
    # score the same generation against every passage, then take the best match
    predictions = [generation for _ in passages]
    result = scorer.compute(predictions=predictions, references=passages, lang="en")
    return max(result["f1"])
92 |
93 |
def compute_bleurt(passages: list[str], generation: str, compare_to_combined_passage: bool = False):
    """Return the max BLEURT score of the generation against each passage.

    Args:
        passages (list[str]): Reference passages. Not mutated.
        generation (str): Generated text to score.
        compare_to_combined_passage (bool, optional): If True, also score against all
            passages joined with newlines. Defaults to False.
    """
    bleurt = get_bleurt_metric_object()
    # copy the passages: the previous implementation aliased the caller's list, and
    # `references += [...]` mutated it in place whenever compare_to_combined_passage was True
    references = list(passages)
    if compare_to_combined_passage:
        references.append("\n".join(passages))
    predictions = [generation] * len(references)
    scores = bleurt.compute(predictions=predictions, references=references)["scores"]
    return max(scores)
102 |
103 |
def compute_bleurt_batch(passages_list: list[list[str]], generation_list: list[str]):
    """Score many (passages, generation) pairs with BLEURT in a single batch.

    TODO we can make this faster by implementing this
    """
    assert len(passages_list) == len(generation_list)
    raise NotImplementedError("Batch scoring not yet implemented.")
108 |
--------------------------------------------------------------------------------
/src/experiment/qualtrics.py:
--------------------------------------------------------------------------------
1 | import html
2 | import json
3 | import logging
4 | import math
5 | from datetime import datetime
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import pandas as pd
10 |
# Placeholder values used for survey slots beyond the number of available annotation rows
OVERFLOW_RESPONSE = "No more responses to annotate"
OVERFLOW_DOCUMENT = "n/a"
OVERFLOW_QUERY = "No more student queries in this survey, just click through to the end."

# Template placeholder keys; the per-question keys are suffixed with a 1-based slot number
# when searched/replaced in the template (see validate_template_survey_text / create_survey)
survey_id_key = "RoriSurveyId"
response_keys = ["Response1Q", "Response2Q", "Response3Q"]
query_text_key = "QueryTextQ"
document_key = "DocumentQ"
19 |
20 |
def chunker(seq, size):
    """Yield successive size-length slices of seq; the final slice may be shorter."""
    return (seq[start : start + size] for start in range(0, len(seq), size))
23 |
24 |
def get_template(template_filepath: Path):
    """Read a Qualtrics survey template from disk, validating it before returning the raw text."""
    with open(template_filepath) as template_file:
        template_text = template_file.read()
    validate_template_survey_text(template_text)
    return template_text
30 |
31 |
def validate_template_survey_text(survey_text, expected_survey_size: int = 15):
    """Sanity-check a survey template: it must parse as JSON and contain every slot key.

    Raises AssertionError when the template is malformed or a numbered slot key is missing.
    """
    assert json.loads(survey_text)

    # validate expected keys
    all_base_keys = response_keys + [query_text_key, document_key]
    for base_key in all_base_keys:
        for slot_number in range(1, expected_survey_size + 1):
            slot_key = base_key + str(slot_number)
            assert slot_key in survey_text, slot_key
40 |
41 |
def convert_text(text, use_br=True):
    """Escape text for embedding in the Qualtrics survey template HTML.

    NOTE(review): the string literals below appear garbled in this copy of the file —
    the quoted joins span raw line breaks, which is not valid Python, and both branches
    look identical. The original HTML markup (e.g. <br>/paragraph wrappers that
    distinguished the use_br branches) was likely stripped when this copy was produced;
    restore the literals from version control before editing this function.
    """
    # not sure this replace is necessary, but I was having some issues with escaping
    text = text.replace("\\", "/")
    text = html.escape(text)
    if use_br:
        text = "" + "
".join(text.split("\n")) + "
"
    else:
        text = "" + "
".join(text.split("\n")) + "
"
    return text
51 |
52 |
def create_surveys(
    df: pd.DataFrame,
    template_survey_text: str,
    survey_dir: Path,
    survey_size: int = 15,
    rng: np.random.Generator | None = None,
) -> pd.DataFrame:
    """Build .qsf survey files from annotation rows and write them to survey_dir.

    Groups df by query (each group must contain exactly 3 unique generations),
    shuffles responses within each group and the group order overall, writes one
    survey file per survey_size chunk, and returns a DataFrame describing the
    surveys that were created.

    Assumed columns:
    - generation
    - query
    - document
    Any additional columns are preserved per-response in the *_meta columns.
    """
    if rng is None:
        rng = np.random.default_rng()
    rows = []
    for query, group in df.groupby("query"):
        assert len(group) == 3, "Qualtrics survey hard-coded to accept 3 responses."
        assert group["generation"].nunique() == 3, "Generations/responses should be unique."
        # the document is expected to be shared across the 3 responses; warn (don't fail) otherwise
        for key in ["document"]:
            n_unique = group[key].nunique()
            if n_unique != 1:
                logging.warning(f"Expected 1 unique value in column {key}, found {n_unique}")
        # shuffle rows
        group = group.sample(frac=1, random_state=rng)
        # build new data structure
        responses = []
        metas = []
        for i in range(3):
            row = group.iloc[i]
            response = row.generation
            responses.append(response)
            # meta keeps every extra column for this response
            meta = row.to_dict()
            del meta["document"]
            del meta["generation"]
            del meta["query"]
            metas.append(meta)
        row = [query, group.iloc[0]["document"], *responses, *metas]
        rows.append(row)
    # randomize row order
    rng.shuffle(rows)

    survey_dir.mkdir(exist_ok=True)
    for i, survey_rows in enumerate(chunker(rows, survey_size)):
        # survey id embeds the creation date and a "part k/n" suffix
        survey_id = f"s_{datetime.now().strftime('%Y%m%d')}_{i+1}/{math.ceil(len(rows) / survey_size)}"
        survey_text = create_survey(survey_id, survey_rows, template_survey_text, survey_size)
        survey_filepath = survey_dir / f"Rori_ranking_annotations_-_survey{i}.qsf"
        with open(survey_filepath, "w") as outfile:
            outfile.write(survey_text)
        # append the survey id in place AFTER writing, so create_survey sees 8-element rows
        for row in survey_rows:
            row.append(survey_id)

    # build survey_df from rows
    survey_df = pd.DataFrame(
        rows,
        columns=[
            "query",
            "document",
            "response1",
            "response2",
            "response3",
            "response1_meta",
            "response2_meta",
            "response3_meta",
            "survey_id",
        ],
    )
    return survey_df
120 |
121 |
def create_survey(survey_id: str, rows: list, template_survey_text: str, survey_size: int = 15) -> str:
    """Fill a survey template's slots with row data, returning the completed survey JSON text.

    Slots beyond the number of available rows are filled with the overflow placeholders.
    """
    survey_text = template_survey_text.replace(survey_id_key, survey_id)
    for slot_number in range(1, survey_size + 1):
        row_index = slot_number - 1
        if row_index < len(rows):
            query, document, r1, r2, r3, _, _, _ = rows[row_index]
        else:
            query = OVERFLOW_QUERY
            document = OVERFLOW_DOCUMENT
            r1 = r2 = r3 = OVERFLOW_RESPONSE
        for base_key, response in zip(response_keys, (r1, r2, r3)):
            survey_text = survey_text.replace(base_key + str(slot_number), convert_text(response, use_br=False), 1)
        survey_text = survey_text.replace(query_text_key + str(slot_number), convert_text(query), 1)
        survey_text = survey_text.replace(document_key + str(slot_number), convert_text(document), 1)
    # round-trip through json to verify the filled template is still valid JSON
    return json.dumps(json.loads(survey_text))
143 |
--------------------------------------------------------------------------------
/src/experiment/tokenize.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import re
3 |
4 | from spacy.lang.en import English
5 |
6 |
@functools.cache
def get_spacy_english():
    """Return a cached bare spacy English pipeline (used here only for its tokenizer)."""
    nlp = English()
    return nlp
11 |
12 |
def get_tokens(string_to_tokenize: str, lower: bool = True, remove_nonalphanumeric_tokens: bool = False) -> list[str]:
    """Tokenize a string with the (cached) spacy English tokenizer.

    Args:
        string_to_tokenize (str): Input text.
        lower (bool, optional): Lowercase each token. Defaults to True.
        remove_nonalphanumeric_tokens (bool, optional): Keep only tokens that start with a
            word character. Defaults to False.

    Returns:
        list[str]: Token strings, in order.
    """
    tokenizer = get_spacy_english().tokenizer
    tokens = [token.text.lower() if lower else token.text for token in tokenizer(string_to_tokenize)]
    if remove_nonalphanumeric_tokens:
        tokens = [token for token in tokens if re.match("\\w+", token)]
    return tokens
23 |
--------------------------------------------------------------------------------
/src/rag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/src/rag/__init__.py
--------------------------------------------------------------------------------
/src/rag/embedding_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 | import openai
5 | import tiktoken
6 |
# Dimensionality of vectors produced by EMBEDDING_MODEL below
EMBEDDING_DIM = 1536
# Token budget for a single embedding API request (see batch_embed_texts)
MAX_TOKENS_PER_REQUEST = 8191
# OpenAI embedding model used throughout this module
EMBEDDING_MODEL = "text-embedding-ada-002"
10 |
11 |
def get_token_counts(text_list: list[str]) -> list[int]:
    """Given a list of texts, returns a list of the same length with the number of tokens from the EMBEDDING_MODEL tokenizer.

    Args:
        text_list (list[str]): Texts to tokenize.

    Returns:
        list[int]: Token counts corresponding to text_list.
    """
    tokenizer = tiktoken.encoding_for_model(EMBEDDING_MODEL)
    return [len(tokenizer.encode(text)) for text in text_list]
26 |
27 |
def get_openai_embeddings(
    texts: list[str],
    embedding_model: str = EMBEDDING_MODEL,
) -> list[np.ndarray]:
    """Given the list of texts, query the openai Embedding API, returning the embeddings in a list of numpy arrays.

    Note: calls through to a cached function; texts are converted to a tuple so they
    can serve as a hashable lru_cache key.

    Args:
        texts (list[str]): List of texts to embed.
        embedding_model (str, optional): Embedding model to use. Defaults to EMBEDDING_MODEL.

    Returns:
        list[np.ndarray]: Embeddings, in the same order as the given texts.
    """
    return get_openai_embeddings_cached(tuple(texts), embedding_model=embedding_model)
43 |
44 |
@functools.lru_cache(maxsize=512, typed=True)
def get_openai_embeddings_cached(
    texts: tuple[str, ...],
    embedding_model: str = EMBEDDING_MODEL,
) -> list[np.ndarray]:
    """Cached call to the OpenAI Embedding API.

    `texts` must be a tuple (hashable) so results can be memoized by lru_cache.
    Returns one numpy array per input text, in input order.
    """
    result = openai.Embedding.create(input=texts, engine=embedding_model)
    embedding_list = [np.array(d["embedding"]) for d in result.data]
    return embedding_list
53 |
54 |
def batch_embed_texts(
    input_text_list: list[str],
    n_tokens_list: list[int],
) -> list[np.ndarray]:
    """Embed the given texts, respecting the API max tokens limit given MAX_TOKENS_PER_REQUEST.

    Args:
        input_text_list (list[str]): Texts to embed.
        n_tokens_list (list[int]): As returned by `get_token_counts`

    Returns:
        list[np.ndarray]: List of embeddings, stored in numpy arrays, in input order.
    """
    curr_batch_token_count = 0
    batch_texts = []
    embedding_list = []
    for text, n_tokens in zip(input_text_list, n_tokens_list):
        if curr_batch_token_count + n_tokens > MAX_TOKENS_PER_REQUEST and batch_texts:
            # flush the current batch before it would exceed the per-request limit
            embedding_list.extend(get_openai_embeddings(batch_texts, EMBEDDING_MODEL))
            batch_texts = [text]
            # bug fix: the counter previously reset to 0 here even though the new batch
            # already contains `text`, undercounting and letting later requests exceed the limit
            curr_batch_token_count = n_tokens
        else:
            batch_texts.append(text)
            curr_batch_token_count += n_tokens
    if len(batch_texts) > 0:
        embedding_list.extend(get_openai_embeddings(batch_texts, EMBEDDING_MODEL))
    return embedding_list
82 |
--------------------------------------------------------------------------------
/src/rag/gpf_utils.py:
--------------------------------------------------------------------------------
def get_gpd_codes(lesson_code):
    """
    The Global Proficiency Framework (GPF) uses GPD codes.
    This is a slight variant of GPD codes that includes the grade before the GPD code.

    Structure is: G<grade>.<construct>.<subconstruct part>.<skill part>.<index>

    e.g. G9.N5.1.3.1 has:
        grade 9
        domain N
        construct N5
        subconstruct N5.1
        skill N5.1.3
        index 1

    Returns:
        tuple: (grade: int, domain: str, construct: str, subconstruct: str, skill: str, index: int)
    """
    tokens = lesson_code.split(".")
    # parse the full grade number after the leading "G"
    # (previously int(tokens[0][1]) truncated multi-digit grades, e.g. G10 -> 1)
    grade = int(tokens[0][1:])
    index = int(tokens[-1])

    skill = ".".join(tokens[1:-1])

    domain = tokens[1][0]
    construct = tokens[1]
    subconstruct = tokens[1] + "." + tokens[2]
    return grade, domain, construct, subconstruct, skill, index
26 |
--------------------------------------------------------------------------------
/src/rag/logit_bias.py:
--------------------------------------------------------------------------------
1 | # Utilities associated with logit_bias
2 | # docs: https://platform.openai.com/docs/api-reference/chat/create#logit_bias
3 | # help doc: https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability
4 | # logit_bias takes at most 300 tokens: https://aidungeon.medium.com/controlling-gpt-3-with-logit-bias-55866d593292
5 | import functools
6 | import importlib.resources
7 | import json
8 | import re
9 | import string as string_utils
10 | from collections import Counter
11 |
12 | import tiktoken
13 |
14 | from rag import resources
15 |
16 |
@functools.cache
def get_tokenizer(model_name: str = "gpt-3.5-turbo") -> tiktoken.Encoding:
    """Get the tiktoken tokenizer for the given OpenAI model. Cached per model name.

    Args:
        model_name (str, optional): The model tokenizer to load. Defaults to "gpt-3.5-turbo".

    Returns:
        tiktoken.Encoding: The tiktoken/OpenAI tokenizer.
    """
    tokenizer = tiktoken.encoding_for_model(model_name)
    return tokenizer
29 |
30 |
@functools.cache
def get_stopword_tokens() -> set[int]:
    """Cached version of `load_stopword_tokens`. Callers share the returned set; do not mutate."""
    return load_stopword_tokens()
35 |
36 |
def load_stopword_tokens() -> set[int]:
    """Load the bundled Dolma stopword token ids from the packaged JSON resource."""
    resource_filepath = importlib.resources.files(resources) / "dolma_stopwords.json"
    stopwords_dict = json.loads(resource_filepath.read_text())
    return set(stopwords_dict["stopword_tokens"])
42 |
43 |
def create_stopword_token_set_from_word_list(word_list: list[str]) -> set[int]:
    """Create a set of stopword tokens from the given list of stop words.
    Used to create the default stopword resource loaded by `load_stopword_tokens`.

    Args:
        word_list (list[str]): List of words to include in the stopword set.

    Returns:
        set[int]: Set of stopword tokens.
    """
    tokenizer = get_tokenizer()
    # include casing variants of each word, plus bare whitespace and punctuation characters
    variants = (
        word_list
        + [word.lower() for word in word_list]
        + [word.upper() for word in word_list]
        + [word.capitalize() for word in word_list]
        + list(string_utils.whitespace)
        + list(string_utils.punctuation)
    )
    surrounding_chars = string_utils.whitespace + string_utils.punctuation
    stopword_tokens: set[int] = set()
    for word in variants:
        # a word counts as a stopword token only if it encodes to a single token,
        # including when prefixed or suffixed by whitespace/punctuation
        candidates = [word]
        candidates.extend(char + word for char in surrounding_chars)
        candidates.extend(word + char for char in surrounding_chars)
        for candidate in candidates:
            tokens = tokenizer.encode(candidate)
            if len(tokens) == 1:
                stopword_tokens.add(tokens[0])
    return stopword_tokens
71 |
72 |
def get_nonstopword_tokens(text: str) -> list[int]:
    """Tokenize text, dropping stopword tokens and any token whose text contains
    characters other than ASCII letters and spaces."""
    tokenizer = get_tokenizer()
    stopword_tokens = get_stopword_tokens()
    kept_tokens = []
    for token in tokenizer.encode(text):
        if token in stopword_tokens:
            continue
        if not re.fullmatch("[ A-Za-z]+", tokenizer.decode([token])):
            continue
        kept_tokens.append(token)
    return kept_tokens
83 |
84 |
85 | def get_logit_bias(
86 | tokens: list[int],
87 | min_count: int = 2,
88 | n_tokens: int | None = None,
89 | max_tokens: int = 50,
90 | min_bias: float = 1.0,
91 | max_bias: float = 5.0,
92 | ) -> dict[int, float]:
93 | """Given a list of tokens, create a corresponding logit_bias dictionary.
94 |
95 | Roughly, identifies the most frequent max_tokens tokens, stopping at n_tokens if provided,
96 | that occur at least min_count times. Bias values are assigned based on frequency, in the range min_bias to max_bias.
97 | The most frequent token will always have a weight of max_bias in the resulting logit_bias.
98 | Bias defaults are generally inspired by this doc: https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability
99 |
100 | Args:
101 | tokens (list[int]): The list of tokens, e.g. after having stopword tokens removed.
102 | min_count (int, optional): Defaults to 2.
103 | n_tokens (int | None, optional): Defaults to None.
104 | max_tokens (int, optional): Defaults to 50.
105 | min_bias (float, optional): Defaults to 1.0.
106 | max_bias (float, optional): Defaults to 5.0.
107 |
108 | Returns:
109 | dict[int, float]: The logit_bias dict that can be passed to the logit_bias parameter accepted by the OpenAI API.
110 | """
111 | if len(tokens) == 0:
112 | return {}
113 | logit_bias = {}
114 | c = Counter(tokens).most_common(max_tokens)
115 | max_count = c[0][1] # count of most-frequently-occurring token
116 | if max_count >= min_count:
117 | for token, count in c:
118 | if count < min_count:
119 | continue
120 | bias = min_bias + (max_bias - min_bias) * (count / max_count)
121 | logit_bias[token] = bias
122 | if n_tokens is not None and len(logit_bias) >= n_tokens:
123 | break
124 | return logit_bias
125 |
126 |
def get_logit_bias_from_slot(
    recent_slot_fill_dict: list[dict[str, str]],
    include: list[str] | None = None,
    exclude: list[str] | None = None,
    **kwargs,
) -> dict[int, float]:
    """Given texts that fill one or more slots, create an appropriate logit_bias.
    This is probably not a very reasonable way to do to this.

    Args:
        recent_slot_fill_dict (list[dict[str, str]]): See `prompt_utils.PromptManager`.
        include (list[str] | None, optional): Slot texts to consider. Defaults to None, meaning all slots are included.
        exclude (list[str] | None, optional): Slot texts to ignore. Defaults to None, meaning nothing is excluded.

    Returns:
        dict[int, float]: logit_bias
    """
    # avoid the mutable-default-argument pitfall: normalize None to a fresh empty list
    if exclude is None:
        exclude = []
    texts = []
    for slot_fill_dict in recent_slot_fill_dict:
        for key, value in slot_fill_dict.items():
            if (include is None or key in include) and key not in exclude:
                texts.append(value)
    text = "\n".join(texts)
    tokens = get_nonstopword_tokens(text)
    return get_logit_bias(tokens, **kwargs)
152 |
--------------------------------------------------------------------------------
/src/rag/misconceptions.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import importlib.resources
3 | import json
4 | from operator import itemgetter
5 |
6 | from rag import resources
7 |
8 |
@functools.cache
def get_misconception_list():
    """Cached version of `load_misconception_list`.

    Note: every caller receives the same shared list object; callers should not mutate it.
    """
    return load_misconception_list()
13 |
14 |
def load_misconception_list() -> list[dict[str, str]]:
    """Misconceptions resource generated in the `MisconceptionData.ipynb` notebook.

    Note: despite the .ndjson extension, the file is parsed here as a single JSON document.

    Returns:
        list[dict[str, str]]: List of misconceptions.
    """
    resource_filepath = importlib.resources.files(resources) / "misconceptions.ndjson"
    return json.loads(resource_filepath.read_text())
25 |
26 |
@functools.cache
def get_misconceptions_string() -> str:
    """Intended for use in prompts, this is a semi-colon delimited list of common math misconceptions.

    Only entries whose ID starts with "MaE" are included, and only the first sentence
    of each misconception description is kept.

    Returns:
        str: List of misconceptions that is approximately 1000 tokens (see notebook).
    """
    # sort a copy: get_misconception_list() returns the shared cached list, and the
    # previous in-place .sort() mutated that cached object as a side effect
    misconception_list = sorted(get_misconception_list(), key=itemgetter("Topic", "ID"))
    descriptions = []
    for row in misconception_list:
        if not row["ID"].startswith("MaE"):
            continue
        # keep only the first sentence, and sanitize the `;` delimiter and newlines
        description = row["Misconception"].split(".")[0]
        s = description.replace(";", ",").replace("\n", " ").strip()
        descriptions.append(s)
    misconception_string = "; ".join(descriptions)
    return misconception_string
46 |
--------------------------------------------------------------------------------
/src/rag/prompt_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 |
5 | from rag import embedding_utils, retrieval_strategies
6 |
# Message roles recognized when parsing/serializing conversation strings (see PromptSelector)
VALID_ROLES: list[str] = ["user", "assistant", "system"]
8 |
9 |
class PromptSelector:
    """PromptSelector provides utilities to enumerate and choose prompts.

    Prompts are stored in dictionaries in the `prompts` modules: each entry maps a
    prompt id to a dict with a "messages" list and an optional "pretty_name".
    """

    def __init__(self, intro_prompt_dict: dict):
        self.intro_prompt_dict = intro_prompt_dict
        # map from a prompt's pretty name (or the "Prompt {i}" fallback) to its dict key.
        # bug fix: this previously iterated enumerate(self.intro_prompt_dict) — keys only —
        # so t[0]/t[1] indexed into the key *string*, silently ignoring pretty names and
        # mapping to one-character key prefixes; iterate .items() to get (key, value) pairs.
        self.pretty_name_to_id_map = {
            prompt_info["pretty_name"] if "pretty_name" in prompt_info else f"Prompt {i}": prompt_id
            for i, (prompt_id, prompt_info) in enumerate(self.intro_prompt_dict.items())
        }

    def get_intro_prompt_pretty_names(self) -> list[str]:
        """Return each prompt's pretty name, falling back to "Prompt {i}" when absent."""
        pretty_name_list = []
        for i, prompt_info in enumerate(self.intro_prompt_dict.values()):
            pretty_name = prompt_info["pretty_name"] if "pretty_name" in prompt_info else f"Prompt {i}"
            pretty_name_list.append(pretty_name)
        return pretty_name_list

    def get_intro_prompt_message_lists(self) -> list[dict[str, str]]:
        """Return the messages list of every prompt, in dict insertion order."""
        return [prompt_info["messages"] for prompt_info in self.intro_prompt_dict.values()]

    def get_default_intro_prompt(self) -> dict:
        """Return the first prompt in the dict (insertion order)."""
        return self.intro_prompt_dict[next(iter(self.intro_prompt_dict.keys()))]

    @staticmethod
    def convert_conversation_to_string(messages) -> str:
        """Serialize a messages list to a string: "ROLE:\\n<content>\\n" per message."""
        conversation_string = ""
        for message in messages:
            conversation_string += message["role"].upper() + ":\n"
            conversation_string += message["content"] + "\n"
        return conversation_string

    @staticmethod
    def convert_string_to_conversation(
        conversation_string: str,
    ) -> list[dict[str, str]]:
        """Given a string representing a conversation, convert into the expected messages list format.

        Follows a pretty basic convention, defined in this implementation: a line like
        "USER:" starts a new message with that role; subsequent lines are its content.

        Args:
            conversation_string (str): String representing a conversation.

        Returns:
            list[dict[str, str]]: List of messages, each with a "role" and "content".
        """
        messages = []
        message = {
            "content": "",
        }
        for line in conversation_string.split("\n"):
            # a role-header line ends with ":", so strip the final char before matching
            possible_role = line[:-1].lower()
            if possible_role in VALID_ROLES:
                if "role" in message:
                    # finish the previous message before starting a new one
                    message["content"] = message["content"].strip()
                    messages.append(message)
                    message = {
                        "content": "",
                    }
                message["role"] = possible_role
            else:
                message["content"] += line + "\n"
        if "role" in message:
            message["content"] = message["content"].strip()
            messages.append(message)
        return messages
79 |
80 |
class PromptManager:
    """Stores prompts and generates message lists for passing to the OpenAI API.

    Holds intro messages, a retrieval strategy used to fill {slot} placeholders,
    and a running conversation history (`stored_messages`) that `build_query` mutates.
    """

    def __init__(self):
        # messages prepended when a new conversation starts
        self.intro_messages: list[dict[str, str]] = []
        # strategy used to fill {slot} placeholders found in message templates
        self.retrieval_strategy: retrieval_strategies.RetrievalStrategy = retrieval_strategies.NoRetrievalStrategy()
        # running conversation history; mutated by build_query and add_stored_message
        self.stored_messages: list[dict[str, str]] = []
        # slot fills from the most recent message that contained slots
        self.most_recent_slot_fill_dict: dict[str, str] = {}
        # one slot-fill dict per message from the latest build_query call, in message order
        self.recent_slot_fill_dict: list[dict[str, str]] = []

    def set_intro_messages(self, intro_messages: list[dict[str, str]]) -> PromptManager:
        """Set the messages used to start a new conversation. Returns self for chaining."""
        self.intro_messages = intro_messages
        return self

    def set_retrieval_strategy(
        self,
        retrieval_strategy: retrieval_strategies.RetrievalStrategy,
    ) -> PromptManager:
        """Set the strategy used to fill prompt slots. Returns self for chaining."""
        self.retrieval_strategy = retrieval_strategy
        return self

    def get_retrieval_strategy(self) -> retrieval_strategies.RetrievalStrategy:
        """Return the current retrieval strategy."""
        return self.retrieval_strategy

    def add_stored_message(self, message: dict[str, str]) -> PromptManager:
        """Append a message to the stored conversation history. Returns self for chaining."""
        self.stored_messages.append(message)
        return self

    def clear_stored_messages(self) -> PromptManager:
        """Empty the stored conversation history. Returns self for chaining."""
        self.stored_messages.clear()
        return self

    def build_query(
        self,
        user_query: str | None = None,
        previous_messages: list[dict[str, str]] | None = None,
        query_for_retrieval_context: str | None = None,
    ) -> list[dict[str, str]]:
        """Given a user_query (or the `intro_messages` set on this PromptManager), build a set of messages to pass to the OpenAI API.

        Args:
            user_query (str | None, optional): If provided, will construct a new user message from this user. Defaults to None.
            previous_messages (list[dict[str, str]] | None, optional): If provided, will continue a conversation. Defaults to None.
            query_for_retrieval_context (str | None, optional): If provided, this is used for any RetrievalStrategies that require querying. Defaults to None, meaning the user_query or the most recent user message will be used.

        Raises:
            KeyError: If the given RetrievalStrategy doesn't fill all the identified slots in the prompts.

        Returns:
            list[dict[str, str]]: List of messages, to pass to the OpenAI API.
        """
        if previous_messages is None:
            previous_messages = self.stored_messages
        if len(previous_messages) == 0:
            # this is a new query
            messages = [message.copy() for message in self.intro_messages]
            self.stored_messages.extend(messages)
        else:
            # not a new query,
            # so include the previous messages as context
            messages = [message.copy() for message in previous_messages]
        if user_query is not None:
            user_message = {
                "role": "user",
                "content": user_query,
            }
            messages.append(user_message)
            self.stored_messages.append(user_message)

        should_remove_user_query_message = False
        if query_for_retrieval_context is None:
            query_for_retrieval_context = ""
        # walk the messages newest-to-oldest, filling any {slot}s found in each
        for message in messages[::-1]:
            expected_slots = PromptManager.identify_slots(message["content"])
            if len(expected_slots) > 0:
                slot_fill_dict = self.retrieval_strategy.do_retrieval(
                    expected_slots,
                    query_for_retrieval_context,
                    messages,
                )
                self.most_recent_slot_fill_dict = slot_fill_dict
                self.recent_slot_fill_dict.append(slot_fill_dict)
                assert len(slot_fill_dict) == len(
                    expected_slots,
                ), "Unexpected fill provided."
                if "user_query" in slot_fill_dict and user_query is not None:
                    # special case: fill user_query slots with the current user_query
                    slot_fill_dict["user_query"] = user_query
                    should_remove_user_query_message = True
                try:
                    message["content"] = message["content"].format(**slot_fill_dict)
                except KeyError:
                    raise KeyError(
                        f"Failed to fill {expected_slots} with {slot_fill_dict}.",
                    )
            else:
                self.recent_slot_fill_dict.append({})
            if query_for_retrieval_context == "" and message["role"] == "user":
                # use as retrieval context the most recent user message
                # TODO rethink this, providing a more flexible way to specify the retrieval context
                query_for_retrieval_context = message["content"]
        # fills were appended newest-first above; reverse back into message order
        self.recent_slot_fill_dict = self.recent_slot_fill_dict[::-1]
        if should_remove_user_query_message:
            # the user's query was inlined into a slot, so drop the separate trailing user message
            self.stored_messages.pop()
            assert messages[-1]["content"] == user_query
            messages = messages[:-1]
        return messages

    def compute_stored_token_counts(self) -> int:
        """Return the total token count of all stored message contents (via the embedding tokenizer)."""
        token_counts = embedding_utils.get_token_counts(
            [message["content"] for message in self.stored_messages],
        )
        total_token_count = sum(token_counts)
        return total_token_count

    def get_recent_slot_fill(self, slot_key: str) -> str | None:
        """Return the first fill recorded for slot_key in the latest build_query call, or None."""
        for slot_fill_dict in self.recent_slot_fill_dict:
            for key, value in slot_fill_dict.items():
                if key == slot_key:
                    return value
        return None

    # NOTE: defined without self and always invoked as PromptManager.identify_slots(...)
    # (effectively a static method)
    def identify_slots(prompt_string: str) -> list[str]:
        """Uses a regex to identify missing slots in a prompt_string.

        More advanced slot formatting is not supported.

        Args:
            prompt_string (str): The prompt itself, with format-style slots to fill e.g. "This is a prompt with a slot: {slot_to_fill}"

        Returns:
            list[str]: List of identified slots.
        """
        expected_slots = re.findall(r"{[^{} ]+}", prompt_string)
        return sorted({slot[1:-1] for slot in expected_slots})
216 |
--------------------------------------------------------------------------------
/src/rag/prompts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/src/rag/prompts/__init__.py
--------------------------------------------------------------------------------
/src/rag/prompts/hints.py:
--------------------------------------------------------------------------------
# Intro prompt templates for hint generation. Each entry has a "pretty_name" and a
# "messages" list of chat messages whose {slot} placeholders (lesson, question, etc.)
# are filled by the retrieval layer (see prompt_utils.PromptManager).
# NOTE(review): two system prompts say "may or not be relevant" while the other two say
# "may or may not be relevant" — likely an unintentional inconsistency. The strings are
# experimental prompt conditions, so they are deliberately left byte-identical here;
# confirm before normalizing.
intro_prompts = {
    "hint_sequence": {
        "pretty_name": "Hint sequence",
        "messages": [
            {
                "role": "system",
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """I just received this math lesson:
{lesson}

Provide a hint for this math question:
{question}

I answered {incorrect_answer}, which is incorrect. Generate four hints for the correct answer {correct_answer} by following the steps below:

FIRST: Tell me the goal of the problem.

SECOND: Tell me what information I need to accomplish the goal.

THIRD: Tell me what mathematical computation I need to do using the information I have to accomplish the goal.

FOURTH: Tell me how to do the mathematical computation and show that it results in the correct answer "{correct_answer}".""",
            },
        ],
    },
    "slip_correction": {
        "pretty_name": "Slip correction",
        "messages": [
            {
                "role": "system",
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or may not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """I just received this math lesson:
{lesson}

Provide a hint for this math question:
{question}

The correct answer is {correct_answer}, but I answered {incorrect_answer}.
I think I made a small slip-up. What did I do wrong?
Your answer should be one sentence and start with "Remember to".""",
            },
        ],
    },
    "misconception": {
        "pretty_name": "Misconception-based hint",
        "messages": [
            {
                "role": "system",
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or may not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """I want you to give a hint for a student who just answered a maths question incorrectly. Explain that their incorrect answer might be due to a misconception.

Here are some common misconceptions that lead middle-school math students to get an incorrect answer:
{misconception_string}

Relevant lesson: {lesson}
Question: {question}
Answer: {answer}
Incorrect Answer: {incorrect_answer}

Give a hint that identifies the possible misconception that led to the incorrect answer "{incorrect_answer}" rather than the correct answer "{answer}".""",
            },
        ],
    },
    "comparative_hint": {
        "pretty_name": "Comparative hint",
        "messages": [
            {
                "role": "system",
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """Provide a hint for this math question:
{question}

I answered {incorrect_answer}. I know the correct answer is {correct_answer}, but I need a hint to understand why.

Here's the relevant lesson for this problem:
{lesson}

FIRST: Repeat the worked example from the lesson.

SECOND: Compare my question to the worked example, explaining how it is different.

THIRD: Give me the steps to solve my question correctly, identifying the correct answer as {correct_answer} in the final step.""",
            },
        ],
    },
}
115 |
# Prompt templates for identifying misconceptions ("MaEs") in a student's
# incorrect answer.  Same schema as the other prompt dicts in this package:
# prompt id -> {"pretty_name": ..., "messages": [...]}.
# The message content carries the format slots {mae_count} and
# {mae_spreadsheet_string}, filled in at query time.
misconception_identification = {
    "creature_ai": {
        "pretty_name": "Nancy Otero's misconception identification prompt",
        "messages": [
            {
                "role": "user",
                "content": """I’ll give you a spreadsheet with a list of MaEs. Each MaE has an ID, an explanation of the MaE, and 4 examples of the MaE.
Then I'll show you a student incorrect answer to a math question.
I want you to tell me how many of the {mae_count} MaEs you can identify in the answers, identify as many as you can.
Please process the answer and tell me:
If the answer is correct
If the answer is not correct: how many MaEs can you identify? Which ones and why?
Spreadsheet:
{mae_spreadsheet_string}
""",
            },
        ],
    },
}
135 |
--------------------------------------------------------------------------------
/src/rag/prompts/mathqa.py:
--------------------------------------------------------------------------------
# Intro prompt templates for math question-answering.  Keys are internal prompt
# ids; each value holds a human-readable "pretty_name" and an OpenAI-style
# "messages" list.  Message contents carry format slots such as
# {rori_microlesson_texts}, {openstax_subsection_texts}, and {user_query},
# filled in at query time (presumably by the retrieval strategies — see
# `rag.retrieval_strategies`).
intro_prompts = {
    "general_math_qa_intro": {
        "pretty_name": "General middle-school math prompt",
        "messages": [
            {
                "role": "system",
                "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9.
This student lives in Ghana or Nigeria.
You will be encouraging and factual.
{rori_microlesson_texts}
{openstax_subsection_texts}
Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.
""",
            },
        ],
        "retrieval_config": { # experimental configuration info, not yet implemented
            "rori_microlesson_texts": {
                "prefix": "Here is some lesson content that might be relevant:\n",
            },
            "openstax_subsection_texts": {
                "prefix": "Here are some excerpts from a math textbook. If they are relevant to the question, feel free to use language or examples from these excerpts:\n",
            },
        },
    },
    "retrieval_reliant_math_qa_intro": {
        "pretty_name": "Retrieval-reliant middle-school math prompt",
        "messages": [
            {
                "role": "system",
                "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9.
This student lives in Ghana or Nigeria.
You will be encouraging and factual.

Use examples and language from the section below to format your response:
===
{rori_microlesson_texts}
{openstax_subsection_texts}
===

Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.
""",
            },
        ],
    },
    "instruct_qa": {
        # inspired by the instruct-qa QAPromptTemplate: https://github.com/McGill-NLP/instruct-qa/blob/main/instruct_qa/prompt/templates.py
        "pretty_name": "McGill's (very general) instruct-qa prompt",
        "messages": [
            {
                "role": "user",
                "content": """Please answer the following question given the following passages:
{rori_microlesson_texts}
{openstax_subsection_texts}
Question: {user_query}
Answer: """,
            },
        ],
    },
}
62 |
--------------------------------------------------------------------------------
/src/rag/retrieval.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import collections.abc
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import scipy
9 |
10 | from rag import embedding_utils
11 |
12 |
class RetrievalDb:
    """In-memory retrieval helper backed by a DataFrame and a matrix of embeddings.

    When creating new embeddings:
    ```python
    assert "text_col" in df.columns
    retrieval_db = RetrievalDb(Path("embedding_dir"), "db_name", "text_col", df)
    retrieval_db.create_embeddings()
    retrieval_db.save_df()
    ```

    When loading existing embeddings:
    ```python
    retrieval_db = RetrievalDb(Path("embedding_dir"), "db_name", "text_col")
    # `load()` is called during construction if df is not provided
    assert "text_col" in retrieval_db.df.columns
    ```
    """

    def __init__(
        self,
        embedding_dir: Path,
        db_name: str,
        embed_col: str,
        df: pd.DataFrame | None = None,
        n_tokens_col: str = "n_tokens",
    ):
        """Build a new db from `df`, or load an existing one from disk when `df` is None.

        Args:
            embedding_dir: Directory holding the parquet/npy files for this db.
            db_name: Basename used to derive the on-disk filenames.
            embed_col: DataFrame column containing the texts to embed.
            df: If provided, a new db is built from it; otherwise load() runs.
            n_tokens_col: Column holding per-text token counts (computed if absent).
        """
        self.embedding_dir = embedding_dir
        self.db_name = db_name

        self.df_filepath = self.embedding_dir / f"{self.db_name}_df.parquet"
        self.embedding_filepath = self.embedding_dir / f"{self.db_name}_embed.npy"

        self.embed_col = embed_col
        if df is None:
            self.load()
        else:
            self.df = df
            self.normalize_strings()
        assert self.embed_col in self.df.columns

        self.n_tokens_col = n_tokens_col
        if n_tokens_col not in self.df.columns:
            self.compute_token_counts()

    def normalize_strings(self):
        """Normalize the embed column in-place (newlines -> spaces, stripped)."""
        self.df[self.embed_col] = self.df[self.embed_col].map(normalize_text)

    def compute_token_counts(self):
        """Populate the token-count column from the embed column texts."""
        token_counts = embedding_utils.get_token_counts(self.df[self.embed_col])
        self.df[self.n_tokens_col] = token_counts

    def create_embeddings(self):
        """Embed all texts, build the (n_rows, dim) matrix, and persist it to disk."""
        embedding_list = embedding_utils.batch_embed_texts(
            self.df[self.embed_col],
            self.df[self.n_tokens_col],
        )
        self.embedding_mat = np.concatenate(
            [e.reshape(1, -1) for e in embedding_list],
            axis=0,
        )
        np.save(self.embedding_filepath, self.embedding_mat)

    def save_df(self):
        """Persist the DataFrame next to the embedding matrix."""
        self.df.to_parquet(self.df_filepath)

    def load(self):
        """Load the DataFrame and embedding matrix from disk.

        Raises:
            ValueError: If the DataFrame file does not exist.
        """
        if not self.df_filepath.exists():
            raise ValueError(
                f"Trying to load a dataframe from non-existent path: {self.df_filepath}",
            )
        self.df = pd.read_parquet(self.df_filepath)
        self.embedding_mat = np.load(self.embedding_filepath)

    def compute_embedding_distances(self, query_embedding: np.ndarray) -> np.ndarray:
        """Return cosine distances from one query embedding to every stored embedding."""
        if query_embedding.shape[0] != 1:
            # accept a flat (dim,) vector by promoting it to a (1, dim) row
            query_embedding = query_embedding.reshape(1, -1)
        distances = scipy.spatial.distance.cdist(
            query_embedding,
            self.embedding_mat,
            metric="cosine",
        )[0]
        return distances

    def compute_string_distances(self, query_str: str) -> np.ndarray:
        """Embed a query string, then return cosine distances to every stored embedding."""
        embedding_list = embedding_utils.get_openai_embeddings(
            [normalize_text(query_str)],
        )
        query_embedding = embedding_list[0]
        return self.compute_embedding_distances(query_embedding)

    def iterate_query_embeddings(
        self,
        query_embedding_list: collections.abc.Iterable[np.ndarray],
    ) -> collections.abc.Generator[np.ndarray, None, None]:
        """Yield a distance array for each query embedding in the iterable."""
        for query_embedding in query_embedding_list:
            yield self.compute_embedding_distances(query_embedding)

    def get_top_df(self, distances: np.ndarray, k: int = 5) -> pd.DataFrame:
        """Return the k rows closest to the query (smallest distance first)."""
        # previously also computed and asserted on the top-k scores, unused
        top_k_indices = np.argsort(distances)[:k]
        return self.df.iloc[top_k_indices]
117 |
118 |
def get_distance_sort_indices(distances: np.ndarray) -> np.ndarray:
    """Return indices that sort `distances` ascending (most relevant first).

    Annotations fixed: `np.array` is a function, not a type; use `np.ndarray`.
    """
    return np.argsort(distances)
121 |
122 |
def normalize_text(text: str) -> str:
    """Flatten newlines into single spaces and trim surrounding whitespace."""
    flattened = " ".join(text.split("\n"))
    return flattened.strip()
125 |
126 |
class DbInfo:
    """Wrapper class with info about how retrieved texts should be incorporated in a prompt.

    See `prompt_utils.PromptManager`.
    """

    # configuration attributes that copy() carries over unless overridden
    _COPY_KEYS = (
        "max_tokens",
        "max_texts",
        "prefix",
        "suffix",
        "join_string",
        "use_parent_text",
        "parent_group_cols",
        "parent_sort_cols",
    )

    def __init__(
        self,
        db: RetrievalDb,
        max_tokens: int = 1000,
        max_texts: int = 1000,
        prefix: str = "",
        suffix: str = "",
        join_string: str = "\n",
        use_parent_text: bool = False,
        parent_group_cols: list[str] | None = None,
        parent_sort_cols: list[str] | None = None,
    ):
        """Configure how texts from `db` are selected and joined into a prompt fill string.

        Args:
            db: The RetrievalDb to draw texts from.
            max_tokens: Token budget for the combined fill string.
            max_texts: Maximum number of texts to include.
            prefix: String prepended to the joined texts.
            suffix: String appended to the joined texts.
            join_string: Separator between individual texts.
            use_parent_text: If True, expand each retrieved text to its "parent" document.
            parent_group_cols: Columns identifying rows that share a parent.
            parent_sort_cols: Columns used to order rows within a parent.
        """
        self.db = db
        self.max_tokens = max_tokens
        self.max_texts = max_texts
        self.prefix = prefix
        self.suffix = suffix
        self.join_string = join_string

        # configure parent retrieval
        # list(...) avoids the shared-mutable-default pitfall (defaults were `[]`)
        # and decouples us from later caller-side mutation
        self.use_parent_text = use_parent_text
        self.parent_join_string = "\n"
        self.parent_group_cols = list(parent_group_cols) if parent_group_cols else []
        self.parent_sort_cols = list(parent_sort_cols) if parent_sort_cols else []

    def copy(self, **kwargs) -> DbInfo:
        """Create a copy of this DbInfo, overriding the keyword args with new values if provided.

        Fixed: previously only max_tokens/prefix/suffix were carried over; all
        configuration attributes are now preserved, matching the docstring.

        Returns:
            DbInfo: Newly instantiated copy.
        """
        for expected_key in self._COPY_KEYS:
            kwargs.setdefault(expected_key, getattr(self, expected_key))
        return DbInfo(self.db, **kwargs)

    def get_fill_string_from_distances(self, distances: np.ndarray) -> str:
        """Given distances to the texts within the RetrievalDb, create an appropriate fill string.

        Args:
            distances (np.ndarray): Distances, where closer texts in the RetrievalDb are more relevant.

        Returns:
            str: The string to include in the prompt.
        """
        sort_inds = get_distance_sort_indices(distances)
        used_inds = set()
        texts = []
        total_tokens = 0
        for ind in sort_inds:
            if ind in used_inds:
                continue
            if self.use_parent_text:
                token_budget = self.max_tokens - total_tokens
                parent_result = self.get_parent_text(ind, token_budget)
                if parent_result is None:
                    # this text alone exceeds the remaining budget; previously this
                    # crashed unpacking None — skip it and try the next candidate
                    used_inds.add(ind)
                    continue
                text, n_tokens, new_used_inds = parent_result
                used_inds.update(new_used_inds)
            else:
                text, n_tokens = self.get_single_text(ind)
                used_inds.add(ind)
            if total_tokens + n_tokens > self.max_tokens:
                break
            total_tokens += n_tokens
            texts.append(text)
            if len(texts) >= self.max_texts:
                break
        fill_string = self.prefix + self.join_string.join(texts) + self.suffix
        return fill_string

    def get_single_text(self, ind: int):
        """Return the text and its token count for one positional index.

        Args:
            ind (int): Positional (iloc) index into the RetrievalDb's DataFrame.

        Returns:
            tuple: (text, n_tokens) for the row at `ind`.
        """
        row = self.db.df.iloc[ind]
        text = row[self.db.embed_col]
        n_tokens = row[self.db.n_tokens_col]
        return text, n_tokens

    def get_parent_text(self, ind: int, token_budget: int):
        """
        Intuition of "parent document" retriever is to retrieve for inclusion in a prompt the "parent" document,
        similar to including docs on either "side" as additional context.

        Read more: https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever

        Args:
            ind (int): Most semantically relevant index to retrieve parents of.
            token_budget (int): Maximum tokens the returned text may occupy.

        Returns:
            tuple | None: (text, n_tokens, new_used_inds), or None when even the
            target row alone exceeds token_budget.
        """
        df = self.db.df
        row = df.iloc[ind]
        # a parent is the set of rows matching the target row on all group columns
        and_cond = np.ones(len(df), dtype=bool)
        for col in self.parent_group_cols:
            cond = df[col] == row[col]
            and_cond = np.logical_and(and_cond, cond)
        parent = df[and_cond]
        if self.parent_sort_cols:
            # fixed off-by-one: the old `len(...) > 1` check silently skipped
            # sorting when exactly one sort column was configured
            parent = parent.sort_values(by=self.parent_sort_cols)
        # include a variable amount of context based on the given token_budget
        # preference ranking implemented here:
        # - all docs
        # - up to token_budget docs from target_ind - 0
        total_tokens = parent[self.db.n_tokens_col].sum()
        if row[self.db.n_tokens_col] > token_budget:
            # simple case: NOTHING will fit in the token budget!
            return None
        elif total_tokens <= token_budget:
            # simple case: if all tokens in budget, no extra work
            rows = parent
        else:
            target_ind_val = df.index[ind]
            before = parent.loc[:target_ind_val]
            # see: https://stackoverflow.com/a/37872823
            cumulative_token_counts = before.loc[::-1, self.db.n_tokens_col].cumsum()[::-1]
            rows = before[cumulative_token_counts <= token_budget]
        assert len(rows) >= 1
        new_used_inds = {df.index.get_loc(new_ind) for new_ind in rows.index}
        assert ind in new_used_inds
        texts = rows[self.db.embed_col]
        n_tokens_rows = rows[self.db.n_tokens_col]
        text = self.parent_join_string.join(texts)
        # note this will underestimate the true number of tokens, due to whatever parent_join_string is
        n_tokens = n_tokens_rows.sum()
        return text, n_tokens, new_used_inds
257 |
--------------------------------------------------------------------------------
/src/rag/retrieval_strategies.py:
--------------------------------------------------------------------------------
1 | from rag import retrieval
2 |
3 |
4 | class RetrievalStrategy:
5 | """General retrieval strategy interface."""
6 |
7 | def do_retrieval(
8 | self,
9 | expected_slots: list[str],
10 | user_query: str,
11 | previous_messages: list[dict[str, str]] = [],
12 | ):
13 | raise ValueError("Not implemented.")
14 |
15 |
class NoRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots with the empty string."""

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] | None = None,
    ):
        """Return an empty fill for every slot; query and history are ignored.

        (Mutable default `[]` replaced with None.)
        """
        return {expected_slot: "" for expected_slot in expected_slots}
26 |
27 |
class StaticRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots with a static string."""

    def __init__(self, fill_string: str) -> None:
        """Store the constant fill string used for every slot."""
        super().__init__()
        self.fill_string = fill_string

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] | None = None,
    ):
        """Return the configured static fill for every slot.

        (Mutable default `[]` replaced with None.)
        """
        return {expected_slot: self.fill_string for expected_slot in expected_slots}
42 |
43 |
class EmbeddingRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots with up to max_token texts from the retrieval_db.
    Uses the user_query to produce an embedding and compute RetrievalDb distances."""

    def __init__(self, db: retrieval.RetrievalDb, max_tokens: int = 2000) -> None:
        """Configure the backing db and the token budget for the fill string."""
        super().__init__()
        self.db: retrieval.RetrievalDb = db
        self.max_tokens: int = max_tokens

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] | None = None,
    ):
        """Fill every slot with the closest texts that fit within max_tokens.

        Texts are taken in ascending-distance order; the first text that would
        overflow the budget stops accumulation.
        (Mutable default `[]` replaced with None; history is unused.)
        """
        distances = self.db.compute_string_distances(user_query)
        sort_inds = retrieval.get_distance_sort_indices(distances)
        texts = []
        total_tokens = 0
        for ind in sort_inds:
            row = self.db.df.iloc[ind]
            n_tokens = row[self.db.n_tokens_col]
            if total_tokens + n_tokens > self.max_tokens:
                break
            total_tokens += n_tokens
            texts.append(row[self.db.embed_col])
        fill_string = "\n".join(texts)
        return {expected_slot: fill_string for expected_slot in expected_slots}
73 |
74 |
class MappedEmbeddingRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots based on the entries in slot_map.
    If asked to fill a slot not in slot_map, will use nonmatching_fill instead.
    slot_map can have either static strings or `retrieval.DbInfo`."""

    def __init__(
        self,
        slot_map: dict[str, str | retrieval.DbInfo],
        nonmatching_fill: str = "",
    ) -> None:
        """Store and validate the slot -> fill-source mapping."""
        super().__init__()
        self.slot_map = slot_map
        self.nonmatching_fill = nonmatching_fill
        self._validate_slot_map()

    def _validate_slot_map(self):
        """Ensure keys are strings and values are strings or usable DbInfo wrappers.

        Uses isinstance (previously exact `type(...) is` checks), so subclasses
        of str/DbInfo are now accepted as well — a backward-compatible widening.

        Raises:
            ValueError: On a non-string key, an unexpected value type, or a
                DbInfo whose db cannot compute string distances.
        """
        for key, value in self.slot_map.items():
            if not isinstance(key, str):
                raise ValueError("Slot map keys must be strings.")
            if isinstance(value, str):
                continue
            if isinstance(value, retrieval.DbInfo):
                if not hasattr(value.db, "compute_string_distances"):
                    raise ValueError(
                        "Expected a db with a compute_string_distances() method.",
                    )
            else:
                raise ValueError("Unexpected type in slot map.")

    def update_map(self, slot_updates: dict):
        """Merge new entries into the slot map, re-validating afterwards."""
        self.slot_map.update(slot_updates)
        self._validate_slot_map()

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] | None = None,
    ):
        """Fill each slot from its mapped source, defaulting to nonmatching_fill.

        Static string entries are used verbatim; DbInfo entries retrieve by
        embedding distance to the user query.
        (Mutable default `[]` replaced with None; history is unused.)
        """
        fill_string_map = {}
        for expected_slot in expected_slots:
            if expected_slot in self.slot_map:
                db_info = self.slot_map[expected_slot]
                if isinstance(db_info, str):
                    fill_string = db_info
                else:
                    distances = db_info.db.compute_string_distances(user_query)
                    fill_string = db_info.get_fill_string_from_distances(distances)
            else:
                fill_string = self.nonmatching_fill
            fill_string_map[expected_slot] = fill_string
        return fill_string_map
125 |
--------------------------------------------------------------------------------
/src/streamlit/Annotation_tool.py:
--------------------------------------------------------------------------------
import streamlit as st

# Entry page of the Streamlit app; sibling pages live in the `pages/` directory.
# Currently a placeholder describing the tool.
st.markdown(
    """ # Annotation tool

In-progress annotation tool.
""",
)
9 |
--------------------------------------------------------------------------------
/src/streamlit/pages/Relevance.py:
--------------------------------------------------------------------------------
import streamlit as st

# Streamlit page stub for relevance annotation; placeholder content only.
st.markdown(
    """ # Relevance annotation

In-progress page to annotate relevance.
""",
)
9 |
--------------------------------------------------------------------------------
/tests/unit/conftest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from rag import embedding_utils, retrieval
6 |
7 |
def mock_get_openai_embeddings(input_text_list, *args, **kwargs):
    """Stand-in for the OpenAI embedding call: one random EMBEDDING_DIM vector
    per input text, without touching the network."""
    dim = embedding_utils.EMBEDDING_DIM
    return [np.random.random(size=dim) for _ in input_text_list]
10 |
11 |
@pytest.fixture
def patch_get_openai_embeddings(monkeypatch):
    """Replace rag.embedding_utils.get_openai_embeddings with the random-vector mock."""
    target = "rag.embedding_utils.get_openai_embeddings"
    monkeypatch.setattr(target, mock_get_openai_embeddings)
18 |
19 |
@pytest.fixture
def retrieval_db_path(tmp_path, patch_get_openai_embeddings):
    """Build a small on-disk retrieval database shared by multiple tests; returns its directory."""
    rows = pd.DataFrame(
        {
            "categorical_var": ["A", "B", "C"],
            "group_var": [1, 1, 2],
            "text": ["Test text 1.", "Test text 2.", "Test text 3."],
        },
    )
    db = retrieval.RetrievalDb(tmp_path, "conftestDb", "text", rows)
    db.create_embeddings()
    db.save_df()
    return tmp_path
46 |
47 |
@pytest.fixture
def retrieval_db(retrieval_db_path) -> retrieval.RetrievalDb:
    """Load the shared database created by the `retrieval_db_path` fixture."""
    return retrieval.RetrievalDb(retrieval_db_path, "conftestDb", "text")
51 | return db
52 |
--------------------------------------------------------------------------------
/tests/unit/rag/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
@pytest.fixture(autouse=True)
def no_openai(monkeypatch):
    """Guard against accidental OpenAI calls in RAG unit tests by deleting the
    Embedding and ChatCompletion creation entry points.

    (Effectiveness unverified — kept from the original note.)"""
    for target in ("openai.Embedding.create", "openai.ChatCompletion.create"):
        monkeypatch.delattr(target)
11 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_embedding_utils.py:
--------------------------------------------------------------------------------
1 | from rag import embedding_utils
2 |
3 |
def test_get_token_counts():
    """Trailing whitespace should count as an extra token."""
    assert embedding_utils.get_token_counts(["test", "test "]) == [1, 2]
7 |
8 |
def test_get_openai_embeddings(patch_get_openai_embeddings):
    """A single input text yields a single EMBEDDING_DIM-length vector."""
    embeddings = embedding_utils.get_openai_embeddings(["test"])
    assert len(embeddings) == 1
    (embedding,) = embeddings
    assert embedding.shape[0] == embedding_utils.EMBEDDING_DIM
13 |
14 |
def test_batch_embed_texts(patch_get_openai_embeddings):
    """Batching across the per-request token limit still embeds every text."""
    limit = embedding_utils.MAX_TOKENS_PER_REQUEST
    texts = ["test"] * (limit + 1)
    token_counts = embedding_utils.get_token_counts(texts)
    embeddings = embedding_utils.batch_embed_texts(texts, token_counts)
    assert len(embeddings) == len(texts)
    for embedding in embeddings:
        assert embedding.shape[0] == embedding_utils.EMBEDDING_DIM
24 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_gpf_utils.py:
--------------------------------------------------------------------------------
1 | from rag import gpf_utils
2 |
3 |
def test_get_gpd_codes():
    """A full GPF code string decomposes into its six components."""
    grade, domain, construct, subconstruct, skill, index = gpf_utils.get_gpd_codes(
        "G9.N5.1.3.1",
    )
    assert (grade, domain) == (9, "N")
    assert (construct, subconstruct) == ("N5", "N5.1")
    assert (skill, index) == ("N5.1.3", 1)
14 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_logit_bias.py:
--------------------------------------------------------------------------------
1 | from rag import logit_bias
2 |
THE_TOKEN = 1820  # token id for the string "the"; round-tripped in test_get_tokenizer below
4 |
5 |
def test_get_tokenizer():
    """Encoding and decoding "the" round-trips through THE_TOKEN."""
    tokenizer = logit_bias.get_tokenizer()
    assert tokenizer.encode("the") == [THE_TOKEN]
    assert tokenizer.decode([THE_TOKEN]) == "the"
10 |
11 |
def test_load_stopwords():
    """Loading stopword tokens is sizeable, contains THE_TOKEN, and is cached consistently."""
    loaded = logit_bias.load_stopword_tokens()
    assert len(loaded) >= 11000
    assert THE_TOKEN in loaded, "The 'the' token (1820) should be in the stopwords."
    assert loaded == logit_bias.get_stopword_tokens()
    # the cached getter should be stable across calls
    assert logit_bias.get_stopword_tokens() == logit_bias.get_stopword_tokens()
19 |
20 |
def test_get_nonstopword_tokens():
    """Stopword-only text yields nothing; rare words survive filtering."""
    assert len(logit_bias.get_nonstopword_tokens("the was and were thus")) == 0
    assert len(logit_bias.get_nonstopword_tokens("verily osmogorp")) >= 5
28 |
29 |
def test_get_logit_bias():
    """min_count acts as a frequency threshold on candidate bias tokens."""
    tokens = logit_bias.get_nonstopword_tokens("verily verily verily")
    assert len(tokens) == 6
    # a min_count above the observed frequency filters everything out
    assert len(logit_bias.get_logit_bias(tokens, min_count=4)) == 0
    biases = logit_bias.get_logit_bias(tokens, min_count=3)
    assert len(biases) == 1
    assert any(token in biases for token in tokens)
38 |
39 |
def test_get_logit_bias_from_slot():
    """Recent slot fills contribute tokens to the logit bias dict."""
    recent_fills = [
        {
            "slot1": "verily verily verily",
            "slot2": "the",
        },
    ]
    biases = logit_bias.get_logit_bias_from_slot(recent_fills)
    # 1570 is presumably a token of the repeated "verily" fill — TODO confirm
    assert 1570 in biases
    # TODO make this test cover more cases and be more explanatory
50 |
51 |
def test_create_stopword_token_set_from_word_list():
    """A single word expands to multiple token variants, including THE_TOKEN."""
    stopword_tokens = logit_bias.create_stopword_token_set_from_word_list(["the"])
    assert len(stopword_tokens) > 1
    assert THE_TOKEN in stopword_tokens
57 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_misconceptions.py:
--------------------------------------------------------------------------------
1 | from rag import misconceptions
2 |
3 |
def test_load_misconception_list():
    """Loading yields a sizeable list; the cached getter is stable across calls."""
    loaded = misconceptions.load_misconception_list()
    assert len(loaded) >= 50
    # NOTE: comparing load_misconception_list() to get_misconception_list()
    # was observed to be unequal for unclear reasons, so it is not asserted here.
    assert misconceptions.get_misconception_list() == misconceptions.get_misconception_list()
10 |
11 |
def test_get_misconceptions_string():
    """The rendered misconceptions string is substantial."""
    assert len(misconceptions.get_misconceptions_string()) >= 1000
15 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_prompt_utils.py:
--------------------------------------------------------------------------------
1 | from rag import prompt_utils, retrieval_strategies
2 |
3 |
def test_conversion():
    """A SYSTEM:/USER: transcript converts into role-tagged message dicts."""
    transcript = """SYSTEM:
System prompt.
Multiple lines.
USER:
User query."""
    messages = prompt_utils.PromptSelector.convert_string_to_conversation(transcript)
    assert len(messages) == 2
    system_message, user_message = messages
    assert system_message["role"] == "system"
    assert system_message["content"] == "System prompt.\nMultiple lines.", system_message
    assert user_message["role"] == "user"
    assert user_message["content"] == "User query.", user_message
16 |
17 |
def test_PromptSelector():
    """The selector exposes pretty names, message lists, and a default prompt."""
    prompts = {
        "test_prompt_1": {
            "pretty_name": "Prompt 1",
            "messages": [{"role": "system", "content": "System prompt 1."}],
        },
        "test_prompt_2": {
            "pretty_name": "Prompt 2",
            "messages": [{"role": "system", "content": "System prompt 2."}],
        },
    }
    selector = prompt_utils.PromptSelector(prompts)
    assert selector.get_intro_prompt_pretty_names() == ["Prompt 1", "Prompt 2"]
    message_lists = selector.get_intro_prompt_message_lists()
    assert message_lists[0] == prompts["test_prompt_1"]["messages"]
    assert selector.get_default_intro_prompt()["pretty_name"] == "Prompt 1"
45 |
46 |
def test_PromptManager():
    """Covers a bare query, an intro-message conversation start, and continuation."""
    pm = prompt_utils.PromptManager()

    # a bare query becomes a single stored user message
    messages = pm.build_query("Test")
    assert len(messages) == 1
    assert messages[0]["content"] == "Test"
    assert len(pm.stored_messages) == 1
    assert pm.stored_messages[0]["content"] == "Test"
    pm.clear_stored_messages()

    # starting a conversation with a system message prepends it
    intro = [{"role": "system", "content": "System"}]
    messages = pm.set_intro_messages(intro).build_query("User")
    assert len(messages) == 2
    assert [m["role"] for m in messages] == ["system", "user"], messages
    assert [m["content"] for m in messages] == ["System", "User"]
    pm.clear_stored_messages()

    # continuing a conversation appends after the prior assistant turn
    history = messages + [{"role": "assistant", "content": "Assistant"}]
    messages = pm.set_intro_messages(intro).build_query(
        "User2",
        previous_messages=history,
    )
    assert len(messages) == 4
    assert messages[2]["content"] == "Assistant"
    assert messages[3]["content"] == "User2"
88 |
89 |
def test_PromptManager_retrieval():
    """A static retrieval strategy fills every slot in the intro template."""
    intro = [{"role": "system", "content": "Test {slot1} {slot2}"}]
    strategy = retrieval_strategies.StaticRetrievalStrategy("Fill")
    pm = prompt_utils.PromptManager().set_intro_messages(intro).set_retrieval_strategy(strategy)
    messages = pm.build_query("User")
    assert len(messages) == 2
    # both slots are filled, but the stored intro template is left untouched
    assert messages[0]["content"] == "Test Fill Fill", messages[0]
    assert pm.intro_messages[0]["content"] == "Test {slot1} {slot2}"
103 |
104 |
def test_identify_slots():
    """Slot detection skips names containing whitespace and deduplicates repeats."""
    found = prompt_utils.PromptManager.identify_slots("test {test1} {test2} {test3 }")
    assert found == ["test1", "test2"]

    found = prompt_utils.PromptManager.identify_slots("test {test1} {test1}")
    assert found == ["test1"]
111 |
112 |
def test_user_query_replacement():
    """Only the exact slot name "user_query" absorbs the query into the intro message."""
    intro = [{"role": "user", "content": "Question: {user_query}"}]
    messages = prompt_utils.PromptManager().set_intro_messages(intro).build_query("Test")
    assert len(messages) == 1
    assert messages[0]["content"] == "Question: Test"

    # any other slot name gets a blank fill and the query becomes a second message
    intro = [{"role": "user", "content": "Question: {other_slot}"}]
    messages = prompt_utils.PromptManager().set_intro_messages(intro).build_query("Test")
    assert len(messages) == 2
    assert messages[0]["content"] == "Question: "
    assert messages[1]["content"] == "Test"
136 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_retrieval.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 |
7 | from rag import embedding_utils, retrieval
8 |
9 |
def test_RetrievalDb(tmp_path, patch_get_openai_embeddings):
    """Round-trips a db through creation, persistence, reload, and querying."""
    source_df = pd.DataFrame(
        {
            "categorical_var": ["A", "B"],
            "text": ["Test text.", "Test\n\ntext."],
        },
    )
    db = retrieval.RetrievalDb(tmp_path, "testDb", "text", source_df)
    # nothing is persisted until explicitly requested
    assert not db.embedding_filepath.exists()
    assert not db.df_filepath.exists()
    db.create_embeddings()
    assert db.embedding_filepath.exists()
    assert db.embedding_mat.shape == (len(source_df), embedding_utils.EMBEDDING_DIM)
    db.save_df()
    assert db.df_filepath.exists()

    # reloading from disk restores all rows
    reloaded = retrieval.RetrievalDb(tmp_path, "testDb", "text")
    assert len(reloaded.df) == len(source_df)

    distances = reloaded.compute_string_distances("Test query.")
    assert len(distances) == len(source_df)
    assert len(reloaded.get_top_df(distances, k=1)) == 1

    # loading a database that was never created should fail
    with pytest.raises(ValueError):
        retrieval.RetrievalDb(tmp_path, "testDb2", "text")
43 |
44 |
def test_DbInfo(retrieval_db):
    """copy() preserves settings unless explicitly overridden."""
    original = retrieval.DbInfo(retrieval_db, max_tokens=1)
    assert original.max_tokens == 1
    assert original.copy().max_tokens == 1
    assert original.copy(max_tokens=2).max_tokens == 2
52 |
53 |
def test_DbInfo_get_fill_string_from_distances(retrieval_db):
    """Only the single nearest text appears, wrapped in the prefix and suffix."""
    db_info = retrieval.DbInfo(
        retrieval_db,
        prefix="test_prefix",
        suffix="test_suffix",
        max_texts=1,
    )
    # make the first row the unique nearest neighbor
    distances = np.array([0] + [1] * (len(db_info.db.df) - 1))
    assert len(distances) == len(db_info.db.df)
    fill_string = db_info.get_fill_string_from_distances(distances)
    assert fill_string.startswith("test_prefix")
    assert fill_string.endswith("test_suffix")
    texts = db_info.db.df[db_info.db.embed_col]
    assert texts.iloc[0] in fill_string
    for other_text in texts.iloc[1:]:
        assert other_text not in fill_string
72 |
73 |
def test_DbInfo_get_single_text(retrieval_db):
    """get_single_text returns the stored text plus a positive token count."""
    db_info = retrieval.DbInfo(retrieval_db)
    text, n_tokens = db_info.get_single_text(0)
    expected_text = db_info.db.df[db_info.db.embed_col].iloc[0]
    assert text == expected_text
    assert n_tokens > 0
79 |
80 |
def test_DbInfo_get_parent_text(retrieval_db):
    """Exercise parent-text retrieval: grouping, dedup, and token-budget backoff."""
    db_info = retrieval.DbInfo(
        retrieval_db,
        max_texts=1,
        use_parent_text=True,
        parent_group_cols=["group_var"],
        parent_sort_cols=["categorical_var"],
    )
    # first, verify that with a budget of 0 no texts are returned
    assert db_info.get_parent_text(0, 0) is None

    # second, verify normal case
    text, n_tokens, used_inds = db_info.get_parent_text(0, 1000)
    assert used_inds == {0, 1}
    expected_parent_rows = db_info.db.df[db_info.db.df["group_var"] == 1]
    expected_text = "\n".join(expected_parent_rows[db_info.db.embed_col])
    assert text == expected_text, f"Parent text was {text}"
    assert n_tokens == expected_parent_rows[db_info.db.n_tokens_col].sum()

    # here, we retrieve the second entry in a parent text but we still get back the first and second together
    distances = np.array([1, 0, 1])
    fill_string = db_info.get_fill_string_from_distances(distances)
    assert fill_string == expected_text

    # check for non-duplicate retrieval
    db_info.max_texts = 2
    distances = np.array([0, 0, 1])
    fill_string = db_info.get_fill_string_from_distances(distances)
    # FIX: escape the text so regex metacharacters (e.g. "." in "Test text.")
    # are matched literally rather than as patterns
    assert len(re.findall(re.escape(expected_text), fill_string)) == 1, f"Duplicate retrieval in {fill_string}"

    # third, we verify token budget backoff behavior
    # here, we give only enough budget to get the target text
    token_budget = db_info.db.df[db_info.db.n_tokens_col].iloc[0]
    text, n_tokens, used_inds = db_info.get_parent_text(0, token_budget)
    assert text == db_info.db.df[db_info.db.embed_col].iloc[0]
    assert n_tokens == token_budget
    assert used_inds == {0}

    # here, we give enough budget for two texts
    # FIX: iloc[0:1] sums only the first row (same as the previous case);
    # iloc[0:2] covers the two texts this case is meant to exercise
    token_budget = db_info.db.df[db_info.db.n_tokens_col].iloc[0:2].sum()
    text, n_tokens, used_inds = db_info.get_parent_text(0, token_budget)
    assert len(used_inds - {0, 1}) == 0
123 |
--------------------------------------------------------------------------------
/tests/unit/rag/test_retrieval_strategies.py:
--------------------------------------------------------------------------------
1 | from rag import retrieval, retrieval_strategies
2 |
3 |
def test_NoRetrievalStrategy():
    """NoRetrievalStrategy fills every requested slot with the empty string."""
    strategy = retrieval_strategies.NoRetrievalStrategy()
    filled = strategy.do_retrieval(["test"], "user")
    assert filled == {"test": ""}
7 |
8 |
def test_StaticRetrievalStrategy():
    """StaticRetrievalStrategy fills every slot with its configured constant."""
    strategy = retrieval_strategies.StaticRetrievalStrategy("TestVal")
    filled = strategy.do_retrieval(["test"], "user")
    assert filled == {"test": "TestVal"}
12 |
13 |
def test_EmbeddingRetrievalStrategy(retrieval_db_path):
    """Retrieved text fills the slot; a zero token budget yields an empty fill."""
    db = retrieval.RetrievalDb(retrieval_db_path, "conftestDb", "text")
    strategy = retrieval_strategies.EmbeddingRetrievalStrategy(db, max_tokens=5)
    filled = strategy.do_retrieval(["testSlot"], "testQuery")
    assert "testSlot" in filled
    # note that the actual text retrieved here is random due to the way the retrieval embeddings are generated
    assert filled["testSlot"].startswith("Test text"), "No retrieved text."

    # with no token budget, the slot should come back empty
    strategy = retrieval_strategies.EmbeddingRetrievalStrategy(db, max_tokens=0)
    filled = strategy.do_retrieval(["testSlot"], "testQuery")
    assert filled["testSlot"] == ""
26 |
27 |
def test_MappedEmbeddingRetrievalStrategy(retrieval_db):
    """Static fills, nonmatching fallback, DbInfo-backed slots, and map updates."""
    static_map = {
        "slot1": "fill1",
        "slot2": "fill2",
    }
    strategy = retrieval_strategies.MappedEmbeddingRetrievalStrategy(
        static_map,
        nonmatching_fill="nomatch",
    )
    assert strategy.do_retrieval(["slot1", "slot2"], "") == static_map
    filled = strategy.do_retrieval(["slot1", "slot2", "slot3"], "")
    # unknown slots fall back to the nonmatching fill
    assert filled["slot3"] == "nomatch"
    for slot_name, expected_fill in static_map.items():
        assert filled[slot_name] == expected_fill

    # a DbInfo value triggers embedding-based retrieval for that slot
    mixed_map = {
        "slot1": "fill1",
        "slot2": retrieval.DbInfo(retrieval_db),
    }
    strategy = retrieval_strategies.MappedEmbeddingRetrievalStrategy(mixed_map)
    filled = strategy.do_retrieval(["slot1", "slot2"], "")
    assert filled["slot1"] == "fill1"
    assert filled["slot2"].startswith("Test text")

    # update_map replaces the DbInfo entry with a static fill
    strategy.update_map({"slot2": "fill2"})
    assert strategy.do_retrieval(["slot1", "slot2"], "")["slot2"] == "fill2"
56 |
--------------------------------------------------------------------------------
/tests/unit/test_auth.py:
--------------------------------------------------------------------------------
1 | from experiment import auth
2 |
3 |
def test_passwd_check():
    """A hashed passphrase validates against its original plaintext."""
    plaintext = "test"
    hashed = auth.passwd_hash(plaintext)
    assert auth.passwd_check(hashed, plaintext)
8 |
--------------------------------------------------------------------------------
/tests/unit/test_completion_utils.py:
--------------------------------------------------------------------------------
1 | from experiment import completion_utils
2 |
COMPLETION_ERROR_STRING = "completion_utils completion error"


def mock_get_completion_error(*args, **kwargs):
    """Stand-in for get_completion that always fails with a recognizable message."""
    raise ValueError(COMPLETION_ERROR_STRING)
8 |
9 |
def mock_get_completion_noerror(*args, **kwargs):
    """Stand-in for get_completion that succeeds with a fixed string."""
    return "Test completion"
12 |
13 |
def test_get_completion_noraise(monkeypatch, caplog):
    """get_completion_noraise swallows errors (returning None) and passes through successes."""
    # error path: every attempt raises, so the wrapper should return None
    monkeypatch.setattr(
        "experiment.completion_utils.get_completion",
        mock_get_completion_error,
    )
    result = completion_utils.get_completion_noraise([], sleep=0, sleep_time_between_attempts=0, max_attempts=2)
    assert result is None
    for record in caplog.records:
        assert record.name == completion_utils.logger.name
        assert COMPLETION_ERROR_STRING in record.message
    assert (
        len(caplog.records) == 3
    ), "Expected max_attempts error messages from attempts + 1 caught by get_completion_noraise."

    # success path: the completion should be returned unchanged
    monkeypatch.setattr(
        "experiment.completion_utils.get_completion",
        mock_get_completion_noerror,
    )
    result = completion_utils.get_completion_noraise([], sleep=0, sleep_time_between_attempts=0, max_attempts=2)
    assert result == "Test completion"
37 |
--------------------------------------------------------------------------------
/tests/unit/test_generate.py:
--------------------------------------------------------------------------------
1 | import string
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from experiment import generate
7 |
8 |
def mock_get_completion(*args, **kwargs):
    """Return a space-separated string of 50 randomly chosen lowercase letters."""
    letters = list(string.ascii_lowercase)
    sampled = np.random.choice(letters, size=50)
    return " ".join(sampled)
12 |
13 |
def mock_get_completion_error(*args, **kwargs):
    """Always raise, echoing the positional args in the error message."""
    raise ValueError(f"Args: {args}")
16 |
17 |
@pytest.fixture
def patch_get_completion(monkeypatch):
    """Route experiment.completion_utils.get_completion to the random mock."""
    monkeypatch.setattr("experiment.completion_utils.get_completion", mock_get_completion)
24 |
25 |
def test_GenerationCorpus(tmp_path, patch_get_completion):
    """Generations are deduplicated by (messages, metadata) and persisted across instances."""
    corpus = generate.GenerationCorpus(tmp_path, "test_corpus1")
    assert len(corpus.generations) == 0
    messages = [{"role": "user", "content": "Test query"}]
    metadata = {"test_key": "test_value"}
    # first call generates; an identical second call is a no-op
    assert corpus.generate(messages, metadata, sleep=0)
    assert not corpus.generate(messages, metadata, sleep=0)
    assert len(corpus.generations) == 1
    first_generation = corpus.generations[0]
    assert first_generation["test_key"] == "test_value"
    assert first_generation["generation"] is not None
    metadata = metadata.copy()  # NOTE: this will fail without a copy()!
    metadata["test_key"] = "test_value2"
    # changed metadata counts as a new generation target
    assert not corpus.is_already_generated(messages, metadata)
    assert corpus.generate(messages, metadata, sleep=0)
    assert len(corpus.generations) == 2
    assert corpus.generations[-1]["test_key"] == "test_value2"

    # a fresh corpus with the same name reloads both generations from disk
    reloaded = generate.GenerationCorpus(tmp_path, "test_corpus1")
    assert len(reloaded.generations) == 2
    assert reloaded.generations[0]["test_key"] == "test_value"
    assert reloaded.generations[-1]["test_key"] == "test_value2"
52 |
53 |
def test_GenerationCorpus_batch(tmp_path):
    """Exercise batch generation: parallel generation, dedup on re-run,
    non-matching generation lookup/removal, and filtering/overwriting."""
    corpus = generate.GenerationCorpus(tmp_path, "test_corpus")
    assert len(corpus.generations) == 0
    messages = [
        {
            "role": "user",
            "content": "Test query",
        },
    ]
    n_to_generate = 200
    metadata_list = [{"test_id": i, "messages": messages} for i in range(n_to_generate)]
    n_generated = corpus.batch_generate(
        metadata_list,
        n_processes=4,
        sleep=0,
        completion_func=mock_get_completion,
    )
    # FIX: expected/actual were swapped in this assertion message
    assert n_generated == n_to_generate, f"Expected {n_to_generate} generations, but only produced {n_generated}"
    assert len(corpus.generations) == n_to_generate
    for generation in corpus.generations:
        assert generation["generation"] is not None

    # re-running the same batch should produce nothing new
    assert (
        corpus.batch_generate(
            metadata_list,
            n_processes=4,
            sleep=0,
            completion_func=mock_get_completion,
            fake_kwarg=0,  # also test passing a kwarg
        )
        == 0
    )

    # test get_nonmatching_generations
    n_generations = len(corpus.generations)
    metadata_list = [dict(metadata) for metadata in metadata_list]
    popped_metadata = metadata_list.pop(0)
    nonmatching_generations = corpus.get_nonmatching_generations(metadata_list)
    assert len(nonmatching_generations) == 1
    assert nonmatching_generations[0] == popped_metadata
    assert popped_metadata in corpus.generations, "Unexpectedly removed from generations."
    assert n_generations == len(corpus.generations)
    assert len(corpus.get_nonmatching_generations(metadata_list, should_remove_nonmatching=True)) == 1
    assert (
        len(corpus.generations) == n_generations - 1
    ), "Should be one fewer generation after removing non-matching generations."

    # test filter_generations() and overwrite()
    n_generations = len(corpus.generations)
    assert corpus.filter_generations() == 0
    assert corpus.filter_generations(lambda _: False) == n_generations
    assert len(corpus.generations) == 0
    corpus.overwrite()
107 |
108 |
def test_GenerationCorpus_batch_error(tmp_path):
    """Errors raised by the completion function propagate out of batch_generate."""
    corpus = generate.GenerationCorpus(tmp_path, "test_corpus")
    assert not corpus.generations
    messages = [{"role": "user", "content": "Test query"}]
    metadata_list = [{"test_id": idx, "messages": messages} for idx in range(2)]
    with pytest.raises(ValueError):
        corpus.batch_generate(
            metadata_list,
            n_processes=1,
            sleep=0,
            completion_func=mock_get_completion_error,
        )
127 |
--------------------------------------------------------------------------------
/tests/unit/test_metrics.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from experiment import metrics
4 |
5 |
def test_metric_objects():
    """Both metric objects should be constructible."""
    for metric_getter in (metrics.get_bertscore_metric_object, metrics.get_bleurt_metric_object):
        assert metric_getter() is not None
9 |
10 |
def test_compute_bleurt():
    """BLEURT scores a near-paraphrase of a passage above a loose threshold."""
    passages = ["The alphabet is 26 letters long.", "Math is not so easy."]
    generation = "The English alphabet is 26 letters long."
    score = metrics.compute_bleurt(passages, generation)
    assert score >= 0.3
15 |
16 |
def test_compute_bertscore():
    """BERTScore scores a near-paraphrase of a passage above a loose threshold."""
    passages = ["The alphabet is 26 letters long.", "Math is not so easy."]
    generation = "The English alphabet is 26 letters long."
    score = metrics.compute_bertscore(passages, generation)
    assert score >= 0.3
21 |
22 |
def test_compute_macro_f1():
    """Check macro-F1 values, error handling for empty inputs, and the K-F1++ discount."""
    assert metrics.compute_macro_f1(["Test text", "distractor passage"], "test text") == 1
    assert metrics.compute_macro_f1(["Test"], "test text") == 2 / 3
    assert metrics.compute_macro_f1(["distractor", "distractor passage"], "test text") == 0
    # an empty generation or an empty passage should both raise
    for passages, generation in ((["test"], ""), ([""], "test")):
        with pytest.raises(ValueError):
            metrics.compute_macro_f1(passages, generation)

    # K-F1++: discounted text is excluded from the score
    kf1_score = metrics.compute_macro_f1(
        ["George Washington"],
        "Who was the first president? George Washington",
        discount_text="Who was the first president?",
    )
    assert kf1_score == 1
41 |
--------------------------------------------------------------------------------
/tests/unit/test_qualtrics.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from datetime import datetime
3 |
4 | import pandas as pd
5 | import pytest
6 |
7 | from experiment import qualtrics
8 |
9 |
@pytest.fixture(scope="module")
def qualtrics_template_filepath(pytestconfig) -> pathlib.Path:
    """Path to the Qualtrics survey template; skips the module's tests if absent."""
    filepath = pytestconfig.rootpath / "data" / "raw" / "qualtrics" / "Rori_ranking_annotations_-_template.qsf"
    if not filepath.exists():
        pytest.skip(f"Qualtrics template at {filepath} does not exist.")
    return filepath
16 |
17 |
def test_create_surveys(qualtrics_template_filepath, tmp_path):
    """Surveys are built from the template and generations land in the response slots."""
    template_survey_text = qualtrics.get_template(qualtrics_template_filepath)
    assert template_survey_text is not None

    survey = qualtrics.create_survey("survey_id0", [], template_survey_text)
    assert qualtrics.OVERFLOW_RESPONSE in survey

    rows = [
        ["q1", "d1", "none", "g1", "a"],
        ["q1", "d1", "low", "g2", "a"],
        ["q1", "d1", "high", "g3", "a"],
    ]
    input_df = pd.DataFrame(rows, columns=["query", "document", "guidance", "generation", "metadata"])
    survey_df = qualtrics.create_surveys(input_df, template_survey_text, tmp_path / "test_surveys")
    first_row = survey_df.iloc[0]
    assert first_row["query"] == "q1"
    assert first_row["document"] == "d1"
    # all three generations appear across the response columns (order unspecified)
    assert set(first_row[["response1", "response2", "response3"]]) == {"g1", "g2", "g3"}
    assert first_row["survey_id"] == f"s_{datetime.now().strftime('%Y%m%d')}_1/1"
40 |
--------------------------------------------------------------------------------
/tests/unit/test_tokenize.py:
--------------------------------------------------------------------------------
1 | from experiment import tokenize
2 |
3 |
def test_get_tokens():
    """Tokenization respects lowercasing and punctuation-stripping options."""
    assert tokenize.get_tokens("Test text", lower=True) == ["test", "text"]
    assert tokenize.get_tokens("Test text", lower=False) == ["Test", "text"]
    # punctuation is its own token unless nonalphanumeric tokens are removed
    assert tokenize.get_tokens("Test.", lower=False) == ["Test", "."]
    assert tokenize.get_tokens("Test.", lower=False, remove_nonalphanumeric_tokens=True) == ["Test"]
10 |
--------------------------------------------------------------------------------