├── docs
└── images
│ ├── many_bugs.png
│ ├── zero_shot.png
│ ├── main_loops.png
│ ├── blank_test_tree.png
│ ├── first_suggestions.png
│ ├── topic_suggestions.png
│ └── manual_unconfirmed_inputs.png
├── adaptivetesting
├── resources
│ ├── favicon.png
│ └── main.js.LICENSE.txt
├── __init__.py
├── comm.py
├── utils
│ └── __init__.py
├── _model.py
├── embedders.py
├── _topic_model.py
├── _prompt_builder.py
├── _scorer.py
└── _server.py
├── client
├── .babelrc
├── src
│ ├── index.jsx
│ ├── total-value.jsx
│ ├── context-menu.jsx
│ ├── adatest.jsx
│ ├── bread-crum.jsx
│ ├── CommEvent.ts
│ ├── web-socket-comm.js
│ ├── jupyter-comm.js
│ ├── adatest.css
│ └── content-editable.jsx
├── README
├── tsconfig.json
├── package.json
├── webpack.config.js
└── dist
│ └── main.js.LICENSE.txt
├── test_trees
├── README.md
└── abstract_capabilities.csv
├── .gitignore
├── tests
├── simple_test_tree.csv
├── test_utils.py
├── test_test_tree.py
└── test_generators.py
├── development
├── launch.json
└── scripts
│ ├── build_wheel.py
│ └── install_from_wheel.py
├── LICENSE
├── .github
└── workflows
│ ├── python-app.yml
│ ├── python-wheel-build.yaml
│ └── codeql-analysis.yml
├── setup.py
├── SECURITY.md
├── README.md
└── notebooks
├── IMDB to Hotel Sentiment.ipynb
└── imdb_hotel_conversion.csv
/docs/images/many_bugs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/many_bugs.png
--------------------------------------------------------------------------------
/docs/images/zero_shot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/zero_shot.png
--------------------------------------------------------------------------------
/docs/images/main_loops.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/main_loops.png
--------------------------------------------------------------------------------
/docs/images/blank_test_tree.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/blank_test_tree.png
--------------------------------------------------------------------------------
/docs/images/first_suggestions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/first_suggestions.png
--------------------------------------------------------------------------------
/docs/images/topic_suggestions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/topic_suggestions.png
--------------------------------------------------------------------------------
/adaptivetesting/resources/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/adaptivetesting/resources/favicon.png
--------------------------------------------------------------------------------
/docs/images/manual_unconfirmed_inputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/manual_unconfirmed_inputs.png
--------------------------------------------------------------------------------
/client/.babelrc:
--------------------------------------------------------------------------------
1 | {
2 | "presets": [
3 | "@babel/preset-env",
4 | "@babel/preset-react"
5 | ],
6 | "plugins": [
7 | "@babel/plugin-proposal-class-properties"
8 | ]
9 | }
--------------------------------------------------------------------------------
/client/src/index.jsx:
--------------------------------------------------------------------------------
import React from 'react';
import ReactDOM from 'react-dom';
import AdaTest from './adatest'


// Entry point for a standalone page: mounts the AdaTest root component into
// the #adatest_container_1 element.
// NOTE(review): the first argument to ReactDOM.render appears to have been
// lost in extraction (likely an <AdaTest .../> element); restore from the
// original source.
ReactDOM.render(
  ,
  document.getElementById('adatest_container_1')
);
--------------------------------------------------------------------------------
/client/README:
--------------------------------------------------------------------------------
1 | To build the client run the following in the client directory:
2 |
3 | > npm install
4 |
5 | > npx webpack
6 |
7 | This will create the main.js bundle (see webpack.config.js for the output location).
8 |
9 | When doing dev use:
10 |
11 | npx webpack --watch
12 |
13 | to auto-rebuild the bundle when the JS code changes.
--------------------------------------------------------------------------------
/test_trees/README.md:
--------------------------------------------------------------------------------
1 | These sample test trees are provided without any warranty as to their accuracy. Because they can sometimes cover sensitive topics we encourage users to review and alter them as needed. The files do not represent the opinion of Microsoft and may not represent the opinions of any individual author. PRs are welcome!
2 |
--------------------------------------------------------------------------------
/client/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "outDir": "../adaptivetesting/resources/",
4 | "sourceMap": true,
5 | "noImplicitAny": false,
6 | "module": "commonjs",
7 | "esModuleInterop": true,
8 | "target": "es6",
9 | "jsx": "react",
10 | "baseUrl": ".",
11 | }
12 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | gadfly.egg-info
2 | __pycache__
3 | *.pyc
4 | *.log
5 | .vscode
6 | node_modules
7 | .jekyll-cache
8 | .ipynb_checkpoints
9 | .sass-cache
10 | adaptivetesting.egg-info
11 | dist/gadfly-*
12 | local_scratch*
13 | wandb
14 | local_data
15 | /notebooks/TDD/data
16 | /notebooks/PPT
17 | /build
18 | /dist
19 | /test_trainer/runs
20 | /test_trees/local
21 | /adatest/utils/local
22 | /tests/local
23 |
--------------------------------------------------------------------------------
/tests/simple_test_tree.csv:
--------------------------------------------------------------------------------
1 | ,topic,input,output,label,labeler,description,model score,author
2 | 4f51624790ce4b03ae18680a695d0b67,,Test at top level,POSITIVE,Unknown,imputed,,,
3 | 17975b5cc7ce46cb858b4e5677368a90,/A,,,topic_marker,anonymous,,,
4 | d043a06e342a452494ca73d1b4ee7967,/A/B,,,topic_marker,anonymous,,,
5 | 79ffa3a899374d1eb6870929db5408b0,/A,"Test under A
6 | ",NEGATIVE,Unknown,imputed,,,
7 | 5e5764ab240940aba43cc5aa76cf5429,/A/C,,,topic_marker,anonymous,,,
8 | 4b256f1fb5c64e13bc73106d2146257e,/A/B,Test under B,NEGATIVE,Unknown,imputed,,,
9 |
--------------------------------------------------------------------------------
/adaptivetesting/__init__.py:
--------------------------------------------------------------------------------
1 | from ._test_tree import TestTree
2 | from ._test_tree_browser import TestTreeBrowser
3 | from ._scorer import Scorer, DummyScorer, ClassifierScorer, GeneratorScorer, RawScorer
4 | from ._server import serve
5 | from .embedders import _embed as embed
6 | from ._model import Model
7 | from ._topic_model import ChainTopicModel, StandardTopicModel
8 | from . import generators
9 |
__version__ = '0.3.5'

# Default test trees used to seed generation.
# NOTE(review): this path is resolved relative to the process CWD, not to the
# package — importing adaptivetesting from outside the repo root will likely
# fail to find the CSV. Confirm whether it should resolve against __file__.
default_generators = {
    "abstract": TestTree(r"test_trees/abstract_capabilities.csv")
}
# Module-level holders for shared embedding models; None until configured.
# TODO confirm where these are assigned (presumably by the embedders module).
text_embedding_model = None
image_embedding_model = None
--------------------------------------------------------------------------------
/development/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal",
13 | "justMyCode": true
14 | },
15 | {
16 | "type": "pwa-chrome",
17 | "request": "launch",
18 | "name": "Launch Chrome against localhost",
19 | "url": "http://localhost:8080",
20 | "webRoot": "${workspaceFolder}/client"
21 | }
22 | ]
23 | }
24 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import adaptivetesting.utils as utils
4 |
5 |
class TestIsSubTopic:
    """Unit tests for adaptivetesting.utils.is_subtopic (topic-path containment)."""

    @pytest.mark.parametrize(
        ["topic", "sub_topic"],
        [
            ("/A", "/A/B"),
            ("/A", "/A/B/C"),
            ("/A", "/A/B/C/"),
            # Trailing spaces in topic names are significant and preserved.
            ("/A ", "/A /B"),
            ("/A ", "/A /B "),
            ("/A ", "/A /B /C"),
        ],
    )
    def test_topic_is_subtopic(self, topic, sub_topic):
        assert utils.is_subtopic(topic, sub_topic)

    @pytest.mark.parametrize(
        ["topic", "not_sub_topic"], [("/A/B", "/A/C"), ("/A", "/AB")]
    )
    def test_topic_is_not_subtopic(self, topic, not_sub_topic):
        # "/AB" merely shares a string prefix with "/A"; it is not nested under it.
        assert not utils.is_subtopic(topic, not_sub_topic)

    @pytest.mark.parametrize(
        ["topic"],
        # Extra ',' in tuples to resolve ambiguity between string and list of characters
        [("/A",), ("/A/B",), ("/A /B",), ("/A /B ",), ("/A/B/C",), ("/A /B /C",)],
    )
    def test_topic_own_subtopic(self, topic):
        # Every topic counts as a subtopic of itself.
        assert utils.is_subtopic(topic, topic)
34 |
--------------------------------------------------------------------------------
/client/src/total-value.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import autoBind from 'auto-bind';
3 | import { get, debounce } from 'lodash';
4 |
// Displays the sum of per-child subtotals. Children report their value via
// setSubtotal(id, value); updates are batched through a 100ms lodash debounce
// so a burst of reports causes a single re-render.
export default class TotalValue extends React.Component {
  constructor(props) {
    super(props);
    autoBind(this);

    // Debounced flush of accumulated subtotal updates into component state.
    this.doStateUpdateDebounced = debounce(this.doStateUpdate, 100);

    // Map of id -> subtotal awaiting the next debounced setState.
    this.pendingStateUpdate = {};

    // our starting state
    this.state = {
      // note that all the ids will also be properties of the state
    };
  }

  // Record a child's subtotal and schedule a (debounced) state flush.
  setSubtotal(id, subtotal) {
    this.pendingStateUpdate[id] = subtotal;
    this.doStateUpdateDebounced();
  }

  // Flush all pending subtotals into state in one setState call.
  doStateUpdate() {
    this.setState(this.pendingStateUpdate);
    this.pendingStateUpdate = {};
  }

  render() {
    // we just sum up the current active subtotals
    let total = 0;
    for (let i in this.props.activeIds) {
      total += get(this.state, this.props.activeIds[i], 0);
    }

    // NOTE(review): the JSX element markup wrapping {total} appears to have
    // been lost in extraction; restore from the original source before use.
    return
    {total}

  }
}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/client/src/context-menu.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import autoBind from 'auto-bind';
3 |
// A positioned pop-up menu. Renders one clickable row per entry in props.rows;
// clicking a row calls props.onClick(row), clicking the background calls
// props.onClose().
export default class ContextMenu extends React.Component {
  static defaultProps = {
    top: 0,                    // pixel offset for menu placement
    left: 0,
    open: false,               // whether the menu is shown
    onClick: () => undefined,  // row-click callback; receives the row value
    rows: [
      "test"
    ]
  };

  constructor(props) {
    super(props);
    autoBind(this);
  }

  render() {
    // NOTE(review): the JSX markup below appears to have been lost in
    // extraction (container, backdrop, and row elements); restore from the
    // original source before relying on this render body.
    return
    
    
    {this.props.rows && this.props.rows.map((row, index) => {
      return
 this.handleRowClick(row, e)} className="adatest-hover-gray">{row}
    })}
    
    
  }

  // Clicking outside the menu dismisses it.
  handleBackgroundClick(e) {
    e.preventDefault();
    this.props.onClose();
  }

  // Forward the clicked row's value to the owner.
  handleRowClick(row, e) {
    e.preventDefault();
    this.props.onClick(row);
  }
}
--------------------------------------------------------------------------------
/development/scripts/build_wheel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import platform
4 | import subprocess
5 |
6 | _logger = logging.getLogger(__file__)
7 | logging.basicConfig(level=logging.INFO)
8 |
9 |
def build_client():
    """Build the JavaScript client bundle: `npm install` then `npx webpack`
    inside the client/ directory."""
    _logger.info("Starting build_client")
    _logger.info("Running npm install")
    # Spawning npm appears to work differently in PowerShell, so use a shell
    # on Windows.
    use_shell = platform.system() == "Windows"
    npm_install = ["npm", "install", "--loglevel", "verbose"]
    subprocess.run(npm_install, shell=use_shell, cwd="client", check=True)
    _logger.info("Running npx webpack")
    subprocess.run(["npx", "webpack"], shell=use_shell, cwd="client", check=True)
    _logger.info("Ending build_client")
25 |
26 |
def build_wheel():
    """Build the sdist and wheel for the current repository via setup.py."""
    _logger.info("Starting build_wheel")
    cmd = ["python", "setup.py", "sdist", "bdist_wheel"]
    subprocess.run(cmd, check=True)
    _logger.info("Ending build_wheel")
31 |
32 |
def main():
    """Build the JS client and then the Python wheel.

    Must be run from the repository root, since the npm/webpack and setup.py
    invocations all assume it as the working directory.
    """
    assert "setup.py" in os.listdir(), "Must be run from repo root"
    # Build the client
    build_client()

    # Build the wheel
    build_wheel()

    _logger.info("Completed")
46 |
--------------------------------------------------------------------------------
/development/scripts/install_from_wheel.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import pathlib
4 | import subprocess
5 | import sys
6 |
7 | _logger = logging.getLogger(__file__)
8 | logging.basicConfig(level=logging.INFO)
9 |
10 |
def build_argument_parser():
    """Create the CLI parser; --wheel-dir (required) points at the directory
    holding the built wheel."""
    parser = argparse.ArgumentParser(
        description="Install Adaptive Testing from a wheel file"
    )
    parser.add_argument(
        "--wheel-dir",
        help="Directory containing the AdaTest wheel",
        required=True,
    )
    return parser
22 |
23 |
def main(argv):
    """Locate the single adaptivetesting wheel in --wheel-dir and pip-install
    it with the [dev] extras."""
    parser = build_argument_parser()
    args = parser.parse_args(argv)

    _logger.info("Finding wheel file")
    target_dir = pathlib.Path(args.wheel_dir)
    # Globbing works from Python, but not in Windows builds
    wheel_list = list(target_dir.glob("adaptivetesting*.whl"))
    # Exactly one wheel must be present; anything else means a broken or
    # ambiguous build output.
    assert len(wheel_list) == 1, f"Bad wheel_list: {wheel_list}"
    wheel_path = wheel_list[0].resolve()
    msg = f"Path to wheel: {wheel_path}"
    _logger.info(msg)

    _logger.info("Installing wheel")
    # Use this approach so that extras can be added
    adatest_spec = f"adaptivetesting[dev] @ {wheel_path.as_uri()}"
    subprocess.run(["pip", "install", f"{adatest_spec}"], check=True)
41 |
42 |
43 | if __name__ == "__main__":
44 | main(sys.argv[1:])
45 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ${{ matrix.os }}
16 | strategy:
17 | matrix:
18 | os: [ubuntu-latest, macos-12, windows-latest]
19 | python-version: ['3.7', '3.8', '3.9', '3.10']
20 |
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install -e '.[dev]'
31 | - name: Lint with flake8
32 | run: |
33 | # stop the build if there are Python syntax errors or undefined names
34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
37 | - name: Test with pytest
38 | run: |
39 | python -m pytest tests/
40 |
--------------------------------------------------------------------------------
/adaptivetesting/comm.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | import logging
4 | log = logging.getLogger(__name__)
5 |
class JupyterComm():
    """Bidirectional communication channel with the Jupyter front end.

    Wraps an ipykernel Comm object, either registering a comm target and
    waiting for the front end to open the channel (mode="register"), or
    opening the channel from the kernel side directly (mode="open").
    """

    def __init__(self, target_name, callback=None, mode="register"):
        from ipykernel.comm import Comm

        self.target_name = target_name
        self.callback = callback  # invoked with the decoded message payload
        self.jcomm = None  # remains None until the channel is actually open
        if mode == "register":
            # Wait for the front end to open the comm; jcomm stays None until then.
            def comm_opened(comm, open_msg):
                self.jcomm = comm
                self.jcomm.on_msg(self._fire_callback)
            get_ipython().kernel.comm_manager.register_target(self.target_name, comm_opened) # noqa: F821
        elif mode == "open":
            self.jcomm = Comm(target_name=target_name)
            self.jcomm.on_msg(self._fire_callback)
        else:
            raise Exception("Passed mode must be either 'open' or 'register'!")

    def _fire_callback(self, msg):
        # Unwrap the Jupyter message envelope and hand the payload to the user callback.
        self.callback(msg["content"]["data"])

    def send(self, data):
        """Send `data` to the front end, waiting up to ~5s for the channel to open.

        Raises if the comm channel is never opened by the other side.
        """
        for i in range(10):
            if self.jcomm is None:
                time.sleep(0.5)  # give the front end time to open the channel
            else:
                # Bug fix: removed dead local `s = json.dumps(data)` (result was unused).
                self.jcomm.send({"data": json.dumps(data)}) # we encode the JSON so iPython doesn't mess it up
                return
        raise Exception("The Jupyter comm channel was never opened from the other side, so no message can be sent!")
36 |
37 |
--------------------------------------------------------------------------------
/client/src/adatest.jsx:
--------------------------------------------------------------------------------
1 | import "./adatest.css";
2 |
3 | import React from 'react';
4 | import ReactDOM from 'react-dom';
5 | import { withRouter } from 'react-router-dom';
6 | import { BrowserRouter } from "react-router-dom";
7 | import { MemoryRouter } from 'react-router';
8 | import Browser from './browser'
9 |
10 | const BrowserWithRouter = withRouter(Browser);
11 |
// Root component of the AdaTest UI. Picks BrowserRouter for real web
// deployments (props.environment === "web") and MemoryRouter otherwise
// (e.g. when embedded in a notebook), then renders the routed Browser.
export default class AdaTest extends React.Component {

  constructor(props) {
    super(props);
    console.log("interfaceId", this.props.interfaceId)
    this.state = { enabled: true };
    // Expose the root instance globally for debugging / external control.
    window.adatest_root = this;
  }
  render() {

    const Router = this.props.environment === "web" ? BrowserRouter : MemoryRouter;

    // NOTE(review): the JSX element tree below (Router wrapping
    // BrowserWithRouter) appears to have been lost in extraction; restore it
    // from the original source.
    return (
      
      
      
      

      
      
    );
  }
}
39 |
40 | window.AdaTestReact = React
41 | window.AdaTestReactDOM = ReactDOM
42 | window.AdaTest = AdaTest
43 |
44 |
--------------------------------------------------------------------------------
/client/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "client",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "index.js",
6 | "scripts": {
7 | "test": "echo \"Error: no test specified\" && exit 1",
8 | "build": "NODE_ENV=production webpack",
9 | "build:dev": "NODE_ENV=dev webpack"
10 | },
11 | "keywords": [],
12 | "author": "",
13 | "license": "ISC",
14 | "devDependencies": {
15 | "@babel/core": "^7.13.14",
16 | "@babel/preset-env": "^7.13.12",
17 | "@babel/preset-react": "^7.13.13",
18 | "@babel/preset-typescript": "^7.18.6",
19 | "@fortawesome/fontawesome-free": "^5.15.3",
20 | "@types/react": "^17.0.47",
21 | "@types/react-dom": "^18.0.5",
22 | "@types/react-router-dom": "^5.3.3",
23 | "babel": "^6.23.0",
24 | "babel-core": "^6.26.3",
25 | "babel-loader": "^8.2.2",
26 | "css-loader": "^5.2.0",
27 | "source-map-loader": "^4.0.0",
28 | "string-replace-loader": "^3.0.1",
29 | "style-loader": "^2.0.0",
30 | "ts-loader": "^9.3.1",
31 | "typescript": "^4.7.4",
32 | "webpack": "^5.30.0",
33 | "webpack-cli": "^4.6.0"
34 | },
35 | "dependencies": {
36 | "@fortawesome/fontawesome-svg-core": "^1.2.35",
37 | "@fortawesome/free-solid-svg-icons": "^5.15.3",
38 | "@fortawesome/react-fontawesome": "^0.1.14",
39 | "@material-ui/core": "^4.11.4",
40 | "@material-ui/icons": "^4.11.2",
41 | "@material-ui/lab": "^4.0.0-alpha.58",
42 | "auto-bind": "^4.0.0",
43 | "json5": "^2.2.0",
44 | "lodash": "^4.17.21",
45 | "react": "^17.0.2",
46 | "react-beautiful-dnd": "^13.1.0",
47 | "react-dom": "^17.0.2",
48 | "react-router-dom": "^5.2.0",
49 | "sanitize-html": "^2.3.3"
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/client/webpack.config.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 |
3 | const isDevelopment = process.env.NODE_ENV === 'dev';
4 |
5 | module.exports = {
6 | entry: path.resolve(__dirname, './src/adatest.jsx'),
7 | devtool: isDevelopment ? 'eval-source-map' : false,
8 | module: {
9 | rules: [
10 | {
11 | test: /\.(js|jsx)$/,
12 | exclude: /node_modules/,
13 | use: ['babel-loader'],
14 | },
15 | {
16 | test: /\.ts(x?)$/,
17 | exclude: /node_modules/,
18 | use: [
19 | {
20 | loader: "ts-loader"
21 | }
22 | ]
23 | },
24 | {
25 | test: /\.css$/i,
26 | use: ["style-loader", "css-loader"],
27 | },
28 | {
29 | test: /\.(png|woff|woff2|eot|ttf|svg)$/,
30 | use: ['url-loader'],
31 | },
32 | { // this allows font-awesome to be used during development mode... (since we print to the page in a script tag)
33 | test: /\.js$/,
34 | loader: 'string-replace-loader',
35 | options: {
36 | search: '',
37 | replace: '_/script>',
38 | }
39 | },
40 | // https://github.com/webpack-contrib/source-map-loader
41 | {
42 | enforce: "pre",
43 | test: /\.js$/,
44 | loader: "source-map-loader"
45 | }
46 | ],
47 | },
48 | resolve: {
49 | extensions: ['*', '.js', '.jsx', '.ts', '.tsx'],
50 | },
51 | externals: {
52 | // 'react': 'React',
53 | // 'react-dom': 'ReactDOM'
54 | },
55 | output: {
56 | path: path.resolve(__dirname, '../adatest/resources'),
57 | filename: 'main.js',
58 | },
59 | mode: isDevelopment ? "development" : "production"
60 | };
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import codecs
4 | from setuptools import setup, find_packages
5 |
6 | here = os.path.abspath(os.path.dirname(__file__))
7 |
8 |
def read(*parts):
    """Return the text of the file located at `here` joined with `parts`."""
    path = os.path.join(here, *parts)
    with codecs.open(path, "r") as fp:
        return fp.read()
12 |
13 |
def find_version(*file_paths):
    """Extract the `__version__ = '...'` string from the given source file.

    Raises RuntimeError if no version assignment is found.
    """
    contents = read(*file_paths)
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", contents, re.M)
    if not match:
        raise RuntimeError("Unable to find version string.")
    return match.group(1)
20 |
21 |
22 | setup(
23 | name="adaptivetesting",
24 | version=find_version("adaptivetesting", "__init__.py"),
25 | url="https://github.com/microsoft/adaptive-testing.git",
26 | author="Scott Lundberg and Marco Tulio Ribeiro",
27 | author_email="scott.lundberg@microsoft.com",
28 | description="Adaptively test and debug any natural language machine learning model.",
29 | packages=find_packages(exclude=["user_studies", "notebooks", "client"]),
30 | package_data={"adaptivetesting": ["resources/*"]},
31 | install_requires=[
32 | "aiohttp",
33 | "aiohttp_security",
34 | "aiohttp_session",
35 | "appdirs",
36 | "cryptography",
37 | "diskcache",
38 | "nest_asyncio",
39 | "numpy",
40 | "pandas",
41 | "profanity",
42 | "scikit-learn",
43 | "shap",
44 | ],
45 | extras_require={
46 | "dev": [
47 | "black",
48 | "flake8",
49 | "openai<1",
50 | "datasets",
51 | "transformers<4.26",
52 | "pytest",
53 | "pytest-mock",
54 | "torch",
55 | ]
56 | },
57 | )
58 |
--------------------------------------------------------------------------------
/client/dist/main.js.LICENSE.txt:
--------------------------------------------------------------------------------
1 | /*
2 | object-assign
3 | (c) Sindre Sorhus
4 | @license MIT
5 | */
6 |
7 | /*!
8 | * Font Awesome Free 5.15.3 by @fontawesome - https://fontawesome.com
9 | * License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License)
10 | */
11 |
12 | /*!
13 | * is-plain-object
14 | *
15 | * Copyright (c) 2014-2017, Jon Schlinkert.
16 | * Released under the MIT License.
17 | */
18 |
19 | /**
20 | * @license
21 | * Lodash
22 | * Copyright OpenJS Foundation and other contributors
23 | * Released under MIT license
24 | * Based on Underscore.js 1.8.3
25 | * Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors
26 | */
27 |
28 | /** @license React v0.20.2
29 | * scheduler.production.min.js
30 | *
31 | * Copyright (c) Facebook, Inc. and its affiliates.
32 | *
33 | * This source code is licensed under the MIT license found in the
34 | * LICENSE file in the root directory of this source tree.
35 | */
36 |
37 | /** @license React v16.13.1
38 | * react-is.production.min.js
39 | *
40 | * Copyright (c) Facebook, Inc. and its affiliates.
41 | *
42 | * This source code is licensed under the MIT license found in the
43 | * LICENSE file in the root directory of this source tree.
44 | */
45 |
46 | /** @license React v17.0.2
47 | * react-dom.production.min.js
48 | *
49 | * Copyright (c) Facebook, Inc. and its affiliates.
50 | *
51 | * This source code is licensed under the MIT license found in the
52 | * LICENSE file in the root directory of this source tree.
53 | */
54 |
55 | /** @license React v17.0.2
56 | * react.production.min.js
57 | *
58 | * Copyright (c) Facebook, Inc. and its affiliates.
59 | *
60 | * This source code is licensed under the MIT license found in the
61 | * LICENSE file in the root directory of this source tree.
62 | */
63 |
--------------------------------------------------------------------------------
/adaptivetesting/resources/main.js.LICENSE.txt:
--------------------------------------------------------------------------------
1 | /*
2 | object-assign
3 | (c) Sindre Sorhus
4 | @license MIT
5 | */
6 |
7 | /*!
8 | * Font Awesome Free 5.15.3 by @fontawesome - https://fontawesome.com
9 | * License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License)
10 | */
11 |
12 | /*!
13 | * is-plain-object
14 | *
15 | * Copyright (c) 2014-2017, Jon Schlinkert.
16 | * Released under the MIT License.
17 | */
18 |
19 | /**
20 | * @license
21 | * Lodash
22 | * Copyright OpenJS Foundation and other contributors
23 | * Released under MIT license
24 | * Based on Underscore.js 1.8.3
25 | * Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors
26 | */
27 |
28 | /** @license React v0.20.2
29 | * scheduler.production.min.js
30 | *
31 | * Copyright (c) Facebook, Inc. and its affiliates.
32 | *
33 | * This source code is licensed under the MIT license found in the
34 | * LICENSE file in the root directory of this source tree.
35 | */
36 |
37 | /** @license React v16.13.1
38 | * react-is.production.min.js
39 | *
40 | * Copyright (c) Facebook, Inc. and its affiliates.
41 | *
42 | * This source code is licensed under the MIT license found in the
43 | * LICENSE file in the root directory of this source tree.
44 | */
45 |
46 | /** @license React v17.0.2
47 | * react-dom.production.min.js
48 | *
49 | * Copyright (c) Facebook, Inc. and its affiliates.
50 | *
51 | * This source code is licensed under the MIT license found in the
52 | * LICENSE file in the root directory of this source tree.
53 | */
54 |
55 | /** @license React v17.0.2
56 | * react.production.min.js
57 | *
58 | * Copyright (c) Facebook, Inc. and its affiliates.
59 | *
60 | * This source code is licensed under the MIT license found in the
61 | * LICENSE file in the root directory of this source tree.
62 | */
63 |
--------------------------------------------------------------------------------
/client/src/bread-crum.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import autoBind from 'auto-bind';
3 |
// A single breadcrumb segment in the topic path. Clicking navigates to the
// segment's topic; it also acts as a drag-and-drop target so tests/topics can
// be dropped onto an ancestor topic.
export default class BreadCrum extends React.Component {
  constructor(props) {
    super(props);
    autoBind(this);

    this.state = {
      // Counter (not boolean) so nested dragenter/dragleave events balance out.
      dropHighlighted: 0
    };
  }

  render() {
    // console.log("br", this.props.name, this.props.name === "")
    // NOTE(review): the element's opening tag and attributes appear to have
    // been lost in extraction; restore from the original source.
    return
  
    {this.props.name === "" ? this.props.defaultName : decodeURIComponent(this.props.name)}
    
  }

  // Navigate to this crumb's topic; an empty name means the root crumb.
  onClick(e) {
    e.preventDefault();
    e.stopPropagation();
    if (this.props.onClick) {
      if (this.props.name === "") {
        this.props.onClick(this.props.topic);
      } else {
        this.props.onClick(this.props.topic + "/" + this.props.name);
      }
    }
  }

  onDragOver(e) {
    e.preventDefault();
    e.stopPropagation();
  }

  onDragEnter(e) {
    e.preventDefault();
    e.stopPropagation();
    this.setState({dropHighlighted: this.state.dropHighlighted + 1});
  }

  onDragLeave(e) {
    e.preventDefault();
    e.stopPropagation();
    this.setState({dropHighlighted: this.state.dropHighlighted - 1});
  }

  // Dropped item: move it under this crumb's topic, preserving its leaf name.
  onDrop(e) {
    const id = e.dataTransfer.getData("id");
    this.setState({dropHighlighted: 0});
    if (this.props.onDrop) {
      let suffix = "";
      if (id.includes("/")) {
        suffix = "/" + id.split("/").pop();
      }
      this.props.onDrop(id, this.props.topic + (this.props.name === "" ? "" : "/" + this.props.name) + suffix);
    }
  }
}
--------------------------------------------------------------------------------
/adaptivetesting/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import re
2 | import urllib
3 | import io
4 | import shap
5 |
6 | import numpy as np
7 |
8 |
def parse_test_type(test_type):
    """Split a test-type template string into its text and placeholder parts.

    The template alternates literal text with "{}" / "[]" placeholders, e.g.
    "text1 {} text2 [] text3".  Returns a dict with keys text1..text4 and
    value1..value3; parts missing from the template map to "".
    """
    part_names = ["text1", "value1", "text2", "value2", "text3", "value3", "text4"]
    # re.split with a capturing group keeps the "{}"/"[]" delimiters in the
    # result, so `parts` alternates text, placeholder, text, ...
    parts = re.split(r"(\{\}|\[\])", test_type)
    # Pad (and truncate) to exactly len(part_names) entries; the previous
    # indexed assignment raised IndexError for templates with >4 text segments.
    padded = (parts + [""] * len(part_names))[:len(part_names)]
    return dict(zip(part_names, padded))
16 |
17 |
18 | # https://codereview.stackexchange.com/questions/253198/improved-isinstance-for-ipython
def isinstance_ipython(obj, ref_class):
    """isinstance() that also survives IPython reloads, where a class may be
    re-imported and fail the normal identity-based check; falls back to
    comparing fully qualified class names."""
    if isinstance(obj, ref_class):
        return True

    def full_name(cls):
        qualname = getattr(cls, "__qualname__", getattr(cls, "__name__", ""))
        return (getattr(cls, "__module__", "") + "." + qualname).lstrip(".")

    # Same module-qualified name => treat as the same class.
    return full_name(type(obj)) == full_name(ref_class)
27 |
28 |
29 | _images_cache = {}
30 |
31 |
def get_image(url):
    """Fetch (and memoize) the image at ``url``.

    Downloaded images are cached in the module-level ``_images_cache`` dict so
    repeated requests for the same URL hit the network only once. If the
    download fails with a URLError, a generic "image not available"
    placeholder is fetched instead and cached under the original URL.

    Parameters
    ----------
    url : str
        URL of the image to download.

    Returns
    -------
    The image object returned by ``_download_image`` (a PIL image).

    Raises
    ------
    urllib.error.URLError
        If even the placeholder image cannot be downloaded.
    """
    fallback = "https://upload.wikimedia.org/wikipedia/commons/d/d1/Image_not_available.png"
    if url not in _images_cache:
        try:
            _images_cache[url] = _download_image(url)
        except urllib.error.URLError:
            # If the placeholder itself is unreachable, re-raise instead of
            # recursing: the original recursed unconditionally and would loop
            # forever when the fallback URL also failed.
            if url == fallback:
                raise
            _images_cache[url] = get_image(fallback)

    return _images_cache[url]
42 |
43 |
def _download_image(url):
    """Download ``url`` and return it as a PIL image.

    A browser-like User-Agent header is sent because some hosts reject
    requests carrying the default urllib agent string.
    """
    # Imported lazily so Pillow is only required when images are actually
    # used. Note: `import PIL` alone does not guarantee the PIL.Image
    # submodule is loaded, so import it explicitly.
    import PIL.Image

    urllib_request = urllib.request.Request(
        url,
        data=None,
        headers={
            "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
        },
    )
    # 10s timeout; read the whole body into memory before handing it to PIL.
    with urllib.request.urlopen(urllib_request, timeout=10) as r:
        img_stream = io.BytesIO(r.read())
        return PIL.Image.open(img_stream)
57 |
58 |
def is_subtopic(topic, candidate):
    """Return True if ``candidate`` equals ``topic`` or lies underneath it.

    Both arguments are UNIX-path-like topic strings (e.g. ``"/A/B"``). A
    candidate is a subtopic when it is exactly the topic, or when it extends
    the topic with a "/"-separated suffix (so ``"/AB"`` is NOT a subtopic of
    ``"/A"``).
    """
    if len(candidate) == len(topic):
        return candidate == topic
    # Requiring the separator right after the prefix rules out sibling topics
    # that merely share a name prefix.
    return candidate.startswith(topic + "/")
68 |
69 |
def convert_float(s):
    """Convert ``s`` to a float, returning ``np.nan`` when conversion fails.

    Parameters
    ----------
    s : object
        Value to convert (typically a string from a CSV cell).

    Returns
    -------
    float
        ``float(s)`` on success, otherwise ``np.nan``.
    """
    try:
        f = float(s)
    except (ValueError, TypeError):
        # TypeError covers non-numeric objects such as None, which the
        # original ValueError-only handler let propagate as a crash.
        f = np.nan
    return f
76 |
--------------------------------------------------------------------------------
/.github/workflows/python-wheel-build.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Build Python Wheel
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 |
13 | env:
14 | # Select the Artifact to be used in the 'test' job
15 | blessed-wheel-artifact: wheel-ubuntu-latest-3.10
16 |
17 |
18 | jobs:
19 | build:
20 | runs-on: ${{ matrix.os }}
21 | strategy:
22 | matrix:
23 | os: [ubuntu-latest, windows-latest]
24 | python-version: ['3.10']
25 |
26 | steps:
27 | - uses: actions/checkout@v2
28 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
29 | uses: actions/setup-python@v2
30 | with:
31 | python-version: ${{ matrix.python-version }}
32 | - name: Update pip and setuptools
33 | run: |
34 | pip install --upgrade pip
35 | pip install --upgrade setuptools wheel
36 | - name: Check node version
37 | run: npm --version
38 | - name: Run build_wheel script
39 | run: python development/scripts/build_wheel.py
40 | - name: Save wheel artifact
41 | uses: actions/upload-artifact@v3
42 | with:
43 | # Name must match blessed-wheel-artifact above
44 | name: wheel-${{ matrix.os }}-${{ matrix.python-version }}
45 | path: dist/*
46 |
47 | test:
48 | needs: build
49 | runs-on: ${{ matrix.os }}
50 | strategy:
51 | matrix:
52 |         os: [macos-12, ubuntu-latest, windows-latest]
53 | python-version: ['3.8', '3.9', '3.10']
54 |
55 | steps:
56 | - uses: actions/checkout@v2
57 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
58 | uses: actions/setup-python@v2
59 | with:
60 | python-version: ${{ matrix.python-version }}
61 | - name: Remove AdaTest source directory
62 | uses: JesseTG/rm@v1.0.3
63 | with:
64 | path: adaptivetesting
65 | - name: Download wheel artifact
66 | uses: actions/download-artifact@v3
67 | with:
68 | name: ${{ env.blessed-wheel-artifact }}
69 | - name: Install wheel from file
70 | run: python development/scripts/install_from_wheel.py --wheel-dir=.
71 | - name: Test with pytest
72 | run: |
73 | python -m pytest tests/
--------------------------------------------------------------------------------
/tests/test_test_tree.py:
--------------------------------------------------------------------------------
1 | from operator import truediv
2 | import os
3 | import pathlib
4 | import tempfile
5 |
6 | import numpy as np
7 |
8 | import adaptivetesting
9 |
10 |
def test_simple_init():
    # A freshly constructed TestTree should contain no tests at all.
    empty_tree = adaptivetesting.TestTree()
    assert len(empty_tree) == 0
14 |
15 |
def test_simple_init_with_file():
    # Constructing with a path to a (presumably nonexistent) CSV yields an
    # empty tree.
    # NOTE(review): "temp_test_tree.csv" is resolved against the current
    # working directory and may be created/left behind by TestTree — consider
    # using pytest's tmp_path. TODO confirm TestTree's missing-file behavior.
    tree = adaptivetesting.TestTree("temp_test_tree.csv")
    assert len(tree) == 0
19 |
20 |
def test_simple_init_with_list():
    # Constructing from a list of plain strings creates one row per string.
    # NOTE(review): len == 3 for 2 inputs — presumably an extra topic-marker
    # row is added; confirm against TestTree's constructor.
    tree = adaptivetesting.TestTree(["The food was nice!", "The location is excellent."])
    assert len(tree) == 3
    assert tree.columns.to_list() == [
        "topic",
        "input",
        "output",
        "label",
        "labeler",
        "description",
    ]
    assert "The food was nice!" in tree["input"].to_list()
    assert "The location is excellent." in tree["input"].to_list()
    # np.unique with return_counts gives (sorted unique values, counts):
    # expect one "" output and two "[no output]" outputs.
    outputs = np.unique(tree["output"].to_list(), return_counts=True)
    assert outputs[0][0] == ""
    assert outputs[1][0] == 1
    assert outputs[0][1] == "[no output]"
    assert outputs[1][1] == 2
39 |
40 |
def test_to_csv():
    # Smoke test: a one-row tree can be serialized to a CSV file on disk.
    row = {
        "topic": "",
        "type": "{} should output {}",
        "input": "This is good",
        "output": "NEGATIVE",
        "label": "fail",
    }
    tree = adaptivetesting.TestTree([row])
    with tempfile.TemporaryDirectory() as td:
        target_file = os.path.join(td, "adaptivetesting_out.csv")
        tree.to_csv(target_file)
        assert os.path.exists(target_file)
57 |
58 |
def test_has_subtopic_or_tests():
    """Check subtopic / direct-test queries against the bundled CSV fixture."""
    # Locate the fixture relative to this test file so the test is
    # independent of the working directory.
    curr_file = pathlib.Path(__file__)
    curr_dir = curr_file.parent
    input_csv = curr_dir / "simple_test_tree.csv"
    assert input_csv.exists()
    tree = adaptivetesting.TestTree(str(input_csv))
    # The top-level topic is addressed by the empty string.
    # (Plain truthiness asserts replace the unidiomatic `== True`/`== False`
    # comparisons flagged by flake8 E712.)
    assert tree.topic_has_subtopics("")
    assert tree.topic_has_direct_tests("")
    assert tree.topic_has_direct_tests("/A")
    assert tree.topic_has_subtopics("/A")
    assert tree.topic_has_direct_tests("/A/B")
    assert not tree.topic_has_subtopics("/A/B")
    assert not tree.topic_has_direct_tests("/A/C")
    assert not tree.topic_has_subtopics("/A/C")
74 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ main ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ main ]
20 | schedule:
21 | - cron: '29 12 * * 6'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'python' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support
38 |
39 | steps:
40 | - name: Checkout repository
41 | uses: actions/checkout@v2
42 |
43 | # Initializes the CodeQL tools for scanning.
44 | - name: Initialize CodeQL
45 | uses: github/codeql-action/init@v1
46 | with:
47 | languages: ${{ matrix.language }}
48 | # If you wish to specify custom queries, you can do so here or in a config file.
49 | # By default, queries listed here will override any specified in a config file.
50 | # Prefix the list here with "+" to use these queries and those in the config file.
51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main
52 |
53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
54 | # If this step fails, then you should remove it and run the build manually (see below)
55 | - name: Autobuild
56 | uses: github/codeql-action/autobuild@v1
57 |
58 | # ℹ️ Command-line programs to run using the OS shell.
59 | # 📚 https://git.io/JvXDl
60 |
61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
62 | # and modify them (or add more) to build your code if your project
63 | # uses a compiled language
64 |
65 | #- run: |
66 | # make bootstrap
67 | # make release
68 |
69 | - name: Perform CodeQL Analysis
70 | uses: github/codeql-action/analyze@v1
71 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/client/src/CommEvent.ts:
--------------------------------------------------------------------------------
1 |
2 | // The base class for all communication data over websocket
3 | export class CommEvent {
4 | // Using snake case because these objects will be parsed in Python backend
5 | readonly event_id: string;
6 | constructor(event_id: string, data?: object) {
7 | this.event_id = event_id;
8 | if (data) {
9 | for (const k of Object.keys(data)) {
10 | this[k] = data[k];
11 | }
12 | }
13 | }
14 | }
15 |
16 |
17 | export function finishTopicDescription(topic_marker_id: string, description: string) {
18 | return new CommEvent("change_description", {"topic_marker_id": topic_marker_id, "description": description});
19 | }
20 |
21 |
22 | export function redraw() {
23 | return new CommEvent("redraw");
24 | }
25 |
26 |
27 | export function generateSuggestions(data: object) {
28 | return new CommEvent("generate_suggestions", data);
29 | }
30 |
31 |
32 | export function clearSuggestions() {
33 | return new CommEvent("clear_suggestions")
34 | }
35 |
36 |
37 | export function setFirstModel(model: string) {
38 | return new CommEvent("set_first_model", { "model": model });
39 | }
40 |
41 |
42 | export function changeGenerator(generator: string) {
43 | return new CommEvent("change_generator", { "generator": generator });
44 | }
45 |
46 |
47 | export function changeMode(mode: string) {
48 | return new CommEvent("change_mode", {"mode": mode})
49 | }
50 |
51 |
52 | export function addTopic() {
53 | return new CommEvent("add_new_topic")
54 | }
55 |
56 |
57 | export function addTest() {
58 | return new CommEvent("add_new_test");
59 | }
60 |
61 |
62 | export function changeFilter(filter_text: string) {
63 | return new CommEvent("change_filter", {"filter_text": filter_text});
64 | }
65 |
66 |
67 | export function changeTopic(topic: string) {
68 | return new CommEvent("change_topic", {"topic": topic});
69 | }
70 |
71 |
72 | export function moveTest(test_ids: string[] | string, topic: string) {
73 | if (!Array.isArray(test_ids)) {
74 | test_ids = [test_ids]
75 | }
76 | return new CommEvent("move_test", { "test_ids": test_ids, "topic": topic });
77 | }
78 |
79 |
80 | export function deleteTest(test_ids: string[] | string) {
81 | if (!Array.isArray(test_ids)) {
82 | test_ids = [test_ids]
83 | }
84 | return new CommEvent("delete_test", { "test_ids": test_ids });
85 | }
86 |
87 |
88 | export function changeLabel(test_id: string, label: string, labeler: string) {
89 | return new CommEvent("change_label", { "test_ids": [test_id], "label": label, "labeler": labeler });
90 | }
91 |
92 |
93 | export function changeInput(test_id: string, input: string) {
94 | return new CommEvent("change_input", { "test_ids": [test_id], "input": input });
95 | }
96 |
97 |
98 | export function changeOutput(test_id: string, output: string) {
99 | return new CommEvent("change_output", { "test_ids": [test_id], "output": output });
100 | }
101 |
--------------------------------------------------------------------------------
/client/src/web-socket-comm.js:
--------------------------------------------------------------------------------
1 | import JSON5 from 'json5';
2 | import autoBind from 'auto-bind';
3 | import { defer, debounce } from 'lodash';
4 |
// Comm layer connecting the AdaTest UI to the backend over a raw WebSocket.
// Outgoing values are staged in `pendingData` and flushed either immediately
// (send/sendEvent) or through debounced flushers; incoming messages update
// the local `data` cache and fire any subscribed callbacks.
export default class WebSocketComm {
  constructor(interfaceId, websocketServer, onopen) {
    autoBind(this);
    this.interfaceId = interfaceId;
    this.websocketServer = websocketServer;
    this.callbackMap = {}; // key -> callback fired when that key arrives
    this.data = {}; // local cache of the latest value seen per key
    this.pendingData = {}; // staged outgoing values awaiting a flush
    this.onopen = onopen;
    this.reconnectDelay = 100; // ms before first reconnect; grows 1s per close

    // Debounced flushers share the pendingData buffer with the immediate path.
    this.debouncedSendPendingData500 = debounce(this.sendPendingData, 500);
    this.debouncedSendPendingData1000 = debounce(this.sendPendingData, 1000);

    this.connect();
  }

  // Stage `data` under each key and flush immediately.
  send(keys, data) {
    this.addPendingData(keys, data);
    this.sendPendingData();
  }

  // Stage every field of a CommEvent (including event_id) and flush immediately.
  sendEvent(commEvent) {
    for (const k of Object.keys(commEvent)) {
      this.addPendingData(k, commEvent[k]);
    }
    this.sendPendingData();
  }

  // Like sendEvent, but the flush is debounced to at most once per 500ms.
  debouncedSendEvent500(commEvent) {
    for (const k of Object.keys(commEvent)) {
      this.addPendingData(k, commEvent[k]);
    }
    this.debouncedSendPendingData500();
  }

  // Like send, but the flush is debounced to at most once per 500ms.
  debouncedSend500(keys, data) {
    this.addPendingData(keys, data);
    this.debouncedSendPendingData500();
  }

  // Like send, but the flush is debounced to at most once per 1000ms.
  debouncedSend1000(keys, data) {
    this.addPendingData(keys, data);
    this.debouncedSendPendingData1000();
  }

  // Stage `data` under one key or a list of keys. The local cache is also
  // updated optimistically, as if the server had already applied the change.
  addPendingData(keys, data) {
    // console.log("addPendingData", keys, data);
    if (!Array.isArray(keys)) keys = [keys];
    for (const i in keys) {
      const k = keys[i];
      this.pendingData[k] = data;
      this.data[k] = Object.assign(this.data[k] || {}, data); // pretend it has already changed in our data cache
    }
  }

  // Open the WebSocket. A server path starting with "/" is resolved against
  // the current host; ws/wss follows the page's http/https scheme.
  connect() {
    let wsUri = (window.location.protocol=='https:' ? 'wss://' : 'ws://') + (this.websocketServer.startsWith("/") ? window.location.host : "") + this.websocketServer;
    this.wcomm = new WebSocket(wsUri);
    this.wcomm.onopen = this.onopen;
    this.wcomm.onmessage = this.updateData;
    this.wcomm.onerror = this.onError;
    this.wcomm.onclose = this.onClose;
  }

  // Handle an incoming message: merge each key's payload into the cache and
  // fire the callback subscribed to that key, if any.
  // NOTE(review): Object.assign assumes each data[k] is an object — confirm
  // the server never sends scalar values here.
  updateData(e) {
    // console.log("updateData", e)
    let data = JSON5.parse(e.data);
    console.log("updateData", data)
    for (const k in data) {
      // console.log("data[k]", data[k])
      this.data[k] = Object.assign(this.data[k] || {}, data[k]);
      if (k in this.callbackMap) {
        this.callbackMap[k](data[k]);
      }
    }
  }

  onError(e) {
    console.log("Websocket error", e);
  }

  // Reconnect after a close, backing off by an extra second each attempt
  // (the delay grows without bound).
  onClose(e) {
    console.log('Socket is closed. Reconnect will be attempted...', e.reason);
    setTimeout(this.connect, this.reconnectDelay);
    this.reconnectDelay += 1000;
  }

  // Register `callback` for `key` (one callback per key; later registrations
  // replace earlier ones) and immediately invoke it with the cached value —
  // which may be undefined if nothing has arrived yet.
  subscribe(key, callback) {
    this.callbackMap[key] = callback;
    defer(_ => this.callbackMap[key](this.data[key]));
  }

  // Flush all staged values as a single JSON message and clear the buffer.
  sendPendingData() {
    console.log("sending", this.pendingData);
    this.wcomm.send(JSON.stringify(this.pendingData));
    this.pendingData = {};
  }
}
--------------------------------------------------------------------------------
/adaptivetesting/_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import shap
3 |
4 |
class Model():
    """ This wraps models used in Adaptive Testing so that have a consistent interface.

    This should eventually just be the Model class from SHAP, but we keep a simple version here for now
    so we can easily update it during initial development.
    """

    def __new__(cls, model, *args, **kwargs):
        """ If we are wrapping a model that is already a Model, we just return it.
        """
        # safe_isinstance compares by qualified class name, so this works even
        # without importing the target classes here.
        if shap.utils.safe_isinstance(model, "adaptivetesting.Model") or shap.utils.safe_isinstance(model, "shap.models.Model"):
            return model
        else:
            return super().__new__(cls)

    def __init__(self, model, output_names=None, **kwargs):
        """ Build a new model by wrapping the given model object.

        Parameters
        ----------
        model : object
            The model to wrap. This can be a plain python function that accepts a list of strings and returns either
            a vector of probabilities or another string. It can also be a transformers pipeline object (we try to wrap
            common model types transparently).

        output_names : list of str, optional
            The names of the outputs of the model. If not given, we try to infer them from the model.
        """

        # finish early if we are wrapping an object that is already a Model
        # (in that case __new__ returned `model` itself, so `self is model`)
        if shap.utils.safe_isinstance(model, "adaptivetesting.Model") or shap.utils.safe_isinstance(model, "shap.models.Model"):
            if output_names is not None:
                self.output_names = output_names
            assert len(kwargs) == 0  # extra kwargs are not supported when rewrapping
            return

        # get outputs names from the model if it has them and we don't
        if output_names is None and hasattr(model, "output_names"):
            output_names = model.output_names

        # If we are in the base class we check to see if we should rebuild the model as a specialized subclass
        if self.__class__ is Model:

            # wrap transformer pipeline objects for convenience
            # (reassigning __class__ converts this instance to the subclass in
            # place, then the subclass initializer is invoked explicitly)
            if shap.utils.safe_isinstance(model, "transformers.pipelines.text_classification.TextClassificationPipeline"):
                self.__class__ = shap.models.TransformersPipeline
                shap.models.TransformersPipeline.__init__(self, model, **kwargs)
                if output_names is not None: # Override output names if user supplied
                    self.output_names = output_names

            elif shap.utils.safe_isinstance(model, "transformers.pipelines.text_generation.TextGenerationPipeline"):
                self.__class__ = TransformersTextGenerationPipeline
                TransformersTextGenerationPipeline.__init__(self, model, **kwargs)

            else:
                # plain callable: store it and pass calls straight through
                self.inner_model = model
                self.output_names = output_names

    def __call__(self, *args, **kwargs):
        # Base-class path only (subclasses converted above provide their own
        # __call__): invoke the wrapped callable and coerce to a numpy array.
        return np.array(self.inner_model(*args, **kwargs))
65 |
66 |
class TransformersTextGenerationPipeline(Model):
    """ This wraps the transformer text generation pipeline object to match the Model API.

    TODO: move this to SHAP.
    """
    def __init__(self, pipeline):
        # NOTE(review): stored as `_inner_model` (leading underscore), unlike
        # the base class's `inner_model`; Model.__call__ would therefore not
        # work on instances of this class — the __call__ below is required.
        self._inner_model = pipeline
        self.output_names = None

    def __call__(self, strings):
        # Run the pipeline on all inputs, then strip each prompt prefix so
        # only the newly generated continuation is returned.
        inner_out = self._inner_model(strings)
        out = []
        for s, data in zip(strings, inner_out):
            out.append(data[0]["generated_text"][len(s):]) # remove the input text from the output
        return out
82 |
--------------------------------------------------------------------------------
/client/src/jupyter-comm.js:
--------------------------------------------------------------------------------
1 | import JSON5 from 'json5';
2 | import autoBind from 'auto-bind';
3 | import { defer, debounce } from 'lodash';
4 |
// Comm layer connecting the AdaTest UI to a Python kernel via the Jupyter
// comm machinery (see InnerJupyterComm below). Mirrors the WebSocketComm API
// so the rest of the UI is transport-agnostic.
export default class JupyterComm {
  constructor(interfaceId, onopen) {
    autoBind(this);
    this.interfaceId = interfaceId;
    this.callbackMap = {}; // key -> callback fired when that key arrives
    this.data = {}; // local cache of the latest value seen per key
    this.pendingData = {}; // staged outgoing values awaiting a flush
    this.jcomm = new InnerJupyterComm('adatest_interface_target_'+this.interfaceId, this.updateData);

    this.debouncedSendPendingData500 = debounce(this.sendPendingData, 500);
    this.debouncedSendPendingData1000 = debounce(this.sendPendingData, 1000);
    if (onopen) {
      defer(onopen);
    }
  }

  // Stage `data` under each key and flush immediately.
  send(keys, data) {
    this.addPendingData(keys, data);
    this.sendPendingData();
  }

  // Stage every field of a CommEvent (including event_id) and flush immediately.
  sendEvent(commEvent) {
    for (const k of Object.keys(commEvent)) {
      this.addPendingData(k, commEvent[k]);
    }
    this.sendPendingData();
  }

  // Like sendEvent, but the flush is debounced to at most once per 500ms.
  debouncedSendEvent500(commEvent) {
    for (const k of Object.keys(commEvent)) {
      this.addPendingData(k, commEvent[k]);
    }
    this.debouncedSendPendingData500();
  }

  // Like send, but the flush is debounced to at most once per 500ms.
  debouncedSend500(keys, data) {
    this.addPendingData(keys, data);
    this.debouncedSendPendingData500();
  }

  // Like send, but the flush is debounced to at most once per 1000ms.
  debouncedSend1000(keys, data) {
    this.addPendingData(keys, data);
    this.debouncedSendPendingData1000();
  }

  // Stage `data` under one key or a list of keys.
  // NOTE(review): unlike WebSocketComm.addPendingData, this does NOT update
  // the local `data` cache optimistically — confirm whether that is intended.
  addPendingData(keys, data) {

    // console.log("addPendingData", keys, data);
    if (!Array.isArray(keys)) keys = [keys];
    for (const i in keys) this.pendingData[keys[i]] = data;
  }

  // Handle a message from the kernel: decode the wrapped JSON payload, cache
  // each key's value, then fire any subscribed callbacks.
  updateData(data) {
    data = JSON5.parse(data["data"]) // data from Jupyter is wrapped so we get to do our own JSON encoding
    console.log("updateData", data)

    // save the data locally
    for (const k in data) {
      this.data[k] = data[k];
    }

    // call all the registered callbacks
    for (const k in data) {
      if (k in this.callbackMap) {
        this.callbackMap[k](this.data[k]);
      }
    }
  }

  // Register `callback` for `key` (one callback per key) and immediately
  // invoke it with the cached value, which may be undefined initially.
  subscribe(key, callback) {
    this.callbackMap[key] = callback;
    defer(_ => this.callbackMap[key](this.data[key]));
  }

  // Flush all staged values through the Jupyter comm and clear the buffer.
  sendPendingData() {
    console.log("sending", this.pendingData);
    this.jcomm.send_data(this.pendingData);
    this.pendingData = {};
  }
}
85 |
// Thin wrapper around the Jupyter notebook comm API. Relies on the global
// `Jupyter` object provided by the classic notebook frontend.
class InnerJupyterComm {
  constructor(target_name, callback, mode="open") {
    this._fire_callback = this._fire_callback.bind(this);
    this._register = this._register.bind(this)

    this.jcomm = undefined;
    this.callback = callback;

    // https://jupyter-notebook.readthedocs.io/en/stable/comms.html
    // "register" waits for the kernel to open the comm; the default "open"
    // mode creates the comm from the frontend side immediately.
    if (mode === "register") {
      Jupyter.notebook.kernel.comm_manager.register_target(target_name, this._register);
    } else {
      this.jcomm = Jupyter.notebook.kernel.comm_manager.new_comm(target_name);
      this.jcomm.on_msg(this._fire_callback);
    }
  }

  // Send `data` to the kernel; logs an error and drops the message if the
  // comm has not been established yet.
  send_data(data) {
    if (this.jcomm !== undefined) {
      this.jcomm.send(data);
    } else {
      console.error("Jupyter comm module not yet loaded! So we can't send the message.")
    }
  }

  // Callback for register_target: invoked when the kernel opens the comm.
  _register(jcomm, msg) {
    this.jcomm = jcomm;
    this.jcomm.on_msg(this._fire_callback);
  }

  // Unwrap an incoming comm message and hand its payload to the callback.
  _fire_callback(msg) {
    this.callback(msg.content.data)
  }
}
120 |
121 | // const comm = JupyterComm();
122 |
123 | // // Jupyter.notebook.kernel.comm_manager.register_target('gadfly_comm_target',
124 | // // function(jcomm, msg) {
125 | // // // comm is the frontend comm instance
126 | // // // msg is the comm_open message, which can carry data
127 |
128 | // // comm.jcomm = jcomm
129 |
130 | // // // Register handlers for later messages:
131 | // // inner_comm.on_msg(function(msg) { console.log("MSGG", msg); });
132 | // // inner_comm.on_close(function(msg) { console.log("MSGdG", msg); });
133 | // // comm.send({'foo': 0});
134 | // // }
135 | // // );
136 |
137 | // export default comm;
--------------------------------------------------------------------------------
/tests/test_generators.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | from transformers import pipeline
5 |
6 | from adaptivetesting import generators
7 |
8 | TRANSFORMER_PIPELINE_MODELS = ["EleutherAI/gpt-neo-125M"]
9 |
10 |
class TestTransformers:
    # Smoke tests for generators.Transformers wrapping a raw HuggingFace
    # model + tokenizer. Only result shape/type is checked, since generated
    # text is model-dependent.

    @pytest.mark.parametrize("model_name", TRANSFORMER_PIPELINE_MODELS)
    def test_smoke(self, model_name):
        # Build the generator from a small real model (downloaded on first run).
        hf_model = pipeline("text-generation", model=model_name)
        target = generators.Transformers(hf_model.model, hf_model.tokenizer)

        # Each prompt is an (id, <middle field — presumably topic>, text)
        # triple; the middle field is empty here. TODO confirm its semantics.
        prompts = [
            ("id A", "", "Great hotel"),
            ("id B", "", "Bathroom too small"),
        ]

        desired_result_count = 2

        results = target(
            prompts=prompts,
            topic="",
            mode="tests",
            topic_description="",
            num_samples=desired_result_count,
        )
        assert results is not None
        assert len(results) == desired_result_count
        for item in results:
            assert isinstance(item, str)

    @pytest.mark.parametrize("model_name", TRANSFORMER_PIPELINE_MODELS)
    def test_with_topics(self, model_name):
        hf_model = pipeline("text-generation", model=model_name)
        target = generators.Transformers(hf_model.model, hf_model.tokenizer)

        # NOTE(review): "some string" vs "some_string" differ between the two
        # prompts; TestPipelines uses "some_string" in both — likely a typo.
        # Confirm intent before changing.
        prompts = [
            ("id A", "some string", "Great hotel"),
            ("id B", "some_string", "Bathroom too small"),
        ]

        desired_result_count = 2

        results = target(
            prompts=prompts,
            topic="",
            mode="tests",
            topic_description="",
            num_samples=desired_result_count,
        )
        assert results is not None
        assert len(results) == desired_result_count
        for item in results:
            assert isinstance(item, str)
59 |
60 |
61 | GENERATOR_PIPELINE_MODELS = [
62 | "facebook/opt-125m",
63 | "facebook/opt-350m",
64 | "EleutherAI/gpt-neo-125M",
65 | "gpt2",
66 | "bigscience/bloom-560m",
67 | ]
68 |
69 |
class TestPipelines:
    # Smoke tests for generators.Pipelines wrapping a complete HuggingFace
    # text-generation pipeline. Only result shape/type is checked.

    @pytest.mark.parametrize(
        "model_name",
        GENERATOR_PIPELINE_MODELS,
    )
    def test_smoke(self, model_name):
        hf_pipeline = pipeline("text-generation", model=model_name)
        target = generators.Pipelines(hf_pipeline)

        # Each prompt is an (id, <middle field — presumably topic>, text)
        # triple; the middle field is empty here.
        prompts = [
            ("id A", "", "Great hotel"),
            ("id B", "", "Bathroom too small"),
        ]

        desired_result_count = 2

        results = target(
            prompts=prompts,
            topic="some topic",
            mode="tests",
            topic_description="",
            num_samples=desired_result_count,
        )
        assert results is not None
        assert len(results) == desired_result_count
        for item in results:
            assert isinstance(item, str)

    @pytest.mark.parametrize(
        "model_name",
        GENERATOR_PIPELINE_MODELS,
    )
    def test_with_topics(self, model_name):
        hf_pipeline = pipeline("text-generation", model=model_name)
        target = generators.Pipelines(hf_pipeline)

        # Same as test_smoke, but the prompts carry a non-empty middle field.
        prompts = [
            ("id A", "some_string", "Great hotel"),
            ("id B", "some_string", "Bathroom too small"),
        ]

        desired_result_count = 2

        results = target(
            prompts=prompts,
            topic="some topic",
            mode="tests",
            topic_description="",
            num_samples=desired_result_count,
        )
        assert results is not None
        assert len(results) == desired_result_count
        for item in results:
            assert isinstance(item, str)
124 |
125 |
class TestOpenAI:
    def test_smoke(self, mocker):
        # Dummy key: the real API is never contacted because openai.Completion
        # is mocked below.
        OPENAI_API_KEY = "Not for you, CredScan"

        # Patch the completion endpoint to return two fixed choices.
        openai_completion = mocker.patch("openai.Completion", autospec=True)
        patched_response = {"choices": [{"text": "Ret 1"}, {"text": "Ret 2"}]}
        openai_completion.create.return_value = patched_response

        target = generators.OpenAI("curie", api_key=OPENAI_API_KEY)

        prompts = [
            ("id A", "", "Great hotel"),
            ("id B", "", "Bathroom too small"),
        ]

        desired_result_count = 2

        results = target(
            prompts=prompts,
            topic="",
            topic_description="",
            mode="tests",
            num_samples=desired_result_count,
        )
        # The mocked choices must be passed through unchanged.
        assert results is not None
        assert len(results) == desired_result_count
        assert "Ret 1" in results
        assert "Ret 2" in results
        # Verify the exact prompt construction and sampling parameters the
        # generator sends to the (mocked) API.
        openai_completion.create.assert_called_with(
            model="curie",
            prompt=['"Great hotel"\n"Bathroom too small"\n"'],
            user="adaptivetesting",
            max_tokens=100,
            temperature=1.0,
            top_p=0.95,
            n=2,
            stop='"',
        )
165 |
--------------------------------------------------------------------------------
/client/src/adatest.css:
--------------------------------------------------------------------------------
/* Styles for the adatest test-tree browser UI. */

/* Topic/button controls */
.adatest-button-container {
  display: block;
  width: 100%;
  padding: 3px;
  padding-top: 1px;
  padding-bottom: 1px;
  font-size: 13px;
}
.adatest-button-children {
  display: flex;
  vertical-align: top;
  width: 100%;
}

.adatest-button-box {
  padding: 2px;
  padding-left: 5px;
  padding-right: 5px;
  display: flex;
  cursor: pointer;
  text-align: center;
  width: 100%;
  margin-bottom: 3px;
  border: 1px solid #dddddd;
}

.adatest-explore-toggle {
  float: right;
  color: #666666;
  margin-top: 4px;
  margin-left: 4px;
  width: 30px;
}

.adatest-add-topic {
  color: #bbbbbb;
  margin-top: 4px;
  margin-left: 7px;
  margin-right: 7px;
  cursor: pointer;
}
.adatest-add-topic-wrapper {
  margin-top: 4px;
  margin-left: auto;
  margin-right: auto;
}

.adatest-button-label {
  flex: 1;
  display: inline-block;
  text-align: center;
  user-select: none;
}

/* Editable topic name label */
.adatest-topic-label {
  width: 300px;
  border: 0px;
  border-bottom: 1px solid #999999;
  text-align: left;
  margin-top: 10px;
  margin-bottom: 10px;
  outline: 0px solid transparent;
  background: transparent;
}
.adatest-topic-label:focus {
  outline: 0px solid transparent;
}

/* Frame around a topic's child rows */
.adatest-children-frame {
  padding: 0px;
  margin-top: 10px;
  /* padding-bottom: 10px; */
  border-radius: 7px 7px 7px 7px;
  border: 1px solid rgb(216, 222, 228);
}

/* Individual test rows */
.adatest-row-child {
  display: flex;
  padding: 0px;
  padding-left: 5px;
  padding-right: 0px;
  min-height: 30px;
  text-align: left;
  border-radius: 0px;/*10px 10px 10px 10px;*/
  outline: none;
  padding-top: 5px;
  padding-bottom: 5px;
}
.adatest-row-hidden {
  text-decoration: line-through;
}

.adatest-row-score-plot-box {
  float: right;
  height: 30px;
  flex: 0 0 150px;
  margin-right: 10px;
  padding-top: 0px;
  align-self: center;
}

.adatest-row-add-button {
  margin-top: 8px;
  flex: 0 0 20px;
  display: inline-block;
  margin-left: 10px;
}

.adatest-top-add-button {
  margin-top: 8px;
  flex: 0 0 20px;
  display: inline-block;
  margin-left: 10px;
}

/* Allow text selection inside contentEditable regions */
[contenteditable] {
  -webkit-user-select: text;
  user-select: text;
}

/* [contenteditable]:focus {
  outline-width: 0px;
} */

/* Unstyled <select> used inline in rows */
.adatest-plain-select {
  margin-left: 4px;
  appearance: none;
  border: none;
  background: none;
  -webkit-appearance: none;
  font-size: 13px;
  font-family: Helvetica Neue, Helvetica, Arial, sans-serif;
  margin: 0px;
}
.adatest-plain-select:focus {
  outline: 0px solid transparent;
}

/* Controls that fade in on hover */
.adatest-hover-opacity {
  opacity: 0.1;
}
.adatest-hover-opacity:hover {
  opacity: 0.6;
}

.adatest-scroll-wrap::-webkit-scrollbar {
  display: none;
}

/* Per-row hide (strike-through) toggle button */
.adatest-row-hide-button {
  opacity: 0;
  margin-top: 8px;
  /* line-height: 30px; */
  flex: 0 0 15px;
  display: inline-block;
  margin-left: 10px;
}
.adatest-row-hide-button:hover {
  opacity: 0.6
}
.adatest-row-hide-hovering {
  opacity: 0.1;
}
.adatest-row-hide-hidden {
  opacity: 0.6;
}

/* Numeric score shown at the right of each row */
.adatest-row-score-text-box {
  float: right;
  text-align: right;
  flex: 0 0 50px;
  line-height: 30px;
  height: 30px;
  color: #999999;
  padding-right: 5px;
  font-size: 12px;
  overflow: hidden;
  box-sizing: border-box;
}
.adatest-main-table {
  /* values must be unquoted CSS keywords; the quoted "auto" strings were
     invalid declarations and silently ignored by browsers */
  margin-left: auto;
  margin-right: auto;
  background: #ffffff;
  width: 100%;
}

.adatest-main-table tbody tr:nth-child(odd) {
  background: #ffffff;
}

.adatest-main-table tbody tr:hover {
  background: #ffffff;
}


.adatest-output-text:focus {
  outline: 0px solid transparent;
}

.adatest-row-selected {
  background: #00000008;
  /* border-left: 3px solid #00000099; */
  /* padding-left: 2px; */
}

/* Off-screen element used to measure rendered text width for select sizing */
.adatest-select-width-calculation {
  position: absolute;
  visibility: hidden;
  height: auto;
  width: auto;
  white-space: nowrap;
  font-size: 13px;
  font-family: Helvetica Neue, Helvetica, Arial, sans-serif;
  white-space: pre; /* NOTE(review): duplicates white-space above; this later value wins */
}

.adatest-row {
  display: flex;
  margin-top: 3px;
}
.adatest-row-input {
  flex: 1;
  display: flex;
  justify-content: right;
  align-items: center;
}
.adatest-row-editing {
  border-bottom: 0px dashed rgb(0, 130, 248);
}
.adatest-row-editing:focus {
  outline: 0px solid transparent;
}
.adatest-row-drop-highlighted {
  background: #dddddd;
}
.adatest-row-hover-highlighted {
  background: #f0f0f0;
}
.adatest-row-dragging {
  background: #dddddd;
}


.adatest-crum-selected {
  background: #dddddd;
}

.adatest-editable:focus {
  outline: 0px solid transparent;
}

/* Box holding model-generated test suggestions */
.adatest-suggestions-box {
  position: relative;
  border-radius: 7px 7px 7px 7px;
  background: rgb(246, 248, 250);
  text-align: center;
  padding: 0px;
  padding-top: 10px;
  margin-top: 10px;
  padding-bottom: 0px;
  border: 1px solid rgb(216, 222, 228);
}
/* Fade overlay along the bottom edge of the suggestions box */
.adatest-suggestions-box-after {
  /* was placeholder junk text ('werwerwe'); content is inert on a
     non-pseudo-element anyway, so an empty string is the safe value */
  content: '';
  width: 100%;
  height: 22px;
  position: absolute;
  bottom: 0px;
  /* rgba() is required for the 4-argument comma syntax; rgb() with an alpha
     channel in comma form is invalid in the legacy grammar */
  background: linear-gradient(rgba(246, 248, 250, 0) 0px, rgba(246, 248, 250, 1) 18px);
  pointer-events: none;
  border-radius: 0px 0px 7px 7px;
}
/* Generic drag-target / hover highlight helpers */
.adatest-drop-highlighted {
  background: #dddddd;
}

.adatest-hover-gray:hover {
  background: #dddddd;
}
--------------------------------------------------------------------------------
/adaptivetesting/embedders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import adaptivetesting
3 | from sklearn.preprocessing import normalize
4 | import appdirs
5 | import diskcache
6 | _embedding_memory_cache = {}
7 | _embedding_file_cache = diskcache.Cache(appdirs.user_cache_dir("adaptivetesting") + "/embeddings.diskcache")
8 |
def _embed(strings, normalize=True):
    """ Embed a list of strings (or "__IMAGE=<url>" image references), with caching.

    Results are cached both in an in-process dict and in the on-disk diskcache.
    Text cache keys are namespaced by the text embedding model's `.name` prefix;
    image entries are keyed by the raw "__IMAGE=..." string.

    NOTE(review): the `normalize` parameter shadows sklearn.preprocessing.normalize
    imported at module level; inside this function the name is always the boolean flag.
    """

    # find which strings are not in the cache
    new_text_strings = []
    new_image_urls = []
    text_prefix = _text_embedding_model().name # TODO: need to figure out how to do the same for image embedding, but only when needed
    for s in strings:
        if s.startswith("__IMAGE="):
            prefixed_s = s
        else:
            prefixed_s = text_prefix + s
        if prefixed_s not in _embedding_memory_cache:
            if prefixed_s not in _embedding_file_cache:
                if s.startswith("__IMAGE="):
                    new_image_urls.append(s)
                else:
                    new_text_strings.append(s)
                _embedding_memory_cache[prefixed_s] = None # so we don't embed the same string twice
            else:
                # promote the on-disk entry into the in-memory cache
                _embedding_memory_cache[prefixed_s] = _embedding_file_cache[prefixed_s]

    # embed the new text strings
    if len(new_text_strings) > 0:
        new_embeds = _text_embedding_model()(new_text_strings)
        for i,s in enumerate(new_text_strings):
            prefixed_s = text_prefix + s
            if normalize:
                _embedding_memory_cache[prefixed_s] = new_embeds[i] / np.linalg.norm(new_embeds[i])
            else:
                _embedding_memory_cache[prefixed_s] = new_embeds[i]
            _embedding_file_cache[prefixed_s] = _embedding_memory_cache[prefixed_s]

    # embed the new image urls
    if len(new_image_urls) > 0:
        new_embeds = _image_embedding_model()([url[8:] for url in new_image_urls]) # url[8:] strips the "__IMAGE=" prefix
        for i,s in enumerate(new_image_urls):
            if normalize:
                _embedding_memory_cache[s] = new_embeds[i] / np.linalg.norm(new_embeds[i])
            else:
                _embedding_memory_cache[s] = new_embeds[i]
            _embedding_file_cache[s] = _embedding_memory_cache[s]

    # return results in the original order of `strings`
    return [_embedding_memory_cache[s if s.startswith("__IMAGE=") else text_prefix + s] for s in strings]
52 |
def _text_embedding_model():
    """ Get (lazily creating) the process-wide shared text embedding model.

    The instance is cached on the `adaptivetesting` module so it is only
    constructed once per process.

    Returns
    -------
    A callable with a `.name` attribute that maps a list of strings to a 2D
    array of embeddings (by default a TransformersTextEmbedding).
    """
    if adaptivetesting.text_embedding_model is None:
        # lazily construct the default transformer-based embedder
        adaptivetesting.text_embedding_model = TransformersTextEmbedding()
    return adaptivetesting.text_embedding_model
88 |
def _image_embedding_model():
    """ Get (lazily creating) the process-wide shared image embedding model.

    Uses CLIP ViT-L/14 on the CPU; the callable maps a list of image URLs to a
    stacked matrix of L2-normalized float32 embedding rows.
    """
    if adaptivetesting.image_embedding_model is not None:
        return adaptivetesting.image_embedding_model

    import clip # pylint: disable=import-outside-toplevel
    import torch

    model, preprocess = clip.load("ViT-L/14", device="cpu", jit=True)

    def embed_urls(urls):
        # embed each image and stack the rows into one matrix
        with torch.no_grad():
            rows = []
            for url in urls:
                image = adaptivetesting.utils.get_image(url)
                emb = model.encode_image(preprocess(image).unsqueeze(0).to("cpu"))
                emb /= emb.norm(dim=-1, keepdim=True)
                rows.append(emb.cpu().detach().numpy().astype("float32")[0])
            return np.vstack(rows)

    adaptivetesting.image_embedding_model = embed_urls
    return adaptivetesting.image_embedding_model
110 |
def cos_sim(a, b):
    """ Cosine similarity between the rows of two matrices.

    (The previous docstring said "distance"; this computes similarity.)

    Parameters
    ----------
    a : array-like of shape (n, d)
    b : array-like of shape (m, d)

    Returns
    -------
    Array of shape (n, m) where entry (i, j) is the cosine similarity between
    row i of `a` and row j of `b`. All-zero rows are left un-normalized (their
    similarities come out as 0), matching sklearn.preprocessing.normalize.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a_norms = np.linalg.norm(a, axis=1, keepdims=True)
    b_norms = np.linalg.norm(b, axis=1, keepdims=True)
    # avoid division by zero for all-zero rows (leave them as zeros)
    a_norms[a_norms == 0] = 1
    b_norms[b_norms == 0] = 1
    return (a / a_norms) @ (b / b_norms).T
115 |
class TransformersTextEmbedding():
    """ Text embedder backed by a HuggingFace transformers model with mean pooling. """

    def __init__(self, model="sentence-transformers/stsb-roberta-base-v2"):
        # import locally so transformers is only required when this embedder is used
        import transformers
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model)
        self.model = transformers.AutoModel.from_pretrained(model)
        self.model_name = model
        # cache key prefix used by _embed to namespace cached embeddings per model
        self.name = "adaptivetesting.embedders.TransformersTextEmbedding(" + self.model_name + "):"

    def __call__(self, strings):
        """ Embed a list of strings; returns a (len(strings), hidden_size) numpy array. """
        import torch

        encoded_input = self.tokenizer(strings, padding=True, truncation=True, return_tensors='pt')

        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform mean pooling over tokens, weighted by the attention mask so
        # padding tokens do not contribute
        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
        input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
        embeds = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return embeds.cpu().numpy()
138 |
class OpenAITextEmbedding():
    """ Text embedder backed by the OpenAI embeddings API. """

    def __init__(self, model="text-similarity-babbage-001", api_key=None, replace_newlines=True):
        # import locally so openai is only required when this embedder is used
        import openai
        self.model = model
        if api_key is not None:
            openai.api_key = api_key
        self.replace_newlines = replace_newlines
        self.model_name = model
        # cache key prefix used by _embed to namespace cached embeddings per model
        self.name = "adaptivetesting.embedders.OpenAITextEmbedding(" + self.model_name + "):"

    def __call__(self, strings):
        """ Embed a list of strings; returns a stacked numpy array of embedding rows. """
        import openai

        if len(strings) == 0:
            return np.array([])

        # clean the strings for OpenAI
        cleaned_strings = []
        for s in strings:
            if s == "":
                s = " " # because OpenAI doesn't like empty strings
            elif self.replace_newlines:
                s = s.replace("\n", " ") # OpenAI recommends this for things that are not code
            cleaned_strings.append(s)

        # call the OpenAI API to complete the prompts
        # NOTE(review): the user tag here is "adatest" while other API calls in this
        # project use "adaptivetesting" — confirm which tag is intended.
        response = openai.Embedding.create(
            input=cleaned_strings, model=self.model, user="adatest"
        )

        return np.vstack([e["embedding"] for e in response["data"]])
170 |
--------------------------------------------------------------------------------
/client/src/content-editable.jsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import autoBind from 'auto-bind';
3 | import sanitizeHtml from 'sanitize-html';
4 | import { defer } from 'lodash';
5 |
6 | export default class ContentEditable extends React.Component {
7 | static defaultProps = {
8 | editable: true,
9 | defaultText: "",
10 | finishOnReturn: false
11 | };
12 |
13 | constructor(props) {
14 | super(props);
15 | autoBind(this);
16 | this.lastText = null;
17 |
18 | this.divRef = {};
19 | window["cedit_"+this.props.id] = this;
20 | }
21 |
22 | render() {
23 | //console.log("this.props.text", this.props.text)
24 | const emptyContent = this.props.text === undefined || this.props.text.length === 0;
25 | this.lastEditable = this.props.editable;
26 | if (this.lastText === null) this.lastText = this.props.text;
27 | return this.divRef = el}
29 | id={this.props.id}
30 | style={{opacity: emptyContent ? "0.3" : "1", display: "inline", overflowWrap: "anywhere", whiteSpace: "pre-wrap"}}
31 | onFocus={this.onFocus}
32 | onInput={this.handleInput}
33 | onKeyPress={this.handleKeyPress}
34 | onKeyDown={this.handleKeyDown}
35 | onBlur={this.onBlur}
36 | onDoubleClick={this.handleDoubleClick}
37 | onDragStart={this.stopDrag}
38 | onClick={this.onClick}
39 | contentEditable={this.props.editable}
40 | className="adatest-editable"
41 | dangerouslySetInnerHTML={{__html: sanitizeHtml(emptyContent ? this.props.defaultText : this.props.text)}}
42 | tabIndex="0"
43 | >
44 | }
45 |
46 | stopDrag(e) {
47 | console.log("stopDrag")
48 | e.preventDefault();
49 | return false;
50 | }
51 |
52 | handleDoubleClick(e) {
53 | const range = getMouseEventCaretRange(e);
54 | console.log("handleDoubleClick", range, e)
55 | }
56 |
57 | focus() {
58 |
59 | // we blur without triggering an action so that we can refocus
60 | // this is important to get the cursor to come back sometimes
61 | this.skipBlurAction = true;
62 | this.divRef.blur();
63 | this.skipBlurAction = false;
64 |
65 | this.divRef.focus();
66 | }
67 |
68 | blur() {
69 | this.divRef.blur();
70 | }
71 |
72 | onFocus(e) {
73 | console.log("onFocus in ContentEditable", this.props.text);
74 |
75 | // if (!this.props.editing) return;
76 |
77 | if (this.props.text !== this.props.defaultText && this.divRef.textContent === this.props.defaultText) {
78 | e.preventDefault();
79 | e.stopPropagation();
80 | this.divRef.textContent = "";
81 | if (this.props.onClick) this.props.onClick(e); // why we need this is crazy to me, seems like setting inner text kills the click event
82 | // defer(() => this.focus());
83 | console.log("clear!!", this.props.editable)
84 | defer(() => this.focus());
85 | }
86 | }
87 |
88 | onClick(e) {
89 | // console.log("onClick in ContentEditable", this.props.onClick)
90 | if (this.props.onClick) {
91 | e.preventDefault();
92 | e.stopPropagation();
93 | this.props.onClick(e);
94 | }
95 | e.stopPropagation();
96 | }
97 |
98 | getValue() {
99 | const text = this.divRef.textContent;
100 | if (text === this.props.defaultText) return "";
101 | else return text;
102 | }
103 |
104 | shouldComponentUpdate(nextProps) {
105 | return nextProps.text !== this.divRef.textContent && (nextProps.text != "" || this.divRef.textContent != this.props.defaultText) || nextProps.editable != this.lastEditable;
106 | }
107 |
108 | componentDidUpdate() {
109 | this.componentDidUpdateOrMount(false);
110 | }
111 |
112 | componentDidMount() {
113 | this.componentDidUpdateOrMount(true);
114 | }
115 |
116 | componentDidUpdateOrMount(mount) {
117 | // console.log("ContentEditable componentDidUpdateOrMount", mount, this.props.text, this.props.editable);
118 | if (this.props.text !== this.divRef.textContent) {
119 | if (this.props.text !== undefined && this.props.text !== null && (this.props.text.length > 0 || this.divRef.textContent !== this.props.defaultText)) {
120 | this.divRef.textContent = this.props.text;
121 | } else {
122 | if (mount) this.divRef.textContent = this.props.defaultText;
123 | }
124 | }
125 | if (this.props.text && (this.props.text.startsWith("New topic") || this.props.text === "New test") && this.props.editable) { // hacky but works for now
126 | // console.log("HACK!", this.props.text)
127 | this.divRef.focus();
128 | selectElement(this.divRef);
129 | // document.execCommand('selectAll', false, null);
130 | }
131 | }
132 |
133 | handleInput(e, finishing) {
134 | console.log("handleInput", finishing, this.divRef.textContent)
135 | const text = this.divRef.textContent;
136 | if (this.props.onInput && text !== this.lastText) {
137 | this.props.onInput(text);
138 | this.lastText = text;
139 | }
140 |
141 | if (finishing && this.props.onFinish) {
142 | this.props.onFinish(text);
143 | }
144 |
145 | if (text === this.props.defaultText) this.divRef.style.opacity = 0.3;
146 | else this.divRef.style.opacity = 1.0;
147 | }
148 |
149 | onBlur(e) {
150 | console.log("onBlur in ContentEditable", this.divRef.textContent, this.skipBlurAction)
151 | if (this.skipBlurAction) return;
152 | // if (this.divRef.textContent.length === this.props.defaultText) {
153 | // this.divRef.textContent = "";
154 | // }
155 | this.handleInput(e, true);
156 | if (this.divRef.textContent.length === 0) {
157 | this.divRef.textContent = this.props.defaultText;
158 | this.divRef.style.opacity = 0.3;
159 | }
160 | }
161 |
162 | handleKeyPress(e) {
163 |
164 | console.log("handleKeyPress", e.charCode, this.props.finishOnReturn)
165 | e.stopPropagation();
166 | if (e.charCode == 13 && this.props.finishOnReturn) {
167 | e.preventDefault();
168 |
169 | this.handleInput(e, true);
170 | }
171 | }
172 |
173 | handleKeyDown(e) {
174 | console.log("handleKeyDown", e.charCode, this.props.finishOnReturn)
175 | // only let the enter/return key go through
176 | if (e.charCode != 13 || !this.props.finishOnReturn) e.stopPropagation();
177 | }
178 | }
179 |
// Select the full text contents of the given element.
// (Removed a stray `console.log(this, element)` — `this` is undefined in a
// module-scope function, so the log was meaningless debug noise.)
function selectElement(element){
  var doc = document;
  if (doc.body.createTextRange) { // legacy IE path
    var range = document.body.createTextRange();
    range.moveToElementText(element);
    range.select();
  } else if (window.getSelection) { // standards path
    var selection = window.getSelection();
    var range = document.createRange();
    range.selectNodeContents(element);
    selection.removeAllRanges();
    selection.addRange(range);
  }
}
195 |
// Place the text caret at character offset `pos` inside element `el`.
function setCaret(el, pos) {
  const selection = window.getSelection();
  const caretRange = document.createRange();

  caretRange.setStart(el, pos);
  caretRange.collapse(true);

  selection.removeAllRanges();
  selection.addRange(caretRange);
}
document.setCaret = setCaret; // exposed for debugging from the browser console
207 |
// Walk up the DOM from `el` to the nearest element whose className contains
// `className` (substring match). Falls back to the original element when no
// ancestor matches. Guards against non-string className values (e.g. SVG
// elements expose an SVGAnimatedString), which previously crashed on .includes.
function findParentWithClass(el, className) {
  const origEl = el;
  while (el && !(typeof el.className === "string" && el.className.includes(className))) {
    el = el.parentElement;
  }
  return el ? el : origEl;
}
215 |
// Return a caret Range positioned at the mouse coordinates of the given event,
// using whichever caret-from-point API the current browser supports. May return
// undefined if no supported API is available.
function getMouseEventCaretRange(evt) {
  var range, x = evt.clientX, y = evt.clientY;

  // Try the simple IE way first
  if (document.body.createTextRange) {
    range = document.body.createTextRange();
    range.moveToPoint(x, y);
  }

  else if (typeof document.createRange != "undefined") {
    // Try Mozilla's rangeOffset and rangeParent properties,
    // which are exactly what we want
    if (typeof evt.rangeParent != "undefined") {
      range = document.createRange();
      range.setStart(evt.rangeParent, evt.rangeOffset);
      range.collapse(true);
    }

    // Try the standards-based way next
    else if (document.caretPositionFromPoint) {
      var pos = document.caretPositionFromPoint(x, y);
      range = document.createRange();
      range.setStart(pos.offsetNode, pos.offset);
      range.collapse(true);
    }

    // Next, the WebKit way
    else if (document.caretRangeFromPoint) {
      range = document.caretRangeFromPoint(x, y);
    }
  }

  return range;
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # adaptive-testing
adaptive-testing uses language models against themselves to build suites of unit tests. It is an iterative (and fun!) process between a user and a language model that results in a tree of unit tests specifically adapted to the model you are testing. Fixing any failed tests with fine-tuning then leads to an iterative debugging process similar to traditional software development. See the paper for details.
3 |
4 |
5 | 
6 | Note, adaptive-testing is currently a beta release so please share any issues you encounter.
7 |
8 |
9 | ## Install
10 |
11 | ```
pip install adaptivetesting
13 | ```
14 |
15 | ## Sentiment analysis example
16 |
17 | adaptive-testing can test any NLP model you can call with a python function, here we will test a basic open source sentiment analysis model. Since adaptive-testing relies on a generative language model to help you create tests, you need to specify what generative model it will use, here we use GPT-3 from OpenAI or GPT-Neo locally. Tests are organized into a test tree that follows the DataFrame API and is organized like a file system, here we create a new empty tree, but you can also start with a previous test tree that targets a similar task. The core adaptive-testing loop starts when you call the `.adapt()` method on a test tree passing the model(s) you want to test and the backend generator you want to use. The code for all this is below:
18 |
19 | ```python
20 | import transformers
21 | import adaptivetesting
22 |
23 | # create a HuggingFace sentiment analysis model
24 | classifier = transformers.pipeline("sentiment-analysis", return_all_scores=True)
25 |
26 | # specify the backend generator used to help you write tests
27 | generator = adaptivetesting.generators.OpenAI('curie', api_key=OPENAI_API_KEY)
28 |
29 | # ...or you can use an open source generator
30 | #neo = transformers.pipeline('text-generation', model="EleutherAI/gpt-neo-125M")
31 | #generator = adaptivetesting.generators.Transformers(neo.model, neo.tokenizer)
32 |
33 | # create a new test tree
34 | tests = adaptivetesting.TestTree("hotel_reviews.csv")
35 |
36 | # adapt the tests to our model to launch a notebook-based testing interface
37 | # (wrap with adaptivetesting.serve to launch a standalone server)
38 | tests.adapt(classifier, generator, auto_save=True)
39 | ```
40 |
41 |
42 |
43 |
44 |
45 | Once we have launched a test tree browser, we can use the interface to create new topics and tests. Here we create the topic "/Clear positives/Location" to test how well this model classifies clearly positive statements about a hotel's location. We then add a few starting examples of what we want to see in this topic (clearly positive statements about hotel location):
46 |
47 |
48 |
49 |
50 |
51 | Each test consists of a model input, a model output, a pass/fail label, and a score for the current target model. The input text should fall within the scope of the current topic, which here means it is a clearly positive statement about hotel locations. The output text is what the target model we are testing generated (or it can be manually specified, in which case it turns light grey to show it does not reflect the current model behavior). The label is a pass/fail indicator that denotes if the model output is correct with respect to the aspect being tested in the current topic, in our case the model was correct for all the inputs we entered. The model score represents if the testing model passes or fails and how confident the model is when producing the current output.
52 |
Note that in the above figure all the label indicators are hollow; this means that we have not yet labeled these examples, and adaptive-testing is just guessing that they are correct. They are all correct, so we can click the checkmarks to confirm and label all these examples. By confirming we teach adaptive-testing more about what we want this topic to test, so it becomes better at predicting future labels, and hence automating the testing process. Once we label these examples we can then click "Suggestions" and adaptive-testing will attempt to write new in-topic examples for us, labeling them and sorting them by score so we can see the most likely failures at the top of the list.
54 |
55 |
56 |
57 |
58 |
59 | Starting at the top of the list we can confirm or change the label for each suggestion and so add them to the current topic (like marking "very convientent for walking" -> "POSITIVE" as correct model behavior), while we reject (or just ignore) examples that don't belong in the current topic (like "Second visit" which is not about a hotel's location). After we have added some new suggestions to the current topic (we normally only bother to look at the top few suggestions) we can repeat the process by clicking "Suggestions" again. Repeating the process a few times allows adaptive-testing to learn from our feedback and hill-climb towards generating better and better suggestions (ones that are more likely to be on-topic and reveal model failures). Doing this for a few rounds reveals lots of bugs in the model related to positive hotel location statements.
60 |
61 |
62 |
63 |
64 |
Once we have tested the location aspect enough we can repeat the process to test a new aspect of model behavior, for example comments about hotel swimming pools or gyms. The space of possible concepts for hotel reviews is large, so to help explore it adaptive-testing can suggest new topics once we have a few examples:
66 |
67 |
68 |
69 |
70 |
After we accept some of these new topic suggestions we can open them and fill them out without ever writing seed examples. adaptive-testing can suggest new tests inside an empty topic by using examples from other topics and the current topic's name.
72 |
73 |
74 |
75 |
76 |
77 | This is just a short example of how to find bugs in a sentiment analysis model, but the same process can be applied to any NLP model (even ones that generate free form text). Test trees can be adapted to new models and shared with others collaboratively (they are just CSV files). Once you have enough bugs you can fine tune your model against a mixture of your test tree and the original training data to fix all the bugs in the test tree while retaining performance on your original training data (we will share a full demo notebook of this soon).
78 |
79 |
80 | ## Citation
81 |
82 | If you find adaptive-testing or test trees useful in your work feel free to cite our ACL paper:
83 | [Adaptive Testing and Debugging of NLP Models](https://aclanthology.org/2022.acl-long.230) (Ribeiro & Lundberg, ACL 2022)
84 |
85 | ## Contributing
86 |
87 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
88 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
89 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
90 |
91 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
92 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
93 | provided by the bot. You will only need to do this once across all repos using our CLA.
94 |
95 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
96 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
97 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
98 |
99 | ## Trademarks
100 |
101 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
102 | trademarks or logos is subject to and must follow
103 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
104 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
105 | Any use of third-party trademarks or logos are subject to those third-party's policies.
106 |
--------------------------------------------------------------------------------
/adaptivetesting/_topic_model.py:
--------------------------------------------------------------------------------
1 | import sklearn
2 | import numpy as np
3 | from sklearn import multioutput
4 | from sklearn import preprocessing
5 | from sklearn.linear_model import RidgeClassifierCV
6 | from sklearn.linear_model import LogisticRegression
7 | from sklearn.svm import LinearSVC
8 | import adaptivetesting
9 | import re
10 |
class ConstantModel():
    """ A trivial labeling model that always predicts the same probability. """

    def __init__(self, probability):
        # the fixed probability returned for every sample
        self.probability = probability

    def predict_prob(self, embeddings):
        """ Return the constant probability for one sample, or one value per row of a batch. """
        if hasattr(embeddings[0], "__len__"):
            # a batch of embedding vectors: one value per row
            return [self.probability for _ in embeddings]
        # a single embedding vector
        return self.probability
19 |
class CVModel():
    """ Wraps a cross-validated ridge classifier and exposes probability-like scores. """

    def __init__(self, embeddings, labels):
        # both expected label values are weighted equally
        self.inner_model = RidgeClassifierCV(class_weight={"pass": 1, "fail": 1})
        self.inner_model.fit(embeddings, labels)

    def predict_prob(self, embeddings):
        # the squashing below assumes a binary problem
        assert len(self.inner_model.classes_) == 2
        d = self.inner_model.decision_function(embeddings)
        # squash the decision margin into (0, 1): exp(d)/(exp(d)+exp(-d)) == sigmoid(2d)
        probs = np.exp(d) / (np.exp(d) + np.exp(-d))
        # NOTE(review): classes_ is sorted alphabetically (["fail", "pass"]), so a
        # positive margin (prob near 1) corresponds to "pass" here, while
        # ConstantModel uses 1.0 to mean "fail" — confirm the intended orientation.

        return probs
31 |
class OutputNearestNeighborLabelModel():
    """ Predict labels by 1-nearest-neighbor over the *output* half of the embedding.

    Each embedding row is assumed to be the concatenation [input_embed, output_embed];
    the input half is zeroed so only the output half drives the neighbor match.
    """
    def __init__(self, embeddings, labels):
        # import the submodule explicitly: a bare `import sklearn` at module level
        # does not load sklearn.neighbors, so `sklearn.neighbors.KNeighborsClassifier`
        # could raise AttributeError at runtime
        from sklearn.neighbors import KNeighborsClassifier

        embeddings = np.array(embeddings, copy=True) # copy so we don't clobber the caller's array
        embeddings[:,:embeddings.shape[1]//2] = 0 # zero out the embedding for the input value so we only depend on the output
        self.model = KNeighborsClassifier(1)
        self.model.fit(embeddings, labels)

    def predict(self, embeddings):
        """ Return the nearest-neighbor label for each embedding row. """
        embeddings = np.array(embeddings, copy=True) # copy so we don't clobber the caller's array
        embeddings[:,:embeddings.shape[1]//2] = 0
        return self.model.predict(embeddings)
40 |
class TopicLabelingModel:
    """ A model that predicts a pass/fail score for tests in a given topic.

    Fits on the labeled tests of the topic, falling back first to subtopics and
    then to progressively shorter parent-topic prefixes when the topic itself
    has too few labeled samples. Scores follow the ConstantModel convention
    used below (0.0 for an all-"pass" topic, 1.0 otherwise).
    """
    def __init__(self, topic, test_tree):
        self.topic = topic
        self.test_tree = test_tree

        # mask out entries that do not have a pass/fail label
        valid_mask = ~((test_tree["labeler"] == "imputed") | (test_tree["label"] == "topic_marker") | (test_tree["label"] == "off_topic"))

        # try and select samples from the current topic
        topic_mask = (test_tree["topic"] == topic) & valid_mask

        # if we didn't find enough samples then expand to include subtopics
        if topic_mask.sum() <= 1:
            topic_mask = test_tree["topic"].str.startswith(topic) & valid_mask

        # if we still didn't find enough samples then expand to include parent topics,
        # walking one level up at a time until we have more than one sample
        parts = topic.split("/")
        for i in range(len(parts), 0, -1):
            # parts[:i] (not parts[:i+1]) so each iteration actually shortens the
            # prefix and the walk can reach the root; the old off-by-one repeated
            # the full topic twice and never tried the top-level prefix
            prefix = "/".join(parts[:i])
            if topic_mask.sum() <= 1:
                topic_mask = test_tree["topic"].str.startswith(prefix) & valid_mask
            else:
                break

        # get our features and labels for fitting a model
        strings = list(test_tree["input"][topic_mask]) + list(test_tree["output"][topic_mask])
        labels = list(test_tree["label"][topic_mask])
        unrolled_embeds = adaptivetesting.embed(strings)
        # each row becomes [input_embedding, output_embedding] concatenated
        embeddings = np.hstack([unrolled_embeds[:len(labels)], unrolled_embeds[len(labels):]])

        # empty test tree
        if len(labels) == 0:
            self.model = ConstantModel(0.0)

        # constant label topic
        elif len(set(labels)) == 1:
            self.model = ConstantModel(0.0 if labels[0] == "pass" else 1.0)

        # enough samples to fit a model
        else:
            # We are in a highly overparametrized situation where a max-margin style
            # model generalizes well but is sensitive to label errors.
            # TODO: SML: It seems to me that the SVC seems to do very well as long as there are no "errors" in the data labels. But it will
            # do very poorly if there are errors in the data labels since it will fit them exactly. Perhaps we can help this by
            # ensembling several SVCs together each trained on a different bootstrap sample? This might add the roubustness (against label mismatches)
            # that is lacking with hard-margin SVC fitting (it is also motivated a bit by the connections between SGD and hard-margin SVC fitting, and that
            # in practice SGD works on subsamples of the data so it should be less sensitive to label misspecification).

            # This seemed to be reasonably well calibrated on simple tests, so we use it instead of SVC
            self.model = CVModel(embeddings, labels)

    def __call__(self, input, output):
        """ Score a single (input, output) pair, or parallel lists of inputs and outputs. """
        embeddings = np.hstack(adaptivetesting.embed([input, output]))
        if not hasattr(embeddings[0], "__len__"):
            return self.model.predict_prob([embeddings])[0]
        return self.model.predict_prob(embeddings)
114 |
class TopicMembershipModel:
    """ A model that predicts if a given test fits in a given topic.

    Note that this model only depends on the inputs not the output values for a test.
    """
    def __init__(self, topic, test_tree):
        self.topic = topic
        self.test_tree = test_tree

        # mask out entries that do not have a topic membership label
        valid_mask = ~((test_tree["labeler"] == "imputed") | (test_tree["label"] == "topic_marker"))

        # try and select samples from the current topic
        topic_mask = (test_tree["topic"] == topic) & valid_mask

        # if we didn't find enough samples then expand to include subtopics
        if topic_mask.sum() <= 1:
            topic_mask = test_tree["topic"].str.startswith(topic) & valid_mask

        # if we still didn't find enough samples then expand to include parent topics,
        # walking one level up toward the root each iteration
        parts = topic.split("/")
        for i in range(len(parts), 0, -1):
            # BUGFIX: was parts[:i+1], which never produced a prefix shorter than the
            # topic itself, so the expansion to parent topics could never happen
            prefix = "/".join(parts[:i])
            if topic_mask.sum() <= 1:
                topic_mask = test_tree["topic"].str.startswith(prefix) & valid_mask
            else:
                break

        # get our features and labels for fitting a model
        # (membership only depends on the input text, not the output)
        strings = list(test_tree["input"][topic_mask])
        labels = [l if l == "off_topic" else "on_topic" for l in test_tree["label"][topic_mask]]
        embeddings = np.array(adaptivetesting.embed(strings))

        # empty test tree (default to on-topic)
        if len(labels) == 0:
            self.model = ConstantModel(1.0)

        # constant label topic
        elif len(set(labels)) == 1:
            self.model = ConstantModel(0.0 if labels[0] == "off_topic" else 1.0)

        # enough samples to fit a model
        else:

            # we are in a highly overparametrized situation, so we use a max-margin
            # style model (CVModel) to get reasonable generalization.
            # BUGFIX: CVModel is constructed with its training data elsewhere in this
            # file (CVModel(embeddings, labels)); the previous CVModel() + .fit(...)
            # call did not match that constructor signature.
            self.model = CVModel(embeddings, labels)

    def __call__(self, input):
        """ Return "on_topic"/"off_topic" for a single input string, or a list
        of such labels when a batch of embeddings comes back. """
        embeddings = adaptivetesting.embed([input])[0]
        if not hasattr(embeddings[0], "__len__"):
            return "on_topic" if self.model.predict_prob([embeddings])[0] > 0.5 else "off_topic"
        return ["on_topic" if v > 0.5 else "off_topic" for v in self.model.predict_prob(embeddings)]
168 |
class ChainTopicModel:
    """ Hierarchical topic classifier that predicts one topic level per chain link.

    Topic labels are strings like "A > B > C"; fit() splits them into levels,
    pads short topics with '-', and trains a sklearn ClassifierChain with one
    LabelEncoder per level. predict() reassembles the per-level outputs and
    backs off to the nearest topic actually seen during fitting.
    """
    def __init__(self, model=None):
        # base_model is the per-level classifier cloned by ClassifierChain
        if model is None:
            self.base_model = RidgeClassifierCV()
        else:
            self.base_model = model
    def fit(self, X, y):
        topics = y
        # the deepest topic determines how many chain links we need
        max_levels = max([len(x.split('>')) for x in topics])
        self.model = sklearn.multioutput.ClassifierChain(self.base_model, order=list(range(max_levels)))
        # split each topic into its levels and pad to a rectangular array
        y = [list(map(str.strip, x.split('>'))) for x in topics]
        y = np.array([x + ['-'] * (max_levels - len(x)) for x in y])
        self.encoders = [preprocessing.LabelEncoder() for _ in range(max_levels)]
        # record every topic and all of its ancestor prefixes as valid predictions
        self.possible_topics = set()
        for x in topics:
            self.possible_topics.add(x)
            a = x.split(' > ')
            for i in range(1, len(a)):
                self.possible_topics.add(' > '.join(a[:i]))

        self.classes_ = list(self.possible_topics)
        # label-encode each level independently for the chain
        new_y = np.zeros(y.shape)
        for i in range(y.shape[1]):
            self.encoders[i].fit(y[:, i])
            new_y[:, i] = self.encoders[i].transform(y[:, i])
        self.model.fit(X, new_y)
    def predict(self, X):
        y = self.model.predict(X)
        # decode each level back to its string labels
        ret = []
        for i in range(y.shape[1]):
            ret.append(self.encoders[i].inverse_transform(y[:, i].astype(int)))
        y = np.array(ret).T
        ret = []
        for x in y:
            # drop padding, then back off until we hit a topic seen during fit
            x = [z for z in x if z != '-']
            a = ' > '.join(x)
            # BUGFIX: also stop when x is exhausted; previously a predicted
            # combination with no known ancestor spun in an infinite loop
            while a not in self.possible_topics and len(x) > 0:
                x = x[:-1]
                a = ' > '.join(x)
            ret.append(a)
        return np.array(ret)

    def predict_proba(self, X):
        # This is just a fake function for now, puts 1 in the predicted class and 0 elsewhere
        y = self.predict(X)
        ret = np.zeros((len(X), len(self.classes_)))
        for i, r in enumerate(y):
            ret[i, self.classes_.index(r)] = 1
        return ret
218 |
class StandardTopicModel:
    """ A flat (non-hierarchical) topic classifier built on RidgeClassifierCV.

    The 'Not problematic' class is treated specially in predict(): it is kept
    whenever its probability clears `threshold`, and only overridden by the
    best other class when it does not.
    """
    def __init__(self, threshold=0.5):
        # threshold: minimum probability for predicting 'Not problematic';
        # None means fall back to the raw model prediction in predict()
        self.model= sklearn.linear_model.RidgeClassifierCV()
        self.threshold=threshold
        # add the missing predict_proba method to RidgeClassifierCV
        # (RidgeClassifierCV only exposes decision_function, so we turn its
        # margins into softmax-style probabilities here)
        def predict_proba(self, X):
            # degenerate single-class case: everything has probability 1
            if len(self.classes_) == 1:
                return np.ones((len(X), 1))
            d = self.decision_function(X)
            # binary case: decision_function yields one margin per sample
            if len(self.classes_) == 2:
                probs = np.exp(d) / (np.exp(d) + np.exp(-d))
                return np.array([1 - probs, probs]).T
            # multiclass case: softmax over the per-class margins
            probs = np.exp(d).T / np.sum(np.exp(d), axis=1)
            return probs.T
        # bind the helper as a method of this specific model instance
        self.model.predict_proba = predict_proba.__get__(self.model, self.model.__class__)
    def fit(self, X, y):
        self.model.fit(X, y)
    def predict_proba(self, X):
        return self.model.predict_proba(X)
    def predict(self, X):
        # with no threshold, defer entirely to the underlying model
        if self.threshold is None:
            return self.model.predict(X)
        pps = self.model.predict_proba(X)
        # NOTE(review): assumes the training labels always include a
        # 'Not problematic' class -- .index() raises ValueError otherwise; confirm with callers
        zero_index = list(self.model.classes_).index('Not problematic')
        ret = []
        for p in pps:
            # keep 'Not problematic' whenever it clears the threshold
            if p[zero_index] >= self.threshold:
                ret.append(self.model.classes_[zero_index])
                continue
            else:
                # otherwise pick the most likely class other than 'Not problematic'
                best = np.argsort(p)
                if best[-1] == zero_index:
                    best = best[:-1]
                ret.append(self.model.classes_[best[-1]])
        return np.array(ret)
    # return self.model.predict(X)
255 |
--------------------------------------------------------------------------------
/adaptivetesting/_prompt_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import re
4 | import urllib.parse
5 | import adaptivetesting
6 | from .embedders import cos_sim
7 | from .utils import is_subtopic
8 | log = logging.getLogger(__name__)
9 |
10 |
class PromptBuilder():
    """ A class to build prompts for the model.
    """

    def __init__(self, prompt_size=7, slot_randomization=0.25, score_randomization=0.05, skip_randomization=0.25, prompt_diversity=True,
                 subtopic_diversity=True):
        """ Initialize the prompt builder.

        Parameters
        ----------
        prompt_size : int
            The number of test slots to include in the prompt.

        slot_randomization : float
            The proportion of slots to make fully random (within the current topic).

        score_randomization : float
            The standard deviation of an additive Gaussian randomization factor for the scores.

        skip_randomization : float
            The proportion of times we skip over top ranking tests when building the prompt.

        prompt_diversity : bool
            Whether to include a diversity term when selecting tests for the prompt. This diversity term is based
            on the embeddings of each test.

        subtopic_diversity : bool
            If true, we will try and pick tests from a diverse set of subtopics of the current topic (if we are
            using subtopic tests and not direct child tests).
        """

        assert skip_randomization < 0.99, "skip_randomization must be less than 1, otherwise everything will always be skipped!"

        self.prompt_size = prompt_size
        self.slot_randomization = slot_randomization
        self.score_randomization = score_randomization # TODO: make this scale according to the stddev of the top 7 entries?
        self.skip_randomization = skip_randomization
        self.prompt_diversity = prompt_diversity
        self.subtopic_diversity = subtopic_diversity

    def __call__(self, test_tree, topic, score_column, repetitions=1, filter="", suggest_topics=False, working_set_size=100, embeddings=None):
        """ This builds a prompt for GPT3 that elicits useful input examples.

        Parameters
        ----------
        test_tree : adaptivetesting.TestTree
            The test tree to generate prompts from.

        topic : str
            The topic to build a prompt for.

        score_column : str
            The column to use for scoring the tests.

        repetitions : int
            The number of times to repeat the prompt generation process. This is how many prompts we will return.

        filter : str
            A filter to apply to the test set before selecting tests to build the prompt.

        suggest_topics : bool
            If true, we will create a prompt filled with topic names instead of a list of tests.

        working_set_size : int
            How many top tests to consider when doing the full iterative scoring process. Larger values may take longer.
            Note that this has no effect as long as we never go more than working_set_size tests deep during the prompt
            item selection process.

        embeddings : dict
            A dictionary of embeddings to use for the prompt. This is used to compute the prompt_diversity.
        """

        ids = np.array(test_tree.index)

        # return early for an empty test tree
        if len(ids) == 0:
            return [[] for _ in range(repetitions)]

        # we compute each test's distance from current topic, where distance is measured
        # by the length of the topic prefix shared between the test and the current topic
        topic_scaling = np.ones(test_tree.shape[0])
        topic_parts = topic.split("/")
        for i in range(1, len(topic_parts)):
            prefix = "/".join(topic_parts[:i+1])
            if suggest_topics:
                prefix += "/"
            topic_scaling *= 1 + 99 * np.array([v.startswith(prefix) for v in test_tree["topic"]])

        # promote direct children over subtopic descendants and filter for topics vs tests
        if suggest_topics:
            topic_scaling *= 1 + 99 * np.array([v.rsplit('/', 1)[0] == topic for v in test_tree["topic"]])
            topic_scaling *= np.array(test_tree["label"] == "topic_marker")
        else:
            topic_scaling *= 1 + 99 * np.array([v == topic for v in test_tree["topic"]])
            topic_scaling *= np.array(test_tree["label"] != "topic_marker")
        topic_scaling *= np.array(["__suggestions__" not in t for t in test_tree["topic"]])

        # return early if we have nothing to build a prompt with
        if np.sum(topic_scaling) == 0:
            return [[] for _ in range(repetitions)]

        topic_scaling /= np.max(topic_scaling)

        # hide rows that don't match the filter
        hidden_scaling = np.ones(len(ids))
        if filter != "":
            filter_compiled = re.compile(filter)
            for i,k in enumerate(ids):
                test = test_tree.loc[k]
                # BUGFIX: guard the test_type access the same way input/output are
                # guarded below; test trees without a test_type column previously
                # raised AttributeError here whenever a filter was applied
                if hasattr(test, "test_type") and filter_compiled.search(test.test_type) is not None:
                    continue
                if hasattr(test, "input") and filter_compiled.search(test.input) is not None:
                    continue
                if hasattr(test, "output") and filter_compiled.search(test.output) is not None:
                    continue
                hidden_scaling[i] = 0.0

        # filter down to a single test type (chosen to match the top scoring test)
        if suggest_topics:
            # scores currently do not influence topic suggestions
            # TODO: can we score topics and topic suggestions?
            scores = np.ones(len(ids))
        else:
            # compute a positive single value score for each test
            scores = np.array([score_max(test_tree.loc[k, score_column], test_tree.loc[k, "label"]) for k in ids])

        # filter down to just top rows we will use during the iterative scoring process
        rank_vals = scores * topic_scaling * hidden_scaling
        top_inds = np.argsort(-rank_vals)[:working_set_size]
        ids = ids[top_inds]
        topic_scaling = topic_scaling[top_inds]
        hidden_scaling = hidden_scaling[top_inds]
        scores = scores[top_inds] * 1.0

        # build a list of randomized prompts
        prompts = []
        for _ in range(repetitions):

            # store tmp versions of things we update during the iteration
            scores_curr = scores.copy()
            topic_scaling_curr = topic_scaling.copy()

            # score randomization
            scores_curr += self.score_randomization * np.random.rand(len(ids))

            # sim_avoidance is a vector that marks which items (and items related through similarities)
            # should be avoided (ranked lower for prompt selection)
            if self.prompt_diversity:
                sim_avoidance = np.zeros(len(ids))
                if suggest_topics:
                    embeddings_arr = np.vstack(adaptivetesting.embed(
                        [urllib.parse.unquote(test_tree.loc[id, "topic"].split("/")[-1]) for id in ids]
                    ))
                else:
                    embeddings_arr = np.hstack([
                        np.vstack(adaptivetesting.embed([test_tree.loc[id, "input"] for id in ids])),
                        np.vstack(adaptivetesting.embed([test_tree.loc[id, "output"] for id in ids]))
                    ])
                similarities = cos_sim(embeddings_arr, embeddings_arr)
            hard_avoidance = np.zeros(len(ids))
            diversity = np.ones(len(ids))

            # compute how many greedy and how many random positions we will have
            num_random = max(0, min(np.random.binomial(self.prompt_size, self.slot_randomization), len(ids) - self.prompt_size))
            num_greedy = max(0, min(self.prompt_size - num_random, len(ids) - num_random))

            # iteratively select prompt items
            prompt_ids = []
            outside_topics_used = np.ones(len(ids))
            while len(prompt_ids) < num_greedy + num_random:

                # once we get to the random part of the process we scramble the scores
                if len(prompt_ids) == num_greedy:
                    scores_curr = 1 + np.random.rand(len(ids))*0.1

                # find the next best index
                if self.prompt_diversity:
                    diversity = 1 - (similarities * sim_avoidance).max(1)
                rank_vals = scores_curr * topic_scaling_curr * diversity * (1 - hard_avoidance) * hidden_scaling * outside_topics_used

                if np.nanmax(rank_vals) <= 0 and len(prompt_ids) > 0: # stop if we have run out of the current subtree
                    break

                new_ind = np.nanargmax(rank_vals)
                skip_rand = np.random.rand()

                # make it unlikely we will choose the same outside topic twice
                new_ind_topic = test_tree.loc[ids[new_ind], "topic"]
                if not is_subtopic(topic, new_ind_topic):
                    outside_topics_used *= 1 - 0.9 * np.array([test_tree.loc[id, "topic"] == new_ind_topic for id in ids])

                # add or skip this item
                if skip_rand >= self.skip_randomization:
                    prompt_ids.append(ids[new_ind])
                    avoidance_level = 1
                else:
                    avoidance_level = 1 - 0.1

                # avoid this IO pair as we select the next pairs
                hard_avoidance[new_ind] = avoidance_level
                if self.prompt_diversity:
                    sim_avoidance[new_ind] = avoidance_level

                # lower the weight of the subtopic we just picked from
                if self.subtopic_diversity:
                    new_topic = test_tree.loc[ids[new_ind], "topic"]
                    if topic != new_topic and is_subtopic(topic, new_topic):
                        subtopic = topic + "/" + new_topic[(len(topic)+1):].split("/")[0]
                        subtopic_scaling = np.array([0.001 if is_subtopic(subtopic, test_tree.loc[k, "topic"]) else 1 for k in ids])
                        topic_scaling_curr *= subtopic_scaling

            # create the prompt as a list of tuples
            prompt = []
            for k in reversed(prompt_ids):
                row = test_tree.loc[k]
                if suggest_topics:
                    if row["topic"] == "":
                        continue # we can't use the root to help suggest topic names
                    parents,child = row["topic"].rsplit("/", 1)
                    prompt.append((k, parents, urllib.parse.unquote(child)))
                else:
                    prompt.append((k, row["topic"], row["input"]))
            prompts.append(prompt)

        return prompts
236 |
def score_max(s, label):
    """ Collapse a raw score entry into a single maximal float score.

    Missing scores (empty string, None, or NaN) fall back to 1 for tests
    labeled "fail" and 0 otherwise. String scores may contain several values
    separated by "|"; the largest one is returned.
    """
    fallback = 1 if label == "fail" else 0
    if s is None or s == "":
        return fallback
    if isinstance(s, str):
        return np.max([convert_float(part) for part in s.split("|")])
    if np.isnan(s):
        return fallback
    return np.max(s)
246 |
def convert_float(s):
    """ Convert a value to a float, returning NaN for anything unparseable.

    BUGFIX: float(None) (and other non-string, non-numeric values) raises
    TypeError rather than ValueError, which the previous version did not
    catch; both now fall back to NaN.
    """
    try:
        f = float(s)
    except (ValueError, TypeError):
        f = np.nan
    return f
--------------------------------------------------------------------------------
/test_trees/abstract_capabilities.csv:
--------------------------------------------------------------------------------
1 | ,topic,input,output,label,labeler,description,model score
2 | b72f54bb56ec458c87138faf4ef64f4f,/Semantic Role Labeling,,,topic_marker,adatest_default,,
3 | 3c506e4cbcc14d87ac0c6e3169beeb0f,/Logic,,,topic_marker,adatest_default,,
4 | 2cb3c973d2b446cfb1f4a4a7f40f7253,/Named Entity Recognition,,,topic_marker,adatest_default,,
5 | 9037cb78df8e4b09bab972c5661b735d,/Vocabulary+POS,,,topic_marker,adatest_default,,
6 | 480258ecc6654efb879d29726fc05507,/Fairness,,,topic_marker,adatest_default,,
7 | 87c8b63256d0486fb306b9b13d042ad6,/Negation,,,topic_marker,adatest_default,,
8 | 2c2cccc2ae714d77a33ec72a93339739,/Coreference,,,topic_marker,adatest_default,,
9 | 3eb405c6ab6a4eeb95c57f00c2f8c542,/Fairness/Race,,,topic_marker,adatest_default,,
10 | 6287a53a733e454796a05efd939e0903,/Fairness/Gender and sexuality,,,topic_marker,adatest_default,,
11 | cd722a3a511a4b2cbb211eab3435a968,/Fairness/Gender and sexuality/LGBTQ,,,topic_marker,adatest_default,,
12 | 8f16c69e8fe3490b8be6c9fcafb8dc2f,/Fairness/Politics,,,topic_marker,adatest_default,,
13 | fe40a6e72696481aaf6e4f575868c8b2,/Robustness/Adding or removing irrelevant punctuation,,,topic_marker,adatest_default,,
14 | 683233e699c14d54a6801acdca3b08d3,/Robustness/Spelling errors,,,topic_marker,adatest_default,,
15 | c12be5d95344420a8bf191f52c45d5ce,/Robustness/Adding a random URL or @ mentions,,,topic_marker,adatest_default,,
16 | 045f4ad7e36a4f1694f9260a6d8f32ab,/Fairness/Religion,,,topic_marker,adatest_default,,
17 | 9ce174f1fc544297becdeef49cc061db,/Fairness/Disabilities,,,topic_marker,adatest_default,,
18 | e6a7264a9ce1475091d8fa841a3b758d,/Fairness/Age,,,topic_marker,adatest_default,,
19 | 022942349c854e678b4623f37fdd6cd6,/Logic/Symmetry,,,topic_marker,adatest_default,,
20 | eda7126ae4c2400ea2b22135b19fd5f5,/Logic/Implications,,,topic_marker,adatest_default,,
21 | 5b2621ecc57f4052bc5218049faf825c,/Taxonomy/Synonyms,,,topic_marker,adatest_default,,
22 | 5fe557f4efda4c18956dbf72729e69b2,/Taxonomy/Antonyms,,,topic_marker,adatest_default,,
23 | 45784b5867f24d70bae80cbf60100bc5,/Robustness/Contractions,,,topic_marker,adatest_default,,
24 | 7d00b5c1b43046d28ecfed25860a83b7,/Fairness/Religion/Christianity,,,topic_marker,adatest_default,,
25 | 6f192e1dd87c4441a9121fba0da8534f,/Fairness/Religion/Islam,,,topic_marker,adatest_default,,
26 | 7b3addf57c3c46249e819b624909033f,/Fairness/Religion/Hinduism,,,topic_marker,adatest_default,,
27 | cc309caf440a49c1a4636808ce811e58,/Fairness/Religion/Buddhism,,,topic_marker,adatest_default,,
28 | bc1e8ab4711c4bbfb26876e95bd2ee95,/Fairness/Religion/Judaism,,,topic_marker,adatest_default,,
29 | 14187c4c7c9243a193758455bbaa7776,/Fairness/Religion/Shinto,,,topic_marker,adatest_default,,
30 | 9a12e7319fc3417d85e837daef62449b,/Fairness/Religion/Taoism,,,topic_marker,adatest_default,,
31 | 382da467f7b04ea296f81fdd897e45a0,/Fairness/Religion/Jainism,,,topic_marker,adatest_default,,
32 | 6f1298d713be4fdc9c1f668c31b38b50,/Fairness/Religion/Atheism,,,topic_marker,adatest_default,,
33 | 58c32b204151411f94078d89d4516ea5,/Fairness/Religion/Sikhism,,,topic_marker,adatest_default,,
34 | 936b858cc6684230987ee8d3e12ee8c1,/Fairness/Religion/Spiritism,,,topic_marker,adatest_default,,
35 | 5f1457585d3743a68d0e891ff5d43e51,/Fairness/Religion/Bahai Faith,,,topic_marker,adatest_default,,
36 | c66b61b6c0e9478fbae3b3883d9c6efb,/Fairness/Age/Young,,,topic_marker,adatest_default,,
37 | b88fa13816c7485f8006bd90948dce0a,/Fairness/Age/Old,,,topic_marker,adatest_default,,
38 | 3d97bfa9f086421a8f30683f4e00b653,/Fairness/Disabilities/Autism spectrum disorders,,,topic_marker,adatest_default,,
39 | 0a3a12c82c53428d8b17aa33f6b893c0,/Fairness/Disabilities/Deaf and hard of hearing,,,topic_marker,adatest_default,,
40 | 1464613188b74cae9a64f72bafe9e7e6,/Fairness/Disabilities/Visual impairments,,,topic_marker,adatest_default,,
41 | 3b9c6615468f4699b589f0e0d590b388,/Fairness/Disabilities/Learning disabilities,,,topic_marker,adatest_default,,
42 | 2f3d38e7a0c247e8b1854b78dfb9de70,/Fairness/Disabilities/Mobility disabilities,,,topic_marker,adatest_default,,
43 | 2b69cb088cb64c0dbbf56e67722ce535,/Fairness/Race/East Asian,,,topic_marker,adatest_default,,
44 | bb62f5f3b53447e282f7fbbb77c014e7,/Fairness/Race/European,,,topic_marker,adatest_default,,
45 | eac2219f50954542a5a53d53e5390247,/Fairness/Race/Indigenous American,,,topic_marker,adatest_default,,
46 | 8e3ebca3b259483c837d52e51a9e724c,/Fairness/Race/Sub-Saharan African,,,topic_marker,adatest_default,,
47 | 2312dee91409400f86a75f6992a148f7,/Fairness/Race/North African,,,topic_marker,adatest_default,,
48 | 72dda66a4ea944249ca55f3d502b60f2,/Fairness/Race/West Asian,,,topic_marker,adatest_default,,
49 | 93715bd783e640ca8be8d32d9ddb9ace,/Fairness/Race/Melanesian,,,topic_marker,adatest_default,,
50 | 734d00bbdd664182b96115eeda255c00,/Fairness/Race/Central & South Asian,,,topic_marker,adatest_default,,
51 | 1c856fe4b2f048eab2a9b0302c7c735f,/Fairness/Race/European/Northwestern European,,,topic_marker,adatest_default,,
52 | 6fb01ad5445f42e2b19d390da4e922a8,/Fairness/Race/European/Jewish,,,topic_marker,adatest_default,,
53 | 30d9d929aecd411d86b278c1958e1662,/Fairness/Race/European/Eastern European,,,topic_marker,adatest_default,,
54 | 84d0247422a24e529360a24fdc95944e,/Fairness/Race/European/Southern European,,,topic_marker,adatest_default,,
55 | 05ce2308a4974026ac1ea52be07fbea3,"/Fairness/Race/Arab, Egyptian & Levantine",,,topic_marker,adatest_default,,
56 | b64f34df1c40453196ed5874f8a56a43,/Fairness/Race/European/Northwestern European/British & Irish,,,topic_marker,adatest_default,,
57 | c33650d6aea543c1813e04d8a914f423,/Fairness/Race/European/Northwestern European/Scandinavian,,,topic_marker,adatest_default,,
58 | 990d9443b2d943fb91ed8c5e959e7ba5,/Fairness/Race/European/Northwestern European/French & German,,,topic_marker,adatest_default,,
59 | fa65eedda36b4e0882ff9649b46f30a8,/Fairness/Race/European/Northwestern European/Finnish,,,topic_marker,adatest_default,,
60 | bc1b76925b474d2e9bafda65a0c1e246,/Fairness/Race/European/Southern European/Greek & Balkan,,,topic_marker,adatest_default,,
61 | 2440037191764f3bb1d76601f1560296,/Fairness/Race/European/Southern European/Italian,,,topic_marker,adatest_default,,
62 | e9c73fe0f30e41c6bc489b82e66fec87,/Fairness/Race/European/Southern European/Sardinian,,,topic_marker,adatest_default,,
63 | 872478a8298e439b96e17ff072c330fe,/Fairness/Race/European/Southern European/Spanish & Portuguese,,,topic_marker,adatest_default,,
64 | 18400d3dbec24d5b9524a8f4bd52dd3e,/Fairness/Ethnicity,,,topic_marker,adatest_default,,
65 | eaefb247d21f4d1580d489e045ee3666,/Fairness/Ethnicity/Black or African American,,,topic_marker,adatest_default,,
66 | dc5d4e7ee670443cb6401b629a068c50,/Fairness/Ethnicity/White,,,topic_marker,adatest_default,,
67 | 00b80e0768c94a25acf2cea654b068a8,/Fairness/Ethnicity/Hispanic,,,topic_marker,adatest_default,,
68 | 4d1cfc27dfbf45569e53ed54a1db73ca,/Fairness/Ethnicity/Asian,,,topic_marker,adatest_default,,
69 | 46fe783a83c647058396ea7d5e723f31,/Fairness/Ethnicity/Arab,,,topic_marker,adatest_default,,
70 | ae02ab0695cc474db66b4c1ec3e26cac,/Fairness/Ethnicity/Jewish,,,topic_marker,adatest_default,,
71 | fea6dba1333d4ba8820135c85b13ea7f,/Fairness/Ethnicity/Black or African American/Black or African American,,,topic_marker,adatest_default,,
72 | 86bd9c19b41647519618f94157e57e56,/Fairness/Ethnicity/Nationality,,,topic_marker,adatest_default,,
73 | e6d782ca2f964c1787dfabb070ce9e01,/Fairness/Ethnicity/Nationality/American,,,topic_marker,adatest_default,,
74 | c553bb653f1d42afa0c9a2ae9bc9137b,/Fairness/Ethnicity/Nationality/Chinese,,,topic_marker,adatest_default,,
75 | 0b2e5161d82d40bf8588a896c0a88333,/Fairness/Ethnicity/Nationality/British,,,topic_marker,adatest_default,,
76 | 5c688fa9edd34e998b1a646e5cc92d1e,/Fairness/Ethnicity/Nationality/Irish,,,topic_marker,adatest_default,,
77 | b09935df44a24257bcbf18180ec015b9,/Fairness/Ethnicity/Nationality/Israeli,,,topic_marker,adatest_default,,
78 | 2a8fc16f6f544fd59dd82cca9fac89cb,/Fairness/Ethnicity/Nationality/Egyptian,,,topic_marker,adatest_default,,
79 | 001d29ac38a846eaaeb4b670842d4685,/Fairness/Ethnicity/Nationality/Mexican,,,topic_marker,adatest_default,,
80 | f7fca049c61948fbbbd952616104d59b,/Fairness/Ethnicity/Nationality/Canadian,,,topic_marker,adatest_default,,
81 | 6f0ebdf0d13a4eddb45808ed1c831f0a,/Fairness/Ethnicity/Nationality/Brazilian,,,topic_marker,adatest_default,,
82 | 5aef5b9ecde2464d8bf6e54f34db55b8,/Fairness/Ethnicity/Nationality/Indonesian,,,topic_marker,adatest_default,,
83 | 6062d425b1f745bf97a6331d59925e0a,/Fairness/Ethnicity/Nationality/Russian,,,topic_marker,adatest_default,,
84 | fdb8fe1da4db47e1b1df4696644df6c9,/Fairness/Ethnicity/Nationality/Indian,,,topic_marker,adatest_default,,
85 | c8a4d9a49ca94ff194238d7df9040bb2,/Fairness/Ethnicity/Nationality/Pakistani,,,topic_marker,adatest_default,,
86 | 4fc69a320a60438e936f450f234c18d1,/Fairness/Ethnicity/Nationality/Nigerian,,,topic_marker,adatest_default,,
87 | 65c9db7b6d13432da1cc6f236b7340c5,/Fairness/Ethnicity/Nationality/Bangladeshi,,,topic_marker,adatest_default,,
88 | 4b51ae3043c145de8919dced27724fa4,/Fairness/Ethnicity/Nationality/Japanese,,,topic_marker,adatest_default,,
89 | 39952472ab7e45908dbba6e13c131403,/Fairness/Ethnicity/Nationality/Ethiopian,,,topic_marker,adatest_default,,
90 | 6fcc5a84123a42239a9abbb47251c2b9,/Fairness/Ethnicity/Nationality/Filipino,,,topic_marker,adatest_default,,
91 | 01450c14171d44559113adc39ede0d3f,/Fairness/Ethnicity/Nationality/Vietnamese,,,topic_marker,adatest_default,,
92 | d6f8f0aa6c2a49459af593b2ac34320b,/Fairness/Ethnicity/Nationality/Congolese,,,topic_marker,adatest_default,,
93 | 47c0c7dd5e924010914836602f1c02b6,/Fairness/Ethnicity/Nationality/Iranian,,,topic_marker,adatest_default,,
94 | 79872f5db4ee45b8bde7a54c2f113d58,/Fairness/Ethnicity/Nationality/Turkish,,,topic_marker,adatest_default,,
95 | eb229bf76b624ada96a4f7ecb913016d,/Fairness/Ethnicity/Nationality/German,,,topic_marker,adatest_default,,
96 | 8e2ce2153c4a4ef4bd1e3947c0064fcf,/Fairness/Ethnicity/Nationality/French,,,topic_marker,adatest_default,,
97 | 07affcfaa3f045aebed2861e6df89234,/Fairness/Ethnicity/Nationality/Italian,,,topic_marker,adatest_default,,
98 | b41f2aefe1ca47cebb7955c90b315474,/Fairness/Ethnicity/Nationality/Spanish,,,topic_marker,adatest_default,,
99 | ef25682376c44ba3822eeddf03fcb082,/Fairness/Ethnicity/Nationality/Korean,,,topic_marker,adatest_default,,
100 | a1b7b5aa36444e839987b3c98a6b3940,/Fairness/Ethnicity/Nationality/Australian,,,topic_marker,adatest_default,,
101 | 1f792e47ee404f9182f97e76fda15419,/Literary Devices/Analogies,,,topic_marker,adatest_default,,
102 | c2ab86dd60434b78b33058a9a44b3cd7,/Literary Devices/Analogies,,,topic_marker,adatest_default,,
103 | ef2f42885c98445dbbb2be0fca994d6c,/Literary Devices/Analogies,,,topic_marker,adatest_default,,
104 | cfb77d784eac4cf5abeb11f919e4c64f,/Literary Devices,,,topic_marker,adatest_default,,
105 | 2782002c34c345d6a112665c02e5e9e2,/Literary Devices/Metaphor,,,topic_marker,adatest_default,,
106 | 2d921ac3158748eeb443eaa5880ee5f4,/Literary Devices/Irony+Sarcasm,,,topic_marker,adatest_default,,
107 | e5bdc973808f44b090bc358433976648,/Literary Devices/Hyperbole,,,topic_marker,adatest_default,,
108 | 9465029dfb0844b4b28222633e4ee9da,/Literary Devices/Allusions,,,topic_marker,adatest_default,,
109 | 2a6d5758bf7e487a8b08f40d06e60315,/Literary Devices/Personification,,,topic_marker,adatest_default,,
110 | 97317d629a29480cbcf6a77d9d78d679,/Literary Devices/Imagery,,,topic_marker,adatest_default,,
111 | 5d2ebde64275493da3592b711a53f16d,/Literary Devices/Simile,,,topic_marker,adatest_default,,
112 | 93cd0029f5ff41628474aa46007f21e8,/Literary Devices/Symbolism,,,topic_marker,adatest_default,,
113 | 61cd12099c67425aa2a2a1873a473eb7,/Literary Devices/Puns,,,topic_marker,adatest_default,,
114 | e1bc2b66791c4eb08e2f01b167550110,/Fairness,,,topic_marker,adatest_default,,
115 | fe0f67d4116d4e37940cfd927b59c373,/Fairness,,,topic_marker,adatest_default,,
116 | 28202017462248398be8300204f1cb4d,/Fairness,,,topic_marker,adatest_default,,
117 | fe71895657474d6f9f01602d6ccf3573,/Fairness/Occupation,,,topic_marker,adatest_default,,
118 | ffa7673d58aa454aa92bd5905b7be832,/Fairness/Immigration,,,topic_marker,adatest_default,,
119 | ada79da98a3c4c8eac7dd2c8748dde59,/Vocabulary+POS/Important verbs,,,topic_marker,adatest_default,,
120 | 0b43004d2bf74213b38ef2713233fbc4,/Vocabulary+POS/Important adjectives,,,topic_marker,adatest_default,,
121 | ac574e7cdfdc464c87d86c815cfc9660,/Vocabulary+POS/Important nouns,,,topic_marker,adatest_default,,
122 | c0f21344b01147f986a51d89e47f4fc8,/Vocabulary+POS/Important modifiers,,,topic_marker,adatest_default,,
123 | ca1476723e414a6a9a047cc13c34faf3,/Vocabulary+POS/Irrelevant word changes,,,topic_marker,adatest_default,,
124 | 24c93ae4c4234dce911125b894e260de,/Named Entity Recognition/Person names,,,topic_marker,adatest_default,,
125 | 3dcc45a2c680451db7c00a0c664b79e5,/Named Entity Recognition/Location names,,,topic_marker,adatest_default,,
126 | 704d58451d354a9bb9edf6557383a3d2,/Named Entity Recognition/Person names/Person names or changes that matter for prediction,,,topic_marker,adatest_default,,
127 | ee49e950701742e180ead9fad1d44434,/Named Entity Recognition/Person names/Person names or changes that matter for prediction,,,topic_marker,adatest_default,,
128 | c6a6e0ff746043769e652350de03fb15,/Named Entity Recognition/Person names/Person names or changes that do not matter for prediction,,,topic_marker,adatest_default,,
129 | c94fcc5ab74445f9889db304ca233832,/Named Entity Recognition/Location names/Location names or changes that matter for prediction,,,topic_marker,adatest_default,,
130 | 8954a752d9cc40cd84101fa18786e1d9,/Named Entity Recognition/Location names/Location names or changes that do not matter for prediction,,,topic_marker,adatest_default,,
131 | d3c9c6401ada46b7b9ba0f34785d1fb2,/Named Entity Recognition/Numbers,,,topic_marker,adatest_default,,
132 | e0406c39fc4b4e1cb9d8b92e860001a4,/Named Entity Recognition/Numbers/Numbers or changes that matter for prediction,,,topic_marker,adatest_default,,
133 | 38c621ea185f4bd299d7a32e0bbd0ff5,/Named Entity Recognition/Numbers/Numbers or changes that do not matter for prediction,,,topic_marker,adatest_default,,
134 | 4091e479855a4609ae2469f46a7098c9,/Fairness/Gender and sexuality/Male to Female or vice versa,,,topic_marker,adatest_default,,
135 | 7ef2dd29dc4a46b6a0afc444fa09493c,/Semantic Role Labeling/Agent - Object swap,,,topic_marker,adatest_default,,
136 | d8de9181c54644bc9c8808546a1f3787,/Semantic Role Labeling/Active - passive swap,,,topic_marker,adatest_default,,
137 | 6370cb75a68f4d588e505832fb9e75aa,/Logic/Conjunctions,,,topic_marker,adatest_default,,
138 | 6284e4f2d27a43409c4181ef2505cb08,/Logic/Disjunctions,,,topic_marker,adatest_default,,
139 | d54eb69301924cf1966e302b08f01f6e,/Taxonomy/Subtypes,,,topic_marker,adatest_default,,
140 |
--------------------------------------------------------------------------------
/adaptivetesting/_scorer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import logging
4 | import uuid
5 | import itertools
6 | import shap
7 | from ._model import Model
8 | import adaptivetesting
9 |
10 | log = logging.getLogger(__name__)
11 |
class Scorer():
    # Base scorer class. Acts as a factory: constructing Scorer(model) directly
    # replaces the instance's class with the right specialized subclass based on
    # the wrapped model's output type (see __init__ below).
    def __new__(cls, model, *args, **kwargs):
        """ If we are wrapping an object that is already a Scorer, we just return it.
        """
        if shap.utils.safe_isinstance(model, "adaptivetesting.Scorer"):
            return model
        else:
            return super().__new__(cls)

    def __init__(self, model):
        """ Auto detect the model type and subclass to the right scorer object.
        """

        # ensure we have a model of type Model
        # (if self.model is already a Model -- e.g. we are re-initializing an existing
        # Scorer that __new__ returned unchanged -- leave it untouched)
        if shap.utils.safe_isinstance(getattr(self, "model", None), "adaptivetesting.Model") or shap.utils.safe_isinstance(getattr(self, "model", None), "shap.models.Model"):
            pass
        elif shap.utils.safe_isinstance(model, "adaptivetesting.Model") or shap.utils.safe_isinstance(model, "shap.models.Model"):
            self.model = model
        else:
            self.model = Model(model)

        # If we are in the base class we need to pick the right specialized subclass to become
        if self.__class__ is Scorer:

            # finish early if we are wrapping an object that is already a Scorer (__new__ will have already done the work)
            if shap.utils.safe_isinstance(model, "adaptivetesting.Scorer"):
                return

            # see if we are scoring a generator or a classifier
            # (string outputs -> text generator; anything else -> per-class probabilities)
            out = self.model(["string 1", "string 2"])
            if isinstance(out[0], str):
                self.__class__ = GeneratorScorer
                GeneratorScorer.__init__(self, model)
            else:
                self.__class__ = ClassifierScorer
                ClassifierScorer.__init__(self, model)
48 |
49 |
class DummyScorer(Scorer):
    """ A scorer that ignores the model and parses each test row's `value2` field as its score.

    Useful for exercising the interface without running a real model.
    """
    # NOTE(review): Scorer.__new__ requires a `model` argument, so calling
    # DummyScorer() with no arguments raises a TypeError before __init__ runs --
    # confirm how this class is actually constructed.

    def __init__(self):
        # unique id for this scorer instance
        self._id = uuid.uuid4().hex

    def __call__(self, tests):
        """ Return a numpy array with one score per row of `tests`, parsed from value2 (NaN if unparseable). """
        out = []
        for k, test in tests.iterrows():
            try:
                score = float(test.value2)
            except (ValueError, TypeError):
                # narrow the original bare `except:` -- only parse failures
                # (bad string / None) should fall back to NaN
                score = np.nan
            out.append(score)
        return np.array(out)
62 |
class ClassifierScorer(Scorer):
    """ Wraps a text classification model and defines a callable scorer that returns a score value for any input/output pair.

    Scores are failure probabilities in the range [0, 1]: higher scores indicate likely test
    failures, lower scores indicate tests the model is likely to pass. For example if we wrap
    a text sentiment classifier then `scorer(TestTree([("this is great!", "should be", "POSITIVE")]))`
    will return a score near 0 when the model is very likely to correctly produce that output
    when given that input.
    """

    def __init__(self, model, top_probs=20, output_names=None):
        """ Create a new scorer given a model that returns a probability vector for each input string.

        Parameters:
        -----------
        model : callable
            A model that is callable with a single argument (which is a list of strings) and returns a matrix of outputs.

        top_probs : int
            The number of top output probabilities to consider when scoring tests. This is used to reduce the number of
            input/output pairs that are passed to the local topic labeling model (and so save compute).

        output_names : list of strings
            A list of strings that correspond to the outputs of the model. If None, model.output_names is used.
        """
        super().__init__(model)

        # extract output names from the model if they are not provided directly
        if output_names is None and getattr(self, "output_names", None) is None:
            self.output_names = self.model.output_names
        elif output_names is not None:
            self.output_names = output_names
        elif not hasattr(self, "output_names"):
            self.output_names = None

        # build a reverse name -> column index lookup; guard against output_names being
        # None (the original code called enumerate(None) here and crashed with a TypeError)
        if self.output_names is not None and not callable(self.output_names):
            self._output_name_to_index = {v: i for i, v in enumerate(self.output_names)}
        self.top_probs = top_probs

    def __call__(self, tests, eval_ids):
        """ Compute the scores (and model outputs) for the tests matching the given ids.

        Parameters
        ----------
        tests : TestTree
            A test tree for scoring. Note this should be the full test tree since it defines the local topic label
            models used for scoring.

        eval_ids : list of strings
            The ids of the tests to score.

        Returns
        -------
        (outputs, scores) : tuple of lists
            One entry per expanded template input (a templated test contributes several
            entries, all tied to the same test id).
        """

        # expand templates in the test tree
        eval_inputs = []
        eval_inds = []  # maps each expanded input back to its position in eval_ids
        for i, id in enumerate(eval_ids):
            test = tests.loc[id]
            template_expansions = expand_template(test.input)
            for expansion in template_expansions:
                eval_inputs.append(expansion)
                eval_inds.append(i)

        # run the model
        try:
            model_out = self.model(eval_inputs)
        except Exception as e:
            model_out = np.full((len(eval_inputs), len(self.model.output_names)), np.nan) # TODO: remove this hack after the user study
            log.error(e)
            log.error(eval_inputs)
            log.error("The model threw an exception when evaluating inputs! We are patching this disaster with np.nan for the sake of the user study!")

        # compute the output strings and probabilities for each output in template form
        out_strings = [[] for _ in range(len(eval_ids))]
        out_probs = [[] for _ in range(len(eval_ids))]
        i = 0
        while i < len(model_out):
            out_strings[eval_inds[i]].append(self.model.output_names[np.argmax(model_out[i])])
            out_probs[eval_inds[i]].append(model_out[i])
            i += 1
        for i in set(eval_inds):
            out_strings[i] = "|".join(out_strings[i]) # template outputs are joined by |
            out_probs[i] = np.column_stack(out_probs[i]) # stack probability vectors column-wise, one column per template expansion

        # compute the embeddings as a batch (this fills a cache we will use when scoring below)
        adaptivetesting.embed(list(tests.loc[eval_ids, "input"]))

        # score all the tests
        scores = []
        outputs = []
        for i, ind in enumerate(eval_inds):
            outputs.append(out_strings[ind])
            scores.append(self._score_test(tests, eval_ids[ind], out_probs[ind], self.top_probs))

        return outputs,scores

    def _score_test(self, tests, id, probs, top_probs):
        """ Score a single test: the probability-weighted chance that the local topic
        labeling model marks the model's most likely outputs as failures.

        Parameters
        ----------
        probs : np.ndarray
            Output probabilities with one column per template expansion.

        top_probs : int
            Only the `top_probs` most likely outputs are passed to the topic model.
        """
        test = tests.loc[id]
        total_fail_prob = 0
        total_pass_prob = 0

        # if this is not a templated test
        if probs.shape[1] == 1:
            inds = np.argsort(probs[:,0])[::-1]
            for ind in inds[:top_probs]:

                # Scott: we could use any manually given labels when possible, but then that would make the score depend on the label
                # and so we would either need to save the full output of the model or recompute every time
                # if self.model.output_names[ind] == test["output"] and test["labeler"] != "imputed":
                #     label = test["label"]

                # we use the local topic model to predict the label
                fail_prob = tests.topic_labeling_model(test.topic)(test.input, self.model.output_names[ind])
                total_fail_prob += probs[ind, 0] * fail_prob
                total_pass_prob += probs[ind, 0] * (1 - fail_prob)

            # no probability mass at all (e.g. all-NaN model output) means we cannot score
            if not (total_fail_prob + total_pass_prob > 0):
                return np.nan
            else:
                return total_fail_prob / (total_pass_prob + total_fail_prob)
        else:
            raise NotImplementedError("TODO: implement classifier scoring for templated tests")
183 |
class GeneratorScorer(Scorer):
    """ Wraps a text generation model as a callable scorer that can be applied to a test tree.
    """

    def __init__(self, model):
        """ Create a new scorer for a generative text model.

        Parameters:
        -----------
        model : callable
            A model that is callable with a single argument (which is a list of strings) and returns a list of strings.
        """
        super().__init__(model)

        # we don't want to re-init a class if init has already been done (this can happen when Scorer(maybe_scorer) is called)
        if hasattr(self, "_id"):
            return # already initialized

    def __call__(self, tests, eval_ids):
        """ Score a set of tests.

        Parameters
        ----------
        tests : TestTree or DataFrame
            A dataframe of tests.

        eval_ids : list of strings
            The evaluation IDs to use.

        Returns
        -------
        (outputs, scores) : tuple of lists
            One entry per expanded template input (a templated test contributes several
            entries, all tied to the same test id).
        """

        # determine which rows we need to evaluate
        eval_inputs = []
        eval_inds = []  # maps each expanded input back to its position in eval_ids
        for i, id in enumerate(eval_ids):
            template_expansions = expand_template(tests.loc[id, "input"])
            for expansion in template_expansions:
                eval_inputs.append(expansion)
                eval_inds.append(i)

        # run the model on the rows we need to evaluate
        try:
            model_out = self.model(eval_inputs)
        except Exception as e:
            model_out = [""] * len(eval_inputs) # TODO: remove this hack after the user study
            log.error(e)
            log.error(eval_inputs)
            log.error("The model threw an exception when evaluating inputs! We are patching this disaster with empty strings for the sake of the user study!")

        # compute the output strings for each output
        out_strings = [[] for _ in range(len(eval_ids))]
        i = 0
        while i < len(model_out):
            out_strings[eval_inds[i]].append(str(model_out[i]))
            i += 1
        for i in set(eval_inds):
            out_strings[i] = "|".join(out_strings[i]) # template outputs are joined by |

        # score all the tests
        scores = []
        outputs = []
        for i, ind in enumerate(eval_inds):
            outputs.append(out_strings[ind])
            scores.append(self._score_test(tests, eval_ids[ind], out_strings[ind]))

        return outputs,scores

    def _score_test(self, tests, id, output):
        """ Score a single test as the failure probability assigned by the local topic labeling model. """
        test = tests.loc[id]

        # the topic model for the test's topic judges the input/output pair
        fail_prob = tests.topic_labeling_model(test.topic)(test.input, output)

        return fail_prob
255 |
class RawScorer(Scorer):
    """ Wraps a model that directly outputs a score for each input as a callable scorer.

    The score from the model should be in the range [0,1] with higher scores indicating failures
    (or just more interesting behavior).
    """

    def __init__(self, model):
        """ Create a new scorer given a model that returns a bounded real value for each input string.

        Parameters:
        -----------
        model : callable
            A model that is callable with a single argument (which is a list of strings) and returns a vector of score in the range [0,1].
        """
        super().__init__(model)

    def __call__(self, tests, eval_ids):
        """ Compute the scores (and model outputs) for the tests matching the given ids.

        Parameters
        ----------
        tests : TestTree
            A test tree for scoring. Note this should be the full test tree since it defines the local topic label
            models used for scoring.

        eval_ids : list of strings
            The ids of the tests to score.

        Returns
        -------
        (outputs, scores) : tuple of lists
            One entry per expanded template input; outputs are the scores rendered as
            strings and scores are the raw values (max over template expansions).
        """

        # expand templates in the test tree
        eval_inputs = []
        eval_inds = []  # maps each expanded input back to its position in eval_ids
        for i, id in enumerate(eval_ids):
            test = tests.loc[id]
            template_expansions = expand_template(test.input)
            for expansion in template_expansions:
                eval_inputs.append(expansion)
                eval_inds.append(i)

        # run the model
        try:
            model_out = self.model(eval_inputs)
        except Exception as e:
            model_out = np.full(len(eval_inputs), np.nan) # TODO: remove this hack after the user study
            log.error(e)
            log.error(eval_inputs)
            log.error("The model threw an exception when evaluating inputs! We are patching this disaster with np.nan for the sake of the user study!")

        # compute the output strings and scores for each output in template form
        out_strings = [[] for _ in range(len(eval_ids))]
        out_scores = [[] for _ in range(len(eval_ids))]
        i = 0
        while i < len(model_out):
            out_strings[eval_inds[i]].append(str(np.round(model_out[i], 6))) # convert float to string with precision of 6
            out_scores[eval_inds[i]].append(model_out[i])
            i += 1
        # iterate over set(eval_inds) (matching ClassifierScorer/GeneratorScorer); the
        # original iterated over eval_inds directly, so ids with multiple template
        # expansions were joined/maxed a second time, corrupting the joined strings
        for i in set(eval_inds):
            out_strings[i] = "|".join(out_strings[i]) # template outputs are joined by |
            out_scores[i] = np.max(out_scores[i]) # the score of a set of items is the score of the max item

        # collect the per-expansion outputs and scores
        scores = []
        outputs = []
        for i, ind in enumerate(eval_inds):
            outputs.append(out_strings[ind])
            scores.append(out_scores[ind])

        return outputs,scores
325 |
def expand_template(s, keep_braces=False):
    """ Expand a template string into the list of strings it represents.

    A template uses brace groups of |-separated alternatives, e.g.
    "I {love|hate} this" expands to ["I love this", "I hate this"]. Multiple
    groups expand to the cartesian product of their alternatives.

    Parameters
    ----------
    s : str
        The template string to expand.

    keep_braces : bool
        If True, each chosen alternative stays wrapped in braces, e.g.
        "I {love|hate} this" -> ["I {love} this", "I {hate} this"].

    Returns
    -------
    list of str
        The expanded strings, or the (placeholder-substituted) template
        unmodified if it is not a valid format string.
    """
    matches = re.findall("{[^}]*}", s)
    s = re.sub("{[^}]*}", "{}", s)
    template_groups = [str(m)[1:-1].split("|") for m in matches]
    try:
        if keep_braces:
            # note the f-string: the original used a plain '{{{p}}}' literal, which
            # inserted the literal text "{p}" instead of each alternative value
            return [s.format(*[f'{{{p}}}' for p in parts]) for parts in itertools.product(*template_groups)]
        else:
            return [s.format(*parts) for parts in itertools.product(*template_groups)]
    except ValueError:
        return [s] # we return the template not filled in if it is invalid
341 |
def clean_template(s):
    """ Remove duplicate |-separated alternatives from each {...} group of a template string.
    """
    # pull out every brace group, leaving positional placeholders behind
    groups = re.findall("{[^}]*}", s)
    skeleton = re.sub("{[^}]*}", "{}", s)

    # drop repeated alternatives inside each group while preserving first-seen order
    deduped = []
    for group in groups:
        alternatives = group[1:-1].split("|")
        unique = dict.fromkeys(alternatives)  # ordered de-duplication
        deduped.append("{" + "|".join(unique) + "}")

    try:
        return skeleton.format(*deduped)
    except ValueError:
        return skeleton  # an invalid format string is returned in placeholder form
353 |
--------------------------------------------------------------------------------
/adaptivetesting/_server.py:
--------------------------------------------------------------------------------
1 | import json
2 | import functools
3 | import uuid
4 | import pathlib
5 | import logging
6 | import os
7 |
8 |
9 | import asyncio
10 | import nest_asyncio
11 | nest_asyncio.apply()
12 |
13 | import aiohttp
14 | from aiohttp import web
15 | import aiohttp_session
16 | import aiohttp_session.cookie_storage
17 | import aiohttp_security
18 | #from aiohttp_session import SimpleCookieStorage, session_middleware
19 | from aiohttp_security import check_permission, \
20 | is_anonymous, remember, forget, \
21 | setup as setup_security, SessionIdentityPolicy
22 | from aiohttp_security.abc import AbstractAuthorizationPolicy
23 | import cryptography.fernet
24 | from . import TestTree
25 | import functools
26 |
27 | log = logging.getLogger(__name__)
28 |
29 |
30 | def serve(test_tree_browsers, host="localhost", port=8080, static_dir=None, authenticate=lambda user, password: True,
31 | authorize=lambda user,location: True, auth_duration=60 * 60 * 8, ssl_crt=None, ssl_key=None):
32 | """ Serves the interface at the given host and port.
33 | """
34 | log.debug(f"serve(test_tree_browsers={test_tree_browsers})")
35 |
36 | if isinstance(test_tree_browsers, TestTree):
37 | raise Exception("You cannot serve a TestTree directly! You need to call it with a scorer like test_tree(scorer).")
38 |
39 | if isinstance(authenticate, dict):
40 | auth_dict = authenticate
41 | def check_pass(user, password):
42 | return auth_dict.get(user, object()) == password
43 | authenticate = check_pass
44 |
45 | loop = asyncio.get_event_loop()
46 |
47 | if not hasattr(test_tree_browsers, "interface_event") and callable(test_tree_browsers):
48 | test_tree_browsers = functools.lru_cache(maxsize=None)(test_tree_browsers)
49 |
50 | id = uuid.uuid4().hex
51 |
52 | async def send_ws_data(ws, str_data):
53 | await ws.send_str(str_data)
54 |
55 | async def topic_handler(request):
56 | log.debug(f"topic_handler({request})")
57 | logged_in = not await aiohttp_security.is_anonymous(request)
58 |
59 | if not logged_in:
60 | user = request.rel_url.query.get("user", "anonymous")
61 | if authenticate(user, None):
62 | redirect_response = web.HTTPFound(str(request.rel_url))
63 | await remember(request, redirect_response, user)
64 | return redirect_response
65 | else:
66 | raise web.HTTPFound(f'/_login?user={user}&sendback={str(request.rel_url)}')
67 | else:
68 | user = await aiohttp_security.authorized_userid(request)
69 | if hasattr(test_tree_browsers, "interface_event"):
70 | prefix = ""
71 | test_tree_browser = test_tree_browsers
72 | test_tree_name = 'fake'
73 | else:
74 | test_tree_name = request.match_info["test_tree"]
75 | prefix = "/" + test_tree_name
76 | if callable(test_tree_browsers):
77 | test_tree_browser = test_tree_browsers(test_tree_name)
78 | else:
79 | test_tree_browser = test_tree_browsers.get(test_tree_name, None)
80 |
81 | # make sure we found the given test
82 | if not hasattr(test_tree_browser, "interface_event"):
83 | log.debug(f"The test tree we found was not valid: {test_tree_browsers}")
84 | raise web.HTTPNotFound()
85 | test_tree_browser.user = user
86 | test_tree_browser.name = test_tree_name
87 |
88 | interface_html = f"""
89 |
90 |
91 | Adaptive Testing
92 |
93 |
94 | {test_tree_browser._repr_html_(prefix=prefix, environment="web", websocket_server=prefix+"/_ws")}
95 |
96 |
97 | """
98 |
99 | return web.Response(text=interface_html, content_type="text/html")
100 |
    async def static_handler(request):
        # Serve the favicon and files under static_dir; anonymous visitors are
        # either auto-authenticated (if `authenticate` allows it) or bounced to the login page.
        logged_in = not await aiohttp_security.is_anonymous(request)
        log.debug(f"static_handler({request})")
        if not logged_in:
            user = request.rel_url.query.get("user", "anonymous")
            if authenticate(user, None):
                # passwordless auth succeeded: remember the user and reload the same URL
                redirect_response = web.HTTPFound(str(request.rel_url))
                await remember(request, redirect_response, user)
                return redirect_response
            else:
                raise web.HTTPFound(f'/_login?user={user}&sendback={str(request.rel_url)}')
        else:
            if request.raw_path == "/favicon.ico":
                # favicon ships inside the package's resources directory
                file_path = pathlib.Path(__file__).parent.absolute()
                return web.FileResponse(file_path / "resources" / "favicon.png" )
            elif "file_path" in request.match_info:
                # NOTE(review): stripping ".." is a weak path-traversal guard
                # (e.g. "...." -> ".." after one pass) -- confirm static_dir contents are trusted
                file_path = os.path.join(static_dir, *request.match_info["file_path"].replace("..", "").split("/"))
                return web.FileResponse(file_path)
            else:
                raise web.HTTPNotFound()
        # with open(file_path) as f:
        #     file_data = f.read()
        # if file_path.endswith(".png"):
        #     content_type = "image/png"
        # elif file_path.endswith(".gif"):
        #     content_type = "image/gif"
        # elif file_path.endswith(".jpg") or file_path.endswith(".jpeg"):
        #     content_type = "image/jpeg"
        # elif file_path.endswith(".js"):
        #     content_type = "application/javascript"
        # else:
        #     content_type = "text/html"
        # return web.Response(text=file_data, content_type=content_type)
134 |
135 | async def login_handler(request):
136 | sendback = request.rel_url.query.get("sendback", "/")
137 | user = request.rel_url.query.get("user", "")
138 | return web.Response(text=f"""
139 |
140 |
141 | Adaptive Testing Login
142 |
143 |
144 |
245 |
246 |
247 |
248 | """, content_type="text/html")
249 |
    async def auth_handler(request):
        # Handle the POST from the login form: check the credentials and either
        # remember the user and redirect them back, or bounce to the login page.
        post_params = await request.post()
        sendback = post_params.get("sendback", "/")
        user = post_params.get('user', None)
        password = post_params.get('password', None)

        if authenticate(user if user is not None else "anonymous", password):
            redirect_response = web.HTTPFound(sendback)
            await remember(request, redirect_response, user if user is not None else "anonymous")
            return redirect_response
        else:
            # NOTE(review): `user` and `sendback` are interpolated unescaped into the
            # redirect query string -- confirm they cannot be abused for URL injection
            raise web.HTTPFound(f"/_login?{'user='+user+'&' if user is not None else ''}sendback={sendback}")
262 |
    async def websocket_handler(request):
        # Bridge a browser WebSocket to the test tree browser: incoming JSON messages
        # are fed to interface_event(), and the browser sends updates back via comm.send.
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # build a WebSocket comm object
        class WebSocketComm():
            pass
        def ws_send(data):
            # synchronous send wrapper: runs the async send on the (nested) event loop
            loop.run_until_complete(send_ws_data(ws, json.dumps(data)))
        comm = WebSocketComm()
        comm.send = ws_send

        # resolve which test tree browser this socket talks to (single browser,
        # factory callable, or dict keyed by test tree name)
        if hasattr(test_tree_browsers, "_repr_html_"):
            test_tree_browser = test_tree_browsers
        else:
            test_tree_name = request.match_info["test_tree"]
            if callable(test_tree_browsers):
                test_tree_browser = test_tree_browsers(test_tree_name)
            else:
                test_tree_browser = test_tree_browsers[test_tree_name]
        test_tree_browser.comm = comm

        # pump messages until the client disconnects or sends 'close'
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.TEXT:
                if msg.data == 'close':
                    log.debug(f"Closing WebSocket for user '{getattr(test_tree_browser, 'user', None)}' for test tree '{getattr(test_tree_browser, 'name', None)}'!")
                    await ws.close()
                else:
                    data = json.loads(msg.data)
                    log.info(f"WebSocket message from user '{getattr(test_tree_browser, 'user', None)}' for test tree '{getattr(test_tree_browser, 'name', None)}' is {data}")
                    test_tree_browser.interface_event(data)
            elif msg.type == aiohttp.WSMsgType.ERROR:
                print('WebSocket connection closed with exception %s' % ws.exception())

        return ws
298 |
    async def make_app():
        # Build the aiohttp Application: encrypted session cookies, routes, and security policy.
        # A fresh Fernet key is generated on every start, so sessions do not survive a restart.
        middleware = aiohttp_session.session_middleware(aiohttp_session.cookie_storage.EncryptedCookieStorage(
            cryptography.fernet.Fernet.generate_key().decode(), max_age=auth_duration
        ))
        # middleware = aiohttp_session.session_middleware(aiohttp_session.SimpleCookieStorage())
        app = web.Application(middlewares=[middleware])

        # routes shared by both serving modes
        app.add_routes([
            web.get('/_ws', websocket_handler),
            web.get('/_login', login_handler),
            web.post('/_auth', auth_handler),
            web.get('/favicon.ico', static_handler)
        ])

        if static_dir is not None:
            app.add_routes([web.static('/_static', static_dir)])
        if hasattr(test_tree_browsers, "_repr_html_"):
            # single-browser mode: every remaining path is treated as a topic path
            app.add_routes([
                web.get('/{topic_path:.*}', topic_handler),
                web.get('/_ws', websocket_handler)  # NOTE(review): '/_ws' is already registered above -- confirm this duplicate is intentional
            ])
        else:
            # multi-browser mode: routes are namespaced by test tree name
            if static_dir is not None:
                app.add_routes([web.get('/{test_tree}/_static/{file_path:.*}', static_handler)])
            app.add_routes([
                web.get('/{test_tree}/_ws', websocket_handler),
                web.get('/{test_tree}', topic_handler),
                web.get('/{test_tree}/{topic_path:.*}', topic_handler)
            ])

        policy = SessionIdentityPolicy()
        setup_security(app, policy, AdaptiveTestingPolicy())

        return app
333 |
334 | state = {
335 | "site": None,
336 | "runner": None
337 | }
    async def start_server(state, host, port, ssl_crt, ssl_key):
        # Create the app and start listening; the runner/site handles are stored in
        # the shared `state` dict so stop_server can tear them down later.

        # optionally wrap the socket in TLS when a certificate/key pair is provided
        if ssl_crt is not None:
            import ssl
            ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            ssl_context.load_cert_chain(ssl_crt, ssl_key)
        else:
            ssl_context = None

        app = await make_app()
        state["runner"] = aiohttp.web.AppRunner(app)
        await state["runner"].setup()
        state["site"] = web.TCPSite(state["runner"], host, port, ssl_context=ssl_context)
        await state["site"].start()
        # NOTE(review): prints an "http" URL even when TLS is enabled -- confirm intended
        print(f"Server started at http://{host}:{port}")
353 |
    async def stop_server(state):
        # stop accepting new connections, then shut the runner down cleanly
        await state["site"].stop()
        await state["runner"].shutdown()
357 |
358 |
359 | # aiohttp.web.run_app(make_app(), port=port)
360 | loop.run_until_complete(start_server(state, host=host, port=port, ssl_crt=ssl_crt, ssl_key=ssl_key))
361 |
362 | try:
363 | loop.run_forever()
364 | finally:
365 | loop.run_until_complete(stop_server(state))
366 |
class AdaptiveTestingPolicy(AbstractAuthorizationPolicy):
    # aiohttp_security authorization policy installed by serve()'s make_app().
    async def authorized_userid(self, identity):
        """Retrieve authorized user id.
        Return the user_id of the user identified by the identity
        or 'None' if no user exists related to the identity.
        """
        # the session identity string is used directly as the user id
        return identity

    async def permits(self, identity, permission, context=None):
        """Check user permissions.
        Return True if the identity is allowed the permission
        in the current context, else return False.
        """
        # NOTE(review): this looks like leftover sample code from the aiohttp_security
        # docs -- only the hard-coded user 'jack' is ever granted 'listen', and the
        # `authorize` callable passed to serve() is never consulted. Confirm whether
        # check_permission is used anywhere before relying on this.
        return identity == 'jack' and permission in ('listen',)
381 |
--------------------------------------------------------------------------------
/notebooks/IMDB to Hotel Sentiment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "412d3c6b",
6 | "metadata": {},
7 | "source": [
8 | "# IMDB to Hotel Sentiment\n",
9 | "\n",
10 | "In this notebook we will take a sentiment analysis model trained on IMDB reviews, and fine tune it to analyse tweets about hotels. We will use Adaptive Testing to help us generate a suitable test suite."
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "2e9f4de2",
16 | "metadata": {},
17 | "source": [
18 | "## Seeding the PRNG\n",
19 | "\n",
20 | "Before we do anything else, we first seed the PRNG, to ensure that we have reproducible results:"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "c907dfd0",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import torch\n",
31 | "torch.manual_seed(1012351)"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "id": "52054118",
37 | "metadata": {},
38 | "source": [
39 | "## The Base Model\n",
40 | "\n",
41 | "We will use the [`aychang/roberta-base-imdb` from Hugging Face](https://huggingface.co/aychang/roberta-base-imdb) as our base model. This is a binary model which has been trained on a collection of IMDB reviews. First, we load the model itself:"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "8a1c4c6e",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
52 | "from transformers import pipeline\n",
53 | "\n",
54 | "base_model_name = \"aychang/roberta-base-imdb\"\n",
55 | "\n",
56 | "model = AutoModelForSequenceClassification.from_pretrained(base_model_name,num_labels=2)\n",
57 | "tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n",
58 | "\n",
59 | "original_pipeline = pipeline(\"sentiment-analysis\",\n",
60 | " model=model,\n",
61 | " tokenizer=tokenizer,\n",
62 | " top_k=2)"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "90e5e15b",
68 | "metadata": {},
69 | "source": [
70 | "Now, let's try a few sentences:"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "id": "8960e590",
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "sample_strings = [\n",
81 | " \"Great cinematography but a poor movie overall\",\n",
82 | " \"Snappy dialogue makes for enjoyable entertainment\",\n",
83 | " \"Located on a busy street with much traffic\"\n",
84 | "]\n",
85 | "\n",
86 | "for s in sample_strings:\n",
87 | " print(s, \"\\n\", original_pipeline(s), \"\\n\")"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "d83c2a06",
93 | "metadata": {},
94 | "source": [
 95 |     "We can see that the two statements about movies are well classified, but the final one, about the hotel, is not."
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "id": "4d90d0d8",
101 | "metadata": {},
102 | "source": [
103 | "## Using Adaptive Testing\n",
104 | "\n",
105 | "AdaptiveTesting is a tool to help create training/test suites for language models. The basic workflow is:\n",
106 | "\n",
107 | "1. User provides some sample input\n",
108 | "1. User flags whether the model output is correct or not\n",
109 | "1. AdaptiveTesting uses a second language model to generate more inputs from those already provided\n",
110 | "1. User decides which of the AdaptiveTesting proposed inputs to incorporate (and whether the model provided a correct response)\n",
111 | "\n",
112 | "Iterating through this process a few times can generate a lot of tests quite quickly."
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "id": "e48bbbef",
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "import adaptivetesting"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "id": "139763f8",
128 | "metadata": {},
129 | "source": [
130 | "For our generator, we use OpenAI's GPT-3 model. For this, we need to read the access key in from a file:"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "id": "a6293838",
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "import os\n",
141 | "with open(os.path.expanduser('~/.openai_api_key'), 'r') as file:\n",
142 | " OPENAI_API_KEY = file.read().replace('\\n', '')"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "id": "70f74aec",
148 | "metadata": {},
149 | "source": [
150 | "First, we create the generator object which AdaptiveTesting will use to suggest more tests which are similar to the ones we provide:"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "6e310578",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "generator = adaptivetesting.generators.OpenAI('curie', api_key=OPENAI_API_KEY)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "id": "56095496",
166 | "metadata": {},
167 | "source": [
168 | "Now we create the test tree. We will load a set of tests which we have already started work on, to make the process faster:"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "id": "1bc4e9ff",
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "tests = adaptivetesting.TestTree(\"imdb_hotel_conversion.csv\")"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "id": "58006371",
184 | "metadata": {},
185 | "source": [
186 | "And fire up the AdaptiveTesting interface:"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "id": "056dbba0",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "tests.adapt(original_pipeline, generator, auto_save=True, recompute_scores=True)"
197 | ]
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "id": "17fa752a",
202 | "metadata": {},
203 | "source": [
204 | "With a set of samples composed, we need to use them to finetune the model. To begin this process, load the CSV file we've created into a DataFrame and drop the portions we don't need:"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "id": "32a790c1",
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "import pandas as pd\n",
215 | "\n",
216 | "def load_adatest_data(csv_file: str):\n",
217 | " tmp = pd.read_csv(csv_file)\n",
218 | " \n",
219 | " # Drop topic marker rows\n",
220 | " tmp2 = tmp[tmp['label'] != 'topic_marker']\n",
221 | " # Drop suggestion rows\n",
222 | " tmp3 = tmp2[tmp2['topic'] != 'suggestion']\n",
223 | " \n",
224 | " # Remove columns we don't need\n",
225 | " tmp4 = tmp3.drop(labels=['labeler', 'description', 'author', 'Unnamed: 0'], axis=1)\n",
226 | " \n",
227 | " # Rename columns\n",
228 | " tmp5 = tmp4.rename(mapper={'input': 'sentence', 'label': 'model_is_correct'}, axis=1)\n",
229 | " \n",
230 | " # Remove any spurious rows\n",
231 | " tmp6 = tmp5[tmp5['topic'].notna()]\n",
232 | " \n",
233 | " # Don't need to track original rows\n",
234 | " tmp7 = tmp6.reset_index(drop=True)\n",
235 | " \n",
236 | " return tmp7\n",
237 | "\n",
238 | "\n",
239 | "test_data = load_adatest_data('imdb_hotel_conversion.csv')\n",
240 | "display(test_data)"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "id": "e1777931",
246 | "metadata": {},
247 | "source": [
248 | "Next, we need to get the actual labels corresponding to each sentence. For this we need to combine the column which contains the output of our model and the column containing our manual labelling of whether the model was correct or incorrect."
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "id": "e3e8a3d6",
255 | "metadata": {},
256 | "outputs": [],
257 | "source": [
258 | "def generate_label(row):\n",
259 | " # The model output is either 'pos' or 'neg'\n",
260 | " model_result = row['output']\n",
261 | " # Return based on whether the model response was marked correct or incorrect\n",
262 | " if row['model_is_correct'] == 'pass':\n",
263 | " return model_result\n",
264 | " else:\n",
265 | " if model_result == 'pos':\n",
266 | " return 'neg'\n",
267 | " else:\n",
268 | " return 'pos'\n",
269 | " \n",
270 | "# Apply this to the data\n",
271 | "test_data['label'] = test_data.apply(generate_label, axis=1)\n",
272 | "test_data"
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "id": "4031033a",
278 | "metadata": {},
279 | "source": [
280 | "We can call the pipeline directly on the sentences we have generated, and make sure that we get the same results as the ones stored by Adaptive Testing:"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "id": "649de984",
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "import numpy as np\n",
291 | "\n",
292 | "def get_label(label_probabilities):\n",
293 | " # The pipeline returns all of the label probabilities\n",
294 | " # We need to extract the largest\n",
295 | " max_score = 0\n",
296 | " label = None\n",
297 | " for l in label_probabilities:\n",
298 | " if l['score'] > max_score:\n",
299 | " max_score = l['score']\n",
300 | " label = l['label']\n",
301 | " return label\n",
302 | "\n",
303 | "y_pred = [get_label(x) for x in original_pipeline(test_data.sentence.to_list())]\n",
304 | "\n",
305 | "\n",
306 | "test_data['my_y_pred'] = y_pred\n",
307 | "assert np.array_equal(test_data['my_y_pred'], test_data['output'])\n",
308 | "\n",
309 | "display(test_data)"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "id": "385dbdff",
315 | "metadata": {},
316 | "source": [
317 | "We can also evaluate our chosen metric, and check that the accuracy score matches what we expect from the summary at the top level of the Adaptive Testing widget:"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": null,
323 | "id": "cb2e1f62",
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "from datasets import load_metric\n",
328 | "\n",
329 | "metric_name = 'accuracy'\n",
330 | "\n",
331 | "metric = load_metric(metric_name)\n",
332 | "\n",
333 | "def label_to_int(l: str) -> int:\n",
334 | " # Use the mapping provided by the model\n",
335 | " return model.config.label2id[l]\n",
336 | "\n",
337 | "metric.compute(predictions=test_data['my_y_pred'].apply(label_to_int), references=test_data['label'].apply(label_to_int))"
338 | ]
339 | },
340 | {
341 | "cell_type": "markdown",
342 | "id": "7d8348a9",
343 | "metadata": {},
344 | "source": [
345 | "There is one final tweak to make to our data prior to finetuning the model: the Hugging Face `Trainer`s do not use the human-friendly labels, but the corresponding integer ids. So use the mapping provided by the model to convert the 'label' column:"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "id": "ac67a09a",
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "test_data['label'] = test_data['label'].apply(label_to_int)\n",
356 | "print(test_data.dtypes)"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "id": "ff6a412c",
362 | "metadata": {},
363 | "source": [
364 | "Now, we can split our dataset into training and test sets. We stratify based on the 'topic' column, to ensure that we have samples from all of the various topics we have generated:"
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": null,
370 | "id": "20d7fbf6",
371 | "metadata": {},
372 | "outputs": [],
373 | "source": [
374 | "from sklearn.model_selection import train_test_split\n",
375 | "\n",
376 | "train_df, test_df = train_test_split(test_data, stratify=test_data['topic'], test_size=0.3)"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "id": "4c7b887f",
382 | "metadata": {},
383 | "source": [
384 | "Convert our DataFrames into Hugging Face `Dataset`s:"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": null,
390 | "id": "2aa80d00",
391 | "metadata": {},
392 | "outputs": [],
393 | "source": [
394 | "from datasets import Dataset\n",
395 | "\n",
396 | "train_ds = Dataset.from_pandas(df = train_df)\n",
397 | "test_ds = Dataset.from_pandas(df = test_df)\n",
398 | "train_ds"
399 | ]
400 | },
401 | {
402 | "cell_type": "markdown",
403 | "id": "b5197b33",
404 | "metadata": {},
405 | "source": [
406 | "Encode our datasets:"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": null,
412 | "id": "fffb55b6",
413 | "metadata": {},
414 | "outputs": [],
415 | "source": [
416 | "def preprocess_function(examples):\n",
417 | " result = tokenizer(examples[\"sentence\"],\n",
418 | " add_special_tokens = True,\n",
419 | " truncation = True,\n",
420 | " padding = \"max_length\",\n",
421 | " return_attention_mask = True\n",
422 | " )\n",
423 | " return result\n",
424 | "\n",
425 | "train_encoded = train_ds.map(preprocess_function, batched=True)\n",
426 | "test_encoded = test_ds.map(preprocess_function, batched=True)\n",
427 | "\n",
428 | "drop_cols = ['topic', '__index_level_0__','model_is_correct', 'model score', 'my_y_pred', 'output']\n",
429 | "\n",
430 | "train_encoded = train_encoded.remove_columns(drop_cols)\n",
431 | "test_encoded = test_encoded.remove_columns(drop_cols)"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "id": "a76e81ed",
437 | "metadata": {},
438 | "source": [
439 | "Configure a new training run:"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": null,
445 | "id": "f26489b2",
446 | "metadata": {},
447 | "outputs": [],
448 | "source": [
449 | "from transformers import TrainingArguments\n",
450 | "\n",
451 | "batch_size = 4\n",
452 | "\n",
453 | "args_ft = TrainingArguments(\n",
454 | " f\"hotel_fine_tuned\",\n",
455 | " evaluation_strategy = \"epoch\",\n",
456 | " save_strategy = \"epoch\",\n",
457 | " learning_rate=2e-5,\n",
458 | " per_device_train_batch_size=batch_size,\n",
459 | " per_device_eval_batch_size=batch_size,\n",
460 | " num_train_epochs=5,\n",
461 | " weight_decay=0.01,\n",
462 | " load_best_model_at_end=True,\n",
463 | " metric_for_best_model=metric_name,\n",
464 | " push_to_hub=False,\n",
465 | ")"
466 | ]
467 | },
468 | {
469 | "cell_type": "markdown",
470 | "id": "baf346f1",
471 | "metadata": {},
472 | "source": [
473 | "Now, load a fresh copy of the model for fine-tuning. This will allow us to compare the two models side-by-side:"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "id": "fc8ddbb9",
480 | "metadata": {},
481 | "outputs": [],
482 | "source": [
483 | "ft_model = AutoModelForSequenceClassification.from_pretrained(base_model_name,num_labels=2)"
484 | ]
485 | },
486 | {
487 | "cell_type": "markdown",
488 | "id": "dca9df6b",
489 | "metadata": {},
490 | "source": [
491 | "Create our new `Trainer` object, using the model we've just loaded. We pass in our new datasets for training and evaluation:"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": null,
497 | "id": "4d1e8e93",
498 | "metadata": {},
499 | "outputs": [],
500 | "source": [
501 | "from transformers import Trainer\n",
502 | "\n",
503 | "def compute_metrics(eval_pred):\n",
504 | " predictions, labels = eval_pred\n",
505 | " # Predictions are probabilities, so the actual answer is the index with the highest probability\n",
506 | " predictions = np.argmax(predictions, axis=1)\n",
507 | " return metric.compute(predictions=predictions, references=labels)\n",
508 | "\n",
509 | "trainer_ft = Trainer(\n",
510 | " ft_model,\n",
511 | " args_ft,\n",
512 | " train_dataset=train_encoded,\n",
513 | " eval_dataset=test_encoded,\n",
514 | " tokenizer=tokenizer,\n",
515 | " compute_metrics=compute_metrics\n",
516 | ")"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "id": "f104c611",
522 | "metadata": {},
523 | "source": [
524 | "Now, we can run the training. On a CPU, this may take a few minutes (large values of 'few' may be experienced):"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "id": "4a276759",
531 | "metadata": {
532 | "scrolled": true
533 | },
534 | "outputs": [],
535 | "source": [
536 | "trainer_ft.train()"
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "execution_count": null,
542 | "id": "86ebf95d",
543 | "metadata": {},
544 | "outputs": [],
545 | "source": [
546 | "trainer_ft.evaluate()"
547 | ]
548 | },
549 | {
550 | "cell_type": "markdown",
551 | "id": "58956445",
552 | "metadata": {},
553 | "source": [
554 | "## Assessing the Fine-Tuned Model\n",
555 | "\n",
556 | "Now that we have fine-tuned the model with some examples which talk about hotels, we can see if it performs better. First, we put the new model into a scoring pipeline:"
557 | ]
558 | },
559 | {
560 | "cell_type": "code",
561 | "execution_count": null,
562 | "id": "2e1d2687",
563 | "metadata": {},
564 | "outputs": [],
565 | "source": [
566 | "ft_pipeline = pipeline(\"sentiment-analysis\",\n",
567 | " model=trainer_ft.model.to('cpu'),\n",
568 | " tokenizer=tokenizer,\n",
569 | " top_k=2)"
570 | ]
571 | },
572 | {
573 | "cell_type": "markdown",
574 | "id": "f39b78c5",
575 | "metadata": {},
576 | "source": [
577 | "We can re-run the initial samples we tried above:"
578 | ]
579 | },
580 | {
581 | "cell_type": "code",
582 | "execution_count": null,
583 | "id": "7586109f",
584 | "metadata": {},
585 | "outputs": [],
586 | "source": [
587 | "for s in sample_strings:\n",
588 | " print(s, \"\\n\", ft_pipeline(s), \"\\n\")"
589 | ]
590 | },
591 | {
592 | "cell_type": "markdown",
593 | "id": "53455b62",
594 | "metadata": {},
595 | "source": [
596 | "The sentences about movies are still well classified, but the final one about a hotel has the correct prediction now.\n",
597 | "\n",
598 | "For a more systematic comparison, we can run our `test_df` through both pipelines:"
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": null,
604 | "id": "28332c5e",
605 | "metadata": {},
606 | "outputs": [],
607 | "source": [
608 | "def get_label(label_probabilities):\n",
609 | " # The pipeline returns all of the label probabilities\n",
610 | " # We need to extract the largest\n",
611 | " max_score = 0\n",
612 | " label = None\n",
613 | " for l in label_probabilities:\n",
614 | " if l['score'] > max_score:\n",
615 | " max_score = l['score']\n",
616 | " label = l['label']\n",
617 | " # Convert back to the id\n",
618 | " return ft_model.config.label2id[label]\n",
619 | "\n",
620 | "y_pred_orig = [get_label(x) for x in original_pipeline(test_df.sentence.to_list())]\n",
621 | "y_pred_ft = [get_label(x) for x in ft_pipeline(test_df.sentence.to_list())]\n",
622 | "\n",
623 | "print(\"Original : \", metric.compute(predictions=y_pred_orig, references=test_df.label))\n",
624 | "print(\"Fine Tuned: \", metric.compute(predictions=y_pred_ft, references=test_df.label))"
625 | ]
626 | },
627 | {
628 | "cell_type": "markdown",
629 | "id": "6c7c536a",
630 | "metadata": {},
631 | "source": [
632 | "We see a noticeable improvement in accuracy."
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": null,
638 | "id": "01399407",
639 | "metadata": {},
640 | "outputs": [],
641 | "source": []
642 | }
643 | ],
644 | "metadata": {
645 | "kernelspec": {
646 | "display_name": "Python 3 (ipykernel)",
647 | "language": "python",
648 | "name": "python3"
649 | },
650 | "language_info": {
651 | "codemirror_mode": {
652 | "name": "ipython",
653 | "version": 3
654 | },
655 | "file_extension": ".py",
656 | "mimetype": "text/x-python",
657 | "name": "python",
658 | "nbconvert_exporter": "python",
659 | "pygments_lexer": "ipython3",
660 | "version": "3.9.12"
661 | }
662 | },
663 | "nbformat": 4,
664 | "nbformat_minor": 5
665 | }
666 |
--------------------------------------------------------------------------------
/notebooks/imdb_hotel_conversion.csv:
--------------------------------------------------------------------------------
1 | ,topic,input,output,label,labeler,description,author,model score
2 | 00f2dfd40977438f9942793768959e8a,,,,topic_marker,imputed,,,
3 | 02076434c5fc4a389f127e542c479c64,/Rooms,Sheets didn't seem clean,neg,pass,anonymous,,anonymous,0.03536648847074937
4 | 029c3b23ed7c491bb0f5b35687dd3f8c,/Location,Far from the best beaches,pos,fail,anonymous,,anonymous,0.8619214039468347
5 | 04189a46ab1642789a5f59fb5a758e29,/Price,Expensive for what you get,pos,pass,anonymous,,,0.0
6 | 04f087d371604398b32205eab29cd587,/Location,Close to nowhere,neg,pass,anonymous,,,0.007999936759466829
7 | 061c756434f84e0b9d4dbad828496837,/Location,Just off the bus line,pos,pass,anonymous,,anonymous,0.0
8 | 069449b9da3b4481827417686c244033,/Price,High prices for average services and less convinient location,neg,pass,anonymous,,anonymous,0.0
9 | 06b259ccddbe4769b050fb855dd245dc,/Rooms,Only had a couple pillows,pos,fail,anonymous,,anonymous,0.9068814921475123
10 | 06ffa0d789944320a18d1c91a9567315,/Facilities/Restaurant,Don't eat here,neg,pass,anonymous,,anonymous,0.002532984066390833
11 | 0adc101c0dc549a089ed5733a2893dba,/Service,Service at the hotel is quite slow,neg,pass,anonymous,,anonymous,0.0
12 | 0ba15c3386e34777bfb4906394f15b82,/Price,One star service for a four star price,pos,fail,anonymous,,,1.0
13 | 0e74d4372f484ef089c75e8f82446226,/Location,Bus line runs close by,pos,pass,anonymous,,anonymous,0.0
14 | 10555369f4254142a2d42df7558063d2,/Facilities/Restaurant,"My lasagne was excellent, I could eat it every day!",pos,pass,anonymous,,anonymous,0.0
15 | 123de4a1154a4a47b3f897230b66c3f1,/Service,Any issues are taken care of quickly,pos,pass,anonymous,,anonymous,0.0
16 | 15ded90cab9e453792a97a4e739a644e,/Location,No walkability for people with mobility issues,pos,fail,anonymous,,anonymous,1.0
17 | 1604abdbfbda497e8de46f111bf4e506,/Facilities/Pool,Pool cleaning was horrible,neg,pass,anonymous,,anonymous,0.0015588112163477117
18 | 18c454fa3d444526ad71a95fe5a45ef5,/Price,Low price but high standard of service,pos,pass,anonymous,,anonymous,0.0
19 | 1a256503b0364a40808ef76b10248c30,/Service,Front desk staff were attentive,pos,pass,anonymous,,anonymous,0.0
20 | 1b96d669ba1546dba79174644e9d741f,/Facilities/Restaurant,Restaurant was located in the basement,neg,pass,anonymous,,anonymous,0.0
21 | 1c5020a50e5747cb90f4d08259c09772,/Price,,,topic_marker,anonymous,,,
22 | 1cdedd7004ea4a34a9623c91aae382f0,/Facilities/Restaurant,Waiters were friendly and attentive,pos,pass,anonymous,,anonymous,0.0
23 | 1da0369fd6bb4afb958caf57ac4ceaf5,/Price,Extraordinarily expensive for what you get,neg,pass,anonymous,,anonymous,0.0
24 | 1dc58e624e364c2b82e7f4b20ca06676,/Facilities/Restaurant,Bread did not come at the end of meal,neg,pass,anonymous,,anonymous,0.02117403731632515
25 | 1e27ca77e6bf46beb7282fc50ed9dba0,/Rooms,Only 2 bath towels,pos,fail,anonymous,,anonymous,0.5659768154018169
26 | 1ed5d0e0e01248c4a197e65ee461db7b,/Location,Hotel dropped into an anonymous business park,neg,pass,anonymous,,,0.0038315118022913173
27 | 22259ce0a5df47a5a956c49cbc8755a0,/Service,Always willing to do whatever it takes to make your stay pleasant,pos,pass,anonymous,,anonymous,0.0
28 | 2245eaa2609f40009a5a98471fbec625,/Facilities,,,topic_marker,anonymous,,,
29 | 2346ed4286764ff089c1917cf9da2ddd,/Price,Prices are a little on the steep side,pos,fail,anonymous,,anonymous,1.0
30 | 244f7853fd1142bb9427c36a666b91f9,/Facilities/Restaurant,Food quantity was small,neg,pass,anonymous,,anonymous,0.010578740506381622
31 | 257400db509e49c1b6dd9e287685140e,/Facilities/Restaurant,No toilets in restaurant,pos,fail,anonymous,,anonymous,0.8834572233315061
32 | 27a04b3d63254577b1112ae04145f935,/Facilities/Restaurant,No free refills,pos,fail,anonymous,,anonymous,0.9858127695421115
33 | 287e05c7919345bba29ce27de5b80299,/Facilities/Restaurant,Got food poisoning,neg,pass,anonymous,,anonymous,0.0
34 | 2db7e0e996ac4b8da81fa7bbb286914c,/Location,An oasis of calm during tourist season,pos,pass,anonymous,,anonymous,0.0
35 | 2e3005f8a8c24cbd96401adb7f321458,/Facilities/Restaurant,Breakfast needs a major overhaul,neg,pass,anonymous,,anonymous,0.0
36 | 2f5b87f286b04736bf061c3a7e438911,/Rooms,"No views of the ocean at all, only of a parking lot.",neg,pass,anonymous,,anonymous,0.0055049279431199595
37 | 34998856113841b78e12a3f454ff4420,/Price,Rigged to overcharge,neg,pass,anonymous,,anonymous,0.0
38 | 357caede646a4c5ea1b59da52e807e12,/Location,Slightly isolated from the area's attractions,pos,fail,anonymous,,anonymous,0.9509071111679077
39 | 387d9bf5e9ac47f49503fe8aaa10f2f9,/Rooms,Our air conditioner wasn't working,neg,pass,anonymous,,anonymous,0.0
40 | 38af115065e9475a82ba455e010667e7,/Price,Crazy expensive for what you get,neg,pass,anonymous,,anonymous,0.0
41 | 3ab16b419ab54eda9616b20631da00a2,/Facilities/Restaurant,Get here as early as possible or you'll be waiting for a long time,pos,fail,anonymous,,anonymous,0.9929495269304643
42 | 3b0a1573d9154d938928d7ce3ddff043,/Price,The price seemed reasonable,neg,fail,anonymous,,anonymous,0.16721585891294058
43 | 3b63631e191f473f852a9276e3e7bcde,/Location,Too far from some attractions,neg,pass,anonymous,,anonymous,0.2603970766067505
44 | 3ca8925223ff49969a35dd682a2abdb6,/Facilities/Restaurant,Seafood is so much better elsewhere,pos,fail,anonymous,,anonymous,0.9628090240245476
45 | 40bdeeb227c44fb99e19466df0617dd5,/Rooms,Room was noisy at night,pos,fail,anonymous,,anonymous,0.9222849882329996
46 | 4231463634b84ccf918ae5e3269a32e5,/Location,Can walk to art district,pos,pass,anonymous,,anonymous,0.0
47 | 44dc6a3678654196aa2212a943f64ec1,/Rooms,We could just lie in bed and watch the ships in the harbor,neg,fail,anonymous,,,1.0
48 | 4819f234c2c74680a145f8a4617aaf04,/Facilities/Pool,Water terribly cold,neg,pass,anonymous,,,0.0018953193006728897
49 | 4877e7bd4af94768bd881652c7876066,/Price,Good budget option,pos,pass,anonymous,,anonymous,0.0
50 | 4c272643ac1f4a62a0ccca592273d51c,/Service,Staff very attentive,pos,pass,anonymous,,,0.0
51 | 4cb11859b69a48908e0ca2e02c17ddc2,/Service,Checkout very quick,pos,pass,anonymous,,anonymous,0.0
52 | 4d0c722297bc40dba00809efeff3ddf2,/Rooms,"Room was too hot, could not sleep",neg,pass,anonymous,,anonymous,0.020949860030593202
53 | 4dd929eba2094491a64af7a2ce5d7316,/Rooms,TV didn't work in the room,neg,pass,anonymous,,,0.005870745712330646
54 | 514cf379d06440ccadb04ecf97509779,/Location,Quiet and secluded,pos,pass,anonymous,,anonymous,0.0
55 | 514de23d0393402c81c36a54e60f29ed,/Location,Neighborhood is intimidating,pos,fail,anonymous,,anonymous,0.9916251948979808
56 | 5170081ead85441e8a6887d5b0f62599,/Price,Not bad for the price paid,neg,fail,anonymous,,anonymous,1.0
57 | 54206f4acfc0498dac42279eba4fb303,/Rooms,Clean room,pos,pass,anonymous,,anonymous,0.0
58 | 5584df6c0efe4298ba96be700dce9228,/Rooms,Bed in need of replacement,pos,fail,anonymous,,anonymous,0.9335646420665538
59 | 57617f8ff6f449b69307c77a7bbbf2cb,/Rooms,No bath mats for the shower,neg,pass,anonymous,,anonymous,0.01952743961762565
60 | 57a816da1e594e47a057898522897884,/Service,Happy with how quickly front desk answered the telephone,pos,pass,anonymous,,anonymous,0.0
61 | 57e569d649fc4bd08dfdc2a1fbaa2f6e,/Location,No trains or buses nearby,pos,fail,anonymous,,anonymous,0.9322171141787088
62 | 58f95cfa3e7e4c28b7dcf9a1e1df99ed,/Location,,,topic_marker,anonymous,,,
63 | 5cfefaac1bf1406c968067fe3a912b1d,/Service,"I had a problem with a handle on the toilet...he just replaced it, no big deal",neg,fail,anonymous,,anonymous,1.0
64 | 5d892e8dce8646978c9e3ff5501390b6,/Service,Next day laundry service was very good,pos,pass,anonymous,,anonymous,0.0
65 | 5fd2f4ce35254c049dc40398b2248d18,/Facilities/Restaurant,Breakfast service very efficient,pos,pass,anonymous,,,0.0
66 | 627696d7081c4f56955dc8041858724d,/Service,Laundry service shrank my clothes,neg,pass,anonymous,,,0.0
67 | 62f8f606ac574e66977b800db82cea34,/Price,Great value for money,pos,pass,anonymous,,,0.0
68 | 64027ae207a14db0aaab9be6a9ecea1e,/Location,Tucked away on a quiet street,neg,fail,anonymous,,,0.022233733700274178
69 | 68bd829209644bc9a84dd7bd6f5d5b01,/Facilities/Pool,,,topic_marker,anonymous,,,
70 | 69de51585b4b45a48b6669292ffb3fd6,/Rooms,Room was small but cozy,pos,pass,anonymous,,,0.0
71 | 6ae7e0a652374f8d895d581b66d7085e,/Rooms,Bathroom light didn't work,neg,pass,anonymous,,,0.00546946100179207
72 | 6c3453f4c1e747d2b58e22bbcf1f16e7,/Rooms,No hot water in the shower,neg,pass,anonymous,,,0.01566485443944781
73 | 6d586a3fff17484ea71ee23ecc04582c,/Price,Overpriced given the poor facilities,neg,pass,anonymous,,,0.0
74 | 708fb5c25b7b4b03963710a8e3f34929,/Location,Close to parks and water,pos,pass,anonymous,,anonymous,0.0
75 | 72ba2281ba504b65a2ef931291a611e7,/Facilities/Pool,Pool was almost totally blocked by rubbish,neg,pass,anonymous,,anonymous,0.005074477409384375
76 | 754cdb6c7f294e1483f807e2646466d6,/Service,Service at front desk needs improvement,neg,pass,anonymous,,anonymous,0.0
77 | 75f3a65f995f43f599417ca8c20dfc64,/Facilities/Restaurant,There is a better place to eat a short distance away,pos,fail,anonymous,,anonymous,0.8710957429139538
78 | 769f96e52819496f83c8496c39b92f6a,/Price,Pricey for what you get,neg,pass,anonymous,,anonymous,0.0
79 | 76eb16ffe9c7438bb1da78e680833931,/Location,Surrounding neighborhood is unsafe,neg,pass,anonymous,,anonymous,0.06573726834944035
80 | 77895cede75c46c29f209dfaefcc08aa,/Facilities/Pool,Pool is not working,neg,pass,anonymous,,anonymous,0.008743519935493917
81 | 77a53430419e48c198f7c200a3bdc41b,/Service,The staff would try and make you smile,neg,fail,anonymous,,anonymous,1.0
82 | 78049dd0d393426cbae76ddbb187007f,/Facilities/Restaurant,Don't eat the ciabatta bread,pos,fail,anonymous,,anonymous,0.6495006680488586
83 | 783dd0a357cc4066ac0a662e435c58f5,/Service,Not always enough being done for guests in the lobby,neg,pass,anonymous,,anonymous,0.0
84 | 786bcd6d9b594e0b84ed70c6da5f676c,/Rooms,Desk chair conspicuous by its absence,neg,pass,anonymous,,,0.01946825884876856
85 | 7a9551af176e4af1865515b32f5397bc,/Price,Not cheap enough for this location,neg,pass,anonymous,,anonymous,0.0
86 | 7bc8465e009e4983a1d8fbf50842ffe6,/Facilities/Pool,No towels,pos,fail,anonymous,,,0.956467116010669
87 | 7e24228f5504453b8646f9dd7cc73c79,/Location,Awful neighbourhood for walking,pos,fail,anonymous,,anonymous,0.9462354253697054
88 | 81419edffe684a6391deb9af1d87b2c0,/Price,The most expensive hotel on the strip,pos,fail,anonymous,,anonymous,1.0
89 | 825ea717350c40ff90a8d303b2c3259f,/Location,Too far from beach,neg,pass,anonymous,,anonymous,0.03152497975248768
90 | 8275a1368f3f41e8b02b55ce39e8187e,/Facilities/Restaurant,Food is poor,neg,pass,anonymous,,anonymous,0.05149824440842651
91 | 8279b202475b40bd90d323f0ffdfcac2,/Rooms,Plenty of hot water,neg,fail,anonymous,,anonymous,0.09558087017650185
92 | 853c9fd8147c46f3b579effe6f4520c2,/Facilities/Restaurant,All the meals were a bit over seasoned,pos,fail,anonymous,,anonymous,0.9088237095843423
93 | 86e6fc6fe5d14c8691a9e8158712bb5f,/Service,Overall everyone is willing to help out,pos,pass,anonymous,,anonymous,0.0
94 | 877463f68a7f460c978b0ee39959e8e9,/Price,Your money is not well spent,neg,pass,anonymous,,anonymous,0.0
95 | 891b56253c1a439ab7bbafce933f64b7,/Rooms,View from the balcony was of a parking lot,pos,fail,anonymous,,anonymous,0.9438569021719737
96 | 895f996aeff14dcdbda01e0770b08691,/Rooms,Couldn't see any ocean,pos,fail,anonymous,,anonymous,0.9530809174163879
97 | 8b57269f3e8441a7bbbe4c99cd81cd37,/Facilities/Restaurant,Best nachos in town,pos,pass,anonymous,,anonymous,0.0
98 | 8d40b4fae3cd4779bf40266c0f94f240,/Location,Buses and trains do not stop nearby,pos,fail,anonymous,,anonymous,0.9761177196263865
99 | 908ad023d80f410eac1075ca63cf11a6,/Location,Walkability not great,pos,fail,anonymous,,anonymous,0.5081075581344234
100 | 922d0ee4e1c74a25920ecb627f3dd63c,/Service,"Took a long time to get our keys, until nearly midnight",pos,fail,anonymous,,anonymous,1.0
101 | 939ac59a858b45aa913058de2ecd16d2,/Rooms,Was not satisfied with the condition of the room or the room itself,pos,fail,anonymous,,anonymous,0.9646280156359222
102 | 9621a20b718747b790a7526dfd38cfbd,/Service,"Accommodating staff, courteous",pos,pass,anonymous,,anonymous,0.0
103 | 9632ec45e7b04be6bcff15e52f52e599,/Rooms,Bathroom had been cleaned just prior to our arrival,pos,pass,anonymous,,anonymous,0.0
104 | 978373b2df0f42538c13492cf0d6471a,/Facilities/Restaurant,Sauces a bit too hot,pos,fail,anonymous,,anonymous,0.9444596030643642
105 | 97af1b8a4a2948898130c48753328b54,/Rooms,View was different from promotional photos,pos,fail,anonymous,,anonymous,0.9764968971125088
106 | 9840372f7900424baff1e9258e58e3d7,/Rooms,AC wasn't working well,neg,pass,anonymous,,anonymous,0.020115314715187866
107 | 9a2ff1b8978846ae9f295d4846ec12b8,/Rooms,Didn't have hot water,neg,pass,anonymous,,anonymous,0.01998141284973178
108 | 9b1fef0c3b2d4a839cc5fb0b3d854861,/Price,Rather overpriced for what you get,pos,fail,anonymous,,anonymous,0.0
109 | 9ba4cb824eef42ab878e8fc4ce572973,/Facilities/Pool,Pool service didn't include towels,neg,pass,anonymous,,,0.20532752267528293
110 | 9d5879a6b52f4d41add360507657519a,/Location,Horrendous traffic noise,neg,pass,anonymous,,,0.007439990382418224
111 | 9dfd09a7496f436dafb651947727d51e,/Facilities/Restaurant,Waited 10 minutes after seating for our waitress,pos,fail,anonymous,,anonymous,0.9440716364373023
112 | a1a82990388a4e61b3fd72db79ab5870,/Location,Can walk to downtown/transit,pos,pass,anonymous,,anonymous,0.0
113 | a25a7218116347f1b075c57e7043a373,/Facilities/Pool,Pool closed for repair,pos,fail,anonymous,,,0.9739708322546833
114 | a2c4512803b843a5adc5d4c572dc4912,/Facilities/Pool,"When we were there, the pool was closed",pos,fail,anonymous,,anonymous,0.9828934881076902
115 | a2d1674f57e040dead2d64eb2fbf9916,/Facilities/Restaurant,Like variety on buffet,pos,pass,anonymous,,anonymous,0.0
116 | a52c80339cd64af882265082ae3a42ae,/Service,I was not happy with the service,neg,pass,anonymous,,anonymous,0.0
117 | a6b3ce110f474bada036aa13ccf41548,/Service,,,topic_marker,anonymous,,,
118 | a76d528c740b40019f3f69d1ed1a6b15,/Rooms,Ocean view from our room,pos,pass,anonymous,,anonymous,0.0
119 | a77390fe811d45a187e1d04004df4350,/Facilities/Restaurant,Restaurant wasn't large but served excellent food,pos,pass,anonymous,,,0.0
120 | a9027c10b08d4bc6a865ddd21e4c480d,/Price,Cheap and very quiet,pos,pass,anonymous,,anonymous,0.0
121 | a92ccc2284254033af6974e63383edea,/Price,Very expensive for what you get,neg,pass,anonymous,,anonymous,0.0
122 | a933bcbd4c2542baafaa8475961d38e8,/Rooms,Rooms were dated and small,pos,fail,anonymous,,anonymous,0.9489458026560805
123 | ab3b7acf65b24cd1afa1729a773c4110,/Rooms,The room had no clock,neg,pass,anonymous,,anonymous,0.01801225306431987
124 | abfb9c8a7a364bfdbc142a6f2ac7cdac,/Facilities/Restaurant,Food was unremarkable,neg,pass,anonymous,,,0.0
125 | ac5ebdc88ed245919a6f3364340ef448,/Rooms,Cramped bathrooms,neg,pass,anonymous,,anonymous,0.003501463001233201
126 | adeffd2e8b79435f8aaf02494294c0c0,/Facilities/Restaurant,Late dining was an added bonus,pos,pass,anonymous,,anonymous,0.0
127 | af8f94e9389e49568bbac469c6d540a9,/Price,Fair price for a swanky strip hotel,neg,fail,anonymous,,anonymous,1.0
128 | afd78d1c77d243d2ba54c3686ac2ff3e,/Price,Fairly priced,pos,pass,anonymous,,,0.0
129 | aff3e0264a8d448c836daa4a109419a7,/Location,Hard to get to from tourist attractions,pos,fail,anonymous,,anonymous,0.9668508984189681
130 | b09f27d62f0f45cd851464b44aeefe2f,/Rooms,The smell of mold or mildew was heavy,neg,pass,anonymous,,anonymous,0.03170379541096625
131 | b0ad37a7942e4d0989833f90903f9d6d,/Location,Close to the beach,pos,pass,anonymous,,anonymous,0.0
132 | b0b16ba1a58c4c288a667aef6ccc76b7,/Location,Very few restaurants and shops,neg,pass,anonymous,,anonymous,0.03229419574689485
133 | b1aab53a4f2143b2894c09f1dc9fb82e,/Rooms,Bed was very comfy,pos,pass,anonymous,,anonymous,0.0
134 | b21f1062d7564a76b4a251189da8a1dd,/Facilities/Pool,Two hot tubs!,pos,pass,anonymous,,,0.0
135 | b63f3ac15004474eb013c3d4134e1e01,/Service,They were so willing to help us with everything,pos,pass,anonymous,,anonymous,0.0
136 | b8b872e399ae44c4baa73de896162065,/Service,Concierge got us a reservation at a restaurant,pos,pass,anonymous,,,0.0
137 | b8e1a50475604ee390feab4bfe14424a,/Service,Had been promised early checkin but were turned away,pos,fail,anonymous,,,1.0
138 | b9f5539ec780496aa72e3777d03fbaa7,/Rooms,Our room had an ocean view,pos,pass,anonymous,,,0.0
139 | ba328c365f3943488ecf109c455a5a60,/Location,Close to all the sights,pos,pass,anonymous,,anonymous,0.0
140 | bab0178fc46a427899447113248122c9,/Facilities/Pool,No lane swimming,pos,fail,anonymous,,,0.9611776900817911
141 | bac7bb1c7ffa41b485ca71e09aae6998,/Location,Away from the madness of downtown,pos,pass,anonymous,,anonymous,0.0
142 | bc8637a35f6a43ecab2e9f1e2a6b9517,/Facilities/Restaurant,"Food was terrible, should close down",neg,pass,anonymous,,anonymous,0.013274874064298664
143 | bc931528b0bf4c3dbb258a8a38d0cdc6,/Rooms,Bathroom fan broke,neg,pass,anonymous,,anonymous,0.043131239390885534
144 | be290ab5098c4ce5bee1abace7f6bc7b,/Facilities/Pool,Pool was closed for the season,pos,fail,anonymous,,anonymous,0.9801850118077637
145 | bf49f08e930448d9ad00343919b31886,/Price,This place is more expensive than it appears,pos,fail,anonymous,,anonymous,1.0
146 | bfb2e0d1194c4625898339bd1b2f24e9,/Rooms,The air conditioner was not working,neg,pass,anonymous,,anonymous,0.11590849003673025
147 | c22f5c0b880d40889e599f54ca38928a,/Rooms,Room was very small and dark,neg,pass,anonymous,,anonymous,0.023819851068293596
148 | c30ad576150f4343a129c9bd4872460e,/Facilities/Restaurant,Food portion was mini,pos,fail,anonymous,,anonymous,0.9234152849055209
149 | c58b3df66491447a994fb6eaca312640,/Rooms,Wardrobe lacked coathangers,neg,pass,anonymous,,,0.004106684871886362
150 | c5ced663935e4d14a0ea3c57e1ae343e,/Service,I asked for a late check-out but was told no problem,pos,pass,anonymous,,anonymous,0.0
151 | c6cbf8a3e9754d09b391ef1e6c9978f1,/Facilities/Restaurant,Drinks reasonably priced,pos,pass,anonymous,,anonymous,0.0
152 | c896f64736b54f539606a011b40cffdd,/Rooms,Problem with key card in the elevator,neg,pass,anonymous,,anonymous,0.4970642328262329
153 | c99d0927d84645178879e949ff37787c,/Service,Friendly concierge,pos,pass,anonymous,,anonymous,0.0
154 | ca06223f280541ac9a4e9c27cb7ff986,/Location,Road noise is intense,pos,fail,anonymous,,anonymous,0.9398255207305095
155 | ca8622d782ac41f195a4bb3efd4d4e2c,/Facilities/Restaurant,Dining room did not seem well maintained,neg,pass,anonymous,,anonymous,0.0
156 | caf8976f3f34409cae22dcf6361e2634,/Location,Isolated from the downtown,pos,fail,anonymous,,anonymous,0.5898931866947897
157 | ccf1a41f509148afb4b78b014b2a1848,/Rooms,Rooms have not been kept in good condition,pos,fail,anonymous,,anonymous,0.8474004485875167
158 | cd40428fd68d4db8b4a8a0bf547b943e,/Facilities/Restaurant,,,topic_marker,anonymous,,,
159 | cd92ed58725d4658be9c2d0b0cfb1671,/Price,Expensive but was worth it,pos,pass,anonymous,,anonymous,0.0
160 | ce1cc7a936574950bf822b54ae68b7ff,/Rooms,Plenty of space,neg,fail,anonymous,,anonymous,1.0
161 | ce5ac35f85e5461f82d6ef5d1cf8a653,/Facilities/Restaurant,They have milkshakes. Try them,pos,pass,anonymous,,anonymous,0.0
162 | cf2fc3107ba849628ca3ff4c44f10d7c,/Rooms,Where are the towels?,neg,pass,anonymous,,anonymous,0.039039845402230385
163 | d02af05847554e829fd8c2a54d25a682,/Facilities/Restaurant,Went to restaurant and were promptly seated,pos,pass,anonymous,,anonymous,0.0
164 | d38aa871131c411094fb7eff4329eaf2,/Price,Unreasonably high price,neg,pass,anonymous,,anonymous,0.0
165 | d3aee530ba1c497090a85682d4eedecc,/Rooms,Room was very small and hot and window was broken,neg,pass,anonymous,,anonymous,0.1914782162671484
166 | d438771ca25e4b789a166f42311a8778,/Service,Booked our son's ski lesson and was not told that was an extra fee,pos,fail,anonymous,,anonymous,1.0
167 | d55325e8dc5d48be92c94dd76ba92b0d,/Rooms,Windows need new screens,pos,fail,anonymous,,anonymous,0.9633868193121728
168 | d6fa4ecef1554b108c121c367cea2c09,/Rooms,Shower didn't drain well,neg,pass,anonymous,,anonymous,0.02177738050122612
169 | d7e456cfc5944b108bfd56e6a28c5fc4,/Rooms,The only view from the room was of the front parking lot,neg,pass,anonymous,,anonymous,0.005614943169105575
170 | d88da78b2a304af79d7451acac7fbb13,/Facilities/Restaurant,Baked potatoes taste off,neg,pass,anonymous,,anonymous,0.0027625570047995476
171 | d8e1dc7d1baa4ca696c611a9f15506f3,/Location,Near light rail,pos,pass,anonymous,,anonymous,0.0
172 | d94d5eb7522746cfa8b2a80aa55d38b8,/Service,Staff was friendly,pos,pass,anonymous,,anonymous,0.0
173 | df8ce65ef12847e58ea7052b1ecc4e2a,/Service,Housekeeping quick and efficient,pos,pass,anonymous,,,0.0
174 | e15d2b8a684c429d86d192c3194ae153,/Facilities/Pool,Lockers were broken,pos,fail,anonymous,,,0.9748728620309092
175 | e2765b3385b7450fa7196361c189bd2f,/Service,Receptionists very attentive,pos,pass,anonymous,,anonymous,0.0
176 | e33ec9dbd07b4abda6ff020c1e34fe18,/Facilities/Restaurant,There are better places to eat nearby,pos,fail,anonymous,,anonymous,0.9084235104448649
177 | e36343e0829c4c8a9564662b9b82c2df,/Price,Rigged to rip off,neg,pass,anonymous,,anonymous,0.0
178 | e58c6478b5e546deaf9c94f95c85f7f8,/Service,Hotel staff will always act like they don't understand what you are saying,neg,pass,anonymous,,anonymous,0.0
179 | e66ed9065fcd4820bd6f8765eeca2565,/Price,Fairly priced for what you get,pos,pass,anonymous,,anonymous,0.0
180 | e75cea1ce3274d3cb4e709cd45c62fa3,/Service,Every single person working there made me feel like I'm their only priority,neg,fail,anonymous,,anonymous,0.9832501604566285
181 | e7a6d4bfe0a845c0b8e2f01ab10cf8a8,/Rooms,Our room had a refrigerator and was very clean,pos,pass,anonymous,,anonymous,0.0
182 | e8ef75b760bc445585d4f998909425c3,/Service,Tried to reach a concierge twice and never responded to each of our requests.,neg,pass,anonymous,,anonymous,0.0
183 | e9b85bb224354a7fb311f68bd458be07,/Price,Kind of pricey,neg,pass,anonymous,,anonymous,0.0
184 | ece3562bf84044ed9cb56d27566766f6,/Location,"Nice, quiet neighborhood",pos,pass,anonymous,,anonymous,0.0
185 | ed88f3124b224e61a7286dbb26d694e8,/Price,Quite expensive,neg,pass,anonymous,,anonymous,0.0
186 | f2011cf43eb642f6a503c4ed612bb767,/Facilities/Restaurant,Served my meal in less than 15 minutes,neg,fail,anonymous,,anonymous,1.0
187 | f29cfe9996ad40c8957df6208abeda39,/Price,A bit pricey for what it is,pos,fail,anonymous,,anonymous,0.9759607533207704
188 | f30ef67be728417a956e6438bd1ff375,/Service,Clothes came back neatly folded from laundry,pos,pass,anonymous,,,0.0
189 | f3d016316bfa4d20bfb93e971c8c021c,/Price,Four-star amenities on a one-star budget,pos,pass,anonymous,,anonymous,0.0
190 | f44b30d794014d5f964c3908f5eb63be,/Location,Easy to get to transit,pos,pass,anonymous,,,0.0
191 | f7cd947b83754f7788cf3b6eee47a1f5,/Price,Very expensive for the quality,pos,fail,anonymous,,anonymous,0.976195779895292
192 | f9d57eb623564826bdc697f7dd43be77,/Price,Fair value for money,pos,pass,anonymous,,anonymous,0.0
193 | fa4d7f67a3f14fa79a380afa5764ef80,/Price,Cheap for what you get,neg,fail,anonymous,,anonymous,0.014309784957217955
194 | fb9866889c414e2593158a1ddaa0e610,/Price,Costs have crept up a lot since last time I was here,pos,fail,anonymous,,anonymous,1.0
195 | fc6f00cf6dea4459811bb3d9ba84cec4,/Rooms,,,topic_marker,anonymous,,,
196 | fedd0976a40b49ad84fad703f6d70837,/Rooms,The toilet kept flooding,pos,fail,anonymous,,anonymous,0.945428799227947
197 | fef922cf2a62466584f4582da641ca2c,/Rooms,No door to bathroom!,pos,fail,anonymous,,anonymous,0.5445216298103333
198 | ff38155ab75945caac5c9e08eba4e1c4,/Rooms,Rooms were clean,pos,pass,anonymous,,anonymous,0.0
199 | fff299cde43e42ab8f7ba0a64c7f1a7d,,New test,pos,pass,imputed,,anonymous,0.0
200 |
--------------------------------------------------------------------------------