├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── CITATION.cff ├── Dockerfile.streamlit ├── LICENSE ├── Makefile ├── README.md ├── data ├── .gitignore └── derived │ ├── generations │ └── math_df.csv │ ├── mn_general_student_queries.csv │ └── outcome_df.csv ├── figures ├── dhf-logo-vector-blue.svg ├── dhf-poster-logo.png ├── faithfulness_by_guidance.pdf ├── guidance_autometrics.pdf ├── mean_faithfulness.pdf ├── mean_relevance.pdf ├── mean_relevance2.pdf ├── pairwise_ranks.pdf ├── pairwise_ranks_poster.png ├── rank_distribution.pdf ├── relevance_rank_comparison.pdf ├── relevance_x_faithfulness.pdf ├── rori-rag-final-5.png ├── rori-rag-final.drawio ├── slides_groundedness_v_relevance.png ├── slides_human_groundedness.png ├── slides_kf1.png ├── slides_kf1_hist.png ├── slides_kf1_hist_low.png ├── slides_kf1_hist_low_high.png ├── slides_kf1_no_ir.png ├── slides_preference.png ├── system-diagram-poster.png └── system-diagram.svg ├── notebooks ├── EvaluateRetrievalImpact.ipynb ├── MathNationDataExploration.ipynb ├── MetricTesting.ipynb ├── Qualtrics.ipynb └── SurveyDataAnalysis.ipynb ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── src ├── experiment │ ├── auth.py │ ├── completion_utils.py │ ├── generate.py │ ├── guidance_conditions.py │ ├── metrics.py │ ├── qualtrics.py │ └── tokenize.py ├── rag │ ├── __init__.py │ ├── embedding_utils.py │ ├── gpf_utils.py │ ├── logit_bias.py │ ├── misconceptions.py │ ├── prompt_utils.py │ ├── prompts │ │ ├── __init__.py │ │ ├── hints.py │ │ └── mathqa.py │ ├── resources │ │ ├── dolma_stopwords.json │ │ └── misconceptions.ndjson │ ├── retrieval.py │ └── retrieval_strategies.py └── streamlit │ ├── Annotation_tool.py │ └── pages │ └── Relevance.py └── tests └── unit ├── conftest.py ├── rag ├── conftest.py ├── test_embedding_utils.py ├── test_gpf_utils.py ├── test_logit_bias.py ├── test_misconceptions.py ├── test_prompt_utils.py ├── test_retrieval.py └── test_retrieval_strategies.py ├── 
test_auth.py ├── test_completion_utils.py ├── test_generate.py ├── test_metrics.py ├── test_qualtrics.py └── test_tokenize.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = venv 3 | ignore = E203, E501, W503 4 | max-line-length = 120 5 | 6 | # E203: Whitespace before ':' 7 | # E501: Line too long 8 | # W503: https://github.com/psf/black/issues/52 and https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project-specific 2 | .DS_Store 3 | .streamlit/secrets.toml 4 | .env 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.13.2 10 | hooks: 11 | - id: isort 12 | - repo: https://github.com/asottile/add-trailing-comma 13 | rev: v3.1.0 14 | hooks: 15 | - id: add-trailing-comma 16 | args: [--py36-plus] 17 | - repo: https://github.com/asottile/pyupgrade 18 | rev: v3.17.0 19 | hooks: 20 | - id: pyupgrade 21 | args: [--py37-plus] 22 | - repo: https://github.com/psf/black-pre-commit-mirror 23 | rev: 24.8.0 24 | hooks: 25 | - id: black 26 | language_version: python3.9 27 | - repo: https://github.com/nbQA-dev/nbQA 28 | rev: 1.8.7 29 | hooks: 30 | - id: nbqa-black 31 | - id: nbqa-pyupgrade 32 | args: ["--py37-plus"] 33 | - id: nbqa-isort 34 | args: ["--float-to-top"] 35 | - repo: https://github.com/PyCQA/flake8 36 | rev: 7.1.1 37 | hooks: 38 | - id: flake8 39 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true 7 | } 8 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite the paper as below." 
3 | authors: 4 | - family-names: Levonian 5 | given-names: Zachary 6 | orcid: https://orcid.org/0000-0002-8932-1489 7 | - family-names: Li 8 | given-names: Chenglu 9 | - family-names: Zhu 10 | given-names: Wangda 11 | - family-names: Gade 12 | given-names: Anoushka 13 | - family-names: Henkel 14 | given-names: Owen 15 | - family-names: Postle 16 | given-names: Millie-Ellen 17 | - family-names: Xing 18 | given-names: Wanli 19 | date-released: 2023-10-04 20 | repository-code: "https://github.com/DigitalHarborFoundation/rag-for-math-qa" 21 | preferred-citation: 22 | type: conference-paper 23 | title: "Retrieval-augmented Generation to Improve Math Question-Answering: Trade-offs Between Groundedness and Human Preference" 24 | abstract: "For middle-school math students, interactive question-answering (QA) with tutors is an effective way to learn. The flexibility and emergent capabilities of generative large language models (LLMs) has led to a surge of interest in automating portions of the tutoring process - including interactive QA to support conceptual discussion of mathematical concepts. However, LLM responses to math questions can be incorrect or mismatched to the educational context - such as being misaligned with a school's curriculum. One potential solution is retrieval-augmented generation (RAG), which involves incorporating a vetted external knowledge source in the LLM prompt to increase response quality. In this paper, we designed prompts that retrieve and use content from a high-quality open-source math textbook to generate responses to real student questions. We evaluate the efficacy of this RAG system for middle-school algebra and geometry QA by administering a multi-condition survey, finding that humans prefer responses generated using RAG, but not when responses are too grounded in the textbook content. 
We argue that while RAG is able to improve response quality, designers of math QA systems must consider trade-offs between generating responses preferred by students and responses closely matched to specific educational resources." 25 | doi: 10.48550/arXiv.2310.03184 26 | year: 2023 27 | conference: 28 | name: "NeurIPS'23 Workshop on Generative AI for Education (GAIED)" 29 | city: "New Orleans" 30 | country: "US" 31 | date-start: "2023-12-15" 32 | date-end: "2023-12-15" 33 | authors: 34 | - family-names: Levonian 35 | given-names: Zachary 36 | orcid: https://orcid.org/0000-0002-8932-1489 37 | - family-names: Li 38 | given-names: Chenglu 39 | - family-names: Zhu 40 | given-names: Wangda 41 | - family-names: Gade 42 | given-names: Anoushka 43 | - family-names: Henkel 44 | given-names: Owen 45 | - family-names: Postle 46 | given-names: Millie-Ellen 47 | - family-names: Xing 48 | given-names: Wanli 49 | 50 | 51 | -------------------------------------------------------------------------------- /Dockerfile.streamlit: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | RUN python -m pip install --upgrade pip 4 | RUN curl -sSL https://install.python-poetry.org | python - 5 | 6 | WORKDIR /usr/app 7 | 8 | COPY ./src ./src 9 | COPY ./tests ./tests 10 | 11 | # note: README.md is required for Poetry 12 | COPY README.md README.md 13 | COPY pyproject.toml pyproject.toml 14 | COPY poetry.lock poetry.lock 15 | RUN pip install . 
16 | 17 | EXPOSE 8502 18 | ENTRYPOINT ["streamlit", "run", "src/streamlit/Annotation_tool.py", "--server.port=8502", "--server.address=0.0.0.0"] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zachary Levonian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help install ensure-poetry install-precommits test run-streamlit build-docker run-docker remove-docker 2 | 3 | export PATH := $(HOME)/.local/bin:$(PATH) 4 | 5 | tt: 6 | @which poetry 7 | @poetry install 8 | 9 | help: 10 | @echo "Relevant targets are 'install' and 'test'." 
11 | 12 | install: 13 | @$(MAKE) ensure-poetry 14 | @$(MAKE) install-precommits 15 | @poetry build 16 | 17 | ensure-poetry: 18 | @# see issue: https://stackoverflow.com/questions/77019756/make-not-finding-executable-added-to-path-in-makefile 19 | @if ! command -v poetry &> /dev/null; then \ 20 | echo "Installing poetry"; \ 21 | curl -sSL https://install.python-poetry.org | python - ; \ 22 | echo "Poetry installed, but you might need to update your PATH before make will detect it."; \ 23 | fi 24 | @poetry install 25 | 26 | install-precommits: 27 | @poetry run pre-commit autoupdate 28 | @poetry run pre-commit install --overwrite --install-hooks 29 | 30 | jupyter: 31 | poetry run jupyter lab 32 | 33 | test: 34 | @poetry run pytest --cov=src --cov-report term-missing 35 | 36 | build-docker: 37 | @poetry build 38 | @docker build -t streamlit -f Dockerfile.streamlit . 39 | 40 | run-docker: 41 | @$(MAKE) remove-docker 42 | @docker run \ 43 | --name streamlit_container \ 44 | -p 8502:8502 \ 45 | -v ./.streamlit/secrets.toml:/usr/app/.streamlit/secrets.toml \ 46 | streamlit:latest 47 | 48 | remove-docker: 49 | @if docker ps -q --filter "name=streamlit_container" | grep -q .; then \ 50 | echo "Stopping streamlit container"; \ 51 | docker stop streamlit_container; \ 52 | fi 53 | @if docker ps -a -q --filter "name=streamlit_container" | grep -q .; then \ 54 | echo "Removing streamlit container"; \ 55 | docker remove --volumes streamlit_container; \ 56 | fi 57 | 58 | run-streamlit: 59 | @streamlit run src/streamlit/Annotation_tool.py -- 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Retrieval-augmented generation to improve math question-answering: trade-offs between groundedness and human preference 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2310.03184-b31b1b.svg)](https://arxiv.org/abs/2310.03184) 4 | 
[![License](https://img.shields.io/github/license/DigitalHarborFoundation/rag-for-math-qa)](https://github.com/DigitalHarborFoundation/rag-for-math-qa/blob/main/LICENSE) 5 | 6 | 7 | This repository contains analysis code, prompts, surveys, figures, and data for the paper "Retrieval-augmented generation to improve math question-answering: trade-offs between groundedness and human preference". 8 | 9 | This repository forks the [`llm-math-education`](https://github.com/DigitalHarborFoundation/llm-math-education) package. 10 | 11 | Cite [the paper](https://arxiv.org/abs/2310.03184) using the CITATION.cff file and dropdown: 12 | 13 | >Zachary Levonian, Chenglu Li, Wangda Zhu, Anoushka Gade, Owen Henkel, Millie-Ellen Postle, and Wanli Xing. 2023. Retrieval-augmented Generation to Improve Math Question-Answering: Trade-offs Between Groundedness and Human Preference. In _NeurIPS’23 Workshop on Generative AI for Education (GAIED)_, New Orleans, USA. DOI:https://doi.org/10.48550/arXiv.2310.03184 14 | 15 | ## Development 16 | 17 | Primary code contributor: 18 | 19 | - Zachary Levonian () 20 | 21 | ## Local development setup 22 | 23 | This project uses `make` and `Poetry` to manage and install dependencies. 24 | 25 | On Windows, you'll need to use WSL and maybe make some other changes. 26 | 27 | ### Python development 28 | 29 | Use `make install` to install all needed dependencies (including the pre-commit hooks and Poetry). 30 | 31 | You'll probably need to manually add Poetry to your PATH, e.g. by updating your `.bashrc` (or relevant equivalent): 32 | 33 | ```bash 34 | export PATH="$HOME/.local/bin:$PATH" 35 | ``` 36 | 37 | ### Run tests 38 | 39 | ```bash 40 | make test 41 | ``` 42 | 43 | ### Run Jupyter Lab 44 | 45 | ```bash 46 | make jupyter 47 | ``` 48 | 49 | Which really just runs `poetry run jupyter lab`, so feel free to customize your Jupyter experience. 50 | 51 | ### Other useful commands 52 | 53 | - `poetry run ` - Run the given command, e.g. 
`poetry run pytest` invokes the tests. 54 | - `poetry add ` - Add the given package as a dependency. Use flag `-G dev` to add it as a development dependency. 55 | 56 | ## Other notes 57 | 58 | ### Poster figures 59 | 60 | Some logos are present in the posters directory. 61 | 62 | The Digital Harbor Foundation logo was created using [`rsvg-convert`](https://man.archlinux.org/man/rsvg-convert.1.en), installed via `brew`. 63 | 64 | I manually adjusted the source svg (dhf-logo-vector-blue.svg) to use DHF blue (#0091c9) rather than black (#010101). 65 | 66 | ```bash 67 | brew install librsvg 68 | rsvg-convert -d 150 -p 150 -h 2in figures/dhf-logo-vector-blue.svg > figures/dhf-poster-logo.png 69 | ``` 70 | 71 | Converting the system diagram (converted from draw.io as an SVG, with embedded fonts): 72 | 73 | ```bash 74 | rsvg-convert -d 150 -p 150 -h 4in figures/system-diagram.svg > figures/system-diagram-poster.png 75 | ``` 76 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | raw/* 2 | derived/* 3 | -------------------------------------------------------------------------------- /data/derived/mn_general_student_queries.csv: -------------------------------------------------------------------------------- 1 | post_id,subject_name,post_content,is_respondable_query 2 | 19856,Algebra 1,how do you a graph function rule?!,general 3 | 114676,Algebra 1,what is the quadratic formula?,general 4 | 124734,Algebra 1,what is monomial,general 5 | 199138,Algebra 1,Can someone give me a factoring example I don't get it and I re watched the videos and I still don't get it.,general 6 | 285849,Algebra 1,if all the x values in a function are differnt its a function right?,general 7 | 294570,Algebra 1,Give me some examples of an equation that is not a linear function!,general 8 | 394330,Algebra 1,"what are regression and median-fit lines? 
9 | [Continued:] I really don't get them.(it) (that)",general 10 | 503676,Algebra 1,"how do you oslve with negative exponents (for example 4^-9) 11 | [Continued:] *solve",general 12 | 591424,Algebra 1,How do you find the parent function of a graph?,general 13 | 645067,Algebra 1,after you distribute a negative does it disappear or do you keep it there?,general 14 | 674022,Algebra 1,I always forget the difference between commutative and associative. Does anybody know a way to make me remember?,general 15 | 726510,Algebra 1,How do you find the radius?,general 16 | 859909,Algebra 1,How do you find the domain and range that are reasonable for a certain function?,general 17 | 1105388,Pre-Algebra,Hey guys! What is the quotient rule??,general 18 | 1140336,Algebra 1,how do i solve square root functions,general 19 | 1179906,Pre-Algebra,How do I multiply fractions???????,general 20 | 1254378,Algebra 1,"where can I find completing the square? 21 | ",general 22 | 1420227,Algebra 1,How does standard deivation work?,general 23 | 1423442,Pre-Algebra,How do you multiply fractions?!?!?,general 24 | 1487901,Algebra 1,"how far can a polynomial go 25 | ",general 26 | 1790360,Algebra 1,What is vertex form and how do you solve for it?,general 27 | 1857118,Algebra 1,what is the difference between communtative property of addition and addition property of equality,general 28 | 1900270,Algebra 1,"Just checking, If I have a problem like x to the 4th in parentheses and there's a exponent out of the parentheses lets say 2 would you multiply the exponent in it by the one out making it x to the 8th power?",general 29 | 1928752,Algebra 1,"Can I get the steps for factoring quadratics 30 | ",general 31 | 1942173,Algebra 1,What is the domain and range? 
How do I find it?,general 32 | 2002640,Algebra 1,"Can someone help me with what an irrational number is vs a rational number 33 | ",general 34 | 2015759,Algebra 1,"I need help with this question: 35 | 36 | An entertainment firm offers several DJ choices and light shows that range in price based on the rental time period. The DJ's cost between $219.00 and $369.00 per night and the light shows cost between $159.00 and $309.00 per night. If you are booking both a DJ and a light show, write a compound inequality that represents the possible total amount you would pay, x.",problem 37 | 2132595,Algebra 1,What is a leading coefficient?,general 38 | 2153141,Algebra 1,wait... so if im using pemdas how does that work with fractions ?,general 39 | 2166131,Algebra 1,"I need a standard way of recognizing polynomials and their degrees, please help me figure 40 | it out. ",general 41 | 2207350,Algebra 1,"Does anyone know how to find the domain and range of a problem like f(x̄)=x̄²+8? 42 | ","general,problem" 43 | 2311169,Algebra 1,What is the difference between commutative and associative property?,general 44 | 2353255,Algebra 1,"I know that this has been a consistent question but I have a test on line of best fit. I've been given multiple different answers on the algebra wall and some make sense while others don't. My teacher never gave me an equation so I don't think I'm supposed to be using one, so what is the best way to find the accurate line of best fit. ",general 45 | 2389599,Algebra 1,what is the difference between recursive formula and explicit form.,general 46 | 2400730,Algebra 1,how do you know if a graph is an absolute value graph?,general 47 | 2547127,Algebra 1,i need help on how to graph quadratic funtions,general 48 | 2571397,Algebra 1,could i recieve help with turning a trinomial into grouping,general 49 | 2660369,Geometry,Is supplementary angles are always adjacent,general 50 | 2672704,Algebra 1,"i have a problem is this equation linear? 
51 | 7x + y + 3 = y 52 | if its linear or not linear can somebody help me on how you can tell if its linear or not","general,problem" 53 | 2675038,Algebra 1,Can function Notation be negative?,general 54 | 2689932,Algebra 1,I don't understand how to find x if there are two x's,general 55 | 2690581,Algebra 1,How do you know if a number is a constant?,general 56 | 2727327,Geometry,How do you translate a shape,general 57 | 2750298,Algebra 1,What is a function notation?,general 58 | 2921833,Algebra 1,How do i graph an inequality with a less than or equal to sign?,general 59 | 2928270,Geometry,Why does the sine of an angle equal the sine of its supplement?,general 60 | 2933099,Algebra 1,Is a rational and irrational number always irrational?,general 61 | 2965705,Geometry,How do I add line segments again??,general 62 | 3033408,Algebra 1,How do you find the zeros of a function by graphing?,general 63 | 3034048,Geometry,I dont get how to get the length of the sides without knowing the perimeter.,general 64 | 3069503,Algebra 1,How can you use polynomials to find the perimeter of a square for a example someone giving you a side length of (3s-5),"general,problem" 65 | 3120570,Geometry,"I know that sin A=opposite/hypotenuse, cos A=adjacent/hypotenuse, and tan A=opposite/adjacent. Is there any other ratios like those?",general 66 | 3293299,Geometry,"I'm still a little lost on how you find midpoints on graphs. I'm currently on Topic 4 of Section 1, and I'm watching the video, but I'm still a little confuse on how they find it. Here is the question. If anyone knows the steps to finding the midpoint on a graph, thank you! ",general 67 | 3311347,Geometry,Can someone explain like interior and exterior angles? thanks,general 68 | 3322367,Geometry,"how do i find the slope of a line in slope intercept form. 
69 | ",general 70 | -------------------------------------------------------------------------------- /figures/dhf-logo-vector-blue.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 7 | 8 | 9 | 12 | 13 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 24 | 28 | 31 | 33 | 34 | 37 | 39 | 41 | 44 | 45 | 46 | 47 | 50 | 52 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /figures/dhf-poster-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/dhf-poster-logo.png -------------------------------------------------------------------------------- /figures/faithfulness_by_guidance.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/faithfulness_by_guidance.pdf -------------------------------------------------------------------------------- /figures/guidance_autometrics.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/guidance_autometrics.pdf -------------------------------------------------------------------------------- /figures/mean_faithfulness.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/mean_faithfulness.pdf -------------------------------------------------------------------------------- /figures/mean_relevance.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/mean_relevance.pdf -------------------------------------------------------------------------------- /figures/mean_relevance2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/mean_relevance2.pdf -------------------------------------------------------------------------------- /figures/pairwise_ranks.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/pairwise_ranks.pdf -------------------------------------------------------------------------------- /figures/pairwise_ranks_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/pairwise_ranks_poster.png -------------------------------------------------------------------------------- /figures/rank_distribution.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/rank_distribution.pdf -------------------------------------------------------------------------------- /figures/relevance_rank_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/relevance_rank_comparison.pdf -------------------------------------------------------------------------------- /figures/relevance_x_faithfulness.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/relevance_x_faithfulness.pdf -------------------------------------------------------------------------------- /figures/rori-rag-final-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/rori-rag-final-5.png -------------------------------------------------------------------------------- /figures/rori-rag-final.drawio: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /figures/slides_groundedness_v_relevance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_groundedness_v_relevance.png -------------------------------------------------------------------------------- /figures/slides_human_groundedness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_human_groundedness.png -------------------------------------------------------------------------------- /figures/slides_kf1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1.png -------------------------------------------------------------------------------- /figures/slides_kf1_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_hist.png -------------------------------------------------------------------------------- /figures/slides_kf1_hist_low.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_hist_low.png -------------------------------------------------------------------------------- /figures/slides_kf1_hist_low_high.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_hist_low_high.png -------------------------------------------------------------------------------- /figures/slides_kf1_no_ir.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_kf1_no_ir.png -------------------------------------------------------------------------------- /figures/slides_preference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/slides_preference.png -------------------------------------------------------------------------------- /figures/system-diagram-poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/figures/system-diagram-poster.png -------------------------------------------------------------------------------- /notebooks/MathNationDataExploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f05760e9-f206-4fc7-b515-cf8516fb05bc", 6 | "metadata": {}, 7 | "source": [ 8 | "MathNation Data Exploration\n", 9 | "===\n", 10 | "Exploring a sample of anonymized MathNation data." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "6cb45b53-f77e-48e3-a197-d02df2349542", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from datetime import datetime\n", 21 | "from pathlib import Path\n", 22 | "\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "import numpy as np\n", 25 | "import pandas as pd" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "8378ffc6-9752-4c30-b5dd-6cb60c26561a", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "data_dir = Path(\"../data\")\n", 36 | "assert data_dir.exists()" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "id": "350228a7-5f41-413b-917b-a37ffda3ef0c", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "(152844, 9)" 49 | ] 50 | }, 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "mn_discussion_filepath = data_dir / \"raw\" / \"math_nation\" / \"mn_discussion_20230914.csv\"\n", 58 | "assert mn_discussion_filepath.exists()\n", 59 | "mn_df = pd.read_csv(mn_discussion_filepath)\n", 60 | "mn_df.shape" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "id": "44dfb03d-902a-490a-ac85-9840d9f11dad", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/html": [ 72 | "
\n", 73 | "\n", 86 | "\n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | "
reply_idreply_userreply_contentreply_ts_createdpost_idsubject_namepost_contentpost_userpost_ts_created
8952924689901412702Examples2018-03-13 17:18:452468982Algebra 1degree and terms of a polynomial?\\n73955402018-03-13 17:15:02
145191974764667659MAFS Section 1 video 62015-10-13 21:25:03974748Algebra 1Is there any videos on multiplying and dividin...22524062015-10-13 21:23:07
621851172038681091Jack, lets see! Lets multiply those to see if ...2016-01-30 21:49:401172013Algebra 1Help! solve each equation and check for soluti...35847412016-01-30 21:40:17
\n", 140 | "
" 141 | ], 142 | "text/plain": [ 143 | " reply_id reply_user \\\n", 144 | "89529 2468990 1412702 \n", 145 | "145191 974764 667659 \n", 146 | "62185 1172038 681091 \n", 147 | "\n", 148 | " reply_content \\\n", 149 | "89529 Examples \n", 150 | "145191 MAFS Section 1 video 6 \n", 151 | "62185 Jack, lets see! Lets multiply those to see if ... \n", 152 | "\n", 153 | " reply_ts_created post_id subject_name \\\n", 154 | "89529 2018-03-13 17:18:45 2468982 Algebra 1 \n", 155 | "145191 2015-10-13 21:25:03 974748 Algebra 1 \n", 156 | "62185 2016-01-30 21:49:40 1172013 Algebra 1 \n", 157 | "\n", 158 | " post_content post_user \\\n", 159 | "89529 degree and terms of a polynomial?\\n 7395540 \n", 160 | "145191 Is there any videos on multiplying and dividin... 2252406 \n", 161 | "62185 Help! solve each equation and check for soluti... 3584741 \n", 162 | "\n", 163 | " post_ts_created \n", 164 | "89529 2018-03-13 17:15:02 \n", 165 | "145191 2015-10-13 21:23:07 \n", 166 | "62185 2016-01-30 21:40:17 " 167 | ] 168 | }, 169 | "execution_count": 5, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "mn_df.sample(3)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 7, 181 | "id": "f6e3d1b6-f26a-4f17-9b6b-1182d8b9fbe7", 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "subject_name\n", 188 | "Algebra 1 15943\n", 189 | "Geometry 133\n", 190 | "Pre-Algebra 24\n", 191 | "8th Grade Math 1\n", 192 | "6th Grade Math 1\n", 193 | "Name: count, dtype: int64" 194 | ] 195 | }, 196 | "execution_count": 7, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "mn_df.groupby(\"post_id\").subject_name.head(1).value_counts()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "id": "84d289fe-8a65-47f3-ad6d-bafc6e79c923", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# for row in 
mn_df[mn_df.subject_name == \"Pre-Algebra\"].drop_duplicates(subset=\"post_id\", keep=\"first\").itertuples():\n", 213 | "# print(row.post_content)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 9, 219 | "id": "00686dc9-36eb-434d-b6e6-4526ff366002", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# manually extracted student queries that are relevant for testing\n", 224 | "prealgebra_student_queries = [\n", 225 | " \"What is the quotient rule??\",\n", 226 | " \"How do I multiply fractions???????\",\n", 227 | " \"How do you multiply fractions?!?!?\",\n", 228 | "]" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 9, 234 | "id": "a0b194a0-9c06-49f4-8ba6-e97fc246a181", 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "data": { 239 | "text/plain": [ 240 | "0.08228791454477705" 241 | ] 242 | }, 243 | "execution_count": 9, 244 | "metadata": {}, 245 | "output_type": "execute_result" 246 | } 247 | ], 248 | "source": [ 249 | "first_reply_df = mn_df.sort_values(by=[\"post_id\", \"reply_ts_created\"]).drop_duplicates(subset=\"post_id\", keep=\"first\")\n", 250 | "is_extended_post = first_reply_df.reply_user == first_reply_df.post_user\n", 251 | "is_extended_post.sum() / len(first_reply_df)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 10, 257 | "id": "6c1a08d9-0501-42c1-b19d-e088cf0c9dba", 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "data": { 262 | "text/html": [ 263 | "
\n", 264 | "\n", 277 | "\n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | "
reply_idreply_userreply_contentreply_ts_createdpost_idsubject_namepost_contentpost_userpost_ts_created
846431867945440someone please help\\n2013-10-22 23:35:311865Algebra 1NaN9454402013-10-22 23:35:14
26332065939329Can somebody help me please2013-10-23 01:08:462045Algebra 1If 9a+6b+8c=−3 ,\\n\\nwhat is 54a+48c+36b?9393292013-10-23 01:03:49
506403277572466need help2013-10-26 17:33:163276Algebra 1\\ny + 2x = −1\\n3y − x =5724662013-10-26 17:32:44
3736334871032512Any takers?2013-10-28 01:49:263485Algebra 1Challenge problem! Suppose the polynomial:\\n\\n...10325122013-10-28 01:27:00
64233733524723how would you solve this?\\n2013-10-28 22:52:323732Algebra 1NaN5247232013-10-28 22:52:14
..............................
125233059544599973*part2021-09-11 21:17:513305953GeometryMay someone please help me with let b and c45999732021-09-11 21:17:40
2679433065585197522i think it is but i dont now\\n2021-09-13 20:59:333306554Algebra 1If the question has a<25 is a<=25 the sa...51975222021-09-13 20:58:48
7474833083945223632here is the paper2021-09-15 23:05:463308391Algebra 1for the first box I got x>0 and x<652236322021-09-15 23:05:04
3699733120024519295anybody?2021-09-21 23:51:273311997Algebra 1help45192952021-09-21 23:49:09
4157133206364807952How do I do this,2021-10-12 23:10:243320632Algebra 1Can somebody help me?48079522021-10-12 23:08:45
\n", 427 | "

1325 rows × 9 columns

\n", 428 | "
" 429 | ], 430 | "text/plain": [ 431 | " reply_id reply_user reply_content \\\n", 432 | "84643 1867 945440 someone please help\\n \n", 433 | "2633 2065 939329 Can somebody help me please \n", 434 | "50640 3277 572466 need help \n", 435 | "37363 3487 1032512 Any takers? \n", 436 | "6423 3733 524723 how would you solve this?\\n \n", 437 | "... ... ... ... \n", 438 | "1252 3305954 4599973 *part \n", 439 | "26794 3306558 5197522 i think it is but i dont now\\n \n", 440 | "74748 3308394 5223632 here is the paper \n", 441 | "36997 3312002 4519295 anybody? \n", 442 | "41571 3320636 4807952 How do I do this, \n", 443 | "\n", 444 | " reply_ts_created post_id subject_name \\\n", 445 | "84643 2013-10-22 23:35:31 1865 Algebra 1 \n", 446 | "2633 2013-10-23 01:08:46 2045 Algebra 1 \n", 447 | "50640 2013-10-26 17:33:16 3276 Algebra 1 \n", 448 | "37363 2013-10-28 01:49:26 3485 Algebra 1 \n", 449 | "6423 2013-10-28 22:52:32 3732 Algebra 1 \n", 450 | "... ... ... ... \n", 451 | "1252 2021-09-11 21:17:51 3305953 Geometry \n", 452 | "26794 2021-09-13 20:59:33 3306554 Algebra 1 \n", 453 | "74748 2021-09-15 23:05:46 3308391 Algebra 1 \n", 454 | "36997 2021-09-21 23:51:27 3311997 Algebra 1 \n", 455 | "41571 2021-10-12 23:10:24 3320632 Algebra 1 \n", 456 | "\n", 457 | " post_content post_user \\\n", 458 | "84643 NaN 945440 \n", 459 | "2633 If 9a+6b+8c=−3 ,\\n\\nwhat is 54a+48c+36b? 939329 \n", 460 | "50640 \\ny + 2x = −1\\n3y − x = 572466 \n", 461 | "37363 Challenge problem! Suppose the polynomial:\\n\\n... 1032512 \n", 462 | "6423 NaN 524723 \n", 463 | "... ... ... \n", 464 | "1252 May someone please help me with let b and c 4599973 \n", 465 | "26794 If the question has a<25 is a<=25 the sa... 5197522 \n", 466 | "74748 for the first box I got x>0 and x<6 5223632 \n", 467 | "36997 help 4519295 \n", 468 | "41571 Can somebody help me? 
4807952 \n", 469 | "\n", 470 | " post_ts_created \n", 471 | "84643 2013-10-22 23:35:14 \n", 472 | "2633 2013-10-23 01:03:49 \n", 473 | "50640 2013-10-26 17:32:44 \n", 474 | "37363 2013-10-28 01:27:00 \n", 475 | "6423 2013-10-28 22:52:14 \n", 476 | "... ... \n", 477 | "1252 2021-09-11 21:17:40 \n", 478 | "26794 2021-09-13 20:58:48 \n", 479 | "74748 2021-09-15 23:05:04 \n", 480 | "36997 2021-09-21 23:49:09 \n", 481 | "41571 2021-10-12 23:08:45 \n", 482 | "\n", 483 | "[1325 rows x 9 columns]" 484 | ] 485 | }, 486 | "execution_count": 10, 487 | "metadata": {}, 488 | "output_type": "execute_result" 489 | } 490 | ], 491 | "source": [ 492 | "first_reply_df[is_extended_post]" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 11, 498 | "id": "abc41f40-32e4-43ef-8bf7-6d8c05c64e7c", 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "data": { 503 | "text/plain": [ 504 | "16102" 505 | ] 506 | }, 507 | "execution_count": 11, 508 | "metadata": {}, 509 | "output_type": "execute_result" 510 | } 511 | ], 512 | "source": [ 513 | "posts = []\n", 514 | "for post_id, group in mn_df.sort_values(by=[\"post_id\", \"reply_ts_created\"]).groupby(\"post_id\"):\n", 515 | " post_content = str(group.iloc[0].post_content)\n", 516 | " post_user = group.iloc[0].post_user\n", 517 | " for row in group.itertuples():\n", 518 | " if row.reply_user != post_user:\n", 519 | " break # reply from non-OP\n", 520 | " else:\n", 521 | " # continuation of the post in a reply\n", 522 | " post_content += \"\\n[Continued:] \" + str(row.reply_content)\n", 523 | " prev_row = row\n", 524 | " posts.append(\n", 525 | " {\n", 526 | " \"post_id\": post_id,\n", 527 | " \"post_user\": post_user,\n", 528 | " \"subject_name\": group.iloc[0].subject_name,\n", 529 | " \"post_ts_created\": group.iloc[0].post_ts_created,\n", 530 | " \"post_content\": post_content,\n", 531 | " }\n", 532 | " )\n", 533 | "len(posts)" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 12, 539 | 
"id": "77538bb6-0129-40cd-bf40-8c1280ca81d4", 540 | "metadata": {}, 541 | "outputs": [ 542 | { 543 | "data": { 544 | "text/html": [ 545 | "
\n", 546 | "\n", 559 | "\n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | "
post_idpost_usersubject_namepost_ts_createdpost_content
37777067331218779Algebra 12015-04-09 21:51:42what is 5349 to the third power equal???? HELP...
1182422614683440812Algebra 12017-10-31 13:47:47=| =] =) =} ^o^ ^0^ ^@^ ^u^ all mojys are not ...
39577392472647012Algebra 12015-04-18 15:51:10Melissa who's after ya?
\n", 597 | "
" 598 | ], 599 | "text/plain": [ 600 | " post_id post_user subject_name post_ts_created \\\n", 601 | "3777 706733 1218779 Algebra 1 2015-04-09 21:51:42 \n", 602 | "11824 2261468 3440812 Algebra 1 2017-10-31 13:47:47 \n", 603 | "3957 739247 2647012 Algebra 1 2015-04-18 15:51:10 \n", 604 | "\n", 605 | " post_content \n", 606 | "3777 what is 5349 to the third power equal???? HELP... \n", 607 | "11824 =| =] =) =} ^o^ ^0^ ^@^ ^u^ all mojys are not ... \n", 608 | "3957 Melissa who's after ya? " 609 | ] 610 | }, 611 | "execution_count": 12, 612 | "metadata": {}, 613 | "output_type": "execute_result" 614 | } 615 | ], 616 | "source": [ 617 | "pd.DataFrame(posts).sample(3)" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 44, 623 | "id": "4be3b675-fec3-4d09-b1ce-d1c2fba0c378", 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "pd.DataFrame(posts).to_csv(data_dir / \"derived\" / \"mn_student_queries_raw3.csv\")" 628 | ] 629 | }, 630 | { 631 | "cell_type": "markdown", 632 | "id": "1431a6e0-477c-402e-938b-9d8083e201fc", 633 | "metadata": {}, 634 | "source": [ 635 | "Original sample used for annotation was this:\n", 636 | "\n", 637 | "```\n", 638 | "sdf = mn_df[mn_df.subject_name.isin([\"Algebra 1\", \"Geometry\"])]\n", 639 | "sdf = sdf.drop_duplicates(subset=\"post_id\", keep=\"first\")\n", 640 | "sdf[[\"post_id\", \"subject_name\", \"post_content\", \"post_user\", \"post_ts_created\"]].to_csv(data_dir / \"derived\" / \"mn_student_queries_raw.csv\")\n", 641 | "```\n", 642 | "\n", 643 | "But I threw those annotations away." 
644 | ] 645 | }, 646 | { 647 | "cell_type": "markdown", 648 | "id": "f4c123a5-2c44-4935-b218-dcc4e50cb509", 649 | "metadata": {}, 650 | "source": [ 651 | "### Loading annotated data" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 46, 657 | "id": "27bbdd84-b403-4620-baa5-48e8e0f5122a", 658 | "metadata": {}, 659 | "outputs": [ 660 | { 661 | "data": { 662 | "text/plain": [ 663 | "(553, 8)" 664 | ] 665 | }, 666 | "execution_count": 46, 667 | "metadata": {}, 668 | "output_type": "execute_result" 669 | } 670 | ], 671 | "source": [ 672 | "adf = pd.read_csv(data_dir / \"derived\" / \"mn_student_queries_raw2_annotated.csv\")\n", 673 | "adf.shape" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 47, 679 | "id": "ec091d67-9c49-4880-96cd-510e5f665f6b", 680 | "metadata": {}, 681 | "outputs": [ 682 | { 683 | "data": { 684 | "text/html": [ 685 | "
\n", 686 | "\n", 699 | "\n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | "
indexpost_idpost_usersubject_namepost_ts_createdpost_contentis_respondable_querynotes
2219856848987Algebra 12013-11-16 23:25:43how do you a graph function rule?!generalNaN
46246231195823073615Geometry2020-03-31 19:38:40is anybody there\\nNaNNaN
40540529666828261996Geometry2019-09-11 21:16:05Wouldn't it be the same thing??NaNNaN
\n", 749 | "
" 750 | ], 751 | "text/plain": [ 752 | " index post_id post_user subject_name post_ts_created \\\n", 753 | "2 2 19856 848987 Algebra 1 2013-11-16 23:25:43 \n", 754 | "462 462 3119582 3073615 Geometry 2020-03-31 19:38:40 \n", 755 | "405 405 2966682 8261996 Geometry 2019-09-11 21:16:05 \n", 756 | "\n", 757 | " post_content is_respondable_query notes \n", 758 | "2 how do you a graph function rule?! general NaN \n", 759 | "462 is anybody there\\n NaN NaN \n", 760 | "405 Wouldn't it be the same thing?? NaN NaN " 761 | ] 762 | }, 763 | "execution_count": 47, 764 | "metadata": {}, 765 | "output_type": "execute_result" 766 | } 767 | ], 768 | "source": [ 769 | "adf.sample(n=3)" 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 48, 775 | "id": "787a7ee0-3f39-4de1-af8d-b1ce0d4fb7eb", 776 | "metadata": {}, 777 | "outputs": [ 778 | { 779 | "data": { 780 | "text/plain": [ 781 | "is_respondable_query\n", 782 | "problem 62\n", 783 | "general 52\n", 784 | "confirm 8\n", 785 | "resource request 4\n", 786 | "general,problem 3\n", 787 | "advice 2\n", 788 | "wrong but don't know why 1\n", 789 | "clarify question 1\n", 790 | "stuck 1\n", 791 | "Name: count, dtype: int64" 792 | ] 793 | }, 794 | "execution_count": 48, 795 | "metadata": {}, 796 | "output_type": "execute_result" 797 | } 798 | ], 799 | "source": [ 800 | "adf.is_respondable_query.value_counts()" 801 | ] 802 | }, 803 | { 804 | "cell_type": "code", 805 | "execution_count": 53, 806 | "id": "495fb9fc-46d1-4f18-bf88-9729b26bafb7", 807 | "metadata": {}, 808 | "outputs": [ 809 | { 810 | "data": { 811 | "text/html": [ 812 | "
\n", 813 | "\n", 826 | "\n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | "
subject_name6th Grade MathAlgebra 1GeometryPre-AlgebraAll
is_respondable_query
All1111193134
problem1547062
general0409352
confirm06208
resource request03104
general,problem03003
advice02002
clarify question01001
stuck01001
wrong but don't know why01001
\n", 928 | "
" 929 | ], 930 | "text/plain": [ 931 | "subject_name 6th Grade Math Algebra 1 Geometry Pre-Algebra \\\n", 932 | "is_respondable_query \n", 933 | "All 1 111 19 3 \n", 934 | "problem 1 54 7 0 \n", 935 | "general 0 40 9 3 \n", 936 | "confirm 0 6 2 0 \n", 937 | "resource request 0 3 1 0 \n", 938 | "general,problem 0 3 0 0 \n", 939 | "advice 0 2 0 0 \n", 940 | "clarify question 0 1 0 0 \n", 941 | "stuck 0 1 0 0 \n", 942 | "wrong but don't know why 0 1 0 0 \n", 943 | "\n", 944 | "subject_name All \n", 945 | "is_respondable_query \n", 946 | "All 134 \n", 947 | "problem 62 \n", 948 | "general 52 \n", 949 | "confirm 8 \n", 950 | "resource request 4 \n", 951 | "general,problem 3 \n", 952 | "advice 2 \n", 953 | "clarify question 1 \n", 954 | "stuck 1 \n", 955 | "wrong but don't know why 1 " 956 | ] 957 | }, 958 | "execution_count": 53, 959 | "metadata": {}, 960 | "output_type": "execute_result" 961 | } 962 | ], 963 | "source": [ 964 | "pd.crosstab(adf.subject_name, adf.is_respondable_query, margins=True).T.sort_values(by=\"All\", ascending=False)" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "execution_count": 57, 970 | "id": "8c65b4a6-42a5-42d5-ab88-19a74edf0596", 971 | "metadata": {}, 972 | "outputs": [ 973 | { 974 | "data": { 975 | "text/html": [ 976 | "
\n", 977 | "\n", 990 | "\n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | "
post_idsubject_namepost_contentis_respondable_query
29199138Algebra 1Can someone give me a factoring example I don...general
1381105388Pre-AlgebraHey guys! What is the quotient rule??general
110859909Algebra 1How do you find the domain and range that are ...general
\n", 1024 | "
" 1025 | ], 1026 | "text/plain": [ 1027 | " post_id subject_name post_content \\\n", 1028 | "29 199138 Algebra 1 Can someone give me a factoring example I don... \n", 1029 | "138 1105388 Pre-Algebra Hey guys! What is the quotient rule?? \n", 1030 | "110 859909 Algebra 1 How do you find the domain and range that are ... \n", 1031 | "\n", 1032 | " is_respondable_query \n", 1033 | "29 general \n", 1034 | "138 general \n", 1035 | "110 general " 1036 | ] 1037 | }, 1038 | "execution_count": 57, 1039 | "metadata": {}, 1040 | "output_type": "execute_result" 1041 | } 1042 | ], 1043 | "source": [ 1044 | "general_df = adf[adf.is_respondable_query.map(lambda s: pd.notna(s) and \"general\" in s)][\n", 1045 | " [\"post_id\", \"subject_name\", \"post_content\", \"is_respondable_query\"]\n", 1046 | "]\n", 1047 | "general_df.sample(n=3)" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "code", 1052 | "execution_count": 58, 1053 | "id": "b137f401-9593-4c2e-9f7a-4aff448bb077", 1054 | "metadata": {}, 1055 | "outputs": [], 1056 | "source": [ 1057 | "general_df.to_csv(data_dir / \"derived\" / \"mn_general_student_queries.csv\", index=False)" 1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "code", 1062 | "execution_count": 62, 1063 | "id": "cf10ce6d-1c63-4792-8f42-9dda5a0e0f9c", 1064 | "metadata": {}, 1065 | "outputs": [ 1066 | { 1067 | "name": "stdout", 1068 | "output_type": "stream", 1069 | "text": [ 1070 | "MathNation queries: (55, 4)\n" 1071 | ] 1072 | }, 1073 | { 1074 | "data": { 1075 | "text/html": [ 1076 | "
\n", 1077 | "\n", 1090 | "\n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | "
post_idsubject_namepost_contentis_respondable_query
382672704Algebra 1i have a problem is this equation linear?\\n7x ...general,problem
412690581Algebra 1How do you know if a number is a constant?general
201790360Algebra 1What is vertex form and how do you solve for it?general
\n", 1124 | "
" 1125 | ], 1126 | "text/plain": [ 1127 | " post_id subject_name post_content \\\n", 1128 | "38 2672704 Algebra 1 i have a problem is this equation linear?\\n7x ... \n", 1129 | "41 2690581 Algebra 1 How do you know if a number is a constant? \n", 1130 | "20 1790360 Algebra 1 What is vertex form and how do you solve for it? \n", 1131 | "\n", 1132 | " is_respondable_query \n", 1133 | "38 general,problem \n", 1134 | "41 general \n", 1135 | "20 general " 1136 | ] 1137 | }, 1138 | "execution_count": 62, 1139 | "metadata": {}, 1140 | "output_type": "execute_result" 1141 | } 1142 | ], 1143 | "source": [ 1144 | "# load the mathnation query data\n", 1145 | "mn_general_student_queries_filepath = data_dir / \"derived\" / \"mn_general_student_queries.csv\"\n", 1146 | "query_df = pd.read_csv(mn_general_student_queries_filepath)\n", 1147 | "print(f\"MathNation queries: {query_df.shape}\")\n", 1148 | "query_df.sample(n=3)" 1149 | ] 1150 | }, 1151 | { 1152 | "cell_type": "code", 1153 | "execution_count": null, 1154 | "id": "05c868c2-6288-4173-afb4-1dabe9cd3e4b", 1155 | "metadata": {}, 1156 | "outputs": [], 1157 | "source": [] 1158 | }, 1159 | { 1160 | "cell_type": "code", 1161 | "execution_count": null, 1162 | "id": "87f3cab2-cdc8-49e3-acb5-378dc32857bc", 1163 | "metadata": {}, 1164 | "outputs": [], 1165 | "source": [] 1166 | }, 1167 | { 1168 | "cell_type": "code", 1169 | "execution_count": null, 1170 | "id": "c3abf13e-9498-46c3-b16b-0ac69b411ed0", 1171 | "metadata": {}, 1172 | "outputs": [], 1173 | "source": [] 1174 | } 1175 | ], 1176 | "metadata": { 1177 | "kernelspec": { 1178 | "display_name": "llm-math-education", 1179 | "language": "python", 1180 | "name": "llm-math-education" 1181 | }, 1182 | "language_info": { 1183 | "codemirror_mode": { 1184 | "name": "ipython", 1185 | "version": 3 1186 | }, 1187 | "file_extension": ".py", 1188 | "mimetype": "text/x-python", 1189 | "name": "python", 1190 | "nbconvert_exporter": "python", 1191 | "pygments_lexer": "ipython3", 1192 | 
"version": "3.10.11" 1193 | } 1194 | }, 1195 | "nbformat": 4, 1196 | "nbformat_minor": 5 1197 | } 1198 | -------------------------------------------------------------------------------- /notebooks/MetricTesting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a3b12196-aa39-4532-87ce-3f6f41ec411e", 6 | "metadata": {}, 7 | "source": [ 8 | "Metric testing\n", 9 | "===" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "id": "4379991e-c03a-434d-9daa-314507f635df", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import datasets\n", 20 | "import evaluate\n", 21 | "\n", 22 | "from experiment import metrics" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "id": "47955103-9b48-4bd9-88e3-1a3999c429ab", 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "INFO:tensorflow:Reading checkpoint /Users/zacharylevonian/.cache/huggingface/datasets/downloads/extracted/0b5f615fbe4df81a585448a4e6f47b4bb3af737cc290a4d96effa6ef1840ea73/bleurt-base-512.\n", 36 | "INFO:tensorflow:Config file found, reading.\n", 37 | "INFO:tensorflow:Will load checkpoint bert_custom\n", 38 | "INFO:tensorflow:Loads full paths and checks that files exists.\n", 39 | "INFO:tensorflow:... name:bert_custom\n", 40 | "INFO:tensorflow:... vocab_file:vocab.txt\n", 41 | "INFO:tensorflow:... bert_config_file:bert_config.json\n", 42 | "INFO:tensorflow:... do_lower_case:True\n", 43 | "INFO:tensorflow:... 
max_seq_length:512\n", 44 | "INFO:tensorflow:Creating BLEURT scorer.\n", 45 | "INFO:tensorflow:Creating WordPiece tokenizer.\n", 46 | "INFO:tensorflow:WordPiece tokenizer instantiated.\n", 47 | "INFO:tensorflow:Creating Eager Mode predictor.\n", 48 | "INFO:tensorflow:Loading model.\n", 49 | "INFO:tensorflow:BLEURT initialized.\n" 50 | ] 51 | }, 52 | { 53 | "name": "stderr", 54 | "output_type": "stream", 55 | "text": [ 56 | "INFO:tensorflow:BLEURT initialized.\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "# https://github.com/huggingface/evaluate/issues/428\n", 62 | "bleurt = evaluate.load(\n", 63 | " \"bleurt\", \"bleurt-base-512\", module_type=\"metric\", download_config=datasets.DownloadConfig(use_etag=False)\n", 64 | ")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "6d6faff9-ea7b-4339-b689-251b280ce6a3", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "{'scores': [1.175394892692566, -1.1031553745269775]}" 77 | ] 78 | }, 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "predictions = [\"hello there\", \"general skywalker\"]\n", 86 | "references = [\"hello there\", \"general kenobi\"]\n", 87 | "bleurt.compute(predictions=predictions, references=references)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "id": "f4a07f69-0f92-4056-aee3-a7c98f27d8d5", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "0.4088791608810425" 100 | ] 101 | }, 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "def compute_bleurt(passages: list[str], generation: str):\n", 109 | " references = passages + [\n", 110 | " \"\\n\".join(passages),\n", 111 | " ]\n", 112 | " predictions = [\n", 113 | " generation,\n", 114 | " ] * (len(passages) + 1)\n", 115 | " scores = 
bleurt.compute(predictions=predictions, references=references)[\"scores\"]\n", 116 | " return max(scores)\n", 117 | "\n", 118 | "\n", 119 | "compute_bleurt([\"The alphabet is 26 letters long.\", \"Math is not so easy.\"], \"The English alphabet is 26 letters long.\")" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "a4daae3f-e042-4f94-a4b0-c9e93aa876dd", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "id": "e19b8934-8eb7-43c2-ab12-8725de26d0fb", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "bert_score = metrics.get_bertscore_metric_object()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 7, 143 | "id": "6b9bd28b-60d3-4e9d-919a-ce8604baef7e", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stderr", 148 | "output_type": "stream", 149 | "text": [ 150 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']\n", 151 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 152 | ] 153 | }, 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "{'precision': [0.9999998807907104, 0.9180971384048462],\n", 158 | " 'recall': [0.9999998807907104, 0.8901697397232056],\n", 159 | " 'f1': [0.9999998807907104, 0.9039177894592285],\n", 160 | " 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.33.1)'}" 161 | ] 162 | }, 163 | "execution_count": 7, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "predictions = [\"hello there\", \"general skywalker\"]\n", 170 | "references = [\"hello there\", \"general kenobi\"]\n", 171 | "bert_score.compute(predictions=predictions, references=references, lang=\"en\")" 172 | ] 173 | }, 174 | { 175 | 
"cell_type": "code", 176 | "execution_count": 9, 177 | "id": "1776095e-f9df-4628-8736-49fbfb38275e", 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/plain": [ 183 | "{'precision': [0.8900542259216309, 0.9747406840324402],\n", 184 | " 'recall': [0.8820334672927856, 0.9553087949752808],\n", 185 | " 'f1': [0.8860256671905518, 0.9649269580841064],\n", 186 | " 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.33.1)'}" 187 | ] 188 | }, 189 | "execution_count": 9, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "# must match counts\n", 196 | "bert_score.compute(\n", 197 | " predictions=[\"This is a test.\"] * 2,\n", 198 | " references=[\"Two reference sentences.\", \"Second is a test sentence.\"],\n", 199 | " lang=\"en\",\n", 200 | ")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "136a9fa6-743a-4dcc-ac54-e6287574cbd1", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [] 210 | } 211 | ], 212 | "metadata": { 213 | "kernelspec": { 214 | "display_name": "Python 3 (ipykernel)", 215 | "language": "python", 216 | "name": "python3" 217 | }, 218 | "language_info": { 219 | "codemirror_mode": { 220 | "name": "ipython", 221 | "version": 3 222 | }, 223 | "file_extension": ".py", 224 | "mimetype": "text/x-python", 225 | "name": "python", 226 | "nbconvert_exporter": "python", 227 | "pygments_lexer": "ipython3", 228 | "version": "3.10.11" 229 | } 230 | }, 231 | "nbformat": 4, 232 | "nbformat_minor": 5 233 | } 234 | -------------------------------------------------------------------------------- /notebooks/Qualtrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1fec94f6-6847-4198-b6ba-ec45ea561b8c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Qualtrics survey creation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | 
"execution_count": 96, 14 | "id": "1b24f48a-aa65-44a0-84d7-e125da064978", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import html\n", 19 | "import json\n", 20 | "import re\n", 21 | "from pathlib import Path\n", 22 | "\n", 23 | "import pandas as pd" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 97, 29 | "id": "8a8c6b72-7478-4843-bab5-887137394024", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "data_dir = Path(\"../data\")\n", 34 | "assert data_dir.exists()\n", 35 | "figures_dir = Path(\"../figures\")\n", 36 | "figures_dir.mkdir(exist_ok=True)\n", 37 | "assert figures_dir.exists()" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 98, 43 | "id": "7dfa556b-fce1-4deb-97a6-25c9a2b0340c", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "template_filepath = data_dir / \"raw\" / \"qualtrics\" / \"Rori_ranking_annotations_-_template.qsf\"\n", 48 | "with open(template_filepath) as infile:\n", 49 | " survey_text = infile.read()\n", 50 | " assert json.loads(survey_text)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "fbe34bd3-16fb-4a7d-aa58-1494991e7a9e", 56 | "metadata": {}, 57 | "source": [ 58 | "Keys to fill:\n", 59 | "\n", 60 | " - RoriSurveyId\n", 61 | " - Response{1,2,3}Q*\n", 62 | " - QueryTextQ*\n", 63 | " - DocumentQ*\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 99, 69 | "id": "56091b06-eff6-4325-bad0-1abd1069d71e", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "response_keys = [\"Response1Q\", \"Response2Q\", \"Response3Q\"]\n", 74 | "query_text_key = \"QueryTextQ\"\n", 75 | "document_key = \"DocumentQ\"\n", 76 | "\n", 77 | "# validate expected keys\n", 78 | "expected_survey_size = 15\n", 79 | "for key in response_keys + [query_text_key, document_key]:\n", 80 | " for i in range(1, expected_survey_size + 1):\n", 81 | " qkey = key + str(i)\n", 82 | " assert qkey in survey_text, qkey" 83 | ] 84 | }, 85 | { 
86 | "cell_type": "code", 87 | "execution_count": 100, 88 | "id": "5efb33ad-6cf4-4903-86dc-641fa4da9ae3", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "\n", 96 | "Display\":\"Response1Q2\"},\"\n" 97 | ] 98 | }, 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "['Response1Q2']" 103 | ] 104 | }, 105 | "execution_count": 100, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "for result in re.finditer(\"Response1Q2(?![0-9])\", survey_text):\n", 112 | " print(result)\n", 113 | " ind = result.span()[0]\n", 114 | " print(survey_text[ind - 10 : ind + 15])\n", 115 | "re.findall(\"Response1Q2(?![0-9])\", survey_text)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 101, 121 | "id": "887a2aaf-b173-46ff-82cd-93d1e63775af", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def convert_text(text, use_br=True):\n", 126 | " text = html.escape(text.replace(\"\\\\\", \"/\"))\n", 127 | " # text = \"

\" + \"<\\\\/p>

\".join(text.split(\"\\n\")) + \"<\\\\/p>\"\n", 128 | " if use_br:\n", 129 | " text = \"

\" + \"


\".join(text.split(\"\\n\")) + \"

\"\n", 130 | " else:\n", 131 | " text = \"

\" + \"

\".join(text.split(\"\\n\")) + \"

\"\n", 132 | " return text\n", 133 | "\n", 134 | "\n", 135 | "expected_survey_size = 15\n", 136 | "for i in range(1, expected_survey_size + 1):\n", 137 | " r1 = \"R1 Multi-line string\\n\\nSeveral bits here are normal:\\n - 1\\n - 2\\n - 3\"\n", 138 | " r2 = r\"R2 Single line string, with some maybe-problematic characters: /\\!@#$%^&*()_-+\"\n", 139 | " r3 = \"R3\"\n", 140 | " responses = [r1, r2, r3]\n", 141 | " query = f\"Test query for page {i}\"\n", 142 | " document = \"Paragraph 1: text goes here\\nParagraph 2: Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\"\n", 143 | " for key, response in zip(response_keys, responses):\n", 144 | " qkey = key + str(i)\n", 145 | " text = convert_text(response, use_br=False)\n", 146 | " survey_text = survey_text.replace(qkey, text, 1)\n", 147 | " survey_text = survey_text.replace(query_text_key + str(i), convert_text(query), 1)\n", 148 | " survey_text = survey_text.replace(document_key + str(i), convert_text(document), 1)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 102, 154 | "id": "9dfe6092-dc16-41a7-8f63-6bebb43c24b1", 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "'lti-line string

Several bit'" 161 | ] 162 | }, 163 | "execution_count": 102, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "ind = 67956\n", 170 | "band = 20\n", 171 | "survey_text[ind - band : ind + band]" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 103, 177 | "id": "733f3f68-1a55-45cb-86e7-78744834e0a3", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# verify that we've created valid JSON\n", 182 | "survey_text = json.dumps(json.loads(survey_text))" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 104, 188 | "id": "626257d8-0dfd-447d-82b3-7bfd8cd4d265", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "survey_dir = data_dir / \"derived\" / \"qualtrics\"\n", 193 | "survey_dir.mkdir(exist_ok=True)\n", 194 | "survey_filepath = survey_dir / \"Rori_ranking_annotations_-_survey1.qsf\"\n", 195 | "with open(survey_filepath, \"w\") as outfile:\n", 196 | " outfile.write(survey_text)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "bece161b-aefd-4cd4-8f04-001b61622674", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [] 206 | } 207 | ], 208 | "metadata": { 209 | "kernelspec": { 210 | "display_name": "Python 3 (ipykernel)", 211 | "language": "python", 212 | "name": "python3" 213 | }, 214 | "language_info": { 215 | "codemirror_mode": { 216 | "name": "ipython", 217 | "version": 3 218 | }, 219 | "file_extension": ".py", 220 | "mimetype": "text/x-python", 221 | "name": "python", 222 | "nbconvert_exporter": "python", 223 | "pygments_lexer": "ipython3", 224 | "version": "3.10.11" 225 | } 226 | }, 227 | "nbformat": 4, 228 | "nbformat_minor": 5 229 | } 230 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = true 3 | in-project = true 4 
| -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "rag-for-math-qa" 3 | version = "0.2.0" 4 | description = "Experimentation/analysis code and scripts" 5 | authors = [ 6 | "Zachary Levonian " 7 | ] 8 | license = "MIT" 9 | readme = "README.md" 10 | packages = [{include = "experiment", from = "src"}, {include = "rag", from = "src"}] 11 | repository = "https://github.com/DigitalHarborFoundation/rag-for-math-qa.git" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.10,<3.12" 15 | poetry = "1.6.1" 16 | streamlit = "^1.23.1" 17 | pandas = "^2.0.2" 18 | scipy = "^1.11.0" 19 | argon2-cffi = "^21.3.0" 20 | sqlalchemy = "^2.0.20" 21 | psycopg2-binary = "^2.9.7" 22 | numpy = "^1.24.3" 23 | spacy = "^3.6.1" 24 | statsmodels = "^0.14.0" 25 | tabulate = "^0.9.0" 26 | scikit-learn = "^1.3.0" 27 | openai = "^0.28.0" 28 | tiktoken = "^0.4.0" 29 | python-dotenv = "^1.0.0" 30 | evaluate = "^0.4.0" 31 | bert-score = "^0.3.13" 32 | bleurt = {git = "https://github.com/google-research/bleurt.git"} 33 | tensorflow = {version = "^2.13.0" } 34 | tensorflow-macos = { version = "^2.13.0", platform = "darwin", markers = "platform_machine=='arm64'" } 35 | tensorflow-intel = { version = "^2.13.0", platform = "win32" } 36 | tensorflow-cpu = [ 37 | { version = "^2.13.0", platform = "linux", markers = "platform_machine!='arm64' and platform_machine!='aarch64'" }, 38 | { version = "^2.13.0", platform = "darwin", markers = "platform_machine!='arm64' and platform_machine!='aarch64'" },] 39 | tensorflow-cpu-aws = { version = "^2.13.0", platform = "linux", markers = "platform_machine=='arm64' or platform_machine=='aarch64'" } 40 | # https://github.com/tensorflow/tensorflow/blob/adb39b04e9cb116df4659a7e2de9eea27e62f25c/tensorflow/tools/pip_package/setup.py#L107-L108 41 | # 
https://github.com/python-poetry/poetry/issues/8271#issuecomment-1697740447 42 | tensorflow-io-gcs-filesystem = [ 43 | { version = ">= 0.23.1", markers = "platform_machine!='arm64' or platform_system!='Darwin'" }, 44 | { version = "< 0.32.0", markers = "platform_system == 'Windows'" } 45 | ] 46 | datasets = "2.10.0" 47 | krippendorff = "^0.6.0" 48 | skrub = {git = "https://github.com/skrub-data/skrub.git"} 49 | 50 | [tool.poetry.group.dev.dependencies] 51 | jupyter = "^1.0.0" 52 | matplotlib = "^3.7.1" 53 | black = "^22.12.0" 54 | isort = "^5.12" 55 | flake8 = "^6.0.0" 56 | nbqa = "^1.6.0" 57 | pre-commit = "^2.21.0" 58 | pytest = "^7.2.1" 59 | pytest-cov = "^4.0.0" 60 | jupyterlab = "^4.0.2" 61 | 62 | [build-system] 63 | requires = ["poetry-core"] 64 | build-backend = "poetry.core.masonry.api" 65 | 66 | [tool.black] 67 | line-length = 120 68 | include = '\.pyi?$' 69 | exclude = ''' 70 | /( 71 | .eggs # exclude a few common directories in the 72 | | .git # root of the project 73 | | .github 74 | | .gitignore 75 | | .hg 76 | | .mypy_cache 77 | | .tox 78 | | .venv 79 | | venv 80 | | _build 81 | | buck-out 82 | | build 83 | | ci 84 | | data 85 | | dist 86 | | docs 87 | | docsrc 88 | )/ 89 | ''' 90 | 91 | [tool.isort] 92 | profile = "black" 93 | line_length = 79 94 | multi_line_output = 3 95 | include_trailing_comma = true 96 | virtual_env = "venv" 97 | -------------------------------------------------------------------------------- /src/experiment/auth.py: -------------------------------------------------------------------------------- 1 | def generate_auth_token() -> str: 2 | import binascii 3 | import os 4 | 5 | auth_token = binascii.b2a_hex(os.urandom(16)).decode("ascii") 6 | return auth_token 7 | 8 | 9 | def cast_unicode(s: bytes | str, encoding: str) -> str: 10 | if isinstance(s, bytes): 11 | return s.decode(encoding, "replace") 12 | return s 13 | 14 | 15 | def passwd_hash(passphrase: str) -> str: 16 | from argon2 import PasswordHasher 17 | 18 | ph = 
PasswordHasher( 19 | memory_cost=10240, 20 | time_cost=10, 21 | parallelism=8, 22 | ) 23 | h = ph.hash(passphrase) 24 | 25 | return ":".join(("argon2", cast_unicode(h, "ascii"))) 26 | 27 | 28 | def passwd_check(hashed_passphrase: str, passphrase: str) -> bool: 29 | # modification of source provided with a BSD 3-Clause License, Copyright (c) 2015-, Jupyter Development Team 30 | # from notebook.auth.security 31 | assert hashed_passphrase.startswith("argon2:") 32 | import argon2 33 | import argon2.exceptions 34 | 35 | ph = argon2.PasswordHasher() 36 | try: 37 | return ph.verify(hashed_passphrase[7:], passphrase) 38 | except argon2.exceptions.VerificationError: 39 | return False 40 | -------------------------------------------------------------------------------- /src/experiment/completion_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import openai 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def get_completion_noraise( 10 | messages: list, 11 | sleep: float = 0.1, 12 | should_log_successful: bool = False, 13 | **kwargs, 14 | ) -> str | None: 15 | """Function wrapper that swallows exceptions, intended to use with multiprocessing. 16 | 17 | Returns: 18 | str | None: The completion, or None if an exception was raised. 
19 | """ 20 | try: 21 | generation = get_completion_with_wait(messages, sleep=sleep, **kwargs) 22 | if should_log_successful: 23 | logger.info("Successful completion.") 24 | return generation 25 | except Exception as ex: 26 | logger.warning(f"get_completion_noraise returning None due to {type(ex).__name__} error: {ex}") 27 | return None 28 | 29 | 30 | def get_completion_with_wait(messages: list, sleep: float = 0.1, **kwargs) -> str: 31 | generation = get_completion_with_retries(messages, **kwargs) 32 | if sleep: 33 | time.sleep(sleep) # being a bit polite on repeated api calls 34 | return generation 35 | 36 | 37 | def get_completion_with_retries( 38 | messages: list, 39 | max_attempts: int = 3, 40 | sleep_time_between_attempts: float = 5, 41 | **kwargs, 42 | ) -> str: 43 | """Could use a library for this, but let's keep it simple. 44 | 45 | Args: 46 | messages (list): _description_ 47 | max_attempts (int, optional): Defaults to 3. 48 | sleep_time (float, optional): Defaults to 5 (seconds). 
49 | 50 | Returns: 51 | str: The completion 52 | """ 53 | n_attempts = 0 54 | while n_attempts < max_attempts: 55 | n_attempts += 1 56 | try: 57 | return get_completion(messages, **kwargs) 58 | except Exception as ex: 59 | logger.warning(f"Failure on attempt {n_attempts} / {max_attempts}: {type(ex).__name__} {ex}") 60 | if n_attempts == max_attempts: 61 | raise ex 62 | time.sleep(sleep_time_between_attempts * n_attempts) 63 | raise ValueError( 64 | f"Exceeded max attempts ({max_attempts}), base sleep interval {sleep_time_between_attempts}s; this error indicates an unexpected logical flow", 65 | ) 66 | 67 | 68 | def get_completion(messages: list, model_name: str = "gpt-3.5-turbo-0613", request_timeout: float = 20) -> str: 69 | completion = openai.ChatCompletion.create( 70 | model=model_name, 71 | messages=messages, 72 | request_timeout=request_timeout, 73 | ) 74 | assistant_message = completion["choices"][0]["message"]["content"] 75 | return assistant_message 76 | -------------------------------------------------------------------------------- /src/experiment/generate.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import multiprocessing as mp 4 | import time 5 | from pathlib import Path 6 | from typing import Callable 7 | 8 | from experiment import completion_utils 9 | 10 | 11 | def is_valid_generation(generation: dict): 12 | return "messages" in generation and "generation" in generation 13 | 14 | 15 | class GenerationCorpus: 16 | def __init__(self, output_dir: Path, key: str, overwrite: bool = False): 17 | output_dir.mkdir(exist_ok=True) 18 | self.output_dir = output_dir 19 | self.key = key 20 | self.generation_filepath = output_dir / f"{key}_generations.ndjson" 21 | self.generations = [] 22 | if not overwrite: 23 | self.load() 24 | 25 | def load(self): 26 | if self.generation_filepath.exists(): 27 | with open(self.generation_filepath) as infile: 28 | for line in infile: 29 | d = 
    def overwrite(self) -> None:
        """Rewrite the on-disk ndjson file from the current in-memory generations list."""
        self._save_generations(self.generations, write_mode="w")

    def filter_generations(
        self,
        should_include_func: Callable = is_valid_generation,
        should_remove_func: Callable | None = None,
    ) -> int:
        """Drop in-memory generations that fail the provided filters.

        Args:
            should_include_func (Callable, optional): Keep only generations for which this
                returns True. Defaults to is_valid_generation.
            should_remove_func (Callable | None, optional): If provided, additionally drop
                generations for which this returns True. Defaults to None.

        Returns:
            int: Number of generations removed.

        NOTE(review): only updates the in-memory list; call overwrite() to persist the result.
        """
        filtered_generations = []
        for generation in self.generations:
            if should_include_func(generation) and (should_remove_func is None or not should_remove_func(generation)):
                filtered_generations.append(generation)
        n_removed = len(self.generations) - len(filtered_generations)
        self.generations = filtered_generations
        return n_removed

    def _save_generation(self, metadata: dict) -> None:
        """Append a single generation record to the ndjson file."""
        with open(self.generation_filepath, "a") as outfile:
            outfile.write(json.dumps(metadata) + "\n")

    def _save_generations(self, metadata_list: list[dict], write_mode: str = "a") -> None:
        """Write multiple generation records to the ndjson file ("a" appends, "w" overwrites)."""
        with open(self.generation_filepath, write_mode) as outfile:
            for metadata in metadata_list:
                outfile.write(json.dumps(metadata) + "\n")

    def is_already_generated(
        self,
        messages: list,
        metadata: dict | None,
        exclude_keys: set[str] = {"generation", "messages"},
    ) -> bool:
        """Determine if a generation was already created for this set of messages (and corresponding metadata).

        Args:
            messages (list): Message list to pass to the OpenAI API.
            metadata (dict | None): Optional metadata associated with the generation.
            exclude_keys (set[str], optional): Metadata keys to ignore when determining if this is a duplicate. Defaults to {"generation", "messages"}.

        Returns:
            bool: True if already in self.generations, False otherwise.
        """
        if metadata is None:
            metadata = {}
        for generation in self.generations:
            assert "messages" in generation
            assert "generation" in generation
            # a duplicate must have identical messages AND a non-None generation;
            # failed (None) generations are eligible to be retried
            if generation["messages"] == messages and generation["generation"] is not None:
                is_metadata_match = True
                for key, value in metadata.items():
                    if key in exclude_keys:
                        continue
                    # every non-excluded metadata key must be present and equal
                    if key not in generation or generation[key] != value:
                        is_metadata_match = False
                if is_metadata_match:
                    return True
        return False

    def generate(
        self,
        messages: list,
        metadata: dict | None,
        sleep: float | None = 0.1,
    ) -> bool:
        """Generate a new ChatCompletion.

        Args:
            messages (list): List of messages.
            metadata (dict | None): Metadata to save with the completion.
            sleep (float | None, optional): Time to wait after this request, in seconds. Defaults to 0.1.

        Returns:
            bool: True if a new generation was created and saved, False otherwise.
        """
        if metadata is None:
            metadata = {}
        # skip (and report False) if an identical request was already completed
        if self.is_already_generated(messages, metadata):
            return False
        generation = completion_utils.get_completion_with_retries(messages)
        metadata["messages"] = messages
        metadata["generation"] = generation
        self.generations.append(metadata)
        self._save_generation(metadata)
        if sleep:
            time.sleep(sleep)  # being a bit polite on repeated api calls
        return True

    def batch_filter_not_already_generated(self, metadata_list: list[dict]) -> list[dict]:
        """Return the subset of metadata_list that has no completed generation yet.

        Raises:
            ValueError: If 'messages' is missing from any provided metadata dict.
        """
        metadata_to_process = []
        for metadata in metadata_list:
            if "messages" not in metadata:
                raise ValueError("Expected 'messages' in all provided metadata.")
            if not self.is_already_generated(metadata["messages"], metadata):
                metadata_to_process.append(metadata)
        return metadata_to_process

    def get_nonmatching_generations(
        self,
        metadata_list: list[dict],
        exclude_keys: set[str] = {"generation", "messages"},
        should_remove_nonmatching: bool = False,
    ) -> list[dict]:
        """Find stored generations that match none of the given metadata dicts.

        Args:
            metadata_list (list[dict]): Candidate metadata; a generation "matches" a candidate
                when all of the candidate's non-excluded keys are present and equal.
            exclude_keys (set[str], optional): Keys ignored during comparison. Defaults to {"generation", "messages"}.
            should_remove_nonmatching (bool, optional): If True, also remove the non-matching
                generations from the in-memory list. Defaults to False.

        Returns:
            list[dict]: The non-matching generations.
        """
        nonmatching_generations = []
        nonmatching_inds = []
        for i, generation in enumerate(self.generations):
            generation_match_found = False
            for metadata in metadata_list:
                is_metadata_match = True
                for key, value in metadata.items():
                    if key in exclude_keys:
                        continue
                    if key not in generation or generation[key] != value:
                        is_metadata_match = False
                        break
                if is_metadata_match:
                    generation_match_found = True
                    break
            if not generation_match_found:
                nonmatching_generations.append(generation)
                nonmatching_inds.append(i)
        if should_remove_nonmatching:
            # pop from the end first so earlier indices stay valid
            for ind in sorted(nonmatching_inds, reverse=True):
                self.generations.pop(ind)
        return nonmatching_generations

    def batch_generate(
        self,
        metadata_list: list[dict],
        n_processes: int = 4,
        sleep: float = 0.1,
        completion_func: Callable = completion_utils.get_completion_noraise,
        **kwargs,
    ) -> int:
        """Generate completions for many requests in parallel, skipping already-completed ones.

        Args:
            metadata_list (list[dict]): List of metadata dictionaries, that must each include a 'messages' key with a list of messages.
            n_processes (int, optional): # processes to spawn in the pool. Defaults to 4.
            sleep (float, optional): Time to sleep between requests IN EACH PROCESS. Defaults to 0.1.
            completion_func (Callable, optional): Function to use to produce generations from a list of messages. Defaults to completion_utils.get_completion_noraise.
            Other keyword args are passed to completion_func

        Raises:
            ValueError: If 'messages' is not included in one of the metadata dicts.

        Returns:
            int: Number of new generations. Note this MAY imply generations failed if < len(metadata_list), but only if no metadata were already generated.
        """
        metadata_to_process = self.batch_filter_not_already_generated(metadata_list)
        if len(metadata_to_process) == 0:
            return 0
        get_completion_func = functools.partial(completion_func, sleep=sleep, **kwargs)
        with mp.Pool(processes=n_processes) as pool:
            message_lists = (md["messages"] for md in metadata_to_process)
            results = pool.map(get_completion_func, message_lists)
        assert len(results) == len(metadata_to_process)
        metadata_completed = []
        for metadata, result in zip(metadata_to_process, results):
            # completion_func returns None on failure; keep only successful results
            if result is not None:
                metadata["generation"] = result
                metadata_completed.append(metadata)
        if len(metadata_completed) > 0:
            self.generations.extend(metadata_completed)
            self._save_generations(metadata_completed)
        return len(metadata_completed)
# Prompt "guidance" conditions: each is a list of OpenAI-chat-format messages.
# {slot_name} placeholders (e.g. {rori_microlesson_texts}) are filled by the
# retrieval code before the messages are sent to the API.

# Baseline: no retrieved content is shown to the model.
none = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9 and lives in Ghana.
You will be encouraging and factual.
Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
]

# Low guidance: retrieved content is offered as optional ("Only if it is relevant").
low = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9 and lives in Ghana.
You will be encouraging and factual.

Only if it is relevant, examples and language from the section below may be helpful to format your response:
===
{rori_microlesson_texts}
{openstax_subsection_texts}
===

Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
]

# Medium guidance: the model is told to use the retrieved content
# (currently commented out of the condition map below).
medium = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9 and lives in Ghana.
You will be encouraging and factual.

Use examples and language from the section below to format your response:
===
{rori_microlesson_texts}
{openstax_subsection_texts}
===

Prefer simple, short responses.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
]

# High guidance: a user message explicitly instructs the model to reference the textbook.
high = [
    {
        "role": "system",
        "content": """You are going to act as a mathematics tutor for a 13 year old student who is in grade 8 or 9.
This student lives in Ghana or Nigeria.
You will be encouraging and factual.
Prefer simple, short responses based on the textbook.
If the student says something inappropriate or off topic you will say you can only focus on mathematics and ask them if they have any math-related follow-up questions.""",
    },
    {
        "role": "user",
        "content": """Answer the following question: {user_query}

Reference content from this textbook section in your response:
{openstax_subsection_texts}

End your response by relating the question to an example or definition in the textbook.""",
    },
]

# Extraction condition: the model only identifies/repeats the most relevant passage.
extract_relevant = [
    {
        "role": "user",
        "content": """Given a middle-school math student's question, you will identify the most relevant section from a textbook.

Student question: {user_query}

Repeat the student's question and then repeat in full the most relevant paragraph from my math textbook. If none of them seem relevant, take a deep breath and output the most relevant. Don't say anything else.

Textbook paragraphs:

{rori_microlesson_texts}
{openstax_subsection_texts}""",
    },
]

# TODO consider if these guidance conditions could include their own dbinfo settings as well
# Maps condition name -> message list; "medium" is deliberately excluded from the experiment.
guidance_condition_messages_map = {
    "none": none,
    "low": low,
    # "medium": medium,
    "high": high,
    "extract_relevant": extract_relevant,
}
guidance_condition_name_list = [key for key in guidance_condition_messages_map.keys()]


import functools

import evaluate

import experiment.tokenize


def compute_macro_f1(passages: list[str], generation: str, discount_text: str | None = None) -> float:
    """Returns the max F1 across all the passages.
    Depending on arguments, this can be Knowledge F1 or just F1.

    SQuAD paper (http://arxiv.org/abs/1606.05250):
    "This metric measures the average overlap between the prediction and ground truth answer.
    We treat the prediction and ground truth as bags of tokens, and compute their F1.
    We take the maximum F1 over all of the ground truth answers for a given question, and then average over all of the questions."

    K-F1++ (https://aclanthology.org/2023.findings-acl.60):
    "Knowledge-F1 (K-F1) ... calculates the unigram overlap between the response and a knowledge snippet K,
    providing a verbatim measure of grounding to the input source.
    We propose K-F1++, a variant of K-F1,
    that captures only the novel information in the generated response and discounts any lexical alignment to the question:
    it calculates the unigram overlap between the response and K,
    after subtracting any tokens appearing in the question from the response."
    To use K-F1++, pass in the text to ignore to discount_text.

    NOTE(review): this implementation uses token *sets* (unique unigrams), while the
    SQuAD description says "bags of tokens" (multisets) — confirm the set-based
    variant is the intended behavior.

    Raises:
        ValueError: If generation has no tokens (after discounting), or if all
            passages are empty (after discounting).
    """
    # first create the tokenization function to use
    get_tokens = functools.partial(experiment.tokenize.get_tokens, lower=True, remove_nonalphanumeric_tokens=True)
    generation_tokens = set(get_tokens(generation))
    if discount_text:
        # K-F1++: remove the question's tokens from the prediction side
        discount_tokens = set(get_tokens(discount_text))
        generation_tokens -= discount_tokens
    n_predicted_tokens = len(generation_tokens)
    if n_predicted_tokens == 0:
        raise ValueError("Expected generation to be non-empty.")
    f1_scores = []
    for passage in passages:
        passage_tokens = set(get_tokens(passage))
        if discount_text:
            passage_tokens -= discount_tokens
        n_ground_truth_tokens = len(passage_tokens)
        if n_ground_truth_tokens == 0:
            # an empty passage cannot contribute a score
            continue
        n_correct_tokens = len(passage_tokens & generation_tokens)
        precision = n_correct_tokens / n_predicted_tokens
        recall = n_correct_tokens / n_ground_truth_tokens
        if precision + recall == 0:
            # avoid division by zero when there is no overlap at all
            f1 = 0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)
        f1_scores.append(f1)
    if len(f1_scores) == 0:
        raise ValueError("No non-empty passages.")
    max_f1 = max(f1_scores)
    return max_f1


@functools.cache
def get_bertscore_metric_object() -> evaluate.EvaluationModule:
    """
    See https://huggingface.co/spaces/evaluate-metric/bertscore
    """
    return evaluate.load("bertscore")


@functools.cache
def get_bleurt_metric_object(checkpoint: str = "bleurt-20") -> evaluate.EvaluationModule:
    """See https://huggingface.co/spaces/evaluate-metric/bleurt

    According to https://github.com/google-research/bleurt, the current recommended checkpoint is BLEURT-20.
    TODO use that checkpoint??

    Args:
        checkpoint (str, optional): bleurt-base-512, bleurt-large-512, etc. Defaults to "bleurt-20".
74 | 75 | Returns: 76 | evaluate.EvaluationModule: The metric object 77 | """ 78 | return evaluate.load( 79 | "bleurt", 80 | checkpoint=checkpoint, 81 | module_type="metric", 82 | # download_config=datasets.DownloadConfig(use_etag=False), 83 | ) 84 | 85 | 86 | def compute_bertscore(passages: list[str], generation: str): 87 | bert_score = get_bertscore_metric_object() 88 | references = passages 89 | predictions = [generation] * len(references) 90 | score_dict = bert_score.compute(predictions=predictions, references=references, lang="en") 91 | return max(score_dict["f1"]) 92 | 93 | 94 | def compute_bleurt(passages: list[str], generation: str, compare_to_combined_passage: bool = False): 95 | bleurt = get_bleurt_metric_object() 96 | references = passages 97 | if compare_to_combined_passage: 98 | references += ["\n".join(passages)] 99 | predictions = [generation] * len(references) 100 | scores = bleurt.compute(predictions=predictions, references=references)["scores"] 101 | return max(scores) 102 | 103 | 104 | def compute_bleurt_batch(passages_list: list[list[str]], generation_list: list[str]): 105 | # TODO we can make this faster by implementing this 106 | assert len(passages_list) == len(generation_list) 107 | raise NotImplementedError("Batch scoring not yet implemented.") 108 | -------------------------------------------------------------------------------- /src/experiment/qualtrics.py: -------------------------------------------------------------------------------- 1 | import html 2 | import json 3 | import logging 4 | import math 5 | from datetime import datetime 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | OVERFLOW_RESPONSE = "No more responses to annotate" 12 | OVERFLOW_DOCUMENT = "n/a" 13 | OVERFLOW_QUERY = "No more student queries in this survey, just click through to the end." 

# Placeholder keys that appear in the .qsf template and are replaced per-question.
survey_id_key = "RoriSurveyId"
response_keys = ["Response1Q", "Response2Q", "Response3Q"]
query_text_key = "QueryTextQ"
document_key = "DocumentQ"


def chunker(seq, size):
    # yield successive size-length slices of seq (the last slice may be shorter)
    return (seq[pos : pos + size] for pos in range(0, len(seq), size))


def get_template(template_filepath: Path):
    """Read a Qualtrics .qsf survey template from disk and validate it."""
    with open(template_filepath) as infile:
        survey_text = infile.read()
    validate_template_survey_text(survey_text)
    return survey_text


def validate_template_survey_text(survey_text, expected_survey_size: int = 15):
    """Assert the template is valid JSON and contains every expected placeholder key."""
    assert json.loads(survey_text)

    # validate expected keys
    for key in response_keys + [query_text_key, document_key]:
        for i in range(1, expected_survey_size + 1):
            qkey = key + str(i)
            assert qkey in survey_text, qkey


def convert_text(text, use_br=True):
    """HTML-escape text for embedding in Qualtrics question HTML,
    converting newlines to <br> tags.

    NOTE(review): the exact <br> counts below were reconstructed from a
    formatting-mangled source — verify against the original file.
    """
    # not sure this replace is necessary, but I was having some issues with escaping
    text = text.replace("\\", "/")
    text = html.escape(text)
    if use_br:
        text = "<br><br>" + "<br><br><br>".join(text.split("\n")) + "<br><br>"
    else:
        text = "<br><br>" + "<br><br>".join(text.split("\n")) + "<br><br>"
    return text


def create_surveys(
    df: pd.DataFrame,
    template_survey_text: str,
    survey_dir: Path,
    survey_size: int = 15,
    rng: np.random.Generator = None,
) -> pd.DataFrame:
    """Build and write one .qsf survey file per survey_size chunk of queries.

    Assumed columns:
    - generation
    - query
    - document

    Each query must appear exactly 3 times (one row per system/response).
    Returns a DataFrame recording the shuffled assignment for later analysis.
    """
    if rng is None:
        rng = np.random.default_rng()
    rows = []
    for query, group in df.groupby("query"):
        assert len(group) == 3, "Qualtrics survey hard-coded to accept 3 responses."
        assert group["generation"].nunique() == 3, "Generations/responses should be unique."
        for key in ["document"]:
            n_unique = group[key].nunique()
            if n_unique != 1:
                logging.warning(f"Expected 1 unique value in column {key}, found {n_unique}")
        # shuffle rows (so annotators can't infer which system produced which response)
        group = group.sample(frac=1, random_state=rng)
        # build new data structure
        responses = []
        metas = []
        for i in range(3):
            row = group.iloc[i]
            response = row.generation
            responses.append(response)
            # meta keeps every remaining column for post-hoc identification
            meta = row.to_dict()
            del meta["document"]
            del meta["generation"]
            del meta["query"]
            metas.append(meta)
        row = [query, group.iloc[0]["document"], *responses, *metas]
        rows.append(row)
    # randomize row order
    rng.shuffle(rows)

    survey_dir.mkdir(exist_ok=True)
    for i, survey_rows in enumerate(chunker(rows, survey_size)):
        # survey_id encodes the date and "part X/Y" of the batch
        survey_id = f"s_{datetime.now().strftime('%Y%m%d')}_{i+1}/{math.ceil(len(rows) / survey_size)}"
        survey_text = create_survey(survey_id, survey_rows, template_survey_text, survey_size)
        survey_filepath = survey_dir / f"Rori_ranking_annotations_-_survey{i}.qsf"
        with open(survey_filepath, "w") as outfile:
            outfile.write(survey_text)
        # record which survey each row landed in (appended AFTER writing the file)
        for row in survey_rows:
            row.append(survey_id)

    # build survey_df from rows
    survey_df = pd.DataFrame(
        rows,
        columns=[
            "query",
            "document",
            "response1",
            "response2",
            "response3",
"response1_meta", 114 | "response2_meta", 115 | "response3_meta", 116 | "survey_id", 117 | ], 118 | ) 119 | return survey_df 120 | 121 | 122 | def create_survey(survey_id: str, rows: list, template_survey_text: str, survey_size: int = 15) -> str: 123 | survey_text = template_survey_text 124 | survey_text = survey_text.replace(survey_id_key, survey_id) 125 | for i in range(1, survey_size + 1): 126 | if i - 1 < len(rows): 127 | row = rows[i - 1] 128 | query, document, r1, r2, r3, _, _, _ = row 129 | else: 130 | query = OVERFLOW_QUERY 131 | document = OVERFLOW_DOCUMENT 132 | r1, r2, r3 = OVERFLOW_RESPONSE, OVERFLOW_RESPONSE, OVERFLOW_RESPONSE 133 | responses = [r1, r2, r3] 134 | for key, response in zip(response_keys, responses): 135 | qkey = key + str(i) 136 | text = convert_text(response, use_br=False) 137 | survey_text = survey_text.replace(qkey, text, 1) 138 | survey_text = survey_text.replace(query_text_key + str(i), convert_text(query), 1) 139 | survey_text = survey_text.replace(document_key + str(i), convert_text(document), 1) 140 | # verify that we've created valid JSON 141 | survey_text = json.dumps(json.loads(survey_text)) 142 | return survey_text 143 | -------------------------------------------------------------------------------- /src/experiment/tokenize.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import re 3 | 4 | from spacy.lang.en import English 5 | 6 | 7 | @functools.cache 8 | def get_spacy_english(): 9 | nlp = English() 10 | return nlp 11 | 12 | 13 | def get_tokens(string_to_tokenize: str, lower: bool = True, remove_nonalphanumeric_tokens: bool = False) -> list[str]: 14 | nlp = get_spacy_english() 15 | doc = nlp.tokenizer(string_to_tokenize) 16 | if lower: 17 | tokens = [t.text.lower() for t in doc] 18 | else: 19 | tokens = [t.text for t in doc] 20 | if remove_nonalphanumeric_tokens: 21 | tokens = [token for token in tokens if re.match("\\w+", token)] 22 | return tokens 23 | 
-------------------------------------------------------------------------------- /src/rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/src/rag/__init__.py -------------------------------------------------------------------------------- /src/rag/embedding_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | import openai 5 | import tiktoken 6 | 7 | EMBEDDING_DIM = 1536 8 | MAX_TOKENS_PER_REQUEST = 8191 9 | EMBEDDING_MODEL = "text-embedding-ada-002" 10 | 11 | 12 | def get_token_counts(text_list: list[str]) -> list[int]: 13 | """Given a list of texts, returns a list of the same length with the number of tokens from the EMBEDDING_MODEL tokenizer. 14 | 15 | Args: 16 | text_list (list[str]): Texts to tokenize. 17 | 18 | Returns: 19 | list[int]: Token counts corresponding to text_list. 20 | """ 21 | tokenizer = tiktoken.encoding_for_model(EMBEDDING_MODEL) 22 | token_counts = [] 23 | for string in text_list: 24 | token_counts.append(len(tokenizer.encode(string))) 25 | return token_counts 26 | 27 | 28 | def get_openai_embeddings( 29 | texts: list[str], 30 | embedding_model: str = EMBEDDING_MODEL, 31 | ) -> list[np.array]: 32 | """Given the list of texts, query the openai Embedding API, returning the embeddings in a list of numpy arrays. 33 | Note: calls through to a cahced function. 34 | 35 | Args: 36 | texts (list[str]): List of texts to embed. 37 | embedding_model (str, optional): Embedding model to use. Defaults to EMBEDDING_MODEL. 38 | 39 | Returns: 40 | list[np.array]: Embeddings, in the same order as the given texts. 
41 | """ 42 | return get_openai_embeddings_cached(tuple(texts), embedding_model=embedding_model) 43 | 44 | 45 | @functools.lru_cache(maxsize=512, typed=True) 46 | def get_openai_embeddings_cached( 47 | texts: tuple[str], 48 | embedding_model: str = EMBEDDING_MODEL, 49 | ) -> list[np.array]: 50 | result = openai.Embedding.create(input=texts, engine=embedding_model) 51 | embedding_list = [np.array(d["embedding"]) for d in result.data] 52 | return embedding_list 53 | 54 | 55 | def batch_embed_texts( 56 | input_text_list: list[str], 57 | n_tokens_list: list[int], 58 | ) -> list[np.array]: 59 | """Embed the given texts, respecting the API max tokens limit given MAX_TOKENS_PER_REQUEST. 60 | 61 | Args: 62 | input_text_list (list[str]): Texts to embed. 63 | n_tokens_list (list[int]): As returned by `get_token_counts` 64 | 65 | Returns: 66 | list[np.array]: List of embeddings, stored in numpy arrays. 67 | """ 68 | curr_batch_token_count = 0 69 | texts = [] 70 | embedding_list = [] 71 | for text, n_tokens in zip(input_text_list, n_tokens_list): 72 | if curr_batch_token_count + n_tokens > MAX_TOKENS_PER_REQUEST: 73 | embedding_list.extend(get_openai_embeddings(texts, EMBEDDING_MODEL)) 74 | texts = [text] 75 | curr_batch_token_count = 0 76 | else: 77 | texts.append(text) 78 | curr_batch_token_count += n_tokens 79 | if len(texts) > 0: 80 | embedding_list.extend(get_openai_embeddings(texts, EMBEDDING_MODEL)) 81 | return embedding_list 82 | -------------------------------------------------------------------------------- /src/rag/gpf_utils.py: -------------------------------------------------------------------------------- 1 | def get_gpd_codes(lesson_code): 2 | """ 3 | The Global Proficiency Framework (GPF) uses GPD codes. 4 | This is a slight variant of GPD codes that includes the grade before the GPD code. 5 | 6 | Structure is: G.... 7 | 8 | e.g. 
G9.N5.1.3.1 has: 9 | grade 9 10 | domain N 11 | construct N5 12 | subconstruct N5.1 13 | skill N5.1.3 14 | index 1 15 | """ 16 | tokens = lesson_code.split(".") 17 | grade = int(tokens[0][1]) 18 | index = int(tokens[-1]) 19 | 20 | skill = ".".join(tokens[1:-1]) 21 | 22 | domain = tokens[1][0] 23 | construct = tokens[1] 24 | subconstruct = tokens[1] + "." + tokens[2] 25 | return grade, domain, construct, subconstruct, skill, index 26 | -------------------------------------------------------------------------------- /src/rag/logit_bias.py: -------------------------------------------------------------------------------- 1 | # Utilities associated with logit_bias 2 | # docs: https://platform.openai.com/docs/api-reference/chat/create#logit_bias 3 | # help doc: https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability 4 | # logit_bias takes at most 300 tokens: https://aidungeon.medium.com/controlling-gpt-3-with-logit-bias-55866d593292 5 | import functools 6 | import importlib.resources 7 | import json 8 | import re 9 | import string as string_utils 10 | from collections import Counter 11 | 12 | import tiktoken 13 | 14 | from rag import resources 15 | 16 | 17 | @functools.cache 18 | def get_tokenizer(model_name: str = "gpt-3.5-turbo") -> tiktoken.Encoding: 19 | """Get the tokenizer. Cached. 20 | 21 | Args: 22 | model_name (str, optional): The model tokenizer to load. Defaults to "gpt-3.5-turbo". 23 | 24 | Returns: 25 | tiktoken.Encoding: The tiktoken/OpenAI tokenizer. 
26 | """ 27 | tokenizer = tiktoken.encoding_for_model(model_name) 28 | return tokenizer 29 | 30 | 31 | @functools.cache 32 | def get_stopword_tokens(): 33 | """Cached version of `load_stopword_tokens`.""" 34 | return load_stopword_tokens() 35 | 36 | 37 | def load_stopword_tokens() -> set[int]: 38 | resource_filepath = importlib.resources.files(resources) / "dolma_stopwords.json" 39 | with resource_filepath.open("r") as infile: 40 | stopwords_dict = json.load(infile) 41 | return set(stopwords_dict["stopword_tokens"]) 42 | 43 | 44 | def create_stopword_token_set_from_word_list(word_list: list[str]) -> set[int]: 45 | """Create a set of stopword tokens from the given list of stop words. 46 | Used to create the default stopword resource loaded by `load_stopword_tokens`. 47 | 48 | Args: 49 | word_list (list[str]): List of words to include in the stopword set. 50 | 51 | Returns: 52 | set[int]: Set of stopword tokens. 53 | """ 54 | tokenizer = get_tokenizer() 55 | stopword_tokens = set() 56 | stopwords = ( 57 | word_list 58 | + list(map(str.lower, word_list)) 59 | + list(map(str.upper, word_list)) 60 | + list(map(str.capitalize, word_list)) 61 | + list(string_utils.whitespace) 62 | + list(string_utils.punctuation) 63 | ) 64 | for word in stopwords: 65 | for char in string_utils.whitespace + string_utils.punctuation: 66 | for string in [word, char + word, word + char]: 67 | tokens = tokenizer.encode(string) 68 | if len(tokens) == 1: 69 | stopword_tokens.add(tokens[0]) 70 | return stopword_tokens 71 | 72 | 73 | def get_nonstopword_tokens(text: str) -> list[int]: 74 | tokenizer = get_tokenizer() 75 | stopword_tokens = get_stopword_tokens() 76 | tokens = tokenizer.encode(text) 77 | tokens = [ 78 | token 79 | for token in tokens 80 | if token not in stopword_tokens and re.fullmatch("[ A-Za-z]+", tokenizer.decode([token])) 81 | ] 82 | return tokens 83 | 84 | 85 | def get_logit_bias( 86 | tokens: list[int], 87 | min_count: int = 2, 88 | n_tokens: int | None = None, 89 | 
max_tokens: int = 50, 90 | min_bias: float = 1.0, 91 | max_bias: float = 5.0, 92 | ) -> dict[int, float]: 93 | """Given a list of tokens, create a corresponding logit_bias dictionary. 94 | 95 | Roughly, identifies the most frequent max_tokens tokens, stopping at n_tokens if provided, 96 | that occur at least min_count times. Bias values are assigned based on frequency, in the range min_bias to max_bias. 97 | The most frequent token will always have a weight of max_bias in the resulting logit_bias. 98 | Bias defaults are generally inspired by this doc: https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability 99 | 100 | Args: 101 | tokens (list[int]): The list of tokens, e.g. after having stopword tokens removed. 102 | min_count (int, optional): Defaults to 2. 103 | n_tokens (int | None, optional): Defaults to None. 104 | max_tokens (int, optional): Defaults to 50. 105 | min_bias (float, optional): Defaults to 1.0. 106 | max_bias (float, optional): Defaults to 5.0. 107 | 108 | Returns: 109 | dict[int, float]: The logit_bias dict that can be passed to the logit_bias parameter accepted by the OpenAI API. 110 | """ 111 | if len(tokens) == 0: 112 | return {} 113 | logit_bias = {} 114 | c = Counter(tokens).most_common(max_tokens) 115 | max_count = c[0][1] # count of most-frequently-occurring token 116 | if max_count >= min_count: 117 | for token, count in c: 118 | if count < min_count: 119 | continue 120 | bias = min_bias + (max_bias - min_bias) * (count / max_count) 121 | logit_bias[token] = bias 122 | if n_tokens is not None and len(logit_bias) >= n_tokens: 123 | break 124 | return logit_bias 125 | 126 | 127 | def get_logit_bias_from_slot( 128 | recent_slot_fill_dict: list[dict[str, str]], 129 | include: list[str] | None = None, 130 | exclude: list[str] = [], 131 | **kwargs, 132 | ) -> dict[int, float]: 133 | """Given texts that fill one or more slots, create an appropriate logit_bias. 
134 | This is probably not a very reasonable way to do to this. 135 | 136 | Args: 137 | recent_slot_fill_dict (list[dict[str, str]]): See `prompt_utils.PromptManager`. 138 | include (list[str] | None, optional): Slot texts to consider. Defaults to None, meaning all slots are included. 139 | exclude (list[str], optional): Slot texts to ignore. Defaults to []. 140 | 141 | Returns: 142 | dict[int, float]: logit_bias 143 | """ 144 | texts = [] 145 | for slot_fill_dict in recent_slot_fill_dict: 146 | for key, value in slot_fill_dict.items(): 147 | if (include is None or key in include) and key not in exclude: 148 | texts.append(value) 149 | text = "\n".join(texts) 150 | tokens = get_nonstopword_tokens(text) 151 | return get_logit_bias(tokens, **kwargs) 152 | -------------------------------------------------------------------------------- /src/rag/misconceptions.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib.resources 3 | import json 4 | from operator import itemgetter 5 | 6 | from rag import resources 7 | 8 | 9 | @functools.cache 10 | def get_misconception_list(): 11 | """Cached version of `load_misconception_list`.""" 12 | return load_misconception_list() 13 | 14 | 15 | def load_misconception_list() -> list[dict[str, str]]: 16 | """Misconceptions resource generated in the `MisconceptionData.ipynb` notebook. 17 | 18 | Returns: 19 | list[dict[str, str]]: List of misconceptions. 20 | """ 21 | resource_filepath = importlib.resources.files(resources) / "misconceptions.ndjson" 22 | with resource_filepath.open("r") as infile: 23 | misconception_list = json.load(infile) 24 | return misconception_list 25 | 26 | 27 | @functools.cache 28 | def get_misconceptions_string() -> str: 29 | """Intended for use in prompts, this is a semi-colon delimited list of common math misconceptions. 30 | 31 | Returns: 32 | str: List of misconceptions that is approximately 1000 tokens (see notebook). 
33 | """ 34 | misconception_list = get_misconception_list() 35 | misconception_list.sort(key=itemgetter("Topic", "ID")) 36 | descriptions = [] 37 | for row in misconception_list: 38 | if not row["ID"].startswith("MaE"): 39 | continue 40 | description = row["Misconception"] 41 | description = description.split(".")[0] 42 | s = description.replace(";", ",").replace("\n", " ").strip() 43 | descriptions.append(s) 44 | misconception_string = "; ".join(descriptions) 45 | return misconception_string 46 | -------------------------------------------------------------------------------- /src/rag/prompt_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | from rag import embedding_utils, retrieval_strategies 6 | 7 | VALID_ROLES: list[str] = ["user", "assistant", "system"] 8 | 9 | 10 | class PromptSelector: 11 | """PromptSelector provides utilities to enumerate and choose prompts. 12 | 13 | Prompts are stored in dictionaries in the `prompts` modules. 
14 | """ 15 | 16 | def __init__(self, intro_prompt_dict: dict): 17 | self.intro_prompt_dict = intro_prompt_dict 18 | self.pretty_name_to_id_map = { 19 | t[1]["pretty_name"] if "pretty_name" in t[1] else f"Prompt {i}": t[0] 20 | for i, t in enumerate(self.intro_prompt_dict) 21 | } 22 | 23 | def get_intro_prompt_pretty_names(self): 24 | pretty_name_list = [] 25 | for i, prompt_info in enumerate(self.intro_prompt_dict.values()): 26 | pretty_name = prompt_info["pretty_name"] if "pretty_name" in prompt_info else f"Prompt {i}" 27 | pretty_name_list.append(pretty_name) 28 | return pretty_name_list 29 | 30 | def get_intro_prompt_message_lists(self) -> list[dict[str, str]]: 31 | message_lists = [] 32 | for prompt_info in self.intro_prompt_dict.values(): 33 | message_lists.append(prompt_info["messages"]) 34 | return message_lists 35 | 36 | def get_default_intro_prompt(self) -> dict[str]: 37 | return self.intro_prompt_dict[next(iter(self.intro_prompt_dict.keys()))] 38 | 39 | def convert_conversation_to_string(messages): 40 | conversation_string = "" 41 | for message in messages: 42 | conversation_string += message["role"].upper() + ":\n" 43 | conversation_string += message["content"] + "\n" 44 | return conversation_string 45 | 46 | def convert_string_to_conversation( 47 | conversation_string: str, 48 | ) -> list[dict[str, str]]: 49 | """Given a string representing a conversation, convert into the expected messages list format. 50 | 51 | Follows a pretty basic convention, defined in this implementation. 52 | 53 | Args: 54 | conversation_string (str): String representing a conversation. 55 | 56 | Returns: 57 | list[dict[str, str]]: List of messages, each with a "role" and "content". 
58 | """ 59 | messages = [] 60 | message = { 61 | "content": "", 62 | } 63 | for line in conversation_string.split("\n"): 64 | possible_role = line[:-1].lower() 65 | if possible_role in VALID_ROLES: 66 | if "role" in message: 67 | message["content"] = message["content"].strip() 68 | messages.append(message) 69 | message = { 70 | "content": "", 71 | } 72 | message["role"] = possible_role 73 | else: 74 | message["content"] += line + "\n" 75 | if "role" in message: 76 | message["content"] = message["content"].strip() 77 | messages.append(message) 78 | return messages 79 | 80 | 81 | class PromptManager: 82 | """Stores prompts and generates message lists for passing to the OpenAI API.""" 83 | 84 | def __init__(self): 85 | self.intro_messages: list[dict[str, str]] = [] 86 | self.retrieval_strategy: retrieval_strategies.RetrievalStrategy = retrieval_strategies.NoRetrievalStrategy() 87 | self.stored_messages: list[dict[str, str]] = [] 88 | self.most_recent_slot_fill_dict: dict[str, str] = {} 89 | self.recent_slot_fill_dict: list[dict[str, str]] = [] 90 | 91 | def set_intro_messages(self, intro_messages: list[dict[str, str]]) -> PromptManager: 92 | self.intro_messages = intro_messages 93 | return self 94 | 95 | def set_retrieval_strategy( 96 | self, 97 | retrieval_strategy: retrieval_strategies.RetrievalStrategy, 98 | ) -> PromptManager: 99 | self.retrieval_strategy = retrieval_strategy 100 | return self 101 | 102 | def get_retrieval_strategy(self) -> retrieval_strategies.RetrievalStrategy: 103 | return self.retrieval_strategy 104 | 105 | def add_stored_message(self, message: dict[str, str]) -> PromptManager: 106 | self.stored_messages.append(message) 107 | return self 108 | 109 | def clear_stored_messages(self) -> PromptManager: 110 | self.stored_messages.clear() 111 | return self 112 | 113 | def build_query( 114 | self, 115 | user_query: str | None = None, 116 | previous_messages: list[dict[str, str]] | None = None, 117 | query_for_retrieval_context: str | None = None, 
118 | ) -> list[dict[str, str]]: 119 | """Given a user_query (or the `intro_messages` set on this PromptManager), build a set of messages to pass to the OpenAI API. 120 | 121 | Args: 122 | user_query (str | None, optional): If provided, will construct a new user message from this user. Defaults to None. 123 | previous_messages (list[dict[str, str]] | None, optional): If provided, will continue a conversation. Defaults to None. 124 | query_for_retrieval_context (str | None, optional): If provided, this is used for any RetrievalStrategies that require querying. Defaults to None, meaning the user_query or the most recent user message will be used. 125 | 126 | Raises: 127 | KeyError: If the given RetrievalStrategy doesn't fill all the identified slots in the prompts. 128 | 129 | Returns: 130 | list[dict[str, str]]: List of messages, to pass to the OpenAI API. 131 | """ 132 | if previous_messages is None: 133 | previous_messages = self.stored_messages 134 | if len(previous_messages) == 0: 135 | # this is a new query 136 | messages = [message.copy() for message in self.intro_messages] 137 | self.stored_messages.extend(messages) 138 | else: 139 | # not a new query, 140 | # so include the previous messages as context 141 | messages = [message.copy() for message in previous_messages] 142 | if user_query is not None: 143 | user_message = { 144 | "role": "user", 145 | "content": user_query, 146 | } 147 | messages.append(user_message) 148 | self.stored_messages.append(user_message) 149 | 150 | should_remove_user_query_message = False 151 | if query_for_retrieval_context is None: 152 | query_for_retrieval_context = "" 153 | for message in messages[::-1]: 154 | expected_slots = PromptManager.identify_slots(message["content"]) 155 | if len(expected_slots) > 0: 156 | slot_fill_dict = self.retrieval_strategy.do_retrieval( 157 | expected_slots, 158 | query_for_retrieval_context, 159 | messages, 160 | ) 161 | self.most_recent_slot_fill_dict = slot_fill_dict 162 | 
self.recent_slot_fill_dict.append(slot_fill_dict) 163 | assert len(slot_fill_dict) == len( 164 | expected_slots, 165 | ), "Unexpected fill provided." 166 | if "user_query" in slot_fill_dict and user_query is not None: 167 | # special case: fill user_query slots with the current user_query 168 | slot_fill_dict["user_query"] = user_query 169 | should_remove_user_query_message = True 170 | try: 171 | message["content"] = message["content"].format(**slot_fill_dict) 172 | except KeyError: 173 | raise KeyError( 174 | f"Failed to fill {expected_slots} with {slot_fill_dict}.", 175 | ) 176 | else: 177 | self.recent_slot_fill_dict.append({}) 178 | if query_for_retrieval_context == "" and message["role"] == "user": 179 | # use as retrieval context the most recent user message 180 | # TODO rethink this, providing a more flexible way to specify the retrieval context 181 | query_for_retrieval_context = message["content"] 182 | self.recent_slot_fill_dict = self.recent_slot_fill_dict[::-1] 183 | if should_remove_user_query_message: 184 | self.stored_messages.pop() 185 | assert messages[-1]["content"] == user_query 186 | messages = messages[:-1] 187 | return messages 188 | 189 | def compute_stored_token_counts(self) -> int: 190 | token_counts = embedding_utils.get_token_counts( 191 | [message["content"] for message in self.stored_messages], 192 | ) 193 | total_token_count = sum(token_counts) 194 | return total_token_count 195 | 196 | def get_recent_slot_fill(self, slot_key: str) -> str | None: 197 | for slot_fill_dict in self.recent_slot_fill_dict: 198 | for key, value in slot_fill_dict.items(): 199 | if key == slot_key: 200 | return value 201 | return None 202 | 203 | def identify_slots(prompt_string: str) -> list[str]: 204 | """Uses a regex to identify missing slots in a prompt_string. 205 | 206 | More advanced slot formatting is not supported. 207 | 208 | Args: 209 | prompt_string (str): The prompt itself, with format-style slots to fill e.g. 
"This is a prompt with a slot: {slot_to_fill}" 210 | 211 | Returns: 212 | list[str]: List of identified slots. 213 | """ 214 | expected_slots = re.findall(r"{[^{} ]+}", prompt_string) 215 | return sorted({slot[1:-1] for slot in expected_slots}) 216 | -------------------------------------------------------------------------------- /src/rag/prompts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DigitalHarborFoundation/rag-for-math-qa/611fa7fdd71e7a85660ab9a13b0ecdbbbc280c0b/src/rag/prompts/__init__.py -------------------------------------------------------------------------------- /src/rag/prompts/hints.py: -------------------------------------------------------------------------------- 1 | intro_prompts = { 2 | "hint_sequence": { 3 | "pretty_name": "Hint sequence", 4 | "messages": [ 5 | { 6 | "role": "system", 7 | "content": """You are an expert mathematics tutor who gives useful hints for middle-school students. 8 | 9 | The following paragraphs are examples of content that may or not be relevant in helping the student write a hint. 10 | {rori_microlesson_texts} 11 | {openstax_subsection_texts}""", 12 | }, 13 | { 14 | "role": "user", 15 | "content": """I just received this math lesson: 16 | {lesson} 17 | 18 | Provide a hint for this math question: 19 | {question} 20 | 21 | I answered {incorrect_answer}, which is incorrect. Generate four hints for the correct answer {correct_answer} by following the steps below: 22 | 23 | FIRST: Tell me the goal of the problem. 24 | 25 | SECOND: Tell me what information I need to accomplish the goal. 26 | 27 | THIRD: Tell me what mathematical computation I need to do using the information I have to accomplish the goal. 
# Prompt templates for hint generation.  Each entry maps a prompt key to a
# pretty name (for UI display) and an OpenAI-style message list whose {slot}s
# are filled by a RetrievalStrategy (see prompt_utils.PromptManager).
intro_prompts = {
    "hint_sequence": {
        "pretty_name": "Hint sequence",
        "messages": [
            {
                "role": "system",
                # FIX: "may or may not" (was "may or not"), matching the
                # wording of the sibling prompts below.
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or may not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """I just received this math lesson:
{lesson}

Provide a hint for this math question:
{question}

I answered {incorrect_answer}, which is incorrect. Generate four hints for the correct answer {correct_answer} by following the steps below:

FIRST: Tell me the goal of the problem.

SECOND: Tell me what information I need to accomplish the goal.

THIRD: Tell me what mathematical computation I need to do using the information I have to accomplish the goal.

FOURTH: Tell me how to do the mathematical computation and show that it results in the correct answer "{correct_answer}".""",
            },
        ],
    },
    "slip_correction": {
        "pretty_name": "Slip correction",
        "messages": [
            {
                "role": "system",
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or may not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """I just received this math lesson:
{lesson}

Provide a hint for this math question:
{question}

The correct answer is {correct_answer}, but I answered {incorrect_answer}.
I think I made a small slip-up. What did I do wrong?
Your answer should be one sentence and start with "Remember to".""",
            },
        ],
    },
    "misconception": {
        "pretty_name": "Misconception-based hint",
        "messages": [
            {
                "role": "system",
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or may not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """I want you to give a hint for a student who just answered a maths question incorrectly. Explain that their incorrect answer might be due to a misconception.

Here are some common misconceptions that lead middle-school math students to get an incorrect answer:
{misconception_string}

Relevant lesson: {lesson}
Question: {question}
Answer: {answer}
Incorrect Answer: {incorrect_answer}

Give a hint that identifies the possible misconception that led to the incorrect answer "{incorrect_answer}" rather than the correct answer "{answer}".""",
            },
        ],
    },
    "comparative_hint": {
        "pretty_name": "Comparative hint",
        "messages": [
            {
                "role": "system",
                # FIX: "may or may not" (was "may or not"), as above.
                "content": """You are an expert mathematics tutor who gives useful hints for middle-school students.

The following paragraphs are examples of content that may or may not be relevant in helping the student write a hint.
{rori_microlesson_texts}
{openstax_subsection_texts}""",
            },
            {
                "role": "user",
                "content": """Provide a hint for this math question:
{question}

I answered {incorrect_answer}. I know the correct answer is {correct_answer}, but I need a hint to understand why.

Here's the relevant lesson for this problem:
{lesson}

FIRST: Repeat the worked example from the lesson.

SECOND: Compare my question to the worked example, explaining how it is different.

THIRD: Give me the steps to solve my question correctly, identifying the correct answer as {correct_answer} in the final step.""",
            },
        ],
    },
}

# Prompts for identifying which misconception ("MaE") produced a wrong answer.
misconception_identification = {
    "creature_ai": {
        "pretty_name": "Nancy Otero's misconception identification prompt",
        "messages": [
            {
                "role": "user",
                "content": """I’ll give you a spreadsheet with a list of MaEs. Each MaE has an ID, an explanation of the MaE, and 4 examples of the MaE.
Then I'll show you a student incorrect answer to a math question.
I want you to tell me how many of the {mae_count} MaEs you can identify in the answers, identify as many as you can.
Please process the answer and tell me:
If the answer is correct
If the answer is not correct: how many MaEs can you identify? Which ones and why?
Spreadsheet:
{mae_spreadsheet_string}
""",
            },
        ],
    },
}
class RetrievalDb:
    """In-memory retrieval helper class.

    When creating new embeddings:
    ```python
    assert "text_col" in df.columns
    retrieval_db = RetrievalDb(Path("embedding_dir"), "db_name", "text_col", df)
    retrieval_db.create_embeddings()
    retrieval_db.save_df()
    ```

    When loading existing embeddings:
    ```python
    retrieval_db = RetrievalDb(Path("embedding_dir"), "db_name", "text_col")
    # `load()` is called during construction if df is not provided
    assert "text_col" in retrieval_db.df.columns
    ```
    """

    def __init__(
        self,
        embedding_dir: Path,
        db_name: str,
        embed_col: str,
        df: pd.DataFrame | None = None,
        n_tokens_col: str = "n_tokens",
    ):
        """Create or load a retrieval database.

        Args:
            embedding_dir: Directory holding the parquet/npy artifacts.
            db_name: Base name used to derive the artifact filenames.
            embed_col: Column of df containing the text to embed.
            df: If None, a previously-saved db is loaded from embedding_dir;
                otherwise this frame is used (and its text normalized).
            n_tokens_col: Column holding per-text token counts; computed if absent.
        """
        self.embedding_dir = embedding_dir
        self.db_name = db_name

        self.df_filepath = self.embedding_dir / f"{self.db_name}_df.parquet"
        self.embedding_filepath = self.embedding_dir / f"{self.db_name}_embed.npy"

        self.embed_col = embed_col
        if df is None:
            self.load()
        else:
            self.df = df
            self.normalize_strings()
        assert self.embed_col in self.df.columns

        self.n_tokens_col = n_tokens_col
        if n_tokens_col not in self.df.columns:
            self.compute_token_counts()

    def normalize_strings(self):
        """Normalize the embed column in place (collapse newlines, strip)."""
        self.df[self.embed_col] = self.df[self.embed_col].map(normalize_text)

    def compute_token_counts(self):
        """Compute and store per-row token counts for the embed column."""
        token_counts = embedding_utils.get_token_counts(self.df[self.embed_col])
        self.df[self.n_tokens_col] = token_counts

    def create_embeddings(self):
        """Embed every text and persist the embedding matrix to disk."""
        embedding_list = embedding_utils.batch_embed_texts(
            self.df[self.embed_col],
            self.df[self.n_tokens_col],
        )
        # stack the 1-D embeddings into an (n_rows, dim) matrix
        self.embedding_mat = np.concatenate(
            [e.reshape(1, -1) for e in embedding_list],
            axis=0,
        )
        np.save(self.embedding_filepath, self.embedding_mat)

    def save_df(self):
        """Persist the dataframe to parquet alongside the embeddings."""
        self.df.to_parquet(self.df_filepath)

    def load(self):
        """Load the dataframe and embedding matrix saved by a previous run.

        Raises:
            ValueError: If no dataframe was previously saved at this path.
        """
        if not self.df_filepath.exists():
            raise ValueError(
                f"Trying to load a dataframe from non-existent path: {self.df_filepath}",
            )
        self.df = pd.read_parquet(self.df_filepath)
        self.embedding_mat = np.load(self.embedding_filepath)

    def compute_embedding_distances(self, query_embedding: np.ndarray) -> np.ndarray:
        """Cosine distance from query_embedding to every stored embedding."""
        if query_embedding.shape[0] != 1:
            # accept a 1-D embedding; cdist requires a 2-D (1, dim) array
            query_embedding = query_embedding.reshape(1, -1)
        distances = scipy.spatial.distance.cdist(
            query_embedding,
            self.embedding_mat,
            metric="cosine",
        )[0]
        return distances

    def compute_string_distances(self, query_str: str) -> np.ndarray:
        """Embed query_str and return its distances to every stored text."""
        embedding_list = embedding_utils.get_openai_embeddings(
            [normalize_text(query_str)],
        )
        query_embedding = embedding_list[0]
        return self.compute_embedding_distances(query_embedding)

    def iterate_query_embeddings(
        self,
        query_embedding_list: collections.abc.Iterable[np.ndarray],
    ) -> collections.abc.Generator[np.ndarray, None, None]:
        """Yield a distance array for each query embedding in turn.

        (FIX: the annotation previously used the invalid one-argument form
        ``Generator[np.array]``; ``np.array`` is a function, not a type.)
        """
        for query_embedding in query_embedding_list:
            yield self.compute_embedding_distances(query_embedding)

    def get_top_df(self, distances: np.ndarray, k: int = 5) -> pd.DataFrame:
        """Return the rows for the k smallest distances, most relevant first.

        (FIX: removed dead top_k_scores computation and its tautological assert.)
        """
        top_k_indices = np.argsort(distances)[:k]
        return self.df.iloc[top_k_indices]


def get_distance_sort_indices(distances: np.ndarray) -> np.ndarray:
    """Indices that sort distances ascending (most relevant first)."""
    return np.argsort(distances)


def normalize_text(text: str) -> str:
    """Collapse newlines to spaces and strip surrounding whitespace."""
    return text.replace("\n", " ").strip()
131 | """ 132 | 133 | def __init__( 134 | self, 135 | db: RetrievalDb, 136 | max_tokens: int = 1000, 137 | max_texts: int = 1000, 138 | prefix: str = "", 139 | suffix: str = "", 140 | join_string: str = "\n", 141 | use_parent_text: bool = False, 142 | parent_group_cols: list[str] = [], 143 | parent_sort_cols: list[str] = [], 144 | ): 145 | self.db = db 146 | self.max_tokens = max_tokens 147 | self.max_texts = max_texts 148 | self.prefix = prefix 149 | self.suffix = suffix 150 | self.join_string = join_string 151 | 152 | # configure parent retrieval 153 | self.use_parent_text = use_parent_text 154 | self.parent_join_string = "\n" 155 | self.parent_group_cols = parent_group_cols 156 | self.parent_sort_cols = parent_sort_cols 157 | 158 | def copy(self, **kwargs) -> DbInfo: 159 | """Create a copy of this DbInfo, overriding the keyword args with new values if provided. 160 | 161 | Returns: 162 | DbInfo: Newly instantiated copy. 163 | """ 164 | for expected_key in ["max_tokens", "prefix", "suffix"]: 165 | if expected_key not in kwargs: 166 | kwargs[expected_key] = getattr(self, expected_key) 167 | return DbInfo(self.db, **kwargs) 168 | 169 | def get_fill_string_from_distances(self, distances: np.array) -> str: 170 | """Given distances to the texts within the RetrievalDb, create an appropriate fill string. 171 | 172 | Args: 173 | distances (np.array): Distances, where closer texts in the RetrievalDb are more relevant. 174 | 175 | Returns: 176 | str: The string to include in the prompt. 
177 | """ 178 | sort_inds = get_distance_sort_indices(distances) 179 | used_inds = set() 180 | texts = [] 181 | total_tokens = 0 182 | for ind in sort_inds: 183 | if ind in used_inds: 184 | continue 185 | if self.use_parent_text: 186 | token_budget = self.max_tokens - total_tokens 187 | text, n_tokens, new_used_inds = self.get_parent_text(ind, token_budget) 188 | used_inds.update(new_used_inds) 189 | else: 190 | text, n_tokens = self.get_single_text(ind) 191 | used_inds.add(ind) 192 | if total_tokens + n_tokens > self.max_tokens: 193 | break 194 | total_tokens += n_tokens 195 | texts.append(text) 196 | if len(texts) >= self.max_texts: 197 | break 198 | fill_string = self.prefix + self.join_string.join(texts) + self.suffix 199 | return fill_string 200 | 201 | def get_single_text(self, ind: int): 202 | """Given a index, return the text and corresponding number of tokens from the RetrievalDb. 203 | 204 | Args: 205 | ind (int): _description_ 206 | """ 207 | row = self.db.df.iloc[ind] 208 | text = row[self.db.embed_col] 209 | n_tokens = row[self.db.n_tokens_col] 210 | return text, n_tokens 211 | 212 | def get_parent_text(self, ind: int, token_budget: int): 213 | """ 214 | Intuition of "parent document" retriever is to retrieve for inclusion in a prompt the "parent" document, 215 | similar to including docs on either "side" as additional context. 216 | 217 | Read more: https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever 218 | 219 | Args: 220 | ind (int): Most semantically relevant index to retrieve parents of. 
221 | """ 222 | df = self.db.df 223 | row = df.iloc[ind] 224 | and_cond = np.ones(len(df), dtype=bool) 225 | for col in self.parent_group_cols: 226 | cond = df[col] == row[col] 227 | and_cond = np.logical_and(and_cond, cond) 228 | parent = df[and_cond] 229 | if self.parent_sort_cols is not None and len(self.parent_sort_cols) > 1: 230 | parent = parent.sort_values(by=self.parent_sort_cols) 231 | # include a variable amount of context based on the given token_budget 232 | # preference ranking implemented here: 233 | # - all docs 234 | # - up to token_budget docs from target_ind - 0 235 | total_tokens = parent[self.db.n_tokens_col].sum() 236 | if row[self.db.n_tokens_col] > token_budget: 237 | # simple case: NOTHING will fit in the token budget! 238 | return None 239 | elif total_tokens <= token_budget: 240 | # simple case: if all tokens in budget, no extra work 241 | rows = parent 242 | else: 243 | target_ind_val = df.index[ind] 244 | before = parent.loc[:target_ind_val] 245 | # see: https://stackoverflow.com/a/37872823 246 | cumulative_token_counts = before.loc[::-1, self.db.n_tokens_col].cumsum()[::-1] 247 | rows = before[cumulative_token_counts <= token_budget] 248 | assert len(rows) >= 1 249 | new_used_inds = {df.index.get_loc(new_ind) for new_ind in rows.index} 250 | assert ind in new_used_inds 251 | texts = rows[self.db.embed_col] 252 | n_tokens_rows = rows[self.db.n_tokens_col] 253 | text = self.parent_join_string.join(texts) 254 | # note this will underestimate the true number of tokens, due to whatever parent_join_string is 255 | n_tokens = n_tokens_rows.sum() 256 | return text, n_tokens, new_used_inds 257 | -------------------------------------------------------------------------------- /src/rag/retrieval_strategies.py: -------------------------------------------------------------------------------- 1 | from rag import retrieval 2 | 3 | 4 | class RetrievalStrategy: 5 | """General retrieval strategy interface.""" 6 | 7 | def do_retrieval( 8 | self, 9 | 
class RetrievalStrategy:
    """General retrieval strategy interface."""

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] = [],
    ):
        """Map each slot name in expected_slots to a fill string.

        Subclasses must override this.
        (FIX: raise NotImplementedError, not ValueError, for an abstract method.)
        """
        raise NotImplementedError("Subclasses must implement do_retrieval().")


class NoRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots with the empty string."""

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] = [],
    ):
        return {expected_slot: "" for expected_slot in expected_slots}


class StaticRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots with a static string."""

    def __init__(self, fill_string: str) -> None:
        super().__init__()
        self.fill_string = fill_string

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] = [],
    ):
        return {expected_slot: self.fill_string for expected_slot in expected_slots}


class EmbeddingRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots with up to max_token texts from the retrieval_db.
    Uses the user_query to produce an embedding and compute RetrievalDb distances."""

    def __init__(self, db: retrieval.RetrievalDb, max_tokens: int = 2000) -> None:
        super().__init__()
        self.db: retrieval.RetrievalDb = db
        self.max_tokens: int = max_tokens

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] = [],
    ):
        distances = self.db.compute_string_distances(user_query)
        sort_inds = retrieval.get_distance_sort_indices(distances)
        texts = []
        total_tokens = 0
        # greedily take the most relevant texts until the token budget is hit
        for ind in sort_inds:
            row = self.db.df.iloc[ind]
            n_tokens = row[self.db.n_tokens_col]
            if total_tokens + n_tokens > self.max_tokens:
                break
            total_tokens += n_tokens
            texts.append(row[self.db.embed_col])
        fill_string = "\n".join(texts)
        return {expected_slot: fill_string for expected_slot in expected_slots}


class MappedEmbeddingRetrievalStrategy(RetrievalStrategy):
    """Fill all expected_slots based on the entries in slot_map.
    If asked to fill a slot not in slot_map, will use nonmatching_fill instead.
    slot_map can have either static strings or `retrieval.DbInfo`."""

    def __init__(
        self,
        slot_map: dict[str, str | retrieval.DbInfo],
        nonmatching_fill: str = "",
    ) -> None:
        super().__init__()
        self.slot_map = slot_map
        self.nonmatching_fill = nonmatching_fill
        self._validate_slot_map()

    def _validate_slot_map(self):
        """Raise ValueError if any slot_map entry has an unsupported type.

        (FIX: uses isinstance rather than exact type() checks, so subclasses
        of str/DbInfo are accepted; string values short-circuit before the
        DbInfo check.)
        """
        for key, value in self.slot_map.items():
            if not isinstance(key, str):
                raise ValueError("Slot map keys must be strings.")
            if isinstance(value, str):
                continue
            if isinstance(value, retrieval.DbInfo):
                if not hasattr(value.db, "compute_string_distances"):
                    raise ValueError(
                        "Expected a db with a compute_string_distances() method.",
                    )
            else:
                raise ValueError("Unexpected type in slot map.")

    def update_map(self, slot_updates: dict):
        """Merge slot_updates into slot_map, re-validating the result."""
        self.slot_map.update(slot_updates)
        self._validate_slot_map()

    def do_retrieval(
        self,
        expected_slots: list[str],
        user_query: str,
        previous_messages: list[dict[str, str]] = [],
    ):
        fill_string_map = {}
        for expected_slot in expected_slots:
            if expected_slot in self.slot_map:
                entry = self.slot_map[expected_slot]
                if isinstance(entry, str):
                    fill_string = entry
                else:
                    # DbInfo: embed the query, rank texts, build the fill string
                    distances = entry.db.compute_string_distances(user_query)
                    fill_string = entry.get_fill_string_from_distances(distances)
            else:
                fill_string = self.nonmatching_fill
            fill_string_map[expected_slot] = fill_string
        return fill_string_map
7 | """, 8 | ) 9 | -------------------------------------------------------------------------------- /src/streamlit/pages/Relevance.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | st.markdown( 4 | """ # Relevance annotation 5 | 6 | In-progress page to annotate relevance. 7 | """, 8 | ) 9 | -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from rag import embedding_utils, retrieval 6 | 7 | 8 | def mock_get_openai_embeddings(input_text_list, *args, **kwargs): 9 | return [np.random.random(size=embedding_utils.EMBEDDING_DIM) for _ in input_text_list] 10 | 11 | 12 | @pytest.fixture 13 | def patch_get_openai_embeddings(monkeypatch): 14 | monkeypatch.setattr( 15 | "rag.embedding_utils.get_openai_embeddings", 16 | mock_get_openai_embeddings, 17 | ) 18 | 19 | 20 | @pytest.fixture 21 | def retrieval_db_path(tmp_path, patch_get_openai_embeddings): 22 | # Creates a retrieval database that can be used by multiple tests 23 | df = pd.DataFrame( 24 | [ 25 | { 26 | "categorical_var": "A", 27 | "group_var": 1, 28 | "text": "Test text 1.", 29 | }, 30 | { 31 | "categorical_var": "B", 32 | "group_var": 1, 33 | "text": "Test text 2.", 34 | }, 35 | { 36 | "categorical_var": "C", 37 | "group_var": 2, 38 | "text": "Test text 3.", 39 | }, 40 | ], 41 | ) 42 | db = retrieval.RetrievalDb(tmp_path, "conftestDb", "text", df) 43 | db.create_embeddings() 44 | db.save_df() 45 | return tmp_path 46 | 47 | 48 | @pytest.fixture 49 | def retrieval_db(retrieval_db_path) -> retrieval.RetrievalDb: 50 | db = retrieval.RetrievalDb(retrieval_db_path, "conftestDb", "text") 51 | return db 52 | -------------------------------------------------------------------------------- /tests/unit/rag/conftest.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture(autouse=True) 5 | def no_openai(monkeypatch): 6 | """Remove Embedding and ChatCompletion creations during all RAG unit tests. 7 | 8 | It's not actually clear to me that this works.""" 9 | monkeypatch.delattr("openai.Embedding.create") 10 | monkeypatch.delattr("openai.ChatCompletion.create") 11 | -------------------------------------------------------------------------------- /tests/unit/rag/test_embedding_utils.py: -------------------------------------------------------------------------------- 1 | from rag import embedding_utils 2 | 3 | 4 | def test_get_token_counts(): 5 | n_tokens_list = embedding_utils.get_token_counts(["test", "test "]) 6 | assert n_tokens_list == [1, 2] 7 | 8 | 9 | def test_get_openai_embeddings(patch_get_openai_embeddings): 10 | result = embedding_utils.get_openai_embeddings(["test"]) 11 | assert len(result) == 1 12 | assert result[0].shape[0] == embedding_utils.EMBEDDING_DIM 13 | 14 | 15 | def test_batch_embed_texts(patch_get_openai_embeddings): 16 | max_tokens = embedding_utils.MAX_TOKENS_PER_REQUEST 17 | input_text_list = ["test"] * (max_tokens + 1) 18 | embedding_list = embedding_utils.batch_embed_texts( 19 | input_text_list, 20 | embedding_utils.get_token_counts(input_text_list), 21 | ) 22 | assert len(embedding_list) == len(input_text_list) 23 | assert all(emb.shape[0] == embedding_utils.EMBEDDING_DIM for emb in embedding_list) 24 | -------------------------------------------------------------------------------- /tests/unit/rag/test_gpf_utils.py: -------------------------------------------------------------------------------- 1 | from rag import gpf_utils 2 | 3 | 4 | def test_get_gpd_codes(): 5 | grade, domain, construct, subconstruct, skill, index = gpf_utils.get_gpd_codes( 6 | "G9.N5.1.3.1", 7 | ) 8 | assert grade == 9 9 | assert domain == "N" 10 | assert construct == "N5" 11 | assert subconstruct == "N5.1" 12 
# tests/unit/rag/test_logit_bias.py

from rag import logit_bias

THE_TOKEN = 1820  # token id for the string "the"


def test_get_tokenizer():
    tok = logit_bias.get_tokenizer()
    assert tok.encode("the") == [THE_TOKEN]
    assert tok.decode([THE_TOKEN]) == "the"


def test_load_stopwords():
    stopword_tokens = logit_bias.load_stopword_tokens()
    assert len(stopword_tokens) >= 11000
    assert THE_TOKEN in stopword_tokens, "The 'the' token (1820) should be in the stopwords."
    assert stopword_tokens == logit_bias.get_stopword_tokens()
    # caching: repeated calls return equal sets
    assert logit_bias.get_stopword_tokens() == logit_bias.get_stopword_tokens()


def test_get_nonstopword_tokens():
    # all stopwords -> nothing survives
    assert len(logit_bias.get_nonstopword_tokens("the was and were thus")) == 0
    # nonsense words -> several tokens survive
    assert len(logit_bias.get_nonstopword_tokens("verily osmogorp")) >= 5


def test_get_logit_bias():
    tokens = logit_bias.get_nonstopword_tokens("verily verily verily")
    assert len(tokens) == 6
    assert len(logit_bias.get_logit_bias(tokens, min_count=4)) == 0
    bias = logit_bias.get_logit_bias(tokens, min_count=3)
    assert len(bias) == 1
    assert any(token in bias for token in tokens)


def test_get_logit_bias_from_slot():
    recent_slot_fill_dict = [
        {
            "slot1": "verily verily verily",
            "slot2": "the",
        },
    ]
    bias = logit_bias.get_logit_bias_from_slot(recent_slot_fill_dict)
    assert 1570 in bias
    # TODO make this test cover more cases and be more explanatory


def test_create_stopword_token_set_from_word_list():
    stopword_tokens = logit_bias.create_stopword_token_set_from_word_list(["the"])
    assert len(stopword_tokens) > 1
    assert THE_TOKEN in stopword_tokens


# tests/unit/rag/test_misconceptions.py

from rag import misconceptions


def test_load_misconception_list():
    misconception_list = misconceptions.load_misconception_list()
    assert len(misconception_list) >= 50
    # assert misconception_list == misconceptions.get_misconception_list()  # I'm not sure why this inequality is false
    # caching: repeated calls return equal lists
    assert misconceptions.get_misconception_list() == misconceptions.get_misconception_list()


def test_get_misconceptions_string():
    assert len(misconceptions.get_misconceptions_string()) >= 1000
def test_PromptSelector():
    test_prompts = {
        "test_prompt_1": {
            "pretty_name": "Prompt 1",
            "messages": [
                {
                    "role": "system",
                    "content": "System prompt 1.",
                },
            ],
        },
        "test_prompt_2": {
            "pretty_name": "Prompt 2",
            "messages": [
                {
                    "role": "system",
                    "content": "System prompt 2.",
                },
            ],
        },
    }
    selector = prompt_utils.PromptSelector(test_prompts)
    assert selector.get_intro_prompt_pretty_names() == ["Prompt 1", "Prompt 2"]
    message_lists = selector.get_intro_prompt_message_lists()
    assert message_lists[0] == test_prompts["test_prompt_1"]["messages"]
    assert selector.get_default_intro_prompt()["pretty_name"] == "Prompt 1"


def test_PromptManager():
    pm = prompt_utils.PromptManager()

    # a bare query becomes a single stored user message
    messages = pm.build_query("Test")
    assert len(messages) == 1
    assert messages[0]["content"] == "Test"
    assert len(pm.stored_messages) == 1
    assert pm.stored_messages[0]["content"] == "Test"
    pm.clear_stored_messages()

    # starting a conversation with a system intro message
    test_intro_messages = [
        {
            "role": "system",
            "content": "System",
        },
    ]
    messages = pm.set_intro_messages(test_intro_messages).build_query("User")
    assert len(messages) == 2
    assert [m["role"] for m in messages] == ["system", "user"], messages
    assert [m["content"] for m in messages] == ["System", "User"]
    pm.clear_stored_messages()

    # continuing an existing conversation
    history = messages
    history.append(
        {
            "role": "assistant",
            "content": "Assistant",
        },
    )
    messages = pm.set_intro_messages(test_intro_messages).build_query(
        "User2",
        previous_messages=history,
    )
    assert len(messages) == 4
    assert messages[2]["content"] == "Assistant"
    assert messages[3]["content"] == "User2"


def test_PromptManager_retrieval():
    test_intro_messages = [
        {
            "role": "system",
            "content": "Test {slot1} {slot2}",
        },
    ]
    strategy = retrieval_strategies.StaticRetrievalStrategy("Fill")
    pm = prompt_utils.PromptManager().set_intro_messages(test_intro_messages).set_retrieval_strategy(strategy)
    messages = pm.build_query("User")
    assert len(messages) == 2
    assert messages[0]["content"] == "Test Fill Fill", messages[0]
    # the intro template itself must not be mutated by slot filling
    assert pm.intro_messages[0]["content"] == "Test {slot1} {slot2}"


def test_identify_slots():
    identify = prompt_utils.PromptManager.identify_slots
    assert identify("test {test1} {test2} {test3 }") == ["test1", "test2"]
    assert identify("test {test1} {test1}") == ["test1"]


def test_user_query_replacement():
    # the slot named exactly "user_query" absorbs the query into the template
    test_intro_messages = [
        {
            "role": "user",
            "content": "Question: {user_query}",
        },
    ]
    messages = prompt_utils.PromptManager().set_intro_messages(test_intro_messages).build_query("Test")
    assert len(messages) == 1
    assert messages[0]["content"] == "Question: Test"

    # this behavior requires exactly the slot name "user_query";
    # any other slot name leaves the query as its own separate message
    test_intro_messages = [
        {
            "role": "user",
            "content": "Question: {other_slot}",
        },
    ]
    messages = prompt_utils.PromptManager().set_intro_messages(test_intro_messages).build_query("Test")
    assert len(messages) == 2
    assert messages[0]["content"] == "Question: "
    assert messages[1]["content"] == "Test"
prompt_utils.PromptManager().set_intro_messages(test_intro_messages).build_query("Test") 133 | assert len(messages) == 2 134 | assert messages[0]["content"] == "Question: " 135 | assert messages[1]["content"] == "Test" 136 | -------------------------------------------------------------------------------- /tests/unit/rag/test_retrieval.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from rag import embedding_utils, retrieval 8 | 9 | 10 | def test_RetrievalDb(tmp_path, patch_get_openai_embeddings): 11 | df = pd.DataFrame( 12 | [ 13 | { 14 | "categorical_var": "A", 15 | "text": "Test text.", 16 | }, 17 | { 18 | "categorical_var": "B", 19 | "text": "Test\n\ntext.", 20 | }, 21 | ], 22 | ) 23 | db = retrieval.RetrievalDb(tmp_path, "testDb", "text", df) 24 | assert not db.embedding_filepath.exists() 25 | assert not db.df_filepath.exists() 26 | db.create_embeddings() 27 | assert db.embedding_filepath.exists() 28 | assert db.embedding_mat.shape == (len(df), embedding_utils.EMBEDDING_DIM) 29 | db.save_df() 30 | assert db.df_filepath.exists() 31 | 32 | db = retrieval.RetrievalDb(tmp_path, "testDb", "text") 33 | assert len(db.df) == len(df) 34 | 35 | distances = db.compute_string_distances("Test query.") 36 | assert len(distances) == len(df) 37 | top_df = db.get_top_df(distances, k=1) 38 | assert len(top_df) == 1 39 | 40 | # test non-existent db loading 41 | with pytest.raises(ValueError): 42 | retrieval.RetrievalDb(tmp_path, "testDb2", "text") 43 | 44 | 45 | def test_DbInfo(retrieval_db): 46 | db_info = retrieval.DbInfo(retrieval_db, max_tokens=1) 47 | assert db_info.max_tokens == 1 48 | db_info2 = db_info.copy() 49 | assert db_info2.max_tokens == 1 50 | db_info3 = db_info.copy(max_tokens=2) 51 | assert db_info3.max_tokens == 2 52 | 53 | 54 | def test_DbInfo_get_fill_string_from_distances(retrieval_db): 55 | test_prefix = "test_prefix" 56 | 
test_suffix = "test_suffix" 57 | db_info = retrieval.DbInfo( 58 | retrieval_db, 59 | prefix=test_prefix, 60 | suffix=test_suffix, 61 | max_texts=1, 62 | ) 63 | distances = np.array([0] + [1] * (len(db_info.db.df) - 1)) 64 | assert len(distances) == len(db_info.db.df) 65 | fill_string = db_info.get_fill_string_from_distances(distances) 66 | assert fill_string.startswith(test_prefix) 67 | assert fill_string.endswith(test_suffix) 68 | assert db_info.db.df[db_info.db.embed_col].iloc[0] in fill_string 69 | assert all( 70 | db_info.db.df[db_info.db.embed_col].iloc[1:].map(lambda t: t not in fill_string), 71 | ) 72 | 73 | 74 | def test_DbInfo_get_single_text(retrieval_db): 75 | db_info = retrieval.DbInfo(retrieval_db) 76 | text, n_tokens = db_info.get_single_text(0) 77 | assert text == db_info.db.df[db_info.db.embed_col].iloc[0] 78 | assert n_tokens > 0 79 | 80 | 81 | def test_DbInfo_get_parent_text(retrieval_db): 82 | db_info = retrieval.DbInfo( 83 | retrieval_db, 84 | max_texts=1, 85 | use_parent_text=True, 86 | parent_group_cols=["group_var"], 87 | parent_sort_cols=["categorical_var"], 88 | ) 89 | # first, verify that with a budget of 0 no texts are returned 90 | assert db_info.get_parent_text(0, 0) is None 91 | 92 | # second, verify normal case 93 | text, n_tokens, used_inds = db_info.get_parent_text(0, 1000) 94 | assert used_inds == {0, 1} 95 | expected_parent_rows = db_info.db.df[db_info.db.df["group_var"] == 1] 96 | expected_text = "\n".join(expected_parent_rows[db_info.db.embed_col]) 97 | assert text == expected_text, f"Parent text was {text}" 98 | assert n_tokens == expected_parent_rows[db_info.db.n_tokens_col].sum() 99 | 100 | # here, we retrieve the second entry in a parent text but we still get back the first and second together 101 | distances = np.array([1, 0, 1]) 102 | fill_string = db_info.get_fill_string_from_distances(distances) 103 | assert fill_string == expected_text 104 | 105 | # check for non-duplicate retrieval 106 | db_info.max_texts = 2 107 | 
distances = np.array([0, 0, 1]) 108 | fill_string = db_info.get_fill_string_from_distances(distances) 109 | assert len(re.findall(expected_text, fill_string)) == 1, f"Duplicate retrieval in {fill_string}" 110 | 111 | # third, we verify token budget backoff behavior 112 | # here, we give only enough budget to get the target text 113 | token_budget = db_info.db.df[db_info.db.n_tokens_col].iloc[0] 114 | text, n_tokens, used_inds = db_info.get_parent_text(0, token_budget) 115 | assert text == db_info.db.df[db_info.db.embed_col].iloc[0] 116 | assert n_tokens == token_budget 117 | assert used_inds == {0} 118 | 119 | # here, we give enough budget for two texts 120 | token_budget = db_info.db.df[db_info.db.n_tokens_col].iloc[0:1].sum() 121 | text, n_tokens, used_inds = db_info.get_parent_text(0, token_budget) 122 | assert len(used_inds - {0, 1}) == 0 123 | -------------------------------------------------------------------------------- /tests/unit/rag/test_retrieval_strategies.py: -------------------------------------------------------------------------------- 1 | from rag import retrieval, retrieval_strategies 2 | 3 | 4 | def test_NoRetrievalStrategy(): 5 | retriever = retrieval_strategies.NoRetrievalStrategy() 6 | assert retriever.do_retrieval(["test"], "user") == {"test": ""} 7 | 8 | 9 | def test_StaticRetrievalStrategy(): 10 | retriever = retrieval_strategies.StaticRetrievalStrategy("TestVal") 11 | assert retriever.do_retrieval(["test"], "user") == {"test": "TestVal"} 12 | 13 | 14 | def test_EmbeddingRetrievalStrategy(retrieval_db_path): 15 | db = retrieval.RetrievalDb(retrieval_db_path, "conftestDb", "text") 16 | retriever = retrieval_strategies.EmbeddingRetrievalStrategy(db, max_tokens=5) 17 | filled_slots = retriever.do_retrieval(["testSlot"], "testQuery") 18 | assert "testSlot" in filled_slots 19 | # note that the actual text retrieved here is random due to the way the retrieval embeddings are generated 20 | assert filled_slots["testSlot"].startswith("Test text"), 
"No retrieved text." 21 | 22 | # test max_tokens 23 | retriever = retrieval_strategies.EmbeddingRetrievalStrategy(db, max_tokens=0) 24 | filled_slots = retriever.do_retrieval(["testSlot"], "testQuery") 25 | assert filled_slots["testSlot"] == "" 26 | 27 | 28 | def test_MappedEmbeddingRetrievalStrategy(retrieval_db): 29 | slot_map = { 30 | "slot1": "fill1", 31 | "slot2": "fill2", 32 | } 33 | retriever = retrieval_strategies.MappedEmbeddingRetrievalStrategy( 34 | slot_map, 35 | nonmatching_fill="nomatch", 36 | ) 37 | filled_slots = retriever.do_retrieval(["slot1", "slot2"], "") 38 | assert slot_map == filled_slots 39 | filled_slots = retriever.do_retrieval(["slot1", "slot2", "slot3"], "") 40 | assert filled_slots["slot3"] == "nomatch" 41 | for slot, slot_fill in slot_map.items(): 42 | assert filled_slots[slot] == slot_fill 43 | 44 | slot_map = { 45 | "slot1": "fill1", 46 | "slot2": retrieval.DbInfo(retrieval_db), 47 | } 48 | retriever = retrieval_strategies.MappedEmbeddingRetrievalStrategy(slot_map) 49 | filled_slots = retriever.do_retrieval(["slot1", "slot2"], "") 50 | assert filled_slots["slot1"] == "fill1" 51 | assert filled_slots["slot2"].startswith("Test text") 52 | 53 | retriever.update_map({"slot2": "fill2"}) 54 | filled_slots = retriever.do_retrieval(["slot1", "slot2"], "") 55 | assert filled_slots["slot2"] == "fill2" 56 | -------------------------------------------------------------------------------- /tests/unit/test_auth.py: -------------------------------------------------------------------------------- 1 | from experiment import auth 2 | 3 | 4 | def test_passwd_check(): 5 | passphrase = "test" 6 | hashed_passphrase = auth.passwd_hash(passphrase) 7 | assert auth.passwd_check(hashed_passphrase, passphrase) 8 | -------------------------------------------------------------------------------- /tests/unit/test_completion_utils.py: -------------------------------------------------------------------------------- 1 | from experiment import completion_utils 2 | 3 | 
COMPLETION_ERROR_STRING = "completion_utils completion error" 4 | 5 | 6 | def mock_get_completion_error(*args, **kwargs): 7 | raise ValueError(COMPLETION_ERROR_STRING) 8 | 9 | 10 | def mock_get_completion_noerror(*args, **kwargs): 11 | return "Test completion" 12 | 13 | 14 | def test_get_completion_noraise(monkeypatch, caplog): 15 | # verify behavior when errors occur 16 | monkeypatch.setattr( 17 | "experiment.completion_utils.get_completion", 18 | mock_get_completion_error, 19 | ) 20 | assert completion_utils.get_completion_noraise([], sleep=0, sleep_time_between_attempts=0, max_attempts=2) is None 21 | for record in caplog.records: 22 | assert record.name == completion_utils.logger.name 23 | assert COMPLETION_ERROR_STRING in record.message 24 | assert ( 25 | len(caplog.records) == 3 26 | ), "Expected max_attempts error messages from attempts + 1 caught by get_completion_noraise." 27 | 28 | # verify no-error behavior 29 | monkeypatch.setattr( 30 | "experiment.completion_utils.get_completion", 31 | mock_get_completion_noerror, 32 | ) 33 | assert ( 34 | completion_utils.get_completion_noraise([], sleep=0, sleep_time_between_attempts=0, max_attempts=2) 35 | == "Test completion" 36 | ) 37 | -------------------------------------------------------------------------------- /tests/unit/test_generate.py: -------------------------------------------------------------------------------- 1 | import string 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from experiment import generate 7 | 8 | 9 | def mock_get_completion(*args, **kwargs): 10 | choices = list(string.ascii_lowercase) 11 | return " ".join(np.random.choice(choices, size=50)) 12 | 13 | 14 | def mock_get_completion_error(*args, **kwargs): 15 | raise ValueError(f"Args: {args}") 16 | 17 | 18 | @pytest.fixture 19 | def patch_get_completion(monkeypatch): 20 | monkeypatch.setattr( 21 | "experiment.completion_utils.get_completion", 22 | mock_get_completion, 23 | ) 24 | 25 | 26 | def test_GenerationCorpus(tmp_path, 
patch_get_completion): 27 | corpus = generate.GenerationCorpus(tmp_path, "test_corpus1") 28 | assert len(corpus.generations) == 0 29 | messages = [ 30 | { 31 | "role": "user", 32 | "content": "Test query", 33 | }, 34 | ] 35 | metadata = {"test_key": "test_value"} 36 | assert corpus.generate(messages, metadata, sleep=0) 37 | assert not corpus.generate(messages, metadata, sleep=0) 38 | assert len(corpus.generations) == 1 39 | assert corpus.generations[0]["test_key"] == "test_value" 40 | assert corpus.generations[0]["generation"] is not None 41 | metadata = metadata.copy() # NOTE: this will fail without a copy()! 42 | metadata["test_key"] = "test_value2" 43 | assert not corpus.is_already_generated(messages, metadata) 44 | assert corpus.generate(messages, metadata, sleep=0) 45 | assert len(corpus.generations) == 2 46 | assert corpus.generations[-1]["test_key"] == "test_value2" 47 | 48 | corpus2 = generate.GenerationCorpus(tmp_path, "test_corpus1") 49 | assert len(corpus2.generations) == 2 50 | assert corpus2.generations[0]["test_key"] == "test_value" 51 | assert corpus2.generations[-1]["test_key"] == "test_value2" 52 | 53 | 54 | def test_GenerationCorpus_batch(tmp_path): 55 | corpus = generate.GenerationCorpus(tmp_path, "test_corpus") 56 | assert len(corpus.generations) == 0 57 | messages = [ 58 | { 59 | "role": "user", 60 | "content": "Test query", 61 | }, 62 | ] 63 | n_to_generate = 200 64 | metadata_list = [{"test_id": i, "messages": messages} for i in range(n_to_generate)] 65 | n_generated = corpus.batch_generate( 66 | metadata_list, 67 | n_processes=4, 68 | sleep=0, 69 | completion_func=mock_get_completion, 70 | ) 71 | assert n_generated == n_to_generate, f"Expected {n_generated} generations, but only produced {n_to_generate}" 72 | assert len(corpus.generations) == n_to_generate 73 | for generation in corpus.generations: 74 | assert generation["generation"] is not None 75 | 76 | assert ( 77 | corpus.batch_generate( 78 | metadata_list, 79 | n_processes=4, 80 | 
sleep=0, 81 | completion_func=mock_get_completion, 82 | fake_kwarg=0, # also test passing a kwarg 83 | ) 84 | == 0 85 | ) 86 | 87 | # test get_nonmatching_generations 88 | n_generations = len(corpus.generations) 89 | metadata_list = [dict(metadata) for metadata in metadata_list] 90 | popped_metadata = metadata_list.pop(0) 91 | nonmatching_generations = corpus.get_nonmatching_generations(metadata_list) 92 | assert len(nonmatching_generations) == 1 93 | assert nonmatching_generations[0] == popped_metadata 94 | assert popped_metadata in corpus.generations, "Unexpectedly removed from generations." 95 | assert n_generations == len(corpus.generations) 96 | assert len(corpus.get_nonmatching_generations(metadata_list, should_remove_nonmatching=True)) == 1 97 | assert ( 98 | len(corpus.generations) == n_generations - 1 99 | ), "Should be one fewer generation after removing non-matching generations." 100 | 101 | # test filter_generations() and overwite() 102 | n_generations = len(corpus.generations) 103 | assert corpus.filter_generations() == 0 104 | assert corpus.filter_generations(lambda _: False) == n_generations 105 | assert len(corpus.generations) == 0 106 | corpus.overwrite() 107 | 108 | 109 | def test_GenerationCorpus_batch_error(tmp_path): 110 | corpus = generate.GenerationCorpus(tmp_path, "test_corpus") 111 | assert len(corpus.generations) == 0 112 | messages = [ 113 | { 114 | "role": "user", 115 | "content": "Test query", 116 | }, 117 | ] 118 | n_to_generate = 2 119 | metadata_list = [{"test_id": i, "messages": messages} for i in range(n_to_generate)] 120 | with pytest.raises(ValueError): 121 | corpus.batch_generate( 122 | metadata_list, 123 | n_processes=1, 124 | sleep=0, 125 | completion_func=mock_get_completion_error, 126 | ) 127 | -------------------------------------------------------------------------------- /tests/unit/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from experiment 
import metrics 4 | 5 | 6 | def test_metric_objects(): 7 | assert metrics.get_bertscore_metric_object() is not None 8 | assert metrics.get_bleurt_metric_object() is not None 9 | 10 | 11 | def test_compute_bleurt(): 12 | passages = ["The alphabet is 26 letters long.", "Math is not so easy."] 13 | generation = "The English alphabet is 26 letters long." 14 | assert metrics.compute_bleurt(passages, generation) >= 0.3 15 | 16 | 17 | def test_compute_bertscore(): 18 | passages = ["The alphabet is 26 letters long.", "Math is not so easy."] 19 | generation = "The English alphabet is 26 letters long." 20 | assert metrics.compute_bertscore(passages, generation) >= 0.3 21 | 22 | 23 | def test_compute_macro_f1(): 24 | assert metrics.compute_macro_f1(["Test text", "distractor passage"], "test text") == 1 25 | assert metrics.compute_macro_f1(["Test"], "test text") == 2 / 3 26 | assert metrics.compute_macro_f1(["distractor", "distractor passage"], "test text") == 0 27 | with pytest.raises(ValueError): 28 | metrics.compute_macro_f1(["test"], "") 29 | with pytest.raises(ValueError): 30 | metrics.compute_macro_f1([""], "test") 31 | 32 | # K-F1++ 33 | assert ( 34 | metrics.compute_macro_f1( 35 | ["George Washington"], 36 | "Who was the first president? 
George Washington", 37 | discount_text="Who was the first president?", 38 | ) 39 | == 1 40 | ) 41 | -------------------------------------------------------------------------------- /tests/unit/test_qualtrics.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from datetime import datetime 3 | 4 | import pandas as pd 5 | import pytest 6 | 7 | from experiment import qualtrics 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def qualtrics_template_filepath(pytestconfig) -> pathlib.Path: 12 | template_filepath = pytestconfig.rootpath / "data" / "raw" / "qualtrics" / "Rori_ranking_annotations_-_template.qsf" 13 | if not template_filepath.exists(): 14 | pytest.skip(f"Qualtrics template at {template_filepath} does not exist.") 15 | return template_filepath 16 | 17 | 18 | def test_create_surveys(qualtrics_template_filepath, tmp_path): 19 | template_survey_text = qualtrics.get_template(qualtrics_template_filepath) 20 | assert template_survey_text is not None 21 | 22 | survey = qualtrics.create_survey("survey_id0", [], template_survey_text) 23 | assert qualtrics.OVERFLOW_RESPONSE in survey 24 | 25 | survey_dir = tmp_path / "test_surveys" 26 | df = pd.DataFrame( 27 | [ 28 | ["q1", "d1", "none", "g1", "a"], 29 | ["q1", "d1", "low", "g2", "a"], 30 | ["q1", "d1", "high", "g3", "a"], 31 | ], 32 | columns=["query", "document", "guidance", "generation", "metadata"], 33 | ) 34 | survey_df = qualtrics.create_surveys(df, template_survey_text, survey_dir) 35 | row = survey_df.iloc[0] 36 | assert row["query"] == "q1" 37 | assert row["document"] == "d1" 38 | assert set(row[["response1", "response2", "response3"]]) == {"g1", "g2", "g3"} 39 | assert row["survey_id"] == f"s_{datetime.now().strftime('%Y%m%d')}_1/1" 40 | -------------------------------------------------------------------------------- /tests/unit/test_tokenize.py: -------------------------------------------------------------------------------- 1 | from experiment import 
tokenize 2 | 3 | 4 | def test_get_tokens(): 5 | assert tokenize.get_tokens("Test text", lower=True) == ["test", "text"] 6 | assert tokenize.get_tokens("Test text", lower=False) == ["Test", "text"] 7 | 8 | assert tokenize.get_tokens("Test.", lower=False) == ["Test", "."] 9 | assert tokenize.get_tokens("Test.", lower=False, remove_nonalphanumeric_tokens=True) == ["Test"] 10 | --------------------------------------------------------------------------------