├── docs └── images │ ├── many_bugs.png │ ├── zero_shot.png │ ├── main_loops.png │ ├── blank_test_tree.png │ ├── first_suggestions.png │ ├── topic_suggestions.png │ └── manual_unconfirmed_inputs.png ├── adaptivetesting ├── resources │ ├── favicon.png │ └── main.js.LICENSE.txt ├── __init__.py ├── comm.py ├── utils │ └── __init__.py ├── _model.py ├── embedders.py ├── _topic_model.py ├── _prompt_builder.py ├── _scorer.py └── _server.py ├── client ├── .babelrc ├── src │ ├── index.jsx │ ├── total-value.jsx │ ├── context-menu.jsx │ ├── adatest.jsx │ ├── bread-crum.jsx │ ├── CommEvent.ts │ ├── web-socket-comm.js │ ├── jupyter-comm.js │ ├── adatest.css │ └── content-editable.jsx ├── README ├── tsconfig.json ├── package.json ├── webpack.config.js └── dist │ └── main.js.LICENSE.txt ├── test_trees ├── README.md └── abstract_capabilities.csv ├── .gitignore ├── tests ├── simple_test_tree.csv ├── test_utils.py ├── test_test_tree.py └── test_generators.py ├── development ├── launch.json └── scripts │ ├── build_wheel.py │ └── install_from_wheel.py ├── LICENSE ├── .github └── workflows │ ├── python-app.yml │ ├── python-wheel-build.yaml │ └── codeql-analysis.yml ├── setup.py ├── SECURITY.md ├── README.md └── notebooks ├── IMDB to Hotel Sentiment.ipynb └── imdb_hotel_conversion.csv /docs/images/many_bugs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/many_bugs.png -------------------------------------------------------------------------------- /docs/images/zero_shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/zero_shot.png -------------------------------------------------------------------------------- /docs/images/main_loops.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/main_loops.png -------------------------------------------------------------------------------- /docs/images/blank_test_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/blank_test_tree.png -------------------------------------------------------------------------------- /docs/images/first_suggestions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/first_suggestions.png -------------------------------------------------------------------------------- /docs/images/topic_suggestions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/topic_suggestions.png -------------------------------------------------------------------------------- /adaptivetesting/resources/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/adaptivetesting/resources/favicon.png -------------------------------------------------------------------------------- /docs/images/manual_unconfirmed_inputs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/adaptive-testing/main/docs/images/manual_unconfirmed_inputs.png -------------------------------------------------------------------------------- /client/.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | "@babel/preset-env", 4 | "@babel/preset-react" 5 | ], 6 | "plugins": [ 7 | "@babel/plugin-proposal-class-properties" 8 | ] 9 | } 
-------------------------------------------------------------------------------- /client/src/index.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import AdaTest from './adatest' 4 | 5 | 6 | ReactDOM.render( 7 | , 8 | document.getElementById('adatest_container_1') 9 | ); -------------------------------------------------------------------------------- /client/README: -------------------------------------------------------------------------------- 1 | To build the client run the following in the client directory: 2 | 3 | > npm install 4 | 5 | > npx webpack 6 | 7 | This will create the dist/build.js file. 8 | 9 | When doing dev use: 10 | 11 | npx webpack --watch 12 | 13 | to auto-rebuild the bundle when the JS code changes. -------------------------------------------------------------------------------- /test_trees/README.md: -------------------------------------------------------------------------------- 1 | These sample test trees are provided without any warrenty to their accuracy. Because they can sometimes cover sensitive topics we encourage users to review and alter them as needed. The files do not represent the opinion of Microsoft and may not represent the opinions of any individual author. PRs are welcome! 
2 | -------------------------------------------------------------------------------- /client/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "../adaptivetesting/resources/", 4 | "sourceMap": true, 5 | "noImplicitAny": false, 6 | "module": "commonjs", 7 | "esModuleInterop": true, 8 | "target": "es6", 9 | "jsx": "react", 10 | "baseUrl": ".", 11 | } 12 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | gadfly.egg-info 2 | __pycache__ 3 | *.pyc 4 | *.log 5 | .vscode 6 | node_modules 7 | .jekyll-cache 8 | .ipynb_checkpoints 9 | .sass-cache 10 | adaptivetesting.egg-info 11 | dist/gadfly-* 12 | local_scratch* 13 | wandb 14 | local_data 15 | /notebooks/TDD/data 16 | /notebooks/PPT 17 | /build 18 | /dist 19 | /test_trainer/runs 20 | /test_trees/local 21 | /adatest/utils/local 22 | /tests/local 23 | -------------------------------------------------------------------------------- /tests/simple_test_tree.csv: -------------------------------------------------------------------------------- 1 | ,topic,input,output,label,labeler,description,model score,author 2 | 4f51624790ce4b03ae18680a695d0b67,,Test at top level,POSITIVE,Unknown,imputed,,, 3 | 17975b5cc7ce46cb858b4e5677368a90,/A,,,topic_marker,anonymous,,, 4 | d043a06e342a452494ca73d1b4ee7967,/A/B,,,topic_marker,anonymous,,, 5 | 79ffa3a899374d1eb6870929db5408b0,/A,"Test under A 6 | ",NEGATIVE,Unknown,imputed,,, 7 | 5e5764ab240940aba43cc5aa76cf5429,/A/C,,,topic_marker,anonymous,,, 8 | 4b256f1fb5c64e13bc73106d2146257e,/A/B,Test under B,NEGATIVE,Unknown,imputed,,, 9 | -------------------------------------------------------------------------------- /adaptivetesting/__init__.py: -------------------------------------------------------------------------------- 1 | from ._test_tree import TestTree 2 | from 
._test_tree_browser import TestTreeBrowser 3 | from ._scorer import Scorer, DummyScorer, ClassifierScorer, GeneratorScorer, RawScorer 4 | from ._server import serve 5 | from .embedders import _embed as embed 6 | from ._model import Model 7 | from ._topic_model import ChainTopicModel, StandardTopicModel 8 | from . import generators 9 | 10 | __version__ = '0.3.5' 11 | 12 | default_generators = { 13 | "abstract": TestTree(r"test_trees/abstract_capabilities.csv") 14 | } 15 | text_embedding_model = None 16 | image_embedding_model = None -------------------------------------------------------------------------------- /development/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | }, 15 | { 16 | "type": "pwa-chrome", 17 | "request": "launch", 18 | "name": "Launch Chrome against localhost", 19 | "url": "http://localhost:8080", 20 | "webRoot": "${workspaceFolder}/client" 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import adaptivetesting.utils as utils 4 | 5 | 6 | class TestIsSubTopic: 7 | @pytest.mark.parametrize( 8 | ["topic", "sub_topic"], 9 | [ 10 | ("/A", "/A/B"), 11 | ("/A", "/A/B/C"), 12 | ("/A", "/A/B/C/"), 13 | ("/A ", "/A /B"), 14 | ("/A ", "/A /B "), 15 | ("/A ", "/A /B /C"), 16 | ], 17 | ) 18 | def test_topic_is_subtopic(self, topic, sub_topic): 19 | assert utils.is_subtopic(topic, sub_topic) 20 | 21 | 
@pytest.mark.parametrize( 22 | ["topic", "not_sub_topic"], [("/A/B", "/A/C"), ("/A", "/AB")] 23 | ) 24 | def test_topic_is_not_subtopic(self, topic, not_sub_topic): 25 | assert not utils.is_subtopic(topic, not_sub_topic) 26 | 27 | @pytest.mark.parametrize( 28 | ["topic"], 29 | # Extra ',' in tuples to resolve ambiguity between string and list of characters 30 | [("/A",), ("/A/B",), ("/A /B",), ("/A /B ",), ("/A/B/C",), ("/A /B /C",)], 31 | ) 32 | def test_topic_own_subtopic(self, topic): 33 | assert utils.is_subtopic(topic, topic) 34 | -------------------------------------------------------------------------------- /client/src/total-value.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import autoBind from 'auto-bind'; 3 | import { get, debounce } from 'lodash'; 4 | 5 | export default class TotalValue extends React.Component { 6 | constructor(props) { 7 | super(props); 8 | autoBind(this); 9 | 10 | this.doStateUpdateDebounced = debounce(this.doStateUpdate, 100); 11 | 12 | this.pendingStateUpdate = {}; 13 | 14 | // our starting state 15 | this.state = { 16 | // note that all the ids will also be properties of the state 17 | }; 18 | } 19 | 20 | setSubtotal(id, subtotal) { 21 | this.pendingStateUpdate[id] = subtotal; 22 | this.doStateUpdateDebounced(); 23 | } 24 | 25 | doStateUpdate() { 26 | this.setState(this.pendingStateUpdate); 27 | this.pendingStateUpdate = {}; 28 | } 29 | 30 | render() { 31 | // we just sum up the current active subtotals 32 | let total = 0; 33 | for (let i in this.props.activeIds) { 34 | total += get(this.state, this.props.activeIds[i], 0); 35 | } 36 | 37 | return 38 | {total} 39 | 40 | } 41 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /client/src/context-menu.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import autoBind from 'auto-bind'; 3 | 4 | export default class ContextMenu extends React.Component { 5 | static defaultProps = { 6 | top: 0, 7 | left: 0, 8 | open: false, 9 | onClick: () => undefined, 10 | rows: [ 11 | "test" 12 | ] 13 | }; 14 | 15 | constructor(props) { 16 | super(props); 17 | autoBind(this); 18 | } 19 | 20 | render() { 21 | return
22 |
23 |
24 | {this.props.rows && this.props.rows.map((row, index) => { 25 | return
this.handleRowClick(row, e)} className="adatest-hover-gray">{row}
26 | })} 27 |
28 |
29 | } 30 | 31 | handleBackgroundClick(e) { 32 | e.preventDefault(); 33 | this.props.onClose(); 34 | } 35 | 36 | handleRowClick(row, e) { 37 | e.preventDefault(); 38 | this.props.onClick(row); 39 | } 40 | } -------------------------------------------------------------------------------- /development/scripts/build_wheel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import platform 4 | import subprocess 5 | 6 | _logger = logging.getLogger(__file__) 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | 10 | def build_client(): 11 | # Find our initial directory 12 | _logger.info("Starting build_client") 13 | _logger.info("Running npm install") 14 | # Spawning npm appears to work differently in PowerShell 15 | spawn_shell = platform.system() == "Windows" 16 | subprocess.run( 17 | ["npm", "install", "--loglevel", "verbose"], 18 | shell=spawn_shell, 19 | cwd="client", 20 | check=True, 21 | ) 22 | _logger.info("Running npx webpack") 23 | subprocess.run(["npx", "webpack"], shell=spawn_shell, cwd="client", check=True) 24 | _logger.info("Ending build_client") 25 | 26 | 27 | def build_wheel(): 28 | _logger.info("Starting build_wheel") 29 | subprocess.run(["python", "setup.py", "sdist", "bdist_wheel"], check=True) 30 | _logger.info("Ending build_wheel") 31 | 32 | 33 | def main(): 34 | assert "setup.py" in os.listdir(), "Must be run from repo root" 35 | # Build the client 36 | build_client() 37 | 38 | # Build the wheel 39 | build_wheel() 40 | 41 | _logger.info("Completed") 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /development/scripts/install_from_wheel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pathlib 4 | import subprocess 5 | import sys 6 | 7 | _logger = logging.getLogger(__file__) 8 | 
logging.basicConfig(level=logging.INFO) 9 | 10 | 11 | def build_argument_parser(): 12 | desc = "Install Adaptive Testing from a wheel file" 13 | 14 | parser = argparse.ArgumentParser(description=desc) 15 | parser.add_argument( 16 | "--wheel-dir", 17 | help="Directory containing the AdaTest wheel", 18 | required=True, 19 | ) 20 | 21 | return parser 22 | 23 | 24 | def main(argv): 25 | parser = build_argument_parser() 26 | args = parser.parse_args(argv) 27 | 28 | _logger.info("Finding wheel file") 29 | target_dir = pathlib.Path(args.wheel_dir) 30 | # Globbing works from Python, but not in Windows builds 31 | wheel_list = list(target_dir.glob("adaptivetesting*.whl")) 32 | assert len(wheel_list) == 1, f"Bad wheel_list: {wheel_list}" 33 | wheel_path = wheel_list[0].resolve() 34 | msg = f"Path to wheel: {wheel_path}" 35 | _logger.info(msg) 36 | 37 | _logger.info("Installing wheel") 38 | # Use this approach so that extras can be added 39 | adatest_spec = f"adaptivetesting[dev] @ {wheel_path.as_uri()}" 40 | subprocess.run(["pip", "install", f"{adatest_spec}"], check=True) 41 | 42 | 43 | if __name__ == "__main__": 44 | main(sys.argv[1:]) 45 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest, macos-12, windows-latest] 19 | python-version: ['3.7', '3.8', '3.9', '3.10'] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 
on ${{ matrix.os }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install -e '.[dev]' 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | python -m pytest tests/ 40 | -------------------------------------------------------------------------------- /adaptivetesting/comm.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import logging 4 | log = logging.getLogger(__name__) 5 | 6 | class JupyterComm(): 7 | def __init__(self, target_name, callback=None, mode="register"): 8 | from ipykernel.comm import Comm 9 | 10 | self.target_name = target_name 11 | self.callback = callback 12 | self.jcomm = None 13 | if mode == "register": 14 | def comm_opened(comm, open_msg): 15 | self.jcomm = comm 16 | self.jcomm.on_msg(self._fire_callback) 17 | get_ipython().kernel.comm_manager.register_target(self.target_name, comm_opened) # noqa: F821 18 | elif mode == "open": 19 | self.jcomm = Comm(target_name=target_name) 20 | self.jcomm.on_msg(self._fire_callback) 21 | else: 22 | raise Exception("Passed mode must be either 'open' or 'register'!") 23 | 24 | def _fire_callback(self, msg): 25 | self.callback(msg["content"]["data"]) 26 | 27 | def send(self, data): 28 | for i in range(10): 29 | if self.jcomm is None: 30 | time.sleep(0.5) 31 | else: 32 | s = json.dumps(data) 33 | self.jcomm.send({"data": json.dumps(data)}) # we encode the JSON so iPython doesn't mess it up 34 | return 35 | raise Exception("The Jupyter comm channel was 
never opened from the other side, so not message can be sent!") 36 | 37 | -------------------------------------------------------------------------------- /client/src/adatest.jsx: -------------------------------------------------------------------------------- 1 | import "./adatest.css"; 2 | 3 | import React from 'react'; 4 | import ReactDOM from 'react-dom'; 5 | import { withRouter } from 'react-router-dom'; 6 | import { BrowserRouter } from "react-router-dom"; 7 | import { MemoryRouter } from 'react-router'; 8 | import Browser from './browser' 9 | 10 | const BrowserWithRouter = withRouter(Browser); 11 | 12 | export default class AdaTest extends React.Component { 13 | 14 | constructor(props) { 15 | super(props); 16 | console.log("interfaceId", this.props.interfaceId) 17 | this.state = { enabled: true }; 18 | window.adatest_root = this; 19 | } 20 | render() { 21 | 22 | const Router = this.props.environment === "web" ? BrowserRouter : MemoryRouter; 23 | 24 | return ( 25 |
26 |
27 | 28 | 33 | 34 |
35 |
36 | ); 37 | } 38 | } 39 | 40 | window.AdaTestReact = React 41 | window.AdaTestReactDOM = ReactDOM 42 | window.AdaTest = AdaTest 43 | 44 | -------------------------------------------------------------------------------- /client/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "client", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1", 8 | "build": "NODE_ENV=production webpack", 9 | "build:dev": "NODE_ENV=dev webpack" 10 | }, 11 | "keywords": [], 12 | "author": "", 13 | "license": "ISC", 14 | "devDependencies": { 15 | "@babel/core": "^7.13.14", 16 | "@babel/preset-env": "^7.13.12", 17 | "@babel/preset-react": "^7.13.13", 18 | "@babel/preset-typescript": "^7.18.6", 19 | "@fortawesome/fontawesome-free": "^5.15.3", 20 | "@types/react": "^17.0.47", 21 | "@types/react-dom": "^18.0.5", 22 | "@types/react-router-dom": "^5.3.3", 23 | "babel": "^6.23.0", 24 | "babel-core": "^6.26.3", 25 | "babel-loader": "^8.2.2", 26 | "css-loader": "^5.2.0", 27 | "source-map-loader": "^4.0.0", 28 | "string-replace-loader": "^3.0.1", 29 | "style-loader": "^2.0.0", 30 | "ts-loader": "^9.3.1", 31 | "typescript": "^4.7.4", 32 | "webpack": "^5.30.0", 33 | "webpack-cli": "^4.6.0" 34 | }, 35 | "dependencies": { 36 | "@fortawesome/fontawesome-svg-core": "^1.2.35", 37 | "@fortawesome/free-solid-svg-icons": "^5.15.3", 38 | "@fortawesome/react-fontawesome": "^0.1.14", 39 | "@material-ui/core": "^4.11.4", 40 | "@material-ui/icons": "^4.11.2", 41 | "@material-ui/lab": "^4.0.0-alpha.58", 42 | "auto-bind": "^4.0.0", 43 | "json5": "^2.2.0", 44 | "lodash": "^4.17.21", 45 | "react": "^17.0.2", 46 | "react-beautiful-dnd": "^13.1.0", 47 | "react-dom": "^17.0.2", 48 | "react-router-dom": "^5.2.0", 49 | "sanitize-html": "^2.3.3" 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /client/webpack.config.js: 
-------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | 3 | const isDevelopment = process.env.NODE_ENV === 'dev'; 4 | 5 | module.exports = { 6 | entry: path.resolve(__dirname, './src/adatest.jsx'), 7 | devtool: isDevelopment ? 'eval-source-map' : false, 8 | module: { 9 | rules: [ 10 | { 11 | test: /\.(js|jsx)$/, 12 | exclude: /node_modules/, 13 | use: ['babel-loader'], 14 | }, 15 | { 16 | test: /\.ts(x?)$/, 17 | exclude: /node_modules/, 18 | use: [ 19 | { 20 | loader: "ts-loader" 21 | } 22 | ] 23 | }, 24 | { 25 | test: /\.css$/i, 26 | use: ["style-loader", "css-loader"], 27 | }, 28 | { 29 | test: /\.(png|woff|woff2|eot|ttf|svg)$/, 30 | use: ['url-loader'], 31 | }, 32 | { // this allows font-awesome to be used during development mode... (since we print to the page in a script tag) 33 | test: /\.js$/, 34 | loader: 'string-replace-loader', 35 | options: { 36 | search: '', 37 | replace: '_/script>', 38 | } 39 | }, 40 | // https://github.com/webpack-contrib/source-map-loader 41 | { 42 | enforce: "pre", 43 | test: /\.js$/, 44 | loader: "source-map-loader" 45 | } 46 | ], 47 | }, 48 | resolve: { 49 | extensions: ['*', '.js', '.jsx', '.ts', '.tsx'], 50 | }, 51 | externals: { 52 | // 'react': 'React', 53 | // 'react-dom': 'ReactDOM' 54 | }, 55 | output: { 56 | path: path.resolve(__dirname, '../adatest/resources'), 57 | filename: 'main.js', 58 | }, 59 | mode: isDevelopment ? 
"development" : "production" 60 | }; -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | from setuptools import setup, find_packages 5 | 6 | here = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | 9 | def read(*parts): 10 | with codecs.open(os.path.join(here, *parts), "r") as fp: 11 | return fp.read() 12 | 13 | 14 | def find_version(*file_paths): 15 | version_file = read(*file_paths) 16 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 17 | if version_match: 18 | return version_match.group(1) 19 | raise RuntimeError("Unable to find version string.") 20 | 21 | 22 | setup( 23 | name="adaptivetesting", 24 | version=find_version("adaptivetesting", "__init__.py"), 25 | url="https://github.com/microsoft/adaptive-testing.git", 26 | author="Scott Lundberg and Marco Tulio Ribeiro", 27 | author_email="scott.lundberg@microsoft.com", 28 | description="Adaptively test and debug any natural language machine learning model.", 29 | packages=find_packages(exclude=["user_studies", "notebooks", "client"]), 30 | package_data={"adaptivetesting": ["resources/*"]}, 31 | install_requires=[ 32 | "aiohttp", 33 | "aiohttp_security", 34 | "aiohttp_session", 35 | "appdirs", 36 | "cryptography", 37 | "diskcache", 38 | "nest_asyncio", 39 | "numpy", 40 | "pandas", 41 | "profanity", 42 | "scikit-learn", 43 | "shap", 44 | ], 45 | extras_require={ 46 | "dev": [ 47 | "black", 48 | "flake8", 49 | "openai<1", 50 | "datasets", 51 | "transformers<4.26", 52 | "pytest", 53 | "pytest-mock", 54 | "torch", 55 | ] 56 | }, 57 | ) 58 | -------------------------------------------------------------------------------- /client/dist/main.js.LICENSE.txt: -------------------------------------------------------------------------------- 1 | /* 2 | object-assign 3 | (c) Sindre Sorhus 4 | @license MIT 5 | */ 6 | 7 | 
/*! 8 | * Font Awesome Free 5.15.3 by @fontawesome - https://fontawesome.com 9 | * License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) 10 | */ 11 | 12 | /*! 13 | * is-plain-object 14 | * 15 | * Copyright (c) 2014-2017, Jon Schlinkert. 16 | * Released under the MIT License. 17 | */ 18 | 19 | /** 20 | * @license 21 | * Lodash 22 | * Copyright OpenJS Foundation and other contributors 23 | * Released under MIT license 24 | * Based on Underscore.js 1.8.3 25 | * Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors 26 | */ 27 | 28 | /** @license React v0.20.2 29 | * scheduler.production.min.js 30 | * 31 | * Copyright (c) Facebook, Inc. and its affiliates. 32 | * 33 | * This source code is licensed under the MIT license found in the 34 | * LICENSE file in the root directory of this source tree. 35 | */ 36 | 37 | /** @license React v16.13.1 38 | * react-is.production.min.js 39 | * 40 | * Copyright (c) Facebook, Inc. and its affiliates. 41 | * 42 | * This source code is licensed under the MIT license found in the 43 | * LICENSE file in the root directory of this source tree. 44 | */ 45 | 46 | /** @license React v17.0.2 47 | * react-dom.production.min.js 48 | * 49 | * Copyright (c) Facebook, Inc. and its affiliates. 50 | * 51 | * This source code is licensed under the MIT license found in the 52 | * LICENSE file in the root directory of this source tree. 53 | */ 54 | 55 | /** @license React v17.0.2 56 | * react.production.min.js 57 | * 58 | * Copyright (c) Facebook, Inc. and its affiliates. 59 | * 60 | * This source code is licensed under the MIT license found in the 61 | * LICENSE file in the root directory of this source tree. 
62 | */ 63 | -------------------------------------------------------------------------------- /adaptivetesting/resources/main.js.LICENSE.txt: -------------------------------------------------------------------------------- 1 | /* 2 | object-assign 3 | (c) Sindre Sorhus 4 | @license MIT 5 | */ 6 | 7 | /*! 8 | * Font Awesome Free 5.15.3 by @fontawesome - https://fontawesome.com 9 | * License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) 10 | */ 11 | 12 | /*! 13 | * is-plain-object 14 | * 15 | * Copyright (c) 2014-2017, Jon Schlinkert. 16 | * Released under the MIT License. 17 | */ 18 | 19 | /** 20 | * @license 21 | * Lodash 22 | * Copyright OpenJS Foundation and other contributors 23 | * Released under MIT license 24 | * Based on Underscore.js 1.8.3 25 | * Copyright Jeremy Ashkenas, DocumentCloud and Investigative Reporters & Editors 26 | */ 27 | 28 | /** @license React v0.20.2 29 | * scheduler.production.min.js 30 | * 31 | * Copyright (c) Facebook, Inc. and its affiliates. 32 | * 33 | * This source code is licensed under the MIT license found in the 34 | * LICENSE file in the root directory of this source tree. 35 | */ 36 | 37 | /** @license React v16.13.1 38 | * react-is.production.min.js 39 | * 40 | * Copyright (c) Facebook, Inc. and its affiliates. 41 | * 42 | * This source code is licensed under the MIT license found in the 43 | * LICENSE file in the root directory of this source tree. 44 | */ 45 | 46 | /** @license React v17.0.2 47 | * react-dom.production.min.js 48 | * 49 | * Copyright (c) Facebook, Inc. and its affiliates. 50 | * 51 | * This source code is licensed under the MIT license found in the 52 | * LICENSE file in the root directory of this source tree. 53 | */ 54 | 55 | /** @license React v17.0.2 56 | * react.production.min.js 57 | * 58 | * Copyright (c) Facebook, Inc. and its affiliates. 
59 | * 60 | * This source code is licensed under the MIT license found in the 61 | * LICENSE file in the root directory of this source tree. 62 | */ 63 | -------------------------------------------------------------------------------- /client/src/bread-crum.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import autoBind from 'auto-bind'; 3 | 4 | export default class BreadCrum extends React.Component { 5 | constructor(props) { 6 | super(props); 7 | autoBind(this); 8 | 9 | this.state = { 10 | dropHighlighted: 0 11 | }; 12 | } 13 | 14 | render() { 15 | // console.log("br", this.props.name, this.props.name === "") 16 | return
19 | {this.props.name === "" ? this.props.defaultName : decodeURIComponent(this.props.name)} 20 |
21 | } 22 | 23 | onClick(e) { 24 | e.preventDefault(); 25 | e.stopPropagation(); 26 | if (this.props.onClick) { 27 | if (this.props.name === "") { 28 | this.props.onClick(this.props.topic); 29 | } else { 30 | this.props.onClick(this.props.topic + "/" + this.props.name); 31 | } 32 | } 33 | } 34 | 35 | onDragOver(e) { 36 | e.preventDefault(); 37 | e.stopPropagation(); 38 | } 39 | 40 | onDragEnter(e) { 41 | e.preventDefault(); 42 | e.stopPropagation(); 43 | this.setState({dropHighlighted: this.state.dropHighlighted + 1}); 44 | } 45 | 46 | onDragLeave(e) { 47 | e.preventDefault(); 48 | e.stopPropagation(); 49 | this.setState({dropHighlighted: this.state.dropHighlighted - 1}); 50 | } 51 | 52 | onDrop(e) { 53 | const id = e.dataTransfer.getData("id"); 54 | this.setState({dropHighlighted: 0}); 55 | if (this.props.onDrop) { 56 | let suffix = ""; 57 | if (id.includes("/")) { 58 | suffix = "/" + id.split("/").pop(); 59 | } 60 | this.props.onDrop(id, this.props.topic + (this.props.name === "" ? "" : "/" + this.props.name) + suffix); 61 | } 62 | } 63 | } -------------------------------------------------------------------------------- /adaptivetesting/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | import urllib 3 | import io 4 | import shap 5 | 6 | import numpy as np 7 | 8 | 9 | def parse_test_type(test_type): 10 | part_names = ["text1", "value1", "text2", "value2", "text3", "value3", "text4"] 11 | parts = re.split(r"(\{\}|\[\])", test_type) 12 | part_values = ["" for _ in range(7)] 13 | for i, part in enumerate(parts): 14 | part_values[i] = part 15 | return {name: value for name, value in zip(part_names, part_values)} 16 | 17 | 18 | # https://codereview.stackexchange.com/questions/253198/improved-isinstance-for-ipython 19 | def isinstance_ipython(obj, ref_class): 20 | def _class_name(obj): 21 | name = getattr(obj, "__qualname__", getattr(obj, "__name__", "")) 22 | return (getattr(obj, "__module__", "") + 
"." + name).lstrip(".") 23 | 24 | return isinstance(obj, ref_class) or _class_name(type(obj)) == _class_name( 25 | ref_class 26 | ) 27 | 28 | 29 | _images_cache = {} 30 | 31 | 32 | def get_image(url): 33 | if url not in _images_cache: 34 | try: 35 | _images_cache[url] = _download_image(url) 36 | except urllib.error.URLError: 37 | _images_cache[url] = get_image( 38 | "https://upload.wikimedia.org/wikipedia/commons/d/d1/Image_not_available.png" 39 | ) 40 | 41 | return _images_cache[url] 42 | 43 | 44 | def _download_image(url): 45 | import PIL 46 | 47 | urllib_request = urllib.request.Request( 48 | url, 49 | data=None, 50 | headers={ 51 | "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" 52 | }, 53 | ) 54 | with urllib.request.urlopen(urllib_request, timeout=10) as r: 55 | img_stream = io.BytesIO(r.read()) 56 | return PIL.Image.open(img_stream) 57 | 58 | 59 | def is_subtopic(topic, candidate): 60 | # Returns true if candidate is a subtopic of topic 61 | # Both arguments are strings, which look like UNIX paths 62 | # Return is boolean 63 | # return True if re.search(r"^%s(/|$)" % re.escape(topic), candidate) else False 64 | if len(topic) == len(candidate): 65 | return topic == candidate 66 | else: 67 | return candidate.startswith(topic) and candidate[len(topic)] == "/" 68 | 69 | 70 | def convert_float(s): 71 | try: 72 | f = float(s) 73 | except ValueError: 74 | f = np.nan 75 | return f 76 | -------------------------------------------------------------------------------- /.github/workflows/python-wheel-build.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Build Python Wheel 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | 
branches: [ main ] 11 | 12 | 13 | env: 14 | # Select the Artifact to be used in the 'test' job 15 | blessed-wheel-artifact: wheel-ubuntu-latest-3.10 16 | 17 | 18 | jobs: 19 | build: 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest, windows-latest] 24 | python-version: ['3.10'] 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Update pip and setuptools 33 | run: | 34 | pip install --upgrade pip 35 | pip install --upgrade setuptools wheel 36 | - name: Check node version 37 | run: npm --version 38 | - name: Run build_wheel script 39 | run: python development/scripts/build_wheel.py 40 | - name: Save wheel artifact 41 | uses: actions/upload-artifact@v3 42 | with: 43 | # Name must match blessed-wheel-artifact above 44 | name: wheel-${{ matrix.os }}-${{ matrix.python-version }} 45 | path: dist/* 46 | 47 | test: 48 | needs: build 49 | runs-on: ${{ matrix.os }} 50 | strategy: 51 | matrix: 52 | os: [macos-12-latest, ubuntu-latest, windows-latest] 53 | python-version: ['3.8', '3.9', '3.10'] 54 | 55 | steps: 56 | - uses: actions/checkout@v2 57 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} 58 | uses: actions/setup-python@v2 59 | with: 60 | python-version: ${{ matrix.python-version }} 61 | - name: Remove AdaTest source directory 62 | uses: JesseTG/rm@v1.0.3 63 | with: 64 | path: adaptivetesting 65 | - name: Download wheel artifact 66 | uses: actions/download-artifact@v3 67 | with: 68 | name: ${{ env.blessed-wheel-artifact }} 69 | - name: Install wheel from file 70 | run: python development/scripts/install_from_wheel.py --wheel-dir=. 
71 | - name: Test with pytest 72 | run: | 73 | python -m pytest tests/ -------------------------------------------------------------------------------- /tests/test_test_tree.py: -------------------------------------------------------------------------------- 1 | from operator import truediv 2 | import os 3 | import pathlib 4 | import tempfile 5 | 6 | import numpy as np 7 | 8 | import adaptivetesting 9 | 10 | 11 | def test_simple_init(): 12 | tree = adaptivetesting.TestTree() 13 | assert len(tree) == 0 14 | 15 | 16 | def test_simple_init_with_file(): 17 | tree = adaptivetesting.TestTree("temp_test_tree.csv") 18 | assert len(tree) == 0 19 | 20 | 21 | def test_simple_init_with_list(): 22 | tree = adaptivetesting.TestTree(["The food was nice!", "The location is excellent."]) 23 | assert len(tree) == 3 24 | assert tree.columns.to_list() == [ 25 | "topic", 26 | "input", 27 | "output", 28 | "label", 29 | "labeler", 30 | "description", 31 | ] 32 | assert "The food was nice!" in tree["input"].to_list() 33 | assert "The location is excellent." 
in tree["input"].to_list() 34 | outputs = np.unique(tree["output"].to_list(), return_counts=True) 35 | assert outputs[0][0] == "" 36 | assert outputs[1][0] == 1 37 | assert outputs[0][1] == "[no output]" 38 | assert outputs[1][1] == 2 39 | 40 | 41 | def test_to_csv(): 42 | tree = adaptivetesting.TestTree( 43 | [ 44 | { 45 | "topic": "", 46 | "type": "{} should output {}", 47 | "input": "This is good", 48 | "output": "NEGATIVE", 49 | "label": "fail", 50 | } 51 | ] 52 | ) 53 | with tempfile.TemporaryDirectory() as td: 54 | target_file = os.path.join(td, "adaptivetesting_out.csv") 55 | tree.to_csv(target_file) 56 | assert os.path.exists(target_file) 57 | 58 | 59 | def test_has_subtopic_or_tests(): 60 | curr_file = pathlib.Path(__file__) 61 | curr_dir = curr_file.parent 62 | input_csv = curr_dir / "simple_test_tree.csv" 63 | assert input_csv.exists() 64 | tree = adaptivetesting.TestTree(str(input_csv)) 65 | # The top level topic appears to be an empty string, which is odd 66 | assert tree.topic_has_subtopics("") == True 67 | assert tree.topic_has_direct_tests("") == True 68 | assert tree.topic_has_direct_tests("/A") == True 69 | assert tree.topic_has_subtopics("/A") == True 70 | assert tree.topic_has_direct_tests("/A/B") == True 71 | assert tree.topic_has_subtopics("/A/B") == False 72 | assert tree.topic_has_direct_tests("/A/C") == False 73 | assert tree.topic_has_subtopics("/A/C") == False 74 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. 
Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '29 12 * * 6' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v2 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v1 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v1 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 
59 | # 📚 https://git.io/JvXDl 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v1 71 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 
40 | 41 | -------------------------------------------------------------------------------- /client/src/CommEvent.ts: -------------------------------------------------------------------------------- 1 | 2 | // The base class for all communication data over websocket 3 | export class CommEvent { 4 | // Using snake case because these objects will be parsed in Python backend 5 | readonly event_id: string; 6 | constructor(event_id: string, data?: object) { 7 | this.event_id = event_id; 8 | if (data) { 9 | for (const k of Object.keys(data)) { 10 | this[k] = data[k]; 11 | } 12 | } 13 | } 14 | } 15 | 16 | 17 | export function finishTopicDescription(topic_marker_id: string, description: string) { 18 | return new CommEvent("change_description", {"topic_marker_id": topic_marker_id, "description": description}); 19 | } 20 | 21 | 22 | export function redraw() { 23 | return new CommEvent("redraw"); 24 | } 25 | 26 | 27 | export function generateSuggestions(data: object) { 28 | return new CommEvent("generate_suggestions", data); 29 | } 30 | 31 | 32 | export function clearSuggestions() { 33 | return new CommEvent("clear_suggestions") 34 | } 35 | 36 | 37 | export function setFirstModel(model: string) { 38 | return new CommEvent("set_first_model", { "model": model }); 39 | } 40 | 41 | 42 | export function changeGenerator(generator: string) { 43 | return new CommEvent("change_generator", { "generator": generator }); 44 | } 45 | 46 | 47 | export function changeMode(mode: string) { 48 | return new CommEvent("change_mode", {"mode": mode}) 49 | } 50 | 51 | 52 | export function addTopic() { 53 | return new CommEvent("add_new_topic") 54 | } 55 | 56 | 57 | export function addTest() { 58 | return new CommEvent("add_new_test"); 59 | } 60 | 61 | 62 | export function changeFilter(filter_text: string) { 63 | return new CommEvent("change_filter", {"filter_text": filter_text}); 64 | } 65 | 66 | 67 | export function changeTopic(topic: string) { 68 | return new CommEvent("change_topic", {"topic": 
topic}); 69 | } 70 | 71 | 72 | export function moveTest(test_ids: string[] | string, topic: string) { 73 | if (!Array.isArray(test_ids)) { 74 | test_ids = [test_ids] 75 | } 76 | return new CommEvent("move_test", { "test_ids": test_ids, "topic": topic }); 77 | } 78 | 79 | 80 | export function deleteTest(test_ids: string[] | string) { 81 | if (!Array.isArray(test_ids)) { 82 | test_ids = [test_ids] 83 | } 84 | return new CommEvent("delete_test", { "test_ids": test_ids }); 85 | } 86 | 87 | 88 | export function changeLabel(test_id: string, label: string, labeler: string) { 89 | return new CommEvent("change_label", { "test_ids": [test_id], "label": label, "labeler": labeler }); 90 | } 91 | 92 | 93 | export function changeInput(test_id: string, input: string) { 94 | return new CommEvent("change_input", { "test_ids": [test_id], "input": input }); 95 | } 96 | 97 | 98 | export function changeOutput(test_id: string, output: string) { 99 | return new CommEvent("change_output", { "test_ids": [test_id], "output": output }); 100 | } 101 | -------------------------------------------------------------------------------- /client/src/web-socket-comm.js: -------------------------------------------------------------------------------- 1 | import JSON5 from 'json5'; 2 | import autoBind from 'auto-bind'; 3 | import { defer, debounce } from 'lodash'; 4 | 5 | export default class WebSocketComm { 6 | constructor(interfaceId, websocketServer, onopen) { 7 | autoBind(this); 8 | this.interfaceId = interfaceId; 9 | this.websocketServer = websocketServer; 10 | this.callbackMap = {}; 11 | this.data = {}; 12 | this.pendingData = {}; 13 | this.onopen = onopen; 14 | this.reconnectDelay = 100; 15 | 16 | this.debouncedSendPendingData500 = debounce(this.sendPendingData, 500); 17 | this.debouncedSendPendingData1000 = debounce(this.sendPendingData, 1000); 18 | 19 | this.connect(); 20 | } 21 | 22 | send(keys, data) { 23 | this.addPendingData(keys, data); 24 | this.sendPendingData(); 25 | } 26 | 27 | 
sendEvent(commEvent) { 28 | for (const k of Object.keys(commEvent)) { 29 | this.addPendingData(k, commEvent[k]); 30 | } 31 | this.sendPendingData(); 32 | } 33 | 34 | debouncedSendEvent500(commEvent) { 35 | for (const k of Object.keys(commEvent)) { 36 | this.addPendingData(k, commEvent[k]); 37 | } 38 | this.debouncedSendPendingData500(); 39 | } 40 | 41 | debouncedSend500(keys, data) { 42 | this.addPendingData(keys, data); 43 | this.debouncedSendPendingData500(); 44 | } 45 | 46 | debouncedSend1000(keys, data) { 47 | this.addPendingData(keys, data); 48 | this.debouncedSendPendingData1000(); 49 | } 50 | 51 | addPendingData(keys, data) { 52 | // console.log("addPendingData", keys, data); 53 | if (!Array.isArray(keys)) keys = [keys]; 54 | for (const i in keys) { 55 | const k = keys[i]; 56 | this.pendingData[k] = data; 57 | this.data[k] = Object.assign(this.data[k] || {}, data); // pretend it has already changed in our data cache 58 | } 59 | } 60 | 61 | connect() { 62 | let wsUri = (window.location.protocol=='https:' ? 'wss://' : 'ws://') + (this.websocketServer.startsWith("/") ? window.location.host : "") + this.websocketServer; 63 | this.wcomm = new WebSocket(wsUri); 64 | this.wcomm.onopen = this.onopen; 65 | this.wcomm.onmessage = this.updateData; 66 | this.wcomm.onerror = this.onError; 67 | this.wcomm.onclose = this.onClose; 68 | } 69 | 70 | updateData(e) { 71 | // console.log("updateData", e) 72 | let data = JSON5.parse(e.data); 73 | console.log("updateData", data) 74 | for (const k in data) { 75 | // console.log("data[k]", data[k]) 76 | this.data[k] = Object.assign(this.data[k] || {}, data[k]); 77 | if (k in this.callbackMap) { 78 | this.callbackMap[k](data[k]); 79 | } 80 | } 81 | } 82 | 83 | onError(e) { 84 | console.log("Websocket error", e); 85 | } 86 | 87 | onClose(e) { 88 | console.log('Socket is closed. 
Reconnect will be attempted...', e.reason); 89 | setTimeout(this.connect, this.reconnectDelay); 90 | this.reconnectDelay += 1000; 91 | } 92 | 93 | subscribe(key, callback) { 94 | this.callbackMap[key] = callback; 95 | defer(_ => this.callbackMap[key](this.data[key])); 96 | } 97 | 98 | sendPendingData() { 99 | console.log("sending", this.pendingData); 100 | this.wcomm.send(JSON.stringify(this.pendingData)); 101 | this.pendingData = {}; 102 | } 103 | } -------------------------------------------------------------------------------- /adaptivetesting/_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import shap 3 | 4 | 5 | class Model(): 6 | """ This wraps models used in Adaptive Testing so that have a consistent interface. 7 | 8 | This should eventually just be the Model class from SHAP, but we keep a simple version here for now 9 | so we can easily update it during initial development. 10 | """ 11 | 12 | def __new__(cls, model, *args, **kwargs): 13 | """ If we are wrapping a model that is already a Model, we just return it. 14 | """ 15 | if shap.utils.safe_isinstance(model, "adaptivetesting.Model") or shap.utils.safe_isinstance(model, "shap.models.Model"): 16 | return model 17 | else: 18 | return super().__new__(cls) 19 | 20 | def __init__(self, model, output_names=None, **kwargs): 21 | """ Build a new model by wrapping the given model object. 22 | 23 | Parameters 24 | ---------- 25 | model : object 26 | The model to wrap. This can be a plain python function that accepts a list of strings and returns either 27 | a vector of probabilities or another string. It can also be a transformers pipeline object (we try to wrap 28 | common model types transparently). 29 | 30 | output_names : list of str, optional 31 | The names of the outputs of the model. If not given, we try to infer them from the model. 
32 | """ 33 | 34 | # finish early if we are wrapping an object that is already a Model 35 | if shap.utils.safe_isinstance(model, "adaptivetesting.Model") or shap.utils.safe_isinstance(model, "shap.models.Model"): 36 | if output_names is not None: 37 | self.output_names = output_names 38 | assert len(kwargs) == 0 39 | return 40 | 41 | # get outputs names from the model if it has them and we don't 42 | if output_names is None and hasattr(model, "output_names"): 43 | output_names = model.output_names 44 | 45 | # If we are in the base class we check to see if we should rebuild the model as a specialized subclass 46 | if self.__class__ is Model: 47 | 48 | # wrap transformer pipeline objects for convenience 49 | if shap.utils.safe_isinstance(model, "transformers.pipelines.text_classification.TextClassificationPipeline"): 50 | self.__class__ = shap.models.TransformersPipeline 51 | shap.models.TransformersPipeline.__init__(self, model, **kwargs) 52 | if output_names is not None: # Override output names if user supplied 53 | self.output_names = output_names 54 | 55 | elif shap.utils.safe_isinstance(model, "transformers.pipelines.text_generation.TextGenerationPipeline"): 56 | self.__class__ = TransformersTextGenerationPipeline 57 | TransformersTextGenerationPipeline.__init__(self, model, **kwargs) 58 | 59 | else: 60 | self.inner_model = model 61 | self.output_names = output_names 62 | 63 | def __call__(self, *args, **kwargs): 64 | return np.array(self.inner_model(*args, **kwargs)) 65 | 66 | 67 | class TransformersTextGenerationPipeline(Model): 68 | """ This wraps the transformer text generation pipeline object to match the Model API. 69 | 70 | TODO: move this to SHAP. 
71 | """ 72 | def __init__(self, pipeline): 73 | self._inner_model = pipeline 74 | self.output_names = None 75 | 76 | def __call__(self, strings): 77 | inner_out = self._inner_model(strings) 78 | out = [] 79 | for s, data in zip(strings, inner_out): 80 | out.append(data[0]["generated_text"][len(s):]) # remove the input text from the output 81 | return out 82 | -------------------------------------------------------------------------------- /client/src/jupyter-comm.js: -------------------------------------------------------------------------------- 1 | import JSON5 from 'json5'; 2 | import autoBind from 'auto-bind'; 3 | import { defer, debounce } from 'lodash'; 4 | 5 | export default class JupyterComm { 6 | constructor(interfaceId, onopen) { 7 | autoBind(this); 8 | this.interfaceId = interfaceId; 9 | this.callbackMap = {}; 10 | this.data = {}; 11 | this.pendingData = {}; 12 | this.jcomm = new InnerJupyterComm('adatest_interface_target_'+this.interfaceId, this.updateData); 13 | 14 | this.debouncedSendPendingData500 = debounce(this.sendPendingData, 500); 15 | this.debouncedSendPendingData1000 = debounce(this.sendPendingData, 1000); 16 | if (onopen) { 17 | defer(onopen); 18 | } 19 | } 20 | 21 | send(keys, data) { 22 | this.addPendingData(keys, data); 23 | this.sendPendingData(); 24 | } 25 | 26 | sendEvent(commEvent) { 27 | for (const k of Object.keys(commEvent)) { 28 | this.addPendingData(k, commEvent[k]); 29 | } 30 | this.sendPendingData(); 31 | } 32 | 33 | debouncedSendEvent500(commEvent) { 34 | for (const k of Object.keys(commEvent)) { 35 | this.addPendingData(k, commEvent[k]); 36 | } 37 | this.debouncedSendPendingData500(); 38 | } 39 | 40 | debouncedSend500(keys, data) { 41 | this.addPendingData(keys, data); 42 | this.debouncedSendPendingData500(); 43 | } 44 | 45 | debouncedSend1000(keys, data) { 46 | this.addPendingData(keys, data); 47 | this.debouncedSendPendingData1000(); 48 | } 49 | 50 | addPendingData(keys, data) { 51 | 52 | // console.log("addPendingData", 
keys, data); 53 | if (!Array.isArray(keys)) keys = [keys]; 54 | for (const i in keys) this.pendingData[keys[i]] = data; 55 | } 56 | 57 | updateData(data) { 58 | data = JSON5.parse(data["data"]) // data from Jupyter is wrapped so we get to do our own JSON encoding 59 | console.log("updateData", data) 60 | 61 | // save the data locally 62 | for (const k in data) { 63 | this.data[k] = data[k]; 64 | } 65 | 66 | // call all the registered callbacks 67 | for (const k in data) { 68 | if (k in this.callbackMap) { 69 | this.callbackMap[k](this.data[k]); 70 | } 71 | } 72 | } 73 | 74 | subscribe(key, callback) { 75 | this.callbackMap[key] = callback; 76 | defer(_ => this.callbackMap[key](this.data[key])); 77 | } 78 | 79 | sendPendingData() { 80 | console.log("sending", this.pendingData); 81 | this.jcomm.send_data(this.pendingData); 82 | this.pendingData = {}; 83 | } 84 | } 85 | 86 | class InnerJupyterComm { 87 | constructor(target_name, callback, mode="open") { 88 | this._fire_callback = this._fire_callback.bind(this); 89 | this._register = this._register.bind(this) 90 | 91 | this.jcomm = undefined; 92 | this.callback = callback; 93 | 94 | // https://jupyter-notebook.readthedocs.io/en/stable/comms.html 95 | if (mode === "register") { 96 | Jupyter.notebook.kernel.comm_manager.register_target(target_name, this._register); 97 | } else { 98 | this.jcomm = Jupyter.notebook.kernel.comm_manager.new_comm(target_name); 99 | this.jcomm.on_msg(this._fire_callback); 100 | } 101 | } 102 | 103 | send_data(data) { 104 | if (this.jcomm !== undefined) { 105 | this.jcomm.send(data); 106 | } else { 107 | console.error("Jupyter comm module not yet loaded! 
So we can't send the message.") 108 | } 109 | } 110 | 111 | _register(jcomm, msg) { 112 | this.jcomm = jcomm; 113 | this.jcomm.on_msg(this._fire_callback); 114 | } 115 | 116 | _fire_callback(msg) { 117 | this.callback(msg.content.data) 118 | } 119 | } 120 | 121 | // const comm = JupyterComm(); 122 | 123 | // // Jupyter.notebook.kernel.comm_manager.register_target('gadfly_comm_target', 124 | // // function(jcomm, msg) { 125 | // // // comm is the frontend comm instance 126 | // // // msg is the comm_open message, which can carry data 127 | 128 | // // comm.jcomm = jcomm 129 | 130 | // // // Register handlers for later messages: 131 | // // inner_comm.on_msg(function(msg) { console.log("MSGG", msg); }); 132 | // // inner_comm.on_close(function(msg) { console.log("MSGdG", msg); }); 133 | // // comm.send({'foo': 0}); 134 | // // } 135 | // // ); 136 | 137 | // export default comm; -------------------------------------------------------------------------------- /tests/test_generators.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from transformers import pipeline 5 | 6 | from adaptivetesting import generators 7 | 8 | TRANSFORMER_PIPELINE_MODELS = ["EleutherAI/gpt-neo-125M"] 9 | 10 | 11 | class TestTransformers: 12 | @pytest.mark.parametrize("model_name", TRANSFORMER_PIPELINE_MODELS) 13 | def test_smoke(self, model_name): 14 | hf_model = pipeline("text-generation", model=model_name) 15 | target = generators.Transformers(hf_model.model, hf_model.tokenizer) 16 | 17 | prompts = [ 18 | ("id A", "", "Great hotel"), 19 | ("id B", "", "Bathroom too small"), 20 | ] 21 | 22 | desired_result_count = 2 23 | 24 | results = target( 25 | prompts=prompts, 26 | topic="", 27 | mode="tests", 28 | topic_description="", 29 | num_samples=desired_result_count, 30 | ) 31 | assert results is not None 32 | assert len(results) == desired_result_count 33 | for item in results: 34 | assert isinstance(item, str) 
35 | 36 | @pytest.mark.parametrize("model_name", TRANSFORMER_PIPELINE_MODELS) 37 | def test_with_topics(self, model_name): 38 | hf_model = pipeline("text-generation", model=model_name) 39 | target = generators.Transformers(hf_model.model, hf_model.tokenizer) 40 | 41 | prompts = [ 42 | ("id A", "some string", "Great hotel"), 43 | ("id B", "some_string", "Bathroom too small"), 44 | ] 45 | 46 | desired_result_count = 2 47 | 48 | results = target( 49 | prompts=prompts, 50 | topic="", 51 | mode="tests", 52 | topic_description="", 53 | num_samples=desired_result_count, 54 | ) 55 | assert results is not None 56 | assert len(results) == desired_result_count 57 | for item in results: 58 | assert isinstance(item, str) 59 | 60 | 61 | GENERATOR_PIPELINE_MODELS = [ 62 | "facebook/opt-125m", 63 | "facebook/opt-350m", 64 | "EleutherAI/gpt-neo-125M", 65 | "gpt2", 66 | "bigscience/bloom-560m", 67 | ] 68 | 69 | 70 | class TestPipelines: 71 | @pytest.mark.parametrize( 72 | "model_name", 73 | GENERATOR_PIPELINE_MODELS, 74 | ) 75 | def test_smoke(self, model_name): 76 | hf_pipeline = pipeline("text-generation", model=model_name) 77 | target = generators.Pipelines(hf_pipeline) 78 | 79 | prompts = [ 80 | ("id A", "", "Great hotel"), 81 | ("id B", "", "Bathroom too small"), 82 | ] 83 | 84 | desired_result_count = 2 85 | 86 | results = target( 87 | prompts=prompts, 88 | topic="some topic", 89 | mode="tests", 90 | topic_description="", 91 | num_samples=desired_result_count, 92 | ) 93 | assert results is not None 94 | assert len(results) == desired_result_count 95 | for item in results: 96 | assert isinstance(item, str) 97 | 98 | @pytest.mark.parametrize( 99 | "model_name", 100 | GENERATOR_PIPELINE_MODELS, 101 | ) 102 | def test_with_topics(self, model_name): 103 | hf_pipeline = pipeline("text-generation", model=model_name) 104 | target = generators.Pipelines(hf_pipeline) 105 | 106 | prompts = [ 107 | ("id A", "some_string", "Great hotel"), 108 | ("id B", "some_string", "Bathroom too 
small"), 109 | ] 110 | 111 | desired_result_count = 2 112 | 113 | results = target( 114 | prompts=prompts, 115 | topic="some topic", 116 | mode="tests", 117 | topic_description="", 118 | num_samples=desired_result_count, 119 | ) 120 | assert results is not None 121 | assert len(results) == desired_result_count 122 | for item in results: 123 | assert isinstance(item, str) 124 | 125 | 126 | class TestOpenAI: 127 | def test_smoke(self, mocker): 128 | OPENAI_API_KEY = "Not for you, CredScan" 129 | 130 | openai_completion = mocker.patch("openai.Completion", autospec=True) 131 | patched_response = {"choices": [{"text": "Ret 1"}, {"text": "Ret 2"}]} 132 | openai_completion.create.return_value = patched_response 133 | 134 | target = generators.OpenAI("curie", api_key=OPENAI_API_KEY) 135 | 136 | prompts = [ 137 | ("id A", "", "Great hotel"), 138 | ("id B", "", "Bathroom too small"), 139 | ] 140 | 141 | desired_result_count = 2 142 | 143 | results = target( 144 | prompts=prompts, 145 | topic="", 146 | topic_description="", 147 | mode="tests", 148 | num_samples=desired_result_count, 149 | scorer=None, 150 | ) 151 | assert results is not None 152 | assert len(results) == desired_result_count 153 | assert "Ret 1" in results 154 | assert "Ret 2" in results 155 | openai_completion.create.assert_called_with( 156 | model="curie", 157 | prompt=['"Great hotel"\n"Bathroom too small"\n"'], 158 | user="adaptivetesting", 159 | max_tokens=100, 160 | temperature=1.0, 161 | top_p=0.95, 162 | n=2, 163 | stop='"', 164 | ) 165 | -------------------------------------------------------------------------------- /client/src/adatest.css: -------------------------------------------------------------------------------- 1 | .adatest-button-container { 2 | display: block; 3 | width: 100%; 4 | padding: 3px; 5 | padding-top: 1px; 6 | padding-bottom: 1px; 7 | font-size: 13px; 8 | } 9 | .adatest-button-children { 10 | display: flex; 11 | vertical-align: top; 12 | width: 100%; 13 | } 14 | 15 | 
.adatest-button-box { 16 | padding: 2px; 17 | padding-left: 5px; 18 | padding-right: 5px; 19 | display: flex; 20 | cursor: pointer; 21 | text-align: center; 22 | width: 100%; 23 | margin-bottom: 3px; 24 | border: 1px solid #dddddd; 25 | } 26 | 27 | .adatest-explore-toggle { 28 | float: right; 29 | color: #666666; 30 | margin-top: 4px; 31 | margin-left: 4px; 32 | width: 30px; 33 | } 34 | 35 | .adatest-add-topic { 36 | color: #bbbbbb; 37 | margin-top: 4px; 38 | margin-left: 7px; 39 | margin-right: 7px; 40 | cursor: pointer; 41 | } 42 | .adatest-add-topic-wrapper { 43 | margin-top: 4px; 44 | margin-left: auto; 45 | margin-right: auto; 46 | } 47 | 48 | .adatest-button-label { 49 | flex: 1; 50 | display: inline-block; 51 | text-align: center; 52 | user-select: none; 53 | } 54 | 55 | .adatest-topic-label { 56 | width: 300px; 57 | border: 0px; 58 | border-bottom: 1px solid #999999; 59 | text-align: left; 60 | margin-top: 10px; 61 | margin-bottom: 10px; 62 | outline: 0px solid transparent; 63 | background: transparent; 64 | } 65 | .adatest-topic-label:focus { 66 | outline: 0px solid transparent; 67 | } 68 | 69 | .adatest-children-frame { 70 | padding: 0px; 71 | margin-top: 10px; 72 | /* padding-bottom: 10px; */ 73 | border-radius: 7px 7px 7px 7px; 74 | border: 1px solid rgb(216, 222, 228); 75 | } 76 | 77 | .adatest-row-child { 78 | display: flex; 79 | padding: 0px; 80 | padding-left: 5px; 81 | padding-right: 0px; 82 | min-height: 30px; 83 | text-align: left; 84 | border-radius: 0px;/*10px 10px 10px 10px;*/ 85 | outline: none; 86 | padding-top: 5px; 87 | padding-bottom: 5px; 88 | } 89 | .adatest-row-hidden { 90 | text-decoration: line-through; 91 | } 92 | 93 | .adatest-row-score-plot-box { 94 | float: right; 95 | height: 30px; 96 | flex: 0 0 150px; 97 | margin-right: 10px; 98 | padding-top: 0px; 99 | align-self: center; 100 | } 101 | 102 | .adatest-row-add-button { 103 | margin-top: 8px; 104 | flex: 0 0 20px; 105 | display: inline-block; 106 | margin-left: 10px; 107 | } 108 
| 109 | .adatest-top-add-button { 110 | margin-top: 8px; 111 | flex: 0 0 20px; 112 | display: inline-block; 113 | margin-left: 10px; 114 | } 115 | 116 | [contenteditable] { 117 | -webkit-user-select: text; 118 | user-select: text; 119 | } 120 | 121 | /* [contenteditable]:focus { 122 | outline-width: 0px; 123 | } */ 124 | 125 | .adatest-plain-select { 126 | margin-left: 4px; 127 | appearance: none; 128 | border: none; 129 | background: none; 130 | -webkit-appearance: none; 131 | font-size: 13px; 132 | font-family: Helvetica Neue, Helvetica, Arial, sans-serif; 133 | margin: 0px; 134 | } 135 | .adatest-plain-select:focus { 136 | outline: 0px solid transparent; 137 | } 138 | 139 | .adatest-hover-opacity { 140 | opacity: 0.1; 141 | } 142 | .adatest-hover-opacity:hover { 143 | opacity: 0.6; 144 | } 145 | 146 | .adatest-scroll-wrap::-webkit-scrollbar { 147 | display: none; 148 | } 149 | 150 | .adatest-row-hide-button { 151 | opacity: 0; 152 | margin-top: 8px; 153 | /* line-height: 30px; */ 154 | flex: 0 0 15px; 155 | display: inline-block; 156 | margin-left: 10px; 157 | } 158 | .adatest-row-hide-button:hover { 159 | opacity: 0.6 160 | } 161 | .adatest-row-hide-hovering { 162 | opacity: 0.1; 163 | } 164 | .adatest-row-hide-hidden { 165 | opacity: 0.6; 166 | } 167 | 168 | .adatest-row-score-text-box { 169 | float: right; 170 | text-align: right; 171 | flex: 0 0 50px; 172 | line-height: 30px; 173 | height: 30px; 174 | color: #999999; 175 | padding-right: 5px; 176 | font-size: 12px; 177 | overflow: hidden; 178 | box-sizing: border-box; 179 | } 180 | 181 | .adatest-main-table { 182 | margin-left: "auto"; 183 | margin-right: "auto"; 184 | background: #ffffff; 185 | width: 100%; 186 | } 187 | 188 | .adatest-main-table tbody tr:nth-child(odd) { 189 | background: #ffffff; 190 | } 191 | 192 | .adatest-main-table tbody tr:hover { 193 | background: #ffffff; 194 | } 195 | 196 | 197 | .adatest-output-text:focus { 198 | outline: 0px solid transparent; 199 | } 200 | 201 | 
.adatest-row-selected { 202 | background: #00000008; 203 | /* border-left: 3px solid #00000099; */ 204 | /* padding-left: 2px; */ 205 | } 206 | 207 | .adatest-select-width-calculation { 208 | position: absolute; 209 | visibility: hidden; 210 | height: auto; 211 | width: auto; 212 | white-space: nowrap; /* NOTE(review): overridden by `white-space: pre` below — presumably intentional; confirm and drop one */ 213 | font-size: 13px; 214 | font-family: Helvetica Neue, Helvetica, Arial, sans-serif; 215 | white-space: pre; 216 | } 217 | 218 | .adatest-row { 219 | display: flex; 220 | margin-top: 3px; 221 | } 222 | .adatest-row-input { 223 | flex: 1; 224 | display: flex; 225 | justify-content: right; 226 | align-items: center; 227 | } 228 | .adatest-row-editing { 229 | border-bottom: 0px dashed rgb(0, 130, 248); 230 | } 231 | .adatest-row-editing:focus { 232 | outline: 0px solid transparent; 233 | } 234 | .adatest-row-drop-highlighted { 235 | background: #dddddd; 236 | } 237 | .adatest-row-hover-highlighted { 238 | background: #f0f0f0; 239 | } 240 | .adatest-row-dragging { 241 | background: #dddddd; 242 | } 243 | 244 | 245 | .adatest-crum-selected { 246 | background: #dddddd; 247 | } 248 | 249 | .adatest-editable:focus { 250 | outline: 0px solid transparent; 251 | } 252 | 253 | .adatest-suggestions-box { 254 | position: relative; 255 | border-radius: 7px 7px 7px 7px; 256 | background: rgb(246, 248, 250); 257 | text-align: center; 258 | padding: 0px; 259 | padding-top: 10px; 260 | margin-top: 10px; 261 | padding-bottom: 0px; 262 | border: 1px solid rgb(216, 222, 228); 263 | } 264 | .adatest-suggestions-box-after { 265 | content: ''; /* FIX: removed leftover debug text 'werwerwe'; `content` is inert on a normal element anyway (pseudo-elements only) */ 266 | width: 100%; 267 | height: 22px; 268 | position: absolute; 269 | bottom: 0px; 270 | background: linear-gradient(rgba(246, 248, 250, 0) 0px, rgba(246, 248, 250, 1) 18px); /* FIX: rgb() with 4 args normalized to rgba() for older-browser safety */ 271 | pointer-events: none; 272 | border-radius: 0px 0px 7px 7px; 273 | } 274 | .adatest-drop-highlighted { 275 | background: #dddddd; 276 | } 277 | 278 | .adatest-hover-gray:hover { 279 | background: #dddddd; 280 | }
-------------------------------------------------------------------------------- /adaptivetesting/embedders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import adaptivetesting 3 | from sklearn.preprocessing import normalize 4 | import appdirs 5 | import diskcache 6 | _embedding_memory_cache = {} 7 | _embedding_file_cache = diskcache.Cache(appdirs.user_cache_dir("adaptivetesting") + "/embeddings.diskcache") 8 | 9 | def _embed(strings, normalize=True): 10 | 11 | # find which strings are not in the cache 12 | new_text_strings = [] 13 | new_image_urls = [] 14 | text_prefix = _text_embedding_model().name # TODO: need to figure out how to do the same for image embedding, but only when needed 15 | for s in strings: 16 | if s.startswith("__IMAGE="): 17 | prefixed_s = s 18 | else: 19 | prefixed_s = text_prefix + s 20 | if prefixed_s not in _embedding_memory_cache: 21 | if prefixed_s not in _embedding_file_cache: 22 | if s.startswith("__IMAGE="): 23 | new_image_urls.append(s) 24 | else: 25 | new_text_strings.append(s) 26 | _embedding_memory_cache[prefixed_s] = None # so we don't embed the same string twice 27 | else: 28 | _embedding_memory_cache[prefixed_s] = _embedding_file_cache[prefixed_s] 29 | 30 | # embed the new text strings 31 | if len(new_text_strings) > 0: 32 | new_embeds = _text_embedding_model()(new_text_strings) 33 | for i,s in enumerate(new_text_strings): 34 | prefixed_s = text_prefix + s 35 | if normalize: 36 | _embedding_memory_cache[prefixed_s] = new_embeds[i] / np.linalg.norm(new_embeds[i]) 37 | else: 38 | _embedding_memory_cache[prefixed_s] = new_embeds[i] 39 | _embedding_file_cache[prefixed_s] = _embedding_memory_cache[prefixed_s] 40 | 41 | # embed the new image urls 42 | if len(new_image_urls) > 0: 43 | new_embeds = _image_embedding_model()([url[8:] for url in new_image_urls]) 44 | for i,s in enumerate(new_image_urls): 45 | if normalize: 46 | _embedding_memory_cache[s] = new_embeds[i] / 
np.linalg.norm(new_embeds[i]) 47 | else: 48 | _embedding_memory_cache[s] = new_embeds[i] 49 | _embedding_file_cache[s] = _embedding_memory_cache[s] 50 | 51 | return [_embedding_memory_cache[s if s.startswith("__IMAGE=") else text_prefix + s] for s in strings] 52 | 53 | def _text_embedding_model(): 54 | """ Get the text embedding model. 55 | 56 | Much of this code block is from the sentence_transformers documentation. 57 | """ 58 | if adaptivetesting.text_embedding_model is None: 59 | 60 | # # get the modules we need to compute embeddings 61 | # import torch 62 | # import transformers 63 | 64 | # # Mean Pooling - Take attention mask into account for correct averaging 65 | # def mean_pooling(model_output, attention_mask): 66 | # token_embeddings = model_output[0] # First element of model_output contains all token embeddings 67 | # input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 68 | # return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 69 | 70 | # # Load model from HuggingFace Hub 71 | # tokenizer = transformers.AutoTokenizer.from_pretrained('sentence-transformers/stsb-roberta-base-v2') 72 | # model = transformers.AutoModel.from_pretrained('sentence-transformers/stsb-roberta-base-v2') 73 | 74 | # # Tokenize sentences 75 | # def embed_model(sentences): 76 | # encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') 77 | 78 | # # Compute token embeddings 79 | # with torch.no_grad(): 80 | # model_output = model(**encoded_input) 81 | 82 | # # Perform pooling. In this case, max pooling. 
83 | # return mean_pooling(model_output, encoded_input['attention_mask']).cpu().numpy() 84 | 85 | adaptivetesting.text_embedding_model = TransformersTextEmbedding() 86 | 87 | return adaptivetesting.text_embedding_model 88 | 89 | def _image_embedding_model(): 90 | if adaptivetesting.image_embedding_model is None: 91 | import clip # pylint: disable=import-outside-toplevel 92 | import torch 93 | 94 | model, preprocess = clip.load("ViT-L/14", device="cpu", jit=True) 95 | 96 | def embed_model(urls): 97 | with torch.no_grad(): 98 | out = [] 99 | for url in urls: 100 | image = adaptivetesting.utils.get_image(url) 101 | image_emb = model.encode_image(preprocess(image).unsqueeze(0).to("cpu")) 102 | image_emb /= image_emb.norm(dim=-1, keepdim=True) 103 | image_emb = image_emb.cpu().detach().numpy().astype("float32")[0] 104 | out.append(image_emb) 105 | return np.vstack(out) 106 | 107 | adaptivetesting.image_embedding_model = embed_model 108 | 109 | return adaptivetesting.image_embedding_model 110 | 111 | def cos_sim(a, b): 112 | """ Cosine distance between two vectors. 
113 | """ 114 | return normalize(a, axis=1) @ normalize(b, axis=1).T 115 | 116 | class TransformersTextEmbedding(): 117 | def __init__(self, model="sentence-transformers/stsb-roberta-base-v2"): 118 | import transformers 119 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(model) 120 | self.model = transformers.AutoModel.from_pretrained(model) 121 | self.model_name = model 122 | self.name = "adaptivetesting.embedders.TransformersTextEmbedding(" + self.model_name + "):" 123 | 124 | def __call__(self, strings): 125 | import torch 126 | 127 | encoded_input = self.tokenizer(strings, padding=True, truncation=True, return_tensors='pt') 128 | 129 | # Compute token embeddings 130 | with torch.no_grad(): 131 | model_output = self.model(**encoded_input) 132 | 133 | # Perform mean pooling 134 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings 135 | input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float() 136 | embeds = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 137 | return embeds.cpu().numpy() 138 | 139 | class OpenAITextEmbedding(): 140 | def __init__(self, model="text-similarity-babbage-001", api_key=None, replace_newlines=True): 141 | import openai 142 | self.model = model 143 | if api_key is not None: 144 | openai.api_key = api_key 145 | self.replace_newlines = replace_newlines 146 | self.model_name = model 147 | self.name = "adaptivetesting.embedders.OpenAITextEmbedding(" + self.model_name + "):" 148 | 149 | def __call__(self, strings): 150 | import openai 151 | 152 | if len(strings) == 0: 153 | return np.array([]) 154 | 155 | # clean the strings for OpenAI 156 | cleaned_strings = [] 157 | for s in strings: 158 | if s == "": 159 | s = " " # because OpenAI doesn't like empty strings 160 | elif self.replace_newlines: 161 | s = s.replace("\n", " ") # OpenAI recommends this for things that are not code 
162 | cleaned_strings.append(s) 163 | 164 | # call the OpenAI API to complete the prompts 165 | response = openai.Embedding.create( 166 | input=cleaned_strings, model=self.model, user="adatest" 167 | ) 168 | 169 | return np.vstack([e["embedding"] for e in response["data"]]) 170 | -------------------------------------------------------------------------------- /client/src/content-editable.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import autoBind from 'auto-bind'; 3 | import sanitizeHtml from 'sanitize-html'; 4 | import { defer } from 'lodash'; 5 | 6 | export default class ContentEditable extends React.Component { 7 | static defaultProps = { 8 | editable: true, 9 | defaultText: "", 10 | finishOnReturn: false 11 | }; 12 | 13 | constructor(props) { 14 | super(props); 15 | autoBind(this); 16 | this.lastText = null; 17 | 18 | this.divRef = {}; 19 | window["cedit_"+this.props.id] = this; 20 | } 21 | 22 | render() { 23 | //console.log("this.props.text", this.props.text) 24 | const emptyContent = this.props.text === undefined || this.props.text.length === 0; 25 | this.lastEditable = this.props.editable; 26 | if (this.lastText === null) this.lastText = this.props.text; 27 | return
this.divRef = el} 29 | id={this.props.id} 30 | style={{opacity: emptyContent ? "0.3" : "1", display: "inline", overflowWrap: "anywhere", whiteSpace: "pre-wrap"}} 31 | onFocus={this.onFocus} 32 | onInput={this.handleInput} 33 | onKeyPress={this.handleKeyPress} 34 | onKeyDown={this.handleKeyDown} 35 | onBlur={this.onBlur} 36 | onDoubleClick={this.handleDoubleClick} 37 | onDragStart={this.stopDrag} 38 | onClick={this.onClick} 39 | contentEditable={this.props.editable} 40 | className="adatest-editable" 41 | dangerouslySetInnerHTML={{__html: sanitizeHtml(emptyContent ? this.props.defaultText : this.props.text)}} 42 | tabIndex="0" 43 | >
44 | } 45 | 46 | stopDrag(e) { 47 | console.log("stopDrag") 48 | e.preventDefault(); 49 | return false; 50 | } 51 | 52 | handleDoubleClick(e) { 53 | const range = getMouseEventCaretRange(e); 54 | console.log("handleDoubleClick", range, e) 55 | } 56 | 57 | focus() { 58 | 59 | // we blur without triggering an action so that we can refocus 60 | // this is important to get the cursor to come back sometimes 61 | this.skipBlurAction = true; 62 | this.divRef.blur(); 63 | this.skipBlurAction = false; 64 | 65 | this.divRef.focus(); 66 | } 67 | 68 | blur() { 69 | this.divRef.blur(); 70 | } 71 | 72 | onFocus(e) { 73 | console.log("onFocus in ContentEditable", this.props.text); 74 | 75 | // if (!this.props.editing) return; 76 | 77 | if (this.props.text !== this.props.defaultText && this.divRef.textContent === this.props.defaultText) { 78 | e.preventDefault(); 79 | e.stopPropagation(); 80 | this.divRef.textContent = ""; 81 | if (this.props.onClick) this.props.onClick(e); // why we need this is crazy to me, seems like setting inner text kills the click event 82 | // defer(() => this.focus()); 83 | console.log("clear!!", this.props.editable) 84 | defer(() => this.focus()); 85 | } 86 | } 87 | 88 | onClick(e) { 89 | // console.log("onClick in ContentEditable", this.props.onClick) 90 | if (this.props.onClick) { 91 | e.preventDefault(); 92 | e.stopPropagation(); 93 | this.props.onClick(e); 94 | } 95 | e.stopPropagation(); 96 | } 97 | 98 | getValue() { 99 | const text = this.divRef.textContent; 100 | if (text === this.props.defaultText) return ""; 101 | else return text; 102 | } 103 | 104 | shouldComponentUpdate(nextProps) { 105 | return nextProps.text !== this.divRef.textContent && (nextProps.text != "" || this.divRef.textContent != this.props.defaultText) || nextProps.editable != this.lastEditable; 106 | } 107 | 108 | componentDidUpdate() { 109 | this.componentDidUpdateOrMount(false); 110 | } 111 | 112 | componentDidMount() { 113 | this.componentDidUpdateOrMount(true); 114 | } 115 | 
116 | componentDidUpdateOrMount(mount) { 117 | // console.log("ContentEditable componentDidUpdateOrMount", mount, this.props.text, this.props.editable); 118 | if (this.props.text !== this.divRef.textContent) { 119 | if (this.props.text !== undefined && this.props.text !== null && (this.props.text.length > 0 || this.divRef.textContent !== this.props.defaultText)) { 120 | this.divRef.textContent = this.props.text; 121 | } else { 122 | if (mount) this.divRef.textContent = this.props.defaultText; 123 | } 124 | } 125 | if (this.props.text && (this.props.text.startsWith("New topic") || this.props.text === "New test") && this.props.editable) { // hacky but works for now 126 | // console.log("HACK!", this.props.text) 127 | this.divRef.focus(); 128 | selectElement(this.divRef); 129 | // document.execCommand('selectAll', false, null); 130 | } 131 | } 132 | 133 | handleInput(e, finishing) { 134 | console.log("handleInput", finishing, this.divRef.textContent) 135 | const text = this.divRef.textContent; 136 | if (this.props.onInput && text !== this.lastText) { 137 | this.props.onInput(text); 138 | this.lastText = text; 139 | } 140 | 141 | if (finishing && this.props.onFinish) { 142 | this.props.onFinish(text); 143 | } 144 | 145 | if (text === this.props.defaultText) this.divRef.style.opacity = 0.3; 146 | else this.divRef.style.opacity = 1.0; 147 | } 148 | 149 | onBlur(e) { 150 | console.log("onBlur in ContentEditable", this.divRef.textContent, this.skipBlurAction) 151 | if (this.skipBlurAction) return; 152 | // if (this.divRef.textContent.length === this.props.defaultText) { 153 | // this.divRef.textContent = ""; 154 | // } 155 | this.handleInput(e, true); 156 | if (this.divRef.textContent.length === 0) { 157 | this.divRef.textContent = this.props.defaultText; 158 | this.divRef.style.opacity = 0.3; 159 | } 160 | } 161 | 162 | handleKeyPress(e) { 163 | 164 | console.log("handleKeyPress", e.charCode, this.props.finishOnReturn) 165 | e.stopPropagation(); 166 | if (e.charCode == 13 
&& this.props.finishOnReturn) { 167 | e.preventDefault(); 168 | 169 | this.handleInput(e, true); 170 | } 171 | } 172 | 173 | handleKeyDown(e) { 174 | console.log("handleKeyDown", e.charCode, this.props.finishOnReturn) 175 | // only let the enter/return key go through 176 | if (e.charCode != 13 || !this.props.finishOnReturn) e.stopPropagation(); 177 | } 178 | } 179 | 180 | function selectElement(element){ 181 | var doc = document; 182 | console.log(this, element); 183 | if (doc.body.createTextRange) { 184 | var range = document.body.createTextRange(); 185 | range.moveToElementText(element); 186 | range.select(); 187 | } else if (window.getSelection) { 188 | var selection = window.getSelection(); 189 | var range = document.createRange(); 190 | range.selectNodeContents(element); 191 | selection.removeAllRanges(); 192 | selection.addRange(range); 193 | } 194 | } 195 | 196 | function setCaret(el, pos) { 197 | var range = document.createRange(); 198 | var sel = window.getSelection(); 199 | 200 | range.setStart(el, pos) 201 | range.collapse(true) 202 | 203 | sel.removeAllRanges() 204 | sel.addRange(range) 205 | } 206 | document.setCaret = setCaret; 207 | 208 | function findParentWithClass(el, className) { 209 | const orig_el = el; 210 | while (el && !el.className.includes(className)) { 211 | el = el.parentElement; 212 | } 213 | return el ? 
el : orig_el; 214 | } 215 | 216 | function getMouseEventCaretRange(evt) { 217 | var range, x = evt.clientX, y = evt.clientY; 218 | 219 | // Try the simple IE way first 220 | if (document.body.createTextRange) { 221 | range = document.body.createTextRange(); 222 | range.moveToPoint(x, y); 223 | } 224 | 225 | else if (typeof document.createRange != "undefined") { 226 | // Try Mozilla's rangeOffset and rangeParent properties, 227 | // which are exactly what we want 228 | if (typeof evt.rangeParent != "undefined") { 229 | range = document.createRange(); 230 | range.setStart(evt.rangeParent, evt.rangeOffset); 231 | range.collapse(true); 232 | } 233 | 234 | // Try the standards-based way next 235 | else if (document.caretPositionFromPoint) { 236 | var pos = document.caretPositionFromPoint(x, y); 237 | range = document.createRange(); 238 | range.setStart(pos.offsetNode, pos.offset); 239 | range.collapse(true); 240 | } 241 | 242 | // Next, the WebKit way 243 | else if (document.caretRangeFromPoint) { 244 | range = document.caretRangeFromPoint(x, y); 245 | } 246 | } 247 | 248 | return range; 249 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # adaptive-testing 2 | adaptive-testing uses language models against themselves to build suites of unit tests. It is an interative (and fun!) process between a user and a language model that results in a tree of unit tests specifically adapted to the model you are testing. Fixing any failed tests with fine-tuning then leads to an iterative debugging process similar to traditional software development. See paper for details. 3 | 4 |

5 | adaptive-testing loops
6 | Note, adaptive-testing is currently a beta release so please share any issues you encounter. 7 |

8 | 9 | ## Install 10 | 11 | ``` 12 | pip install adatest 13 | ``` 14 | 15 | ## Sentiment analysis example 16 | 17 | adaptive-testing can test any NLP model you can call with a python function, here we will test a basic open source sentiment analysis model. Since adaptive-testing relies on a generative language model to help you create tests, you need to specify what generative model it will use, here we use GPT-3 from OpenAI or GPT-Neo locally. Tests are organized into a test tree that follows the DataFrame API and is organized like a file system, here we create a new empty tree, but you can also start with a previous test tree that targets a similar task. The core adaptive-testing loop starts when you call the `.adapt()` method on a test tree passing the model(s) you want to test and the backend generator you want to use. The code for all this is below: 18 | 19 | ```python 20 | import transformers 21 | import adaptivetesting 22 | 23 | # create a HuggingFace sentiment analysis model 24 | classifier = transformers.pipeline("sentiment-analysis", return_all_scores=True) 25 | 26 | # specify the backend generator used to help you write tests 27 | generator = adaptivetesting.generators.OpenAI('curie', api_key=OPENAI_API_KEY) 28 | 29 | # ...or you can use an open source generator 30 | #neo = transformers.pipeline('text-generation', model="EleutherAI/gpt-neo-125M") 31 | #generator = adaptivetesting.generators.Transformers(neo.model, neo.tokenizer) 32 | 33 | # create a new test tree 34 | tests = adaptivetesting.TestTree("hotel_reviews.csv") 35 | 36 | # adapt the tests to our model to launch a notebook-based testing interface 37 | # (wrap with adaptivetesting.serve to launch a standalone server) 38 | tests.adapt(classifier, generator, auto_save=True) 39 | ``` 40 | 41 |

42 | adaptive-testing loops 43 |

44 | 45 | Once we have launched a test tree browser, we can use the interface to create new topics and tests. Here we create the topic "/Clear positives/Location" to test how well this model classifies clearly positive statements about a hotel's location. We then add a few starting examples of what we want to see in this topic (clearly positive statements about hotel location): 46 | 47 |

48 | adaptive-testing loops 49 |

50 | 51 | Each test consists of a model input, a model output, a pass/fail label, and a score for the current target model. The input text should fall within the scope of the current topic, which here means it is a clearly positive statement about hotel locations. The output text is what the target model we are testing generated (or it can be manually specified, in which case it turns light grey to show it does not reflect the current model behavior). The label is a pass/fail indicator that denotes if the model output is correct with respect to the aspect being tested in the current topic, in our case the model was correct for all the inputs we entered. The model score represents if the testing model passes or fails and how confident the model is when producing the current output. 52 | 53 | Note that in the above figure all the label indicators are hollow, this means that we have not yet labeled these examples, and adaptive-testing is just guessing that they are correct. They are all correct so we can click the checkmarks to confirm and label all these examples. By confirming we teach adaptive-testing more about what we want this topic to test, so it becomes better at predicting future labels, and hence automating the testing process. Once we label these examples we can then click "Suggestions" and adaptive-testing will attempt to write new in-topic examples for us, labeling them and sorting them by score so we can see the most likely failures at the top of the list. 54 | 55 |

56 | adaptive-testing loops 57 |

58 | 59 | Starting at the top of the list we can confirm or change the label for each suggestion and so add them to the current topic (like marking "very convientent for walking" -> "POSITIVE" as correct model behavior), while we reject (or just ignore) examples that don't belong in the current topic (like "Second visit" which is not about a hotel's location). After we have added some new suggestions to the current topic (we normally only bother to look at the top few suggestions) we can repeat the process by clicking "Suggestions" again. Repeating the process a few times allows adaptive-testing to learn from our feedback and hill-climb towards generating better and better suggestions (ones that are more likely to be on-topic and reveal model failures). Doing this for a few rounds reveals lots of bugs in the model related to positive hotel location statements. 60 | 61 |

62 | adaptive-testing loops 63 |

64 | 65 | Once we have tested the location aspect enough we can repeat the process to test a new aspect of model behavior, for example comments about hotel swimming pools or gyms. The space of possible concepts for hotel reviews is large, so to help explore it adaptive-testing can suggest new topics once we have a few examples: 66 | 67 |

68 | adaptive-testing loops 69 |

70 | 71 | After we accept some of these new topic suggestions we can open them and fill them out without ever even writing seed examples. adaptive-testing can suggest new tests inside an empty topic by just using examples from other topics and the current topic's name. 72 | 73 |

74 | adaptive-testing loops 75 |

76 | 77 | This is just a short example of how to find bugs in a sentiment analysis model, but the same process can be applied to any NLP model (even ones that generate free form text). Test trees can be adapted to new models and shared with others collaboratively (they are just CSV files). Once you have enough bugs you can fine tune your model against a mixture of your test tree and the original training data to fix all the bugs in the test tree while retaining performance on your original training data (we will share a full demo notebook of this soon). 78 | 79 | 80 | ## Citation 81 | 82 | If you find adaptive-testing or test trees useful in your work feel free to cite our ACL paper: 83 | [Adaptive Testing and Debugging of NLP Models](https://aclanthology.org/2022.acl-long.230) (Ribeiro & Lundberg, ACL 2022) 84 | 85 | ## Contributing 86 | 87 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 88 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 89 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 90 | 91 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 92 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 93 | provided by the bot. You will only need to do this once across all repos using our CLA. 94 | 95 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 96 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 97 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 98 | 99 | ## Trademarks 100 | 101 | This project may contain trademarks or logos for projects, products, or services. 
Authorized use of Microsoft 102 | trademarks or logos is subject to and must follow 103 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 104 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 105 | Any use of third-party trademarks or logos are subject to those third-party's policies. 106 | -------------------------------------------------------------------------------- /adaptivetesting/_topic_model.py: -------------------------------------------------------------------------------- 1 | import sklearn 2 | import numpy as np 3 | from sklearn import multioutput 4 | from sklearn import preprocessing 5 | from sklearn.linear_model import RidgeClassifierCV 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.svm import LinearSVC 8 | import adaptivetesting 9 | import re 10 | 11 | class ConstantModel(): 12 | def __init__(self, probability): 13 | self.probability = probability 14 | def predict_prob(self, embeddings): 15 | if not hasattr(embeddings[0], "__len__"): 16 | return self.probability 17 | else: 18 | return [self.probability] * len(embeddings) 19 | 20 | class CVModel(): 21 | def __init__(self, embeddings, labels): 22 | self.inner_model = RidgeClassifierCV(class_weight={"pass": 1, "fail": 1}) 23 | self.inner_model.fit(embeddings, labels) 24 | 25 | def predict_prob(self, embeddings): 26 | assert len(self.inner_model.classes_) == 2 27 | d = self.inner_model.decision_function(embeddings) 28 | probs = np.exp(d) / (np.exp(d) + np.exp(-d)) 29 | 30 | return probs 31 | 32 | class OutputNearestNeighborLabelModel(): 33 | def __init__(self, embeddings, labels): 34 | embeddings[:,:embeddings.shape[1]//2] = 0 # zero out the embedding for the input value so we only depend on the output 35 | self.model = sklearn.neighbors.KNeighborsClassifier(1) 36 | self.model.fit(embeddings, labels) 37 | def 
predict(self, embeddings): 38 | embeddings[:,:embeddings.shape[1]//2] = 0 39 | return self.model.predict(embeddings) 40 | 41 | class TopicLabelingModel: 42 | def __init__(self, topic, test_tree): 43 | self.topic = topic 44 | self.test_tree = test_tree 45 | 46 | # mask out entries that do not have a pass/fail label 47 | valid_mask = ~((test_tree["labeler"] == "imputed") | (test_tree["label"] == "topic_marker") | (test_tree["label"] == "off_topic")) 48 | 49 | # try and select samples from the current topic 50 | topic_mask = (test_tree["topic"] == topic) & valid_mask 51 | 52 | # if we didn't find enough samples then expand to include subtopics 53 | if topic_mask.sum() <= 1: 54 | topic_mask = test_tree["topic"].str.startswith(topic) & valid_mask 55 | 56 | # if we still didn't find enough samples then expand to include parent topics 57 | parts = topic.split("/") 58 | for i in range(len(parts), 0, -1): 59 | prefix = "/".join(parts[:i+1]) 60 | if topic_mask.sum() <= 1: 61 | topic_mask = test_tree["topic"].str.startswith(prefix) & valid_mask 62 | else: 63 | break 64 | 65 | # get our features and labels for fitting a model 66 | strings = list(test_tree["input"][topic_mask]) + list(test_tree["output"][topic_mask]) 67 | labels = list(test_tree["label"][topic_mask]) 68 | unrolled_embeds = adaptivetesting.embed(strings) 69 | embeddings = np.hstack([unrolled_embeds[:len(labels)], unrolled_embeds[len(labels):]]) 70 | 71 | # empty test tree 72 | if len(labels) == 0: 73 | self.model = ConstantModel(0.0) 74 | 75 | # constant label topic 76 | elif len(set(labels)) == 1: 77 | self.model = ConstantModel(0.0 if labels[0] == "pass" else 1.0) 78 | 79 | # enough samples to fit a model 80 | else: 81 | 82 | # we are in a highly overparametrized situation, so we use a linear SVC to get "max-margin" based generalization 83 | # TODO: SML: It seems to me that the SVC seems to do very well as long as there are no "errors" in the data labels. 
But it will 84 | # do very poorly if there are errors in the data labels since it will fit them exactly. Perhaps we can help this by 85 | # ensembling several SVCs together each trained on a different bootstrap sample? This might add the roubustness (against label mismatches) 86 | # that is lacking with hard-margin SVC fitting (it is also motivated a bit by the connections between SGD and hard-margin SVC fitting, and that 87 | # in practice SGD works on subsamples of the data so it should be less sensitive to label misspecification). 88 | # self.model = LinearSVC() 89 | 90 | # self.model = LogisticRegression(penalty='l2', random_state=0, C=1.0, solver='lbfgs', max_iter=1000) 91 | 92 | # This seemed to be reasonably well calibrated on simple tests, so we use it instead of SVC 93 | self.model = CVModel(embeddings, labels) 94 | 95 | # # add the missing predict_proba method to the base model 96 | # def predict_proba(self, X): 97 | # if len(self.classes_) == 1: 98 | # return np.ones((len(X), 1)) 99 | # d = self.decision_function(X) 100 | # if len(self.classes_) == 2: 101 | # probs = np.exp(d) / (np.exp(d) + np.exp(-d)) 102 | # return np.array([1 - probs, probs]).T 103 | # probs = np.exp(d).T / np.sum(np.exp(d), axis=1) 104 | # return probs.T 105 | # self.model.predict_proba = predict_proba.__get__(self.model, self.model.__class__) 106 | 107 | # self.model.fit(embeddings, labels) 108 | 109 | def __call__(self, input, output): 110 | embeddings = np.hstack(adaptivetesting.embed([input, output])) 111 | if not hasattr(embeddings[0], "__len__"): 112 | return self.model.predict_prob([embeddings])[0] 113 | return self.model.predict_prob(embeddings) 114 | 115 | class TopicMembershipModel: 116 | """ A model that predicts if a given test fits in a given topic. 117 | 118 | Note that this model only depends on the inputs not the output values for a test. 
119 | """ 120 | def __init__(self, topic, test_tree): 121 | self.topic = topic 122 | self.test_tree = test_tree 123 | 124 | # mask out entries that do not have a topic membership label 125 | valid_mask = ~((test_tree["labeler"] == "imputed") | (test_tree["label"] == "topic_marker")) 126 | 127 | # try and select samples from the current topic 128 | topic_mask = (test_tree["topic"] == topic) & valid_mask 129 | 130 | # if we didn't find enough samples then expand to include subtopics 131 | if topic_mask.sum() <= 1: 132 | topic_mask = test_tree["topic"].str.startswith(topic) & valid_mask 133 | 134 | # if we still didn't find enough samples then expand to include parent topics 135 | parts = topic.split("/") 136 | for i in range(len(parts), 0, -1): 137 | prefix = "/".join(parts[:i+1]) 138 | if topic_mask.sum() <= 1: 139 | topic_mask = test_tree["topic"].str.startswith(prefix) & valid_mask 140 | else: 141 | break 142 | 143 | # get our features and labels for fitting a model 144 | strings = list(test_tree["input"][topic_mask]) 145 | labels = [l if l == "off_topic" else "on_topic" for l in test_tree["label"][topic_mask]] 146 | embeddings = np.array(adaptivetesting.embed(strings)) 147 | 148 | # empty test tree (default to on-topic) 149 | if len(labels) == 0: 150 | self.model = ConstantModel(1.0) 151 | 152 | # constant label topic 153 | elif len(set(labels)) == 1: 154 | self.model = ConstantModel(0.0 if labels[0] == "off_topic" else 1.0) 155 | 156 | # enough samples to fit a model 157 | else: 158 | 159 | # we are in a highly overparametrized situation, so we use a linear SVC to get "max-margin" based generalization 160 | self.model = CVModel() 161 | self.model.fit(embeddings, labels) 162 | 163 | def __call__(self, input): 164 | embeddings = adaptivetesting.embed([input])[0] 165 | if not hasattr(embeddings[0], "__len__"): 166 | return "on_topic" if self.model.predict_prob([embeddings])[0] > 0.5 else "off_topic" 167 | return ["on_topic" if v > 0.5 else "off_topic" for v in 
class ChainTopicModel:
    """ Predicts hierarchical topic labels (" > "-separated paths) using a classifier chain,
    one classifier per level of the topic hierarchy.
    """

    def __init__(self, model=None):
        """ Create a chain topic model.

        Parameters
        ----------
        model : sklearn-style classifier, optional
            The per-level base classifier; defaults to RidgeClassifierCV.
        """
        if model is None:
            self.base_model = RidgeClassifierCV()
        else:
            self.base_model = model

    def fit(self, X, y):
        """ Fit the classifier chain on features X and " > "-separated topic strings y. """
        topics = y
        max_levels = max([len(x.split('>')) for x in topics])
        self.model = sklearn.multioutput.ClassifierChain(self.base_model, order=list(range(max_levels)))

        # split each topic path into levels and pad with '-' to a fixed width
        y = [list(map(str.strip, x.split('>'))) for x in topics]
        y = np.array([x + ['-'] * (max_levels - len(x)) for x in y])

        # record every topic and every ancestor prefix seen during training
        self.encoders = [preprocessing.LabelEncoder() for _ in range(max_levels)]
        self.possible_topics = set()
        for x in topics:
            self.possible_topics.add(x)
            a = x.split(' > ')
            for i in range(1, len(a)):
                self.possible_topics.add(' > '.join(a[:i]))

        self.classes_ = list(self.possible_topics)

        # label-encode each level independently for the chain
        new_y = np.zeros(y.shape)
        for i in range(y.shape[1]):
            self.encoders[i].fit(y[:, i])
            new_y[:, i] = self.encoders[i].transform(y[:, i])
        self.model.fit(X, new_y)

    def predict(self, X):
        """ Predict a topic path for each row of X, truncated to the nearest path seen at fit time. """
        y = self.model.predict(X)

        # decode each level back to its string label
        ret = []
        for i in range(y.shape[1]):
            ret.append(self.encoders[i].inverse_transform(y[:, i].astype(int)))
        y = np.array(ret).T

        ret = []
        for x in y:
            x = [z for z in x if z != '-']
            a = ' > '.join(x)
            # walk up the predicted path until we hit a topic seen during fit;
            # stop once the path is empty (previously this case looped forever
            # when no prefix of the chain's output was a known topic)
            while a not in self.possible_topics and len(x) > 0:
                x = x[:-1]
                a = ' > '.join(x)
            ret.append(a)
        return np.array(ret)

    def predict_proba(self, X):
        """ Fake probabilities: 1 in the predicted class and 0 elsewhere.

        NOTE(review): if predict() fell all the way back to the empty path (no known
        prefix), the index lookup below will raise ValueError — confirm desired handling.
        """
        y = self.predict(X)
        ret = np.zeros((len(X), len(self.classes_)))
        for i, r in enumerate(y):
            ret[i, self.classes_.index(r)] = 1
        return ret
missing predict_proba method to RidgeClassifierCV 224 | def predict_proba(self, X): 225 | if len(self.classes_) == 1: 226 | return np.ones((len(X), 1)) 227 | d = self.decision_function(X) 228 | if len(self.classes_) == 2: 229 | probs = np.exp(d) / (np.exp(d) + np.exp(-d)) 230 | return np.array([1 - probs, probs]).T 231 | probs = np.exp(d).T / np.sum(np.exp(d), axis=1) 232 | return probs.T 233 | self.model.predict_proba = predict_proba.__get__(self.model, self.model.__class__) 234 | def fit(self, X, y): 235 | self.model.fit(X, y) 236 | def predict_proba(self, X): 237 | return self.model.predict_proba(X) 238 | def predict(self, X): 239 | if self.threshold is None: 240 | return self.model.predict(X) 241 | pps = self.model.predict_proba(X) 242 | zero_index = list(self.model.classes_).index('Not problematic') 243 | ret = [] 244 | for p in pps: 245 | if p[zero_index] >= self.threshold: 246 | ret.append(self.model.classes_[zero_index]) 247 | continue 248 | else: 249 | best = np.argsort(p) 250 | if best[-1] == zero_index: 251 | best = best[:-1] 252 | ret.append(self.model.classes_[best[-1]]) 253 | return np.array(ret) 254 | # return self.model.predict(X) 255 | -------------------------------------------------------------------------------- /adaptivetesting/_prompt_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import re 4 | import urllib.parse 5 | import adaptivetesting 6 | from .embedders import cos_sim 7 | from .utils import is_subtopic 8 | log = logging.getLogger(__name__) 9 | 10 | 11 | class PromptBuilder(): 12 | """ A class to build prompts for the model. 13 | """ 14 | 15 | def __init__(self, prompt_size=7, slot_randomization=0.25, score_randomization=0.05, skip_randomization=0.25, prompt_diversity=True, 16 | subtopic_diversity=True): 17 | """ Initialize the prompt builder. 
18 | 19 | Parameters 20 | ---------- 21 | prompt_size : int 22 | The number of test slots to include in the prompt. 23 | 24 | slot_randomization : float 25 | The proportion of slots to make fully random (within the current topic). 26 | 27 | score_randomization : float 28 | The standard deviation of an additive Gaussian randomization factor for the scores. 29 | 30 | skip_randomization : float 31 | The proportion of times we skip over top ranking tests when building the prompt. 32 | 33 | prompt_diversity : bool 34 | Whether to include a diversity term when selecting tests for the prompt. This diversity term is based 35 | on the embeddings of each test. 36 | 37 | subtopic_diversity : bool 38 | If true, we will try and pick tests from a diverse set of subtopics of the current topic (if we are 39 | using subtopic tests and not direct child tests). 40 | """ 41 | 42 | assert skip_randomization < 0.99, "skip_randomization must be less than 1, otherwise everything will always be skipped!" 43 | 44 | self.prompt_size = prompt_size 45 | self.slot_randomization = slot_randomization 46 | self.score_randomization = score_randomization # TODO: make this scale according to the stddev of the top 7 entries? 47 | self.skip_randomization = skip_randomization 48 | self.prompt_diversity = prompt_diversity 49 | self.subtopic_diversity = subtopic_diversity 50 | 51 | def __call__(self, test_tree, topic, score_column, repetitions=1, filter="", suggest_topics=False, working_set_size=100, embeddings=None): 52 | """ This builds a prompt for GPT3 that elicits useful input examples. 53 | 54 | Parameters 55 | ---------- 56 | test_tree : adaptivetesting.TestTree 57 | The test tree to generate prompts from. 58 | 59 | topic : str 60 | The topic to build a prompt for. 61 | 62 | score_column : str 63 | The column to use for scoring the tests. 64 | 65 | repetitions : int 66 | The number of times to repeat the prompt generation process. This is how many prompots we will return. 
67 | 68 | filter : str 69 | A filter to apply to the test set before selecting tests to build the prompt. 70 | 71 | suggest_topics : bool 72 | If true, we will create a prompt filled with topic names instead of a list of tests. 73 | 74 | working_set_size : int 75 | How many top tests to consider when doing the full iterative scoring process. Larger values may take longer. 76 | Note that this has no effect as long as we never go more than working_set_size tests deep during the prompt 77 | item selection process. 78 | 79 | embeddings : dict 80 | A dictionary of embeddings to use for the prompt. This is used to compute the prompt_diversity. 81 | """ 82 | 83 | ids = np.array(test_tree.index) 84 | 85 | # return early for an empty test tree 86 | if len(ids) == 0: 87 | return [[] for _ in range(repetitions)] 88 | 89 | # we compute each test's distance from current topic, where distance is measured 90 | # by the length of the topic prefix shared between the test and the current topic 91 | topic_scaling = np.ones(test_tree.shape[0]) 92 | topic_parts = topic.split("/") 93 | for i in range(1, len(topic_parts)): 94 | prefix = "/".join(topic_parts[:i+1]) 95 | if suggest_topics: 96 | prefix += "/" 97 | topic_scaling *= 1 + 99 * np.array([v.startswith(prefix) for v in test_tree["topic"]]) 98 | 99 | # promote direct children over subtopic descendants and filter for topics vs tests 100 | if suggest_topics: 101 | topic_scaling *= 1 + 99 * np.array([v.rsplit('/', 1)[0] == topic for v in test_tree["topic"]]) 102 | topic_scaling *= np.array(test_tree["label"] == "topic_marker") 103 | else: 104 | topic_scaling *= 1 + 99 * np.array([v == topic for v in test_tree["topic"]]) 105 | topic_scaling *= np.array(test_tree["label"] != "topic_marker") 106 | topic_scaling *= np.array(["__suggestions__" not in t for t in test_tree["topic"]]) 107 | 108 | # return early if we have nothing to build a prompt with 109 | if np.sum(topic_scaling) == 0: 110 | return [[] for _ in range(repetitions)] 111 | 
112 | topic_scaling /= np.max(topic_scaling) 113 | 114 | # hide rows that don't match the filter 115 | hidden_scaling = np.ones(len(ids)) 116 | if filter != "": 117 | filter_compiled = re.compile(filter) 118 | for i,k in enumerate(ids): 119 | test = test_tree.loc[k] 120 | if filter_compiled.search(test.test_type) is not None: 121 | continue 122 | if hasattr(test, "input") and filter_compiled.search(test.input) is not None: 123 | continue 124 | if hasattr(test, "output") and filter_compiled.search(test.output) is not None: 125 | continue 126 | hidden_scaling[i] = 0.0 127 | 128 | # filter down to a single test type (chosen to match the top scoring test) 129 | if suggest_topics: 130 | # scores currently do not influence topic suggestions 131 | # TODO: can we score topics and topic suggestions? 132 | scores = np.ones(len(ids)) 133 | else: 134 | # compute a positive single value score for each test 135 | scores = np.array([score_max(test_tree.loc[k, score_column], test_tree.loc[k, "label"]) for k in ids]) 136 | 137 | # filter down to just top rows we will use during the iterative scoring process 138 | rank_vals = scores * topic_scaling * hidden_scaling 139 | top_inds = np.argsort(-rank_vals)[:working_set_size] 140 | ids = ids[top_inds] 141 | topic_scaling = topic_scaling[top_inds] 142 | hidden_scaling = hidden_scaling[top_inds] 143 | scores = scores[top_inds] * 1.0 144 | 145 | # build a list of randomized prompts 146 | prompts = [] 147 | for _ in range(repetitions): 148 | 149 | # store tmp versions of things we update during the iteration 150 | scores_curr = scores.copy() 151 | topic_scaling_curr = topic_scaling.copy() 152 | 153 | # score randomization 154 | scores_curr += self.score_randomization * np.random.rand(len(ids)) 155 | 156 | # sim_avoidance is a vector that marks which items (and items related through similarities) 157 | # should be avoided (ranked lower for prompt selection) 158 | if self.prompt_diversity: 159 | sim_avoidance = np.zeros(len(ids)) 160 | if 
suggest_topics: 161 | embeddings_arr = np.vstack(adaptivetesting.embed( 162 | [urllib.parse.unquote(test_tree.loc[id, "topic"].split("/")[-1]) for id in ids] 163 | )) 164 | else: 165 | embeddings_arr = np.hstack([ 166 | np.vstack(adaptivetesting.embed([test_tree.loc[id, "input"] for id in ids])), 167 | np.vstack(adaptivetesting.embed([test_tree.loc[id, "output"] for id in ids])) 168 | ]) 169 | similarities = cos_sim(embeddings_arr, embeddings_arr) 170 | hard_avoidance = np.zeros(len(ids)) 171 | diversity = np.ones(len(ids)) 172 | 173 | # compute how many greedy and how many random positions we will have 174 | num_random = max(0, min(np.random.binomial(self.prompt_size, self.slot_randomization), len(ids) - self.prompt_size)) 175 | num_greedy = max(0, min(self.prompt_size - num_random, len(ids) - num_random)) 176 | 177 | # iteratively select prompt items 178 | prompt_ids = [] 179 | outside_topics_used = np.ones(len(ids)) 180 | while len(prompt_ids) < num_greedy + num_random: 181 | 182 | # once we get to the random part of the process we scramble the scores 183 | if len(prompt_ids) == num_greedy: 184 | scores_curr = 1 + np.random.rand(len(ids))*0.1 185 | 186 | # find the next bext index 187 | if self.prompt_diversity: 188 | diversity = 1 - (similarities * sim_avoidance).max(1) 189 | rank_vals = scores_curr * topic_scaling_curr * diversity * (1 - hard_avoidance) * hidden_scaling * outside_topics_used 190 | 191 | if np.nanmax(rank_vals) <= 0 and len(prompt_ids) > 0: # stop if we have run out of the current subtree 192 | break 193 | 194 | new_ind = np.nanargmax(rank_vals) 195 | skip_rand = np.random.rand() 196 | 197 | # make it unlikely we will choose the same outside topic twice 198 | new_ind_topic = test_tree.loc[ids[new_ind], "topic"] 199 | if not is_subtopic(topic, new_ind_topic): 200 | outside_topics_used *= 1 - 0.9 * np.array([test_tree.loc[id, "topic"] == new_ind_topic for id in ids]) 201 | 202 | # add or skip this item 203 | if skip_rand >= 
self.skip_randomization: 204 | prompt_ids.append(ids[new_ind]) 205 | avoidance_level = 1 206 | else: 207 | avoidance_level = 1 - 0.1 208 | 209 | # avoid this IO pair as we select the next pairs 210 | hard_avoidance[new_ind] = avoidance_level 211 | if self.prompt_diversity: 212 | sim_avoidance[new_ind] = avoidance_level 213 | 214 | # lower the weight of the subtopic we just picked from 215 | if self.subtopic_diversity: 216 | new_topic = test_tree.loc[ids[new_ind], "topic"] 217 | if topic != new_topic and is_subtopic(topic, new_topic): 218 | subtopic = topic + "/" + new_topic[(len(topic)+1):].split("/")[0] 219 | subtopic_scaling = np.array([0.001 if is_subtopic(subtopic, test_tree.loc[k, "topic"]) else 1 for k in ids]) 220 | topic_scaling_curr *= subtopic_scaling 221 | 222 | # create the prompt as a list of tuples 223 | prompt = [] 224 | for k in reversed(prompt_ids): 225 | row = test_tree.loc[k] 226 | if suggest_topics: 227 | if row["topic"] == "": 228 | continue # we can't use the root to help suggest topic names 229 | parents,child = row["topic"].rsplit("/", 1) 230 | prompt.append((k, parents, urllib.parse.unquote(child))) 231 | else: 232 | prompt.append((k, row["topic"], row["input"])) 233 | prompts.append(prompt) 234 | 235 | return prompts 236 | 237 | def score_max(s, label): 238 | if s == "" or s is None: 239 | return 1 if label == "fail" else 0 240 | elif isinstance(s, str): 241 | return np.max([convert_float(v) for v in s.split("|")]) 242 | elif np.isnan(s): 243 | return 1 if label == "fail" else 0 244 | else: 245 | return np.max(s) 246 | 247 | def convert_float(s): 248 | try: 249 | f = float(s) 250 | except ValueError: 251 | f = np.nan 252 | return f -------------------------------------------------------------------------------- /test_trees/abstract_capabilities.csv: -------------------------------------------------------------------------------- 1 | ,topic,input,output,label,labeler,description,model score 2 | b72f54bb56ec458c87138faf4ef64f4f,/Semantic 
Role Labeling,,,topic_marker,adatest_default,, 3 | 3c506e4cbcc14d87ac0c6e3169beeb0f,/Logic,,,topic_marker,adatest_default,, 4 | 2cb3c973d2b446cfb1f4a4a7f40f7253,/Named Entity Recognition,,,topic_marker,adatest_default,, 5 | 9037cb78df8e4b09bab972c5661b735d,/Vocabulary+POS,,,topic_marker,adatest_default,, 6 | 480258ecc6654efb879d29726fc05507,/Fairness,,,topic_marker,adatest_default,, 7 | 87c8b63256d0486fb306b9b13d042ad6,/Negation,,,topic_marker,adatest_default,, 8 | 2c2cccc2ae714d77a33ec72a93339739,/Coreference,,,topic_marker,adatest_default,, 9 | 3eb405c6ab6a4eeb95c57f00c2f8c542,/Fairness/Race,,,topic_marker,adatest_default,, 10 | 6287a53a733e454796a05efd939e0903,/Fairness/Gender and sexuality,,,topic_marker,adatest_default,, 11 | cd722a3a511a4b2cbb211eab3435a968,/Fairness/Gender and sexuality/LGBTQ,,,topic_marker,adatest_default,, 12 | 8f16c69e8fe3490b8be6c9fcafb8dc2f,/Fairness/Politics,,,topic_marker,adatest_default,, 13 | fe40a6e72696481aaf6e4f575868c8b2,/Robustness/Adding or removing irrelevant punctuation,,,topic_marker,adatest_default,, 14 | 683233e699c14d54a6801acdca3b08d3,/Robustness/Spelling errors,,,topic_marker,adatest_default,, 15 | c12be5d95344420a8bf191f52c45d5ce,/Robustness/Adding a random URL or @ mentions,,,topic_marker,adatest_default,, 16 | 045f4ad7e36a4f1694f9260a6d8f32ab,/Fairness/Religion,,,topic_marker,adatest_default,, 17 | 9ce174f1fc544297becdeef49cc061db,/Fairness/Disabilities,,,topic_marker,adatest_default,, 18 | e6a7264a9ce1475091d8fa841a3b758d,/Fairness/Age,,,topic_marker,adatest_default,, 19 | 022942349c854e678b4623f37fdd6cd6,/Logic/Symmetry,,,topic_marker,adatest_default,, 20 | eda7126ae4c2400ea2b22135b19fd5f5,/Logic/Implications,,,topic_marker,adatest_default,, 21 | 5b2621ecc57f4052bc5218049faf825c,/Taxonomy/Synonyms,,,topic_marker,adatest_default,, 22 | 5fe557f4efda4c18956dbf72729e69b2,/Taxonomy/Antonyms,,,topic_marker,adatest_default,, 23 | 45784b5867f24d70bae80cbf60100bc5,/Robustness/Contractions,,,topic_marker,adatest_default,, 
24 | 7d00b5c1b43046d28ecfed25860a83b7,/Fairness/Religion/Christianity,,,topic_marker,adatest_default,, 25 | 6f192e1dd87c4441a9121fba0da8534f,/Fairness/Religion/Islam,,,topic_marker,adatest_default,, 26 | 7b3addf57c3c46249e819b624909033f,/Fairness/Religion/Hinduism,,,topic_marker,adatest_default,, 27 | cc309caf440a49c1a4636808ce811e58,/Fairness/Religion/Buddhism,,,topic_marker,adatest_default,, 28 | bc1e8ab4711c4bbfb26876e95bd2ee95,/Fairness/Religion/Judaism,,,topic_marker,adatest_default,, 29 | 14187c4c7c9243a193758455bbaa7776,/Fairness/Religion/Shinto,,,topic_marker,adatest_default,, 30 | 9a12e7319fc3417d85e837daef62449b,/Fairness/Religion/Taoism,,,topic_marker,adatest_default,, 31 | 382da467f7b04ea296f81fdd897e45a0,/Fairness/Religion/Jainism,,,topic_marker,adatest_default,, 32 | 6f1298d713be4fdc9c1f668c31b38b50,/Fairness/Religion/Atheism,,,topic_marker,adatest_default,, 33 | 58c32b204151411f94078d89d4516ea5,/Fairness/Religion/Sikhism,,,topic_marker,adatest_default,, 34 | 936b858cc6684230987ee8d3e12ee8c1,/Fairness/Religion/Spiritism,,,topic_marker,adatest_default,, 35 | 5f1457585d3743a68d0e891ff5d43e51,/Fairness/Religion/Bahai Faith,,,topic_marker,adatest_default,, 36 | c66b61b6c0e9478fbae3b3883d9c6efb,/Fairness/Age/Young,,,topic_marker,adatest_default,, 37 | b88fa13816c7485f8006bd90948dce0a,/Fairness/Age/Old,,,topic_marker,adatest_default,, 38 | 3d97bfa9f086421a8f30683f4e00b653,/Fairness/Disabilities/Autism spectrum disorders,,,topic_marker,adatest_default,, 39 | 0a3a12c82c53428d8b17aa33f6b893c0,/Fairness/Disabilities/Deaf and hard of hearing,,,topic_marker,adatest_default,, 40 | 1464613188b74cae9a64f72bafe9e7e6,/Fairness/Disabilities/Visual impairments,,,topic_marker,adatest_default,, 41 | 3b9c6615468f4699b589f0e0d590b388,/Fairness/Disabilities/Learning disabilities,,,topic_marker,adatest_default,, 42 | 2f3d38e7a0c247e8b1854b78dfb9de70,/Fairness/Disabilities/Mobility disabilities,,,topic_marker,adatest_default,, 43 | 
2b69cb088cb64c0dbbf56e67722ce535,/Fairness/Race/East Asian,,,topic_marker,adatest_default,, 44 | bb62f5f3b53447e282f7fbbb77c014e7,/Fairness/Race/European,,,topic_marker,adatest_default,, 45 | eac2219f50954542a5a53d53e5390247,/Fairness/Race/Indigenous American,,,topic_marker,adatest_default,, 46 | 8e3ebca3b259483c837d52e51a9e724c,/Fairness/Race/Sub-Saharan African,,,topic_marker,adatest_default,, 47 | 2312dee91409400f86a75f6992a148f7,/Fairness/Race/North African,,,topic_marker,adatest_default,, 48 | 72dda66a4ea944249ca55f3d502b60f2,/Fairness/Race/West Asian,,,topic_marker,adatest_default,, 49 | 93715bd783e640ca8be8d32d9ddb9ace,/Fairness/Race/Melanesian,,,topic_marker,adatest_default,, 50 | 734d00bbdd664182b96115eeda255c00,/Fairness/Race/Central & South Asian,,,topic_marker,adatest_default,, 51 | 1c856fe4b2f048eab2a9b0302c7c735f,/Fairness/Race/European/Northwestern European,,,topic_marker,adatest_default,, 52 | 6fb01ad5445f42e2b19d390da4e922a8,/Fairness/Race/European/Jewish,,,topic_marker,adatest_default,, 53 | 30d9d929aecd411d86b278c1958e1662,/Fairness/Race/European/Eastern European,,,topic_marker,adatest_default,, 54 | 84d0247422a24e529360a24fdc95944e,/Fairness/Race/European/Southern European,,,topic_marker,adatest_default,, 55 | 05ce2308a4974026ac1ea52be07fbea3,"/Fairness/Race/Arab, Egyptian & Levantine",,,topic_marker,adatest_default,, 56 | b64f34df1c40453196ed5874f8a56a43,/Fairness/Race/European/Northwestern European/British & Irish,,,topic_marker,adatest_default,, 57 | c33650d6aea543c1813e04d8a914f423,/Fairness/Race/European/Northwestern European/Scandinavian,,,topic_marker,adatest_default,, 58 | 990d9443b2d943fb91ed8c5e959e7ba5,/Fairness/Race/European/Northwestern European/French & German,,,topic_marker,adatest_default,, 59 | fa65eedda36b4e0882ff9649b46f30a8,/Fairness/Race/European/Northwestern European/Finnish,,,topic_marker,adatest_default,, 60 | bc1b76925b474d2e9bafda65a0c1e246,/Fairness/Race/European/Southern European/Greek & 
Balkan,,,topic_marker,adatest_default,, 61 | 2440037191764f3bb1d76601f1560296,/Fairness/Race/European/Southern European/Italian,,,topic_marker,adatest_default,, 62 | e9c73fe0f30e41c6bc489b82e66fec87,/Fairness/Race/European/Southern European/Sardinian,,,topic_marker,adatest_default,, 63 | 872478a8298e439b96e17ff072c330fe,/Fairness/Race/European/Southern European/Spanish & Portuguese,,,topic_marker,adatest_default,, 64 | 18400d3dbec24d5b9524a8f4bd52dd3e,/Fairness/Ethnicity,,,topic_marker,adatest_default,, 65 | eaefb247d21f4d1580d489e045ee3666,/Fairness/Ethnicity/Black or African American,,,topic_marker,adatest_default,, 66 | dc5d4e7ee670443cb6401b629a068c50,/Fairness/Ethnicity/White,,,topic_marker,adatest_default,, 67 | 00b80e0768c94a25acf2cea654b068a8,/Fairness/Ethnicity/Hispanic,,,topic_marker,adatest_default,, 68 | 4d1cfc27dfbf45569e53ed54a1db73ca,/Fairness/Ethnicity/Asian,,,topic_marker,adatest_default,, 69 | 46fe783a83c647058396ea7d5e723f31,/Fairness/Ethnicity/Arab,,,topic_marker,adatest_default,, 70 | ae02ab0695cc474db66b4c1ec3e26cac,/Fairness/Ethnicity/Jewish,,,topic_marker,adatest_default,, 71 | fea6dba1333d4ba8820135c85b13ea7f,/Fairness/Ethnicity/Black or African American/Black or African American,,,topic_marker,adatest_default,, 72 | 86bd9c19b41647519618f94157e57e56,/Fairness/Ethnicity/Nationality,,,topic_marker,adatest_default,, 73 | e6d782ca2f964c1787dfabb070ce9e01,/Fairness/Ethnicity/Nationality/American,,,topic_marker,adatest_default,, 74 | c553bb653f1d42afa0c9a2ae9bc9137b,/Fairness/Ethnicity/Nationality/Chinese,,,topic_marker,adatest_default,, 75 | 0b2e5161d82d40bf8588a896c0a88333,/Fairness/Ethnicity/Nationality/British,,,topic_marker,adatest_default,, 76 | 5c688fa9edd34e998b1a646e5cc92d1e,/Fairness/Ethnicity/Nationality/Irish,,,topic_marker,adatest_default,, 77 | b09935df44a24257bcbf18180ec015b9,/Fairness/Ethnicity/Nationality/Israeli,,,topic_marker,adatest_default,, 78 | 
2a8fc16f6f544fd59dd82cca9fac89cb,/Fairness/Ethnicity/Nationality/Egyptian,,,topic_marker,adatest_default,, 79 | 001d29ac38a846eaaeb4b670842d4685,/Fairness/Ethnicity/Nationality/Mexican,,,topic_marker,adatest_default,, 80 | f7fca049c61948fbbbd952616104d59b,/Fairness/Ethnicity/Nationality/Canadian,,,topic_marker,adatest_default,, 81 | 6f0ebdf0d13a4eddb45808ed1c831f0a,/Fairness/Ethnicity/Nationality/Brazilian,,,topic_marker,adatest_default,, 82 | 5aef5b9ecde2464d8bf6e54f34db55b8,/Fairness/Ethnicity/Nationality/Indonesian,,,topic_marker,adatest_default,, 83 | 6062d425b1f745bf97a6331d59925e0a,/Fairness/Ethnicity/Nationality/Russian,,,topic_marker,adatest_default,, 84 | fdb8fe1da4db47e1b1df4696644df6c9,/Fairness/Ethnicity/Nationality/Indian,,,topic_marker,adatest_default,, 85 | c8a4d9a49ca94ff194238d7df9040bb2,/Fairness/Ethnicity/Nationality/Pakistani,,,topic_marker,adatest_default,, 86 | 4fc69a320a60438e936f450f234c18d1,/Fairness/Ethnicity/Nationality/Nigerian,,,topic_marker,adatest_default,, 87 | 65c9db7b6d13432da1cc6f236b7340c5,/Fairness/Ethnicity/Nationality/Bangladeshi,,,topic_marker,adatest_default,, 88 | 4b51ae3043c145de8919dced27724fa4,/Fairness/Ethnicity/Nationality/Japanese,,,topic_marker,adatest_default,, 89 | 39952472ab7e45908dbba6e13c131403,/Fairness/Ethnicity/Nationality/Ethiopian,,,topic_marker,adatest_default,, 90 | 6fcc5a84123a42239a9abbb47251c2b9,/Fairness/Ethnicity/Nationality/Filipino,,,topic_marker,adatest_default,, 91 | 01450c14171d44559113adc39ede0d3f,/Fairness/Ethnicity/Nationality/Vietnamese,,,topic_marker,adatest_default,, 92 | d6f8f0aa6c2a49459af593b2ac34320b,/Fairness/Ethnicity/Nationality/Congolese,,,topic_marker,adatest_default,, 93 | 47c0c7dd5e924010914836602f1c02b6,/Fairness/Ethnicity/Nationality/Iranian,,,topic_marker,adatest_default,, 94 | 79872f5db4ee45b8bde7a54c2f113d58,/Fairness/Ethnicity/Nationality/Turkish,,,topic_marker,adatest_default,, 95 | 
eb229bf76b624ada96a4f7ecb913016d,/Fairness/Ethnicity/Nationality/German,,,topic_marker,adatest_default,, 96 | 8e2ce2153c4a4ef4bd1e3947c0064fcf,/Fairness/Ethnicity/Nationality/French,,,topic_marker,adatest_default,, 97 | 07affcfaa3f045aebed2861e6df89234,/Fairness/Ethnicity/Nationality/Italian,,,topic_marker,adatest_default,, 98 | b41f2aefe1ca47cebb7955c90b315474,/Fairness/Ethnicity/Nationality/Spanish,,,topic_marker,adatest_default,, 99 | ef25682376c44ba3822eeddf03fcb082,/Fairness/Ethnicity/Nationality/Korean,,,topic_marker,adatest_default,, 100 | a1b7b5aa36444e839987b3c98a6b3940,/Fairness/Ethnicity/Nationality/Australian,,,topic_marker,adatest_default,, 101 | 1f792e47ee404f9182f97e76fda15419,/Literary Devices/Analogies,,,topic_marker,adatest_default,, 102 | c2ab86dd60434b78b33058a9a44b3cd7,/Literary Devices/Analogies,,,topic_marker,adatest_default,, 103 | ef2f42885c98445dbbb2be0fca994d6c,/Literary Devices/Analogies,,,topic_marker,adatest_default,, 104 | cfb77d784eac4cf5abeb11f919e4c64f,/Literary Devices,,,topic_marker,adatest_default,, 105 | 2782002c34c345d6a112665c02e5e9e2,/Literary Devices/Metaphor,,,topic_marker,adatest_default,, 106 | 2d921ac3158748eeb443eaa5880ee5f4,/Literary Devices/Irony+Sarcasm,,,topic_marker,adatest_default,, 107 | e5bdc973808f44b090bc358433976648,/Literary Devices/Hyperbole,,,topic_marker,adatest_default,, 108 | 9465029dfb0844b4b28222633e4ee9da,/Literary Devices/Allusions,,,topic_marker,adatest_default,, 109 | 2a6d5758bf7e487a8b08f40d06e60315,/Literary Devices/Personification,,,topic_marker,adatest_default,, 110 | 97317d629a29480cbcf6a77d9d78d679,/Literary Devices/Imagery,,,topic_marker,adatest_default,, 111 | 5d2ebde64275493da3592b711a53f16d,/Literary Devices/Simile,,,topic_marker,adatest_default,, 112 | 93cd0029f5ff41628474aa46007f21e8,/Literary Devices/Symbolism,,,topic_marker,adatest_default,, 113 | 61cd12099c67425aa2a2a1873a473eb7,/Literary Devices/Puns,,,topic_marker,adatest_default,, 114 | 
e1bc2b66791c4eb08e2f01b167550110,/Fairness,,,topic_marker,adatest_default,, 115 | fe0f67d4116d4e37940cfd927b59c373,/Fairness,,,topic_marker,adatest_default,, 116 | 28202017462248398be8300204f1cb4d,/Fairness,,,topic_marker,adatest_default,, 117 | fe71895657474d6f9f01602d6ccf3573,/Fairness/Occupation,,,topic_marker,adatest_default,, 118 | ffa7673d58aa454aa92bd5905b7be832,/Fairness/Immigration,,,topic_marker,adatest_default,, 119 | ada79da98a3c4c8eac7dd2c8748dde59,/Vocabulary+POS/Important verbs,,,topic_marker,adatest_default,, 120 | 0b43004d2bf74213b38ef2713233fbc4,/Vocabulary+POS/Important adjectives,,,topic_marker,adatest_default,, 121 | ac574e7cdfdc464c87d86c815cfc9660,/Vocabulary+POS/Important nouns,,,topic_marker,adatest_default,, 122 | c0f21344b01147f986a51d89e47f4fc8,/Vocabulary+POS/Important modifiers,,,topic_marker,adatest_default,, 123 | ca1476723e414a6a9a047cc13c34faf3,/Vocabulary+POS/Irrelevant word changes,,,topic_marker,adatest_default,, 124 | 24c93ae4c4234dce911125b894e260de,/Named Entity Recognition/Person names,,,topic_marker,adatest_default,, 125 | 3dcc45a2c680451db7c00a0c664b79e5,/Named Entity Recognition/Location names,,,topic_marker,adatest_default,, 126 | 704d58451d354a9bb9edf6557383a3d2,/Named Entity Recognition/Person names/Person names or changes that matter for prediction,,,topic_marker,adatest_default,, 127 | ee49e950701742e180ead9fad1d44434,/Named Entity Recognition/Person names/Person names or changes that matter for prediction,,,topic_marker,adatest_default,, 128 | c6a6e0ff746043769e652350de03fb15,/Named Entity Recognition/Person names/Person names or changes that do not matter for prediction,,,topic_marker,adatest_default,, 129 | c94fcc5ab74445f9889db304ca233832,/Named Entity Recognition/Location names/Location names or changes that matter for prediction,,,topic_marker,adatest_default,, 130 | 8954a752d9cc40cd84101fa18786e1d9,/Named Entity Recognition/Location names/Location names or changes that do not matter for 
prediction,,,topic_marker,adatest_default,, 131 | d3c9c6401ada46b7b9ba0f34785d1fb2,/Named Entity Recognition/Numbers,,,topic_marker,adatest_default,, 132 | e0406c39fc4b4e1cb9d8b92e860001a4,/Named Entity Recognition/Numbers/Numbers or changes that matter for prediction,,,topic_marker,adatest_default,, 133 | 38c621ea185f4bd299d7a32e0bbd0ff5,/Named Entity Recognition/Numbers/Numbers or changes that do not matter for prediction,,,topic_marker,adatest_default,, 134 | 4091e479855a4609ae2469f46a7098c9,/Fairness/Gender and sexuality/Male to Female or vice versa,,,topic_marker,adatest_default,, 135 | 7ef2dd29dc4a46b6a0afc444fa09493c,/Semantic Role Labeling/Agent - Object swap,,,topic_marker,adatest_default,, 136 | d8de9181c54644bc9c8808546a1f3787,/Semantic Role Labeling/Active - passive swap,,,topic_marker,adatest_default,, 137 | 6370cb75a68f4d588e505832fb9e75aa,/Logic/Conjunctions,,,topic_marker,adatest_default,, 138 | 6284e4f2d27a43409c4181ef2505cb08,/Logic/Disjunctions,,,topic_marker,adatest_default,, 139 | d54eb69301924cf1966e302b08f01f6e,/Taxonomy/Subtypes,,,topic_marker,adatest_default,, 140 | -------------------------------------------------------------------------------- /adaptivetesting/_scorer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import logging 4 | import uuid 5 | import itertools 6 | import shap 7 | from ._model import Model 8 | import adaptivetesting 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | class Scorer(): 13 | def __new__(cls, model, *args, **kwargs): 14 | """ If we are wrapping an object that is already a Scorer, we just return it. 15 | """ 16 | if shap.utils.safe_isinstance(model, "adaptivetesting.Scorer"): 17 | return model 18 | else: 19 | return super().__new__(cls) 20 | 21 | def __init__(self, model): 22 | """ Auto detect the model type and subclass to the right scorer object. 
23 | """ 24 | 25 | # ensure we have a model of type Model 26 | if shap.utils.safe_isinstance(getattr(self, "model", None), "adaptivetesting.Model") or shap.utils.safe_isinstance(getattr(self, "model", None), "shap.models.Model"): 27 | pass 28 | elif shap.utils.safe_isinstance(model, "adaptivetesting.Model") or shap.utils.safe_isinstance(model, "shap.models.Model"): 29 | self.model = model 30 | else: 31 | self.model = Model(model) 32 | 33 | # If we are in the base class we need to pick the right specialized subclass to become 34 | if self.__class__ is Scorer: 35 | 36 | # finish early if we are wrapping an object that is already a Scorer (__new__ will have already done the work) 37 | if shap.utils.safe_isinstance(model, "adaptivetesting.Scorer"): 38 | return 39 | 40 | # see if we are scoring a generator or a classifier 41 | out = self.model(["string 1", "string 2"]) 42 | if isinstance(out[0], str): 43 | self.__class__ = GeneratorScorer 44 | GeneratorScorer.__init__(self, model) 45 | else: 46 | self.__class__ = ClassifierScorer 47 | ClassifierScorer.__init__(self, model) 48 | 49 | 50 | class DummyScorer(Scorer): 51 | def __init__(self): 52 | self._id = uuid.uuid4().hex 53 | def __call__(self, tests): 54 | out = [] 55 | for k, test in tests.iterrows(): 56 | try: 57 | score = float(test.value2) 58 | except: 59 | score = np.nan 60 | out.append(score) 61 | return np.array(out) 62 | 63 | class ClassifierScorer(Scorer): 64 | """ Wraps a text classification model and defines a callable scorer that returns a score value for any input/output pair. 65 | 66 | Positive scores indicate test failures, positive scores indicate tests that pass. For example if we wrap 67 | a text sentiment classifer the `scorer(TestTree([("this is great!", "should be", "POSITIVE")]))` will return 68 | a large positive value indicating that the model is very likely to correctly produce that output when given 69 | that input. 
70 | """ 71 | 72 | def __init__(self, model, top_probs=20, output_names=None): 73 | """ Create a new scorer given a model that returns a probability vector for each input string. 74 | 75 | Parameters: 76 | ----------- 77 | model : callable 78 | A model that is callable with a single argument (which is a list of strings) and returns a matrix of outputs. 79 | 80 | top_probs : int 81 | The number of top output probabilities to consider when scoring tests. This is used to reduce the number of 82 | input/output pairs that are passed to the local topic labeling model (and so save compute). 83 | 84 | output_names : list of strings 85 | A list of strings that correspond to the outputs of the model. If None, model.output_names is used. 86 | """ 87 | super().__init__(model) 88 | 89 | # extract output names from the model if they are not provided directly 90 | if output_names is None and getattr(self, "output_names", None) is None: 91 | self.output_names = self.model.output_names 92 | elif output_names is not None: 93 | self.output_names = output_names 94 | elif not hasattr(self, "output_names"): 95 | self.output_names = None 96 | 97 | if not callable(self.output_names): 98 | self._output_name_to_index = {v: i for i, v in enumerate(self.output_names)} 99 | self.top_probs = top_probs 100 | 101 | def __call__(self, tests, eval_ids): 102 | """ Compute the scores (and model outputs) for the tests matching the given ids. 103 | 104 | Parameters 105 | ---------- 106 | tests : TestTree 107 | A test tree for scoring. Note this should be the full test tree since it defines the local topic label 108 | models used for scoring. 109 | 110 | eval_ids : list of strings 111 | The ids of the tests to score. 
112 | """ 113 | 114 | # expand templates in the test tree 115 | eval_inputs = [] 116 | eval_inds = [] 117 | for i, id in enumerate(eval_ids): 118 | test = tests.loc[id] 119 | template_expansions = expand_template(test.input) 120 | for expansion in template_expansions: 121 | eval_inputs.append(expansion) 122 | eval_inds.append(i) 123 | 124 | # run the model 125 | try: 126 | model_out = self.model(eval_inputs) 127 | except Exception as e: 128 | model_out = np.zeros((len(eval_inputs), len(self.model.output_names))) * np.nan # TODO: remove this hack after the user study 129 | log.error(e) 130 | log.error(eval_inputs) 131 | log.error("The model threw an exception when evaluating inputs! We are patching this disaster with np.nan for the sake of the user study!") 132 | 133 | # compute the output strings and probabilites for each output in template form 134 | out_strings = [[] for _ in range(len(eval_ids))] 135 | out_probs = [[] for _ in range(len(eval_ids))] 136 | i = 0 137 | while i < len(model_out): 138 | out_strings[eval_inds[i]].append(self.model.output_names[np.argmax(model_out[i])]) 139 | out_probs[eval_inds[i]].append(model_out[i]) 140 | i += 1 141 | for i in set(eval_inds): 142 | out_strings[i] = "|".join(out_strings[i]) # template outputs are joined by | 143 | out_probs[i] = np.column_stack(out_probs[i]) # the probability of a set of items is the prob of the min item 144 | 145 | # compute the embeddings as a batch (this fills a cache we will use when scoring below) 146 | adaptivetesting.embed(list(tests.loc[eval_ids, "input"])) 147 | 148 | # score all the tests 149 | scores = [] 150 | outputs = [] 151 | for i, ind in enumerate(eval_inds): 152 | outputs.append(out_strings[ind]) 153 | scores.append(self._score_test(tests, eval_ids[ind], out_probs[ind], self.top_probs)) 154 | 155 | return outputs,scores 156 | 157 | def _score_test(self, tests, id, probs, top_probs): 158 | test = tests.loc[id] 159 | total_fail_prob = 0 160 | total_pass_prob = 0 161 | 162 | # if this 
is not a templated test 163 | if probs.shape[1] == 1: 164 | inds = np.argsort(probs[:,0])[::-1] 165 | for ind in inds[:top_probs]: 166 | 167 | # Scott: we could use any manually given labels when possible, but then that would make the score depend on the label 168 | # and so we would either need to save the full output of the model or recompute every time 169 | # if self.model.output_names[ind] == test["output"] and test["labeler"] != "imputed": 170 | # label = test["label"] 171 | 172 | # we use the local topic model to predict the label 173 | fail_prob = tests.topic_labeling_model(test.topic)(test.input, self.model.output_names[ind]) 174 | total_fail_prob += probs[ind, 0] * fail_prob 175 | total_pass_prob += probs[ind, 0] * (1 - fail_prob) 176 | 177 | if not (total_fail_prob + total_pass_prob > 0): 178 | return np.nan 179 | else: 180 | return total_fail_prob / (total_pass_prob + total_fail_prob) 181 | else: 182 | raise NotImplementedError("TODO: implement classifer scoring for templated tests") 183 | 184 | class GeneratorScorer(Scorer): 185 | """ Wraps a text generation model as a callable scorer that can be applied to a test tree. 186 | """ 187 | 188 | def __init__(self, model): 189 | """ Create a new scorer for a generative text model. 190 | 191 | Parameters: 192 | ----------- 193 | model : callable 194 | A model that is callable with a single argument (which is a list of strings) and returns a list of strings. 195 | """ 196 | super().__init__(model) 197 | 198 | # we don't want to re-init a class if init has alrady been done (this can happen when Scorer(maybe_scorer) is called) 199 | if hasattr(self, "_id"): 200 | return # already initialized 201 | 202 | def __call__(self, tests, eval_ids): 203 | """ Score a set of tests. 204 | 205 | Parameters 206 | ---------- 207 | tests : TestTree or DataFrame 208 | A dataframe of tests. 209 | 210 | eval_ids : list of strings 211 | The evaluation IDs to use. 
212 | """ 213 | 214 | # determine which rows we need to evaluate 215 | eval_inputs = [] 216 | eval_inds = [] 217 | for i, id in enumerate(eval_ids): 218 | template_expansions = expand_template(tests.loc[id, "input"]) 219 | for expansion in template_expansions: 220 | eval_inputs.append(expansion) 221 | eval_inds.append(i) 222 | 223 | # run the model on the rows we need to evaluate 224 | try: 225 | model_out = self.model(eval_inputs) 226 | except Exception as e: 227 | model_out = [""] * len(eval_inputs) # TODO: remove this hack after the user study 228 | log.error(e) 229 | log.error(eval_inputs) 230 | log.error("The model threw an exception when evaluating inputs! We are patching this disaster with np.nan for the sake of the user study!") 231 | 232 | # compute the output strings for each output 233 | out_strings = [[] for _ in range(len(eval_ids))] 234 | i = 0 235 | while i < len(model_out): 236 | out_strings[eval_inds[i]].append(str(model_out[i])) 237 | i += 1 238 | for i in set(eval_inds): 239 | out_strings[i] = "|".join(out_strings[i]) # template outputs are joined by | 240 | 241 | scores = [] 242 | outputs = [] 243 | for i, ind in enumerate(eval_inds): 244 | outputs.append(out_strings[ind]) 245 | scores.append(self._score_test(tests, eval_ids[ind], out_strings[ind])) 246 | 247 | return outputs,scores 248 | 249 | def _score_test(self, tests, id, output): 250 | test = tests.loc[id] 251 | 252 | fail_prob = tests.topic_labeling_model(test.topic)(test.input, output) 253 | 254 | return fail_prob 255 | 256 | class RawScorer(Scorer): 257 | """ Wraps a model that directly outputs a score each input as a callable scorer. 258 | 259 | The score from the model should be in the range [0,1] with higher scores indicating failures 260 | (or just more interesting behavior). 261 | """ 262 | 263 | def __init__(self, model): 264 | """ Create a new scorer given a model that returns a bounded real value for each input string. 
265 | 266 | Parameters: 267 | ----------- 268 | model : callable 269 | A model that is callable with a single argument (which is a list of strings) and returns a vector of score in the range [0,1]. 270 | """ 271 | super().__init__(model) 272 | 273 | def __call__(self, tests, eval_ids): 274 | """ Compute the scores (and model outputs) for the tests matching the given ids. 275 | 276 | Parameters 277 | ---------- 278 | tests : TestTree 279 | A test tree for scoring. Note this should be the full test tree since it defines the local topic label 280 | models used for scoring. 281 | 282 | eval_ids : list of strings 283 | The ids of the tests to score. 284 | """ 285 | 286 | # expand templates in the test tree 287 | eval_inputs = [] 288 | eval_inds = [] 289 | for i, id in enumerate(eval_ids): 290 | test = tests.loc[id] 291 | template_expansions = expand_template(test.input) 292 | for expansion in template_expansions: 293 | eval_inputs.append(expansion) 294 | eval_inds.append(i) 295 | 296 | # run the model 297 | try: 298 | model_out = self.model(eval_inputs) 299 | except Exception as e: 300 | model_out = np.zeros(len(eval_inputs)) * np.nan # TODO: remove this hack after the user study 301 | log.error(e) 302 | log.error(eval_inputs) 303 | log.error("The model threw an exception when evaluating inputs! 
We are patching this disaster with np.nan for the sake of the user study!") 304 | 305 | # compute the output strings and scores for each output in template form 306 | out_strings = [[] for _ in range(len(eval_ids))] 307 | out_scores = [[] for _ in range(len(eval_ids))] 308 | i = 0 309 | while i < len(model_out): 310 | out_strings[eval_inds[i]].append(str(np.round(model_out[i], 6))) # convert float to string with precision of 6 311 | out_scores[eval_inds[i]].append(model_out[i]) 312 | i += 1 313 | for i in eval_inds: 314 | out_strings[i] = "|".join(out_strings[i]) # template outputs are joined by | 315 | out_scores[i] = np.max(out_scores[i]) # the score of a set of items is the score of the max item 316 | 317 | # score all the tests 318 | scores = [] 319 | outputs = [] 320 | for i, ind in enumerate(eval_inds): 321 | outputs.append(out_strings[ind]) 322 | scores.append(out_scores[ind]) 323 | 324 | return outputs,scores 325 | 326 | def expand_template(s, keep_braces=False): 327 | """ Expand a template string into a list of strings. 328 | """ 329 | # parts = [] 330 | # for s in strings: 331 | matches = re.findall("{[^}]*}", s) 332 | s = re.sub("{[^}]*}", "{}", s) 333 | template_groups = [str(m)[1:-1].split("|") for m in matches] 334 | try: 335 | if keep_braces: 336 | return [s.format(*['{{{p}}}' for p in parts]) for parts in itertools.product(*template_groups)] 337 | else: 338 | return [s.format(*parts) for parts in itertools.product(*template_groups)] 339 | except ValueError: 340 | return [s] # we return the template not filled in if it is invalid 341 | 342 | def clean_template(s): 343 | """ This removes duplicate template entries. 
344 | """ 345 | matches = re.findall("{[^}]*}", s) 346 | s = re.sub("{[^}]*}", "{}", s) 347 | template_groups = [str(m)[1:-1].split("|") for m in matches] 348 | clean_groups = ["{"+"|".join(list({v: None for v in g}.keys()))+"}" for g in template_groups] 349 | try: 350 | return s.format(*clean_groups) 351 | except ValueError: 352 | return s # we return the template not cleaned in if it is invalid 353 | -------------------------------------------------------------------------------- /adaptivetesting/_server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import functools 3 | import uuid 4 | import pathlib 5 | import logging 6 | import os 7 | 8 | 9 | import asyncio 10 | import nest_asyncio 11 | nest_asyncio.apply() 12 | 13 | import aiohttp 14 | from aiohttp import web 15 | import aiohttp_session 16 | import aiohttp_session.cookie_storage 17 | import aiohttp_security 18 | #from aiohttp_session import SimpleCookieStorage, session_middleware 19 | from aiohttp_security import check_permission, \ 20 | is_anonymous, remember, forget, \ 21 | setup as setup_security, SessionIdentityPolicy 22 | from aiohttp_security.abc import AbstractAuthorizationPolicy 23 | import cryptography.fernet 24 | from . import TestTree 25 | import functools 26 | 27 | log = logging.getLogger(__name__) 28 | 29 | 30 | def serve(test_tree_browsers, host="localhost", port=8080, static_dir=None, authenticate=lambda user, password: True, 31 | authorize=lambda user,location: True, auth_duration=60 * 60 * 8, ssl_crt=None, ssl_key=None): 32 | """ Serves the interface at the given host and port. 33 | """ 34 | log.debug(f"serve(test_tree_browsers={test_tree_browsers})") 35 | 36 | if isinstance(test_tree_browsers, TestTree): 37 | raise Exception("You cannot serve a TestTree directly! 
You need to call it with a scorer like test_tree(scorer).") 38 | 39 | if isinstance(authenticate, dict): 40 | auth_dict = authenticate 41 | def check_pass(user, password): 42 | return auth_dict.get(user, object()) == password 43 | authenticate = check_pass 44 | 45 | loop = asyncio.get_event_loop() 46 | 47 | if not hasattr(test_tree_browsers, "interface_event") and callable(test_tree_browsers): 48 | test_tree_browsers = functools.lru_cache(maxsize=None)(test_tree_browsers) 49 | 50 | id = uuid.uuid4().hex 51 | 52 | async def send_ws_data(ws, str_data): 53 | await ws.send_str(str_data) 54 | 55 | async def topic_handler(request): 56 | log.debug(f"topic_handler({request})") 57 | logged_in = not await aiohttp_security.is_anonymous(request) 58 | 59 | if not logged_in: 60 | user = request.rel_url.query.get("user", "anonymous") 61 | if authenticate(user, None): 62 | redirect_response = web.HTTPFound(str(request.rel_url)) 63 | await remember(request, redirect_response, user) 64 | return redirect_response 65 | else: 66 | raise web.HTTPFound(f'/_login?user={user}&sendback={str(request.rel_url)}') 67 | else: 68 | user = await aiohttp_security.authorized_userid(request) 69 | if hasattr(test_tree_browsers, "interface_event"): 70 | prefix = "" 71 | test_tree_browser = test_tree_browsers 72 | test_tree_name = 'fake' 73 | else: 74 | test_tree_name = request.match_info["test_tree"] 75 | prefix = "/" + test_tree_name 76 | if callable(test_tree_browsers): 77 | test_tree_browser = test_tree_browsers(test_tree_name) 78 | else: 79 | test_tree_browser = test_tree_browsers.get(test_tree_name, None) 80 | 81 | # make sure we found the given test 82 | if not hasattr(test_tree_browser, "interface_event"): 83 | log.debug(f"The test tree we found was not valid: {test_tree_browsers}") 84 | raise web.HTTPNotFound() 85 | test_tree_browser.user = user 86 | test_tree_browser.name = test_tree_name 87 | 88 | interface_html = f""" 89 | 90 | 91 | Adaptive Testing 92 | 93 | 94 | 
{test_tree_browser._repr_html_(prefix=prefix, environment="web", websocket_server=prefix+"/_ws")} 95 | 96 | 97 | """ 98 | 99 | return web.Response(text=interface_html, content_type="text/html") 100 | 101 | async def static_handler(request): 102 | logged_in = not await aiohttp_security.is_anonymous(request) 103 | log.debug(f"static_handler({request})") 104 | if not logged_in: 105 | user = request.rel_url.query.get("user", "anonymous") 106 | if authenticate(user, None): 107 | redirect_response = web.HTTPFound(str(request.rel_url)) 108 | await remember(request, redirect_response, user) 109 | return redirect_response 110 | else: 111 | raise web.HTTPFound(f'/_login?user={user}&sendback={str(request.rel_url)}') 112 | else: 113 | if request.raw_path == "/favicon.ico": 114 | file_path = pathlib.Path(__file__).parent.absolute() 115 | return web.FileResponse(file_path / "resources" / "favicon.png" ) 116 | elif "file_path" in request.match_info: 117 | file_path = os.path.join(static_dir, *request.match_info["file_path"].replace("..", "").split("/")) 118 | return web.FileResponse(file_path) 119 | else: 120 | raise web.HTTPNotFound() 121 | # with open(file_path) as f: 122 | # file_data = f.read() 123 | # if file_path.endswith(".png"): 124 | # content_type = "image/png" 125 | # elif file_path.endswith(".gif"): 126 | # content_type = "image/gif" 127 | # elif file_path.endswith(".jpg") or file_path.endswith(".jpeg"): 128 | # content_type = "image/jpeg" 129 | # elif file_path.endswith(".js"): 130 | # content_type = "application/javascript" 131 | # else: 132 | # content_type = "text/html" 133 | # return web.Response(text=file_data, content_type=content_type) 134 | 135 | async def login_handler(request): 136 | sendback = request.rel_url.query.get("sendback", "/") 137 | user = request.rel_url.query.get("user", "") 138 | return web.Response(text=f""" 139 | 140 | 141 | Adaptive Testing Login 142 | 143 | 144 |
145 | 146 | 147 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 |
148 | 150 | 151 | 161 | 170 | 180 | 187 | 196 | 208 | 215 | 216 | 217 | 219 | 220 | 222 | 224 | 225 | 227 | 228 | 229 | 230 | 231 |
Username:
Password:
243 | 244 |
245 | 246 | 247 | 248 | """, content_type="text/html") 249 | 250 | async def auth_handler(request): 251 | post_params = await request.post() 252 | sendback = post_params.get("sendback", "/") 253 | user = post_params.get('user', None) 254 | password = post_params.get('password', None) 255 | 256 | if authenticate(user if user is not None else "anonymous", password): 257 | redirect_response = web.HTTPFound(sendback) 258 | await remember(request, redirect_response, user if user is not None else "anonymous") 259 | return redirect_response 260 | else: 261 | raise web.HTTPFound(f"/_login?{'user='+user+'&' if user is not None else ''}sendback={sendback}") 262 | 263 | async def websocket_handler(request): 264 | ws = web.WebSocketResponse() 265 | await ws.prepare(request) 266 | 267 | # build a WebSocket comm object 268 | class WebSocketComm(): 269 | pass 270 | def ws_send(data): 271 | loop.run_until_complete(send_ws_data(ws, json.dumps(data))) 272 | comm = WebSocketComm() 273 | comm.send = ws_send 274 | 275 | if hasattr(test_tree_browsers, "_repr_html_"): 276 | test_tree_browser = test_tree_browsers 277 | else: 278 | test_tree_name = request.match_info["test_tree"] 279 | if callable(test_tree_browsers): 280 | test_tree_browser = test_tree_browsers(test_tree_name) 281 | else: 282 | test_tree_browser = test_tree_browsers[test_tree_name] 283 | test_tree_browser.comm = comm 284 | 285 | async for msg in ws: 286 | if msg.type == aiohttp.WSMsgType.TEXT: 287 | if msg.data == 'close': 288 | log.debug(f"Closing WebSocket for user '{getattr(test_tree_browser, 'user', None)}' for test tree '{getattr(test_tree_browser, 'name', None)}'!") 289 | await ws.close() 290 | else: 291 | data = json.loads(msg.data) 292 | log.info(f"WebSocket message from user '{getattr(test_tree_browser, 'user', None)}' for test tree '{getattr(test_tree_browser, 'name', None)}' is {data}") 293 | test_tree_browser.interface_event(data) 294 | elif msg.type == aiohttp.WSMsgType.ERROR: 295 | print('WebSocket 
connection closed with exception %s' % ws.exception()) 296 | 297 | return ws 298 | 299 | async def make_app(): 300 | middleware = aiohttp_session.session_middleware(aiohttp_session.cookie_storage.EncryptedCookieStorage( 301 | cryptography.fernet.Fernet.generate_key().decode(), max_age=auth_duration 302 | )) 303 | # middleware = aiohttp_session.session_middleware(aiohttp_session.SimpleCookieStorage()) 304 | app = web.Application(middlewares=[middleware]) 305 | 306 | app.add_routes([ 307 | web.get('/_ws', websocket_handler), 308 | web.get('/_login', login_handler), 309 | web.post('/_auth', auth_handler), 310 | web.get('/favicon.ico', static_handler) 311 | ]) 312 | 313 | if static_dir is not None: 314 | app.add_routes([web.static('/_static', static_dir)]) 315 | if hasattr(test_tree_browsers, "_repr_html_"): 316 | app.add_routes([ 317 | web.get('/{topic_path:.*}', topic_handler), 318 | web.get('/_ws', websocket_handler) 319 | ]) 320 | else: 321 | if static_dir is not None: 322 | app.add_routes([web.get('/{test_tree}/_static/{file_path:.*}', static_handler)]) 323 | app.add_routes([ 324 | web.get('/{test_tree}/_ws', websocket_handler), 325 | web.get('/{test_tree}', topic_handler), 326 | web.get('/{test_tree}/{topic_path:.*}', topic_handler) 327 | ]) 328 | 329 | policy = SessionIdentityPolicy() 330 | setup_security(app, policy, AdaptiveTestingPolicy()) 331 | 332 | return app 333 | 334 | state = { 335 | "site": None, 336 | "runner": None 337 | } 338 | async def start_server(state, host, port, ssl_crt, ssl_key): 339 | 340 | if ssl_crt is not None: 341 | import ssl 342 | ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 343 | ssl_context.load_cert_chain(ssl_crt, ssl_key) 344 | else: 345 | ssl_context = None 346 | 347 | app = await make_app() 348 | state["runner"] = aiohttp.web.AppRunner(app) 349 | await state["runner"].setup() 350 | state["site"] = web.TCPSite(state["runner"], host, port, ssl_context=ssl_context) 351 | await state["site"].start() 352 | 
print(f"Server started at http://{host}:{port}") 353 | 354 | async def stop_server(state): 355 | await state["site"].stop() 356 | await state["runner"].shutdown() 357 | 358 | 359 | # aiohttp.web.run_app(make_app(), port=port) 360 | loop.run_until_complete(start_server(state, host=host, port=port, ssl_crt=ssl_crt, ssl_key=ssl_key)) 361 | 362 | try: 363 | loop.run_forever() 364 | finally: 365 | loop.run_until_complete(stop_server(state)) 366 | 367 | class AdaptiveTestingPolicy(AbstractAuthorizationPolicy): 368 | async def authorized_userid(self, identity): 369 | """Retrieve authorized user id. 370 | Return the user_id of the user identified by the identity 371 | or 'None' if no user exists related to the identity. 372 | """ 373 | return identity 374 | 375 | async def permits(self, identity, permission, context=None): 376 | """Check user permissions. 377 | Return True if the identity is allowed the permission 378 | in the current context, else return False. 379 | """ 380 | return identity == 'jack' and permission in ('listen',) 381 | -------------------------------------------------------------------------------- /notebooks/IMDB to Hotel Sentiment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "412d3c6b", 6 | "metadata": {}, 7 | "source": [ 8 | "# IMDB to Hotel Sentiment\n", 9 | "\n", 10 | "In this notebook we will take a sentiment analysis model trained on IMDB reviews, and fine tune it to analyse tweets about hotels. We will use Adaptive Testing to help us generate a suitable test suite." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "2e9f4de2", 16 | "metadata": {}, 17 | "source": [ 18 | "## Seeding the PRNG\n", 19 | "\n", 20 | "Before we do anything else, we first seed the PRNG, to ensure that we have reproducible results:" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "c907dfd0", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "torch.manual_seed(1012351)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "52054118", 37 | "metadata": {}, 38 | "source": [ 39 | "## The Base Model\n", 40 | "\n", 41 | "We will use the [`aychang/roberta-base-imdb` from Hugging Face](https://huggingface.co/aychang/roberta-base-imdb) as our base model. This is a binary model which has been trained on a collection of IMDB reviews. First, we load the model itself:" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "8a1c4c6e", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", 52 | "from transformers import pipeline\n", 53 | "\n", 54 | "base_model_name = \"aychang/roberta-base-imdb\"\n", 55 | "\n", 56 | "model = AutoModelForSequenceClassification.from_pretrained(base_model_name,num_labels=2)\n", 57 | "tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n", 58 | "\n", 59 | "original_pipeline = pipeline(\"sentiment-analysis\",\n", 60 | " model=model,\n", 61 | " tokenizer=tokenizer,\n", 62 | " top_k=2)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "90e5e15b", 68 | "metadata": {}, 69 | "source": [ 70 | "Now, let's try a few sentences:" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "8960e590", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "sample_strings = [\n", 81 | " \"Great cinematography but a poor movie overall\",\n", 82 | " \"Snappy dialogue makes for 
enjoyable entertainment\",\n", 83 | " \"Located on a busy street with much traffic\"\n", 84 | "]\n", 85 | "\n", 86 | "for s in sample_strings:\n", 87 | " print(s, \"\\n\", original_pipeline(s), \"\\n\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "d83c2a06", 93 | "metadata": {}, 94 | "source": [ 95 | "We can see that the two statements about movies are well classified, but the one about the final one about the hotel is not." 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "4d90d0d8", 101 | "metadata": {}, 102 | "source": [ 103 | "## Using Adaptive Testing\n", 104 | "\n", 105 | "AdaptiveTesting is a tool to help create training/test suites for language models. The basic workflow is:\n", 106 | "\n", 107 | "1. User provides some sample input\n", 108 | "1. User flags whether the model output is correct or not\n", 109 | "1. AdaptiveTesting uses a second language model to generate more inputs from those already provided\n", 110 | "1. User decides which of the AdaptiveTesting proposed inputs to incorporate (and whether the model provided a correct response)\n", 111 | "\n", 112 | "Iterating through this process a few times can generate a lot of tests quite quickly." 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "e48bbbef", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "import adaptivetesting" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "139763f8", 128 | "metadata": {}, 129 | "source": [ 130 | "For our generator, we use OpenAI's GPT-3 model. 
For this, we need to read the access key in from a file:" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "a6293838", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "import os\n", 141 | "with open(os.path.expanduser('~/.openai_api_key'), 'r') as file:\n", 142 | " OPENAI_API_KEY = file.read().replace('\\n', '')" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "70f74aec", 148 | "metadata": {}, 149 | "source": [ 150 | "First, we create the generator object which AdaptiveTesting will use to suggest more tests which are similar to the ones we provide:" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "6e310578", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "generator = adaptivetesting.generators.OpenAI('curie', api_key=OPENAI_API_KEY)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "id": "56095496", 166 | "metadata": {}, 167 | "source": [ 168 | "Now we create the test tree. 
We will load a set of tests which we have already started work on, to make the process faster:" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "id": "1bc4e9ff", 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "tests = adaptivetesting.TestTree(\"imdb_hotel_conversion.csv\")" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "id": "58006371", 184 | "metadata": {}, 185 | "source": [ 186 | "And fire up the AdaptiveTesting interface:" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "056dbba0", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "tests.adapt(original_pipeline, generator, auto_save=True, recompute_scores=True)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "17fa752a", 202 | "metadata": {}, 203 | "source": [ 204 | "With a set of samples composed, we need to use them to finetune the model. To begin this process, load the CSV file we've created into a DataFrame and drop the portions we don't need:" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "32a790c1", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "import pandas as pd\n", 215 | "\n", 216 | "def load_adatest_data(csv_file: str):\n", 217 | " tmp = pd.read_csv(csv_file)\n", 218 | " \n", 219 | " # Drop topic marker rows\n", 220 | " tmp2 = tmp[tmp['label'] != 'topic_marker']\n", 221 | " # Drop suggestion rows\n", 222 | " tmp3 = tmp2[tmp2['topic'] != 'suggestion']\n", 223 | " \n", 224 | " # Remove columns we don't need\n", 225 | " tmp4 = tmp3.drop(labels=['labeler', 'description', 'author', 'Unnamed: 0'], axis=1)\n", 226 | " \n", 227 | " # Rename columns\n", 228 | " tmp5 = tmp4.rename(mapper={'input': 'sentence', 'label': 'model_is_correct'}, axis=1)\n", 229 | " \n", 230 | " # Remove any spurious rows\n", 231 | " tmp6 = tmp5[tmp5['topic'].notna()]\n", 232 | " \n", 233 | " # Don't need to 
track original rows\n", 234 | " tmp7 = tmp6.reset_index(drop=True)\n", 235 | " \n", 236 | " return tmp7\n", 237 | "\n", 238 | "\n", 239 | "test_data = load_adatest_data('imdb_hotel_conversion.csv')\n", 240 | "display(test_data)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "id": "e1777931", 246 | "metadata": {}, 247 | "source": [ 248 | "Next, we need to get the actual labels corresponding to each sentence. For this we need to combine the column which contains the output of our model and the column containing our manual labelling of whether the model was correct or incorrect." 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "id": "e3e8a3d6", 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "def generate_label(row):\n", 259 | " # The model output is either 'pos' or 'neg'\n", 260 | " model_result = row['output']\n", 261 | " # Return based on whether the model response was marked correct or incorrect\n", 262 | " if row['model_is_correct'] == 'pass':\n", 263 | " return model_result\n", 264 | " else:\n", 265 | " if model_result == 'pos':\n", 266 | " return 'neg'\n", 267 | " else:\n", 268 | " return 'pos'\n", 269 | " \n", 270 | "# Apply this to the data\n", 271 | "test_data['label'] = test_data.apply(generate_label, axis=1)\n", 272 | "test_data" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "id": "4031033a", 278 | "metadata": {}, 279 | "source": [ 280 | "We can call the pipeline directly on the sentences we have generated, and make sure that we get the same results as the one stored by Adaptive Testing:" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "id": "649de984", 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "import numpy as np\n", 291 | "\n", 292 | "def get_label(label_probabilities):\n", 293 | " # The pipeline returns all of the label probabilities\n", 294 | " # We need to extract the largest\n", 295 | " 
max_score = 0\n", 296 | " label = None\n", 297 | " for l in label_probabilities:\n", 298 | " if l['score'] > max_score:\n", 299 | " max_score = l['score']\n", 300 | " label = l['label']\n", 301 | " return label\n", 302 | "\n", 303 | "y_pred = [get_label(x) for x in original_pipeline(test_data.sentence.to_list())]\n", 304 | "\n", 305 | "\n", 306 | "test_data['my_y_pred'] = y_pred\n", 307 | "assert np.array_equal(test_data['my_y_pred'], test_data['output'])\n", 308 | "\n", 309 | "display(test_data)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "id": "385dbdff", 315 | "metadata": {}, 316 | "source": [ 317 | "We can also evaluate our chosen metric, and check that the accuracy score matches what we expect from the summary at the top level of the Adaptive Testing widget:" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "id": "cb2e1f62", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "from datasets import load_metric\n", 328 | "\n", 329 | "metric_name = 'accuracy'\n", 330 | "\n", 331 | "metric = load_metric(metric_name)\n", 332 | "\n", 333 | "def label_to_int(l: str) -> int:\n", 334 | " # Use the mapping provided by the model\n", 335 | " return model.config.label2id[l]\n", 336 | "\n", 337 | "metric.compute(predictions=test_data['my_y_pred'].apply(label_to_int), references=test_data['label'].apply(label_to_int))" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "id": "7d8348a9", 343 | "metadata": {}, 344 | "source": [ 345 | "There is one final tweak to make to our data prior to finetuning the model: the Hugging Face `Trainer`s do not use the human-friendly labels, but the corresponding integer ids. 
So use the mapping provided by the model to convert the 'label' column:" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "id": "ac67a09a", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "test_data['label'] = test_data['label'].apply(label_to_int)\n", 356 | "print(test_data.dtypes)" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "id": "ff6a412c", 362 | "metadata": {}, 363 | "source": [ 364 | "Now, we can split our dataset into training and test sets. We stratify based on the 'topic' column, to ensure that we have samples from all of the various topics we have generated:" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "id": "20d7fbf6", 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "from sklearn.model_selection import train_test_split\n", 375 | "\n", 376 | "train_df, test_df = train_test_split(test_data, stratify=test_data['topic'], test_size=0.3)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "id": "4c7b887f", 382 | "metadata": {}, 383 | "source": [ 384 | "Convert our DataFrames into Hugging Face `Dataset`s:" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "id": "2aa80d00", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "from datasets import Dataset\n", 395 | "\n", 396 | "train_ds = Dataset.from_pandas(df = train_df)\n", 397 | "test_ds = Dataset.from_pandas(df = test_df)\n", 398 | "train_ds" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "b5197b33", 404 | "metadata": {}, 405 | "source": [ 406 | "Encode our datasets:" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "fffb55b6", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "def preprocess_function(examples):\n", 417 | " result = tokenizer(examples[\"sentence\"],\n", 418 | " add_special_tokens = True,\n", 419 | " 
truncation = True,\n", 420 | " padding = \"max_length\",\n", 421 | " return_attention_mask = True\n", 422 | " )\n", 423 | " return result\n", 424 | "\n", 425 | "train_encoded = train_ds.map(preprocess_function, batched=True)\n", 426 | "test_encoded = test_ds.map(preprocess_function, batched=True)\n", 427 | "\n", 428 | "drop_cols = ['topic', '__index_level_0__','model_is_correct', 'model score', 'my_y_pred', 'output']\n", 429 | "\n", 430 | "train_encoded = train_encoded.remove_columns(drop_cols)\n", 431 | "test_encoded = test_encoded.remove_columns(drop_cols)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "id": "a76e81ed", 437 | "metadata": {}, 438 | "source": [ 439 | "Configure a new training run:" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "id": "f26489b2", 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "from transformers import TrainingArguments\n", 450 | "\n", 451 | "batch_size = 4\n", 452 | "\n", 453 | "args_ft = TrainingArguments(\n", 454 | " f\"hotel_fine_tuned\",\n", 455 | " evaluation_strategy = \"epoch\",\n", 456 | " save_strategy = \"epoch\",\n", 457 | " learning_rate=2e-5,\n", 458 | " per_device_train_batch_size=batch_size,\n", 459 | " per_device_eval_batch_size=batch_size,\n", 460 | " num_train_epochs=5,\n", 461 | " weight_decay=0.01,\n", 462 | " load_best_model_at_end=True,\n", 463 | " metric_for_best_model=metric_name,\n", 464 | " push_to_hub=False,\n", 465 | ")" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "id": "baf346f1", 471 | "metadata": {}, 472 | "source": [ 473 | "Now, load a fresh copy of the model for fine tuning. 
This will allow us to compare the two models side-by-side:" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "id": "fc8ddbb9", 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "ft_model = AutoModelForSequenceClassification.from_pretrained(base_model_name,num_labels=2)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "id": "dca9df6b", 489 | "metadata": {}, 490 | "source": [ 491 | "Create our new `Trainer` object, using the model we've just loaded. We pass in our new datasets for training and evaluation:" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "id": "4d1e8e93", 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [ 501 | "from transformers import Trainer\n", 502 | "\n", 503 | "def compute_metrics(eval_pred):\n", 504 | " predictions, labels = eval_pred\n", 505 | " # Predictions are probabilities, so the actual answer is the index with the highest probability\n", 506 | " predictions = np.argmax(predictions, axis=1)\n", 507 | " return metric.compute(predictions=predictions, references=labels)\n", 508 | "\n", 509 | "trainer_ft = Trainer(\n", 510 | " ft_model,\n", 511 | " args_ft,\n", 512 | " train_dataset=train_encoded,\n", 513 | " eval_dataset=test_encoded,\n", 514 | " tokenizer=tokenizer,\n", 515 | " compute_metrics=compute_metrics\n", 516 | ")" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "id": "f104c611", 522 | "metadata": {}, 523 | "source": [ 524 | "Now, we can run the training. 
On a CPU, this may take a few minutes (large values of 'few' may be experienced):" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "id": "4a276759", 531 | "metadata": { 532 | "scrolled": true 533 | }, 534 | "outputs": [], 535 | "source": [ 536 | "trainer_ft.train()" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": null, 542 | "id": "86ebf95d", 543 | "metadata": {}, 544 | "outputs": [], 545 | "source": [ 546 | "trainer_ft.evaluate()" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "id": "58956445", 552 | "metadata": {}, 553 | "source": [ 554 | "## Assessing the Fine-Tuned Model\n", 555 | "\n", 556 | "Now that we have fine-tuned the model with some examples which talk about hotels, we can see if it performs better. First, we put the new model into a scoring pipeline:" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "id": "2e1d2687", 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "ft_pipeline = pipeline(\"sentiment-analysis\",\n", 567 | " model=trainer_ft.model.to('cpu'),\n", 568 | " tokenizer=tokenizer,\n", 569 | " top_k=2)" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "id": "f39b78c5", 575 | "metadata": {}, 576 | "source": [ 577 | "We can re-run the initial samples we tried above:" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "id": "7586109f", 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "for s in sample_strings:\n", 588 | " print(s, \"\\n\", ft_pipeline(s), \"\\n\")" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "id": "53455b62", 594 | "metadata": {}, 595 | "source": [ 596 | "The sentences about movies are still well classified, but the final one about a hotel has the correct prediction now.\n", 597 | "\n", 598 | "For a more systematic comparison, we can run our `test_df` through both pipelines:" 599 | ] 600 | }, 601 | { 602 | 
"cell_type": "code", 603 | "execution_count": null, 604 | "id": "28332c5e", 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "def get_label(label_probabilities):\n", 609 | " # The pipeline returns all of the label probabilities\n", 610 | " # We need to extract the largest\n", 611 | " max_score = 0\n", 612 | " label = None\n", 613 | " for l in label_probabilities:\n", 614 | " if l['score'] > max_score:\n", 615 | " max_score = l['score']\n", 616 | " label = l['label']\n", 617 | " # Convert back to the id\n", 618 | " return ft_model.config.label2id[label]\n", 619 | "\n", 620 | "y_pred_orig = [get_label(x) for x in original_pipeline(test_df.sentence.to_list())]\n", 621 | "y_pred_ft = [get_label(x) for x in ft_pipeline(test_df.sentence.to_list())]\n", 622 | "\n", 623 | "print(\"Original : \", metric.compute(predictions=y_pred_orig, references=test_df.label))\n", 624 | "print(\"Fine Tuned: \", metric.compute(predictions=y_pred_ft, references=test_df.label))" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "id": "6c7c536a", 630 | "metadata": {}, 631 | "source": [ 632 | "We see a noticeable improvement in accuracy." 
633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": null, 638 | "id": "01399407", 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [] 642 | } 643 | ], 644 | "metadata": { 645 | "kernelspec": { 646 | "display_name": "Python 3 (ipykernel)", 647 | "language": "python", 648 | "name": "python3" 649 | }, 650 | "language_info": { 651 | "codemirror_mode": { 652 | "name": "ipython", 653 | "version": 3 654 | }, 655 | "file_extension": ".py", 656 | "mimetype": "text/x-python", 657 | "name": "python", 658 | "nbconvert_exporter": "python", 659 | "pygments_lexer": "ipython3", 660 | "version": "3.9.12" 661 | } 662 | }, 663 | "nbformat": 4, 664 | "nbformat_minor": 5 665 | } 666 | -------------------------------------------------------------------------------- /notebooks/imdb_hotel_conversion.csv: -------------------------------------------------------------------------------- 1 | ,topic,input,output,label,labeler,description,author,model score 2 | 00f2dfd40977438f9942793768959e8a,,,,topic_marker,imputed,,, 3 | 02076434c5fc4a389f127e542c479c64,/Rooms,Sheets didn't seem clean,neg,pass,anonymous,,anonymous,0.03536648847074937 4 | 029c3b23ed7c491bb0f5b35687dd3f8c,/Location,Far from the best beaches,pos,fail,anonymous,,anonymous,0.8619214039468347 5 | 04189a46ab1642789a5f59fb5a758e29,/Price,Expensive for what you get,pos,pass,anonymous,,,0.0 6 | 04f087d371604398b32205eab29cd587,/Location,Close to nowhere,neg,pass,anonymous,,,0.007999936759466829 7 | 061c756434f84e0b9d4dbad828496837,/Location,Just off the bus line,pos,pass,anonymous,,anonymous,0.0 8 | 069449b9da3b4481827417686c244033,/Price,High prices for average services and less convinient location,neg,pass,anonymous,,anonymous,0.0 9 | 06b259ccddbe4769b050fb855dd245dc,/Rooms,Only had a couple pillows,pos,fail,anonymous,,anonymous,0.9068814921475123 10 | 06ffa0d789944320a18d1c91a9567315,/Facilities/Restaurant,Don't eat here,neg,pass,anonymous,,anonymous,0.002532984066390833 11 | 
0adc101c0dc549a089ed5733a2893dba,/Service,Service at the hotel is quite slow,neg,pass,anonymous,,anonymous,0.0 12 | 0ba15c3386e34777bfb4906394f15b82,/Price,One star service for a four star price,pos,fail,anonymous,,,1.0 13 | 0e74d4372f484ef089c75e8f82446226,/Location,Bus line runs close by,pos,pass,anonymous,,anonymous,0.0 14 | 10555369f4254142a2d42df7558063d2,/Facilities/Restaurant,"My lasagne was excellent, I could eat it every day!",pos,pass,anonymous,,anonymous,0.0 15 | 123de4a1154a4a47b3f897230b66c3f1,/Service,Any issues are taken care of quickly,pos,pass,anonymous,,anonymous,0.0 16 | 15ded90cab9e453792a97a4e739a644e,/Location,No walkability for people with mobility issues,pos,fail,anonymous,,anonymous,1.0 17 | 1604abdbfbda497e8de46f111bf4e506,/Facilities/Pool,Pool cleaning was horrible,neg,pass,anonymous,,anonymous,0.0015588112163477117 18 | 18c454fa3d444526ad71a95fe5a45ef5,/Price,Low price but high standard of service,pos,pass,anonymous,,anonymous,0.0 19 | 1a256503b0364a40808ef76b10248c30,/Service,Front desk staff were attentive,pos,pass,anonymous,,anonymous,0.0 20 | 1b96d669ba1546dba79174644e9d741f,/Facilities/Restaurant,Restaurant was located in the basement,neg,pass,anonymous,,anonymous,0.0 21 | 1c5020a50e5747cb90f4d08259c09772,/Price,,,topic_marker,anonymous,,, 22 | 1cdedd7004ea4a34a9623c91aae382f0,/Facilities/Restaurant,Waiters were friendly and attentive,pos,pass,anonymous,,anonymous,0.0 23 | 1da0369fd6bb4afb958caf57ac4ceaf5,/Price,Extraordinarily expensive for what you get,neg,pass,anonymous,,anonymous,0.0 24 | 1dc58e624e364c2b82e7f4b20ca06676,/Facilities/Restaurant,Bread did not come at the end of meal,neg,pass,anonymous,,anonymous,0.02117403731632515 25 | 1e27ca77e6bf46beb7282fc50ed9dba0,/Rooms,Only 2 bath towels,pos,fail,anonymous,,anonymous,0.5659768154018169 26 | 1ed5d0e0e01248c4a197e65ee461db7b,/Location,Hotel dropped into an anonymous business park,neg,pass,anonymous,,,0.0038315118022913173 27 | 22259ce0a5df47a5a956c49cbc8755a0,/Service,Always 
willing to do whatever it takes to make your stay pleasant,pos,pass,anonymous,,anonymous,0.0 28 | 2245eaa2609f40009a5a98471fbec625,/Facilities,,,topic_marker,anonymous,,, 29 | 2346ed4286764ff089c1917cf9da2ddd,/Price,Prices are a little on the steep side,pos,fail,anonymous,,anonymous,1.0 30 | 244f7853fd1142bb9427c36a666b91f9,/Facilities/Restaurant,Food quantity was small,neg,pass,anonymous,,anonymous,0.010578740506381622 31 | 257400db509e49c1b6dd9e287685140e,/Facilities/Restaurant,No toilets in restaurant,pos,fail,anonymous,,anonymous,0.8834572233315061 32 | 27a04b3d63254577b1112ae04145f935,/Facilities/Restaurant,No free refills,pos,fail,anonymous,,anonymous,0.9858127695421115 33 | 287e05c7919345bba29ce27de5b80299,/Facilities/Restaurant,Got food poisoning,neg,pass,anonymous,,anonymous,0.0 34 | 2db7e0e996ac4b8da81fa7bbb286914c,/Location,An oasis of calm during tourist season,pos,pass,anonymous,,anonymous,0.0 35 | 2e3005f8a8c24cbd96401adb7f321458,/Facilities/Restaurant,Breakfast needs a major overhaul,neg,pass,anonymous,,anonymous,0.0 36 | 2f5b87f286b04736bf061c3a7e438911,/Rooms,"No views of the ocean at all, only of a parking lot.",neg,pass,anonymous,,anonymous,0.0055049279431199595 37 | 34998856113841b78e12a3f454ff4420,/Price,Rigged to overcharge,neg,pass,anonymous,,anonymous,0.0 38 | 357caede646a4c5ea1b59da52e807e12,/Location,Slightly isolated from the area's attractions,pos,fail,anonymous,,anonymous,0.9509071111679077 39 | 387d9bf5e9ac47f49503fe8aaa10f2f9,/Rooms,Our air conditioner wasn't working,neg,pass,anonymous,,anonymous,0.0 40 | 38af115065e9475a82ba455e010667e7,/Price,Crazy expensive for what you get,neg,pass,anonymous,,anonymous,0.0 41 | 3ab16b419ab54eda9616b20631da00a2,/Facilities/Restaurant,Get here as early as possible or you'll be waiting for a long time,pos,fail,anonymous,,anonymous,0.9929495269304643 42 | 3b0a1573d9154d938928d7ce3ddff043,/Price,The price seemed reasonable,neg,fail,anonymous,,anonymous,0.16721585891294058 43 | 
3b63631e191f473f852a9276e3e7bcde,/Location,Too far from some attractions,neg,pass,anonymous,,anonymous,0.2603970766067505 44 | 3ca8925223ff49969a35dd682a2abdb6,/Facilities/Restaurant,Seafood is so much better elsewhere,pos,fail,anonymous,,anonymous,0.9628090240245476 45 | 40bdeeb227c44fb99e19466df0617dd5,/Rooms,Room was noisy at night,pos,fail,anonymous,,anonymous,0.9222849882329996 46 | 4231463634b84ccf918ae5e3269a32e5,/Location,Can walk to art district,pos,pass,anonymous,,anonymous,0.0 47 | 44dc6a3678654196aa2212a943f64ec1,/Rooms,We could just lie in bed and watch the ships in the harbor,neg,fail,anonymous,,,1.0 48 | 4819f234c2c74680a145f8a4617aaf04,/Facilities/Pool,Water terribly cold,neg,pass,anonymous,,,0.0018953193006728897 49 | 4877e7bd4af94768bd881652c7876066,/Price,Good budget option,pos,pass,anonymous,,anonymous,0.0 50 | 4c272643ac1f4a62a0ccca592273d51c,/Service,Staff very attentive,pos,pass,anonymous,,,0.0 51 | 4cb11859b69a48908e0ca2e02c17ddc2,/Service,Checkout very quick,pos,pass,anonymous,,anonymous,0.0 52 | 4d0c722297bc40dba00809efeff3ddf2,/Rooms,"Room was too hot, could not sleep",neg,pass,anonymous,,anonymous,0.020949860030593202 53 | 4dd929eba2094491a64af7a2ce5d7316,/Rooms,TV didn't work in the room,neg,pass,anonymous,,,0.005870745712330646 54 | 514cf379d06440ccadb04ecf97509779,/Location,Quiet and secluded,pos,pass,anonymous,,anonymous,0.0 55 | 514de23d0393402c81c36a54e60f29ed,/Location,Neighborhood is intimidating,pos,fail,anonymous,,anonymous,0.9916251948979808 56 | 5170081ead85441e8a6887d5b0f62599,/Price,Not bad for the price paid,neg,fail,anonymous,,anonymous,1.0 57 | 54206f4acfc0498dac42279eba4fb303,/Rooms,Clean room,pos,pass,anonymous,,anonymous,0.0 58 | 5584df6c0efe4298ba96be700dce9228,/Rooms,Bed in need of replacement,pos,fail,anonymous,,anonymous,0.9335646420665538 59 | 57617f8ff6f449b69307c77a7bbbf2cb,/Rooms,No bath mats for the shower,neg,pass,anonymous,,anonymous,0.01952743961762565 60 | 57a816da1e594e47a057898522897884,/Service,Happy 
with how quickly front desk answered the telephone,pos,pass,anonymous,,anonymous,0.0 61 | 57e569d649fc4bd08dfdc2a1fbaa2f6e,/Location,No trains or buses nearby,pos,fail,anonymous,,anonymous,0.9322171141787088 62 | 58f95cfa3e7e4c28b7dcf9a1e1df99ed,/Location,,,topic_marker,anonymous,,, 63 | 5cfefaac1bf1406c968067fe3a912b1d,/Service,"I had a problem with a handle on the toilet...he just replaced it, no big deal",neg,fail,anonymous,,anonymous,1.0 64 | 5d892e8dce8646978c9e3ff5501390b6,/Service,Next day laundry service was very good,pos,pass,anonymous,,anonymous,0.0 65 | 5fd2f4ce35254c049dc40398b2248d18,/Facilities/Restaurant,Breakfast service very efficient,pos,pass,anonymous,,,0.0 66 | 627696d7081c4f56955dc8041858724d,/Service,Laundry service shrank my clothes,neg,pass,anonymous,,,0.0 67 | 62f8f606ac574e66977b800db82cea34,/Price,Great value for money,pos,pass,anonymous,,,0.0 68 | 64027ae207a14db0aaab9be6a9ecea1e,/Location,Tucked away on a quiet street,neg,fail,anonymous,,,0.022233733700274178 69 | 68bd829209644bc9a84dd7bd6f5d5b01,/Facilities/Pool,,,topic_marker,anonymous,,, 70 | 69de51585b4b45a48b6669292ffb3fd6,/Rooms,Room was small but cozy,pos,pass,anonymous,,,0.0 71 | 6ae7e0a652374f8d895d581b66d7085e,/Rooms,Bathroom light didn't work,neg,pass,anonymous,,,0.00546946100179207 72 | 6c3453f4c1e747d2b58e22bbcf1f16e7,/Rooms,No hot water in the shower,neg,pass,anonymous,,,0.01566485443944781 73 | 6d586a3fff17484ea71ee23ecc04582c,/Price,Overpriced given the poor facilities,neg,pass,anonymous,,,0.0 74 | 708fb5c25b7b4b03963710a8e3f34929,/Location,Close to parks and water,pos,pass,anonymous,,anonymous,0.0 75 | 72ba2281ba504b65a2ef931291a611e7,/Facilities/Pool,Pool was almost totally blocked by rubbish,neg,pass,anonymous,,anonymous,0.005074477409384375 76 | 754cdb6c7f294e1483f807e2646466d6,/Service,Service at front desk needs improvement,neg,pass,anonymous,,anonymous,0.0 77 | 75f3a65f995f43f599417ca8c20dfc64,/Facilities/Restaurant,There is a better place to eat a short distance 
away,pos,fail,anonymous,,anonymous,0.8710957429139538 78 | 769f96e52819496f83c8496c39b92f6a,/Price,Pricey for what you get,neg,pass,anonymous,,anonymous,0.0 79 | 76eb16ffe9c7438bb1da78e680833931,/Location,Surrounding neighborhood is unsafe,neg,pass,anonymous,,anonymous,0.06573726834944035 80 | 77895cede75c46c29f209dfaefcc08aa,/Facilities/Pool,Pool is not working,neg,pass,anonymous,,anonymous,0.008743519935493917 81 | 77a53430419e48c198f7c200a3bdc41b,/Service,The staff would try and make you smile,neg,fail,anonymous,,anonymous,1.0 82 | 78049dd0d393426cbae76ddbb187007f,/Facilities/Restaurant,Don't eat the ciabatta bread,pos,fail,anonymous,,anonymous,0.6495006680488586 83 | 783dd0a357cc4066ac0a662e435c58f5,/Service,Not always enough being done for guests in the lobby,neg,pass,anonymous,,anonymous,0.0 84 | 786bcd6d9b594e0b84ed70c6da5f676c,/Rooms,Desk chair conspicuous by its absence,neg,pass,anonymous,,,0.01946825884876856 85 | 7a9551af176e4af1865515b32f5397bc,/Price,Not cheap enough for this location,neg,pass,anonymous,,anonymous,0.0 86 | 7bc8465e009e4983a1d8fbf50842ffe6,/Facilities/Pool,No towels,pos,fail,anonymous,,,0.956467116010669 87 | 7e24228f5504453b8646f9dd7cc73c79,/Location,Awful neighbourhood for walking,pos,fail,anonymous,,anonymous,0.9462354253697054 88 | 81419edffe684a6391deb9af1d87b2c0,/Price,The most expensive hotel on the strip,pos,fail,anonymous,,anonymous,1.0 89 | 825ea717350c40ff90a8d303b2c3259f,/Location,Too far from beach,neg,pass,anonymous,,anonymous,0.03152497975248768 90 | 8275a1368f3f41e8b02b55ce39e8187e,/Facilities/Restaurant,Food is poor,neg,pass,anonymous,,anonymous,0.05149824440842651 91 | 8279b202475b40bd90d323f0ffdfcac2,/Rooms,Plenty of hot water,neg,fail,anonymous,,anonymous,0.09558087017650185 92 | 853c9fd8147c46f3b579effe6f4520c2,/Facilities/Restaurant,All the meals were a bit over seasoned,pos,fail,anonymous,,anonymous,0.9088237095843423 93 | 86e6fc6fe5d14c8691a9e8158712bb5f,/Service,Overall everyone is willing to help 
out,pos,pass,anonymous,,anonymous,0.0 94 | 877463f68a7f460c978b0ee39959e8e9,/Price,Your money is not well spent,neg,pass,anonymous,,anonymous,0.0 95 | 891b56253c1a439ab7bbafce933f64b7,/Rooms,View from the balcony was of a parking lot,pos,fail,anonymous,,anonymous,0.9438569021719737 96 | 895f996aeff14dcdbda01e0770b08691,/Rooms,Couldn't see any ocean,pos,fail,anonymous,,anonymous,0.9530809174163879 97 | 8b57269f3e8441a7bbbe4c99cd81cd37,/Facilities/Restaurant,Best nachos in town,pos,pass,anonymous,,anonymous,0.0 98 | 8d40b4fae3cd4779bf40266c0f94f240,/Location,Buses and trains do not stop nearby,pos,fail,anonymous,,anonymous,0.9761177196263865 99 | 908ad023d80f410eac1075ca63cf11a6,/Location,Walkability not great,pos,fail,anonymous,,anonymous,0.5081075581344234 100 | 922d0ee4e1c74a25920ecb627f3dd63c,/Service,"Took a long time to get our keys, until nearly midnight",pos,fail,anonymous,,anonymous,1.0 101 | 939ac59a858b45aa913058de2ecd16d2,/Rooms,Was not satisfied with the condition of the room or the room itself,pos,fail,anonymous,,anonymous,0.9646280156359222 102 | 9621a20b718747b790a7526dfd38cfbd,/Service,"Accommodating staff, courteous",pos,pass,anonymous,,anonymous,0.0 103 | 9632ec45e7b04be6bcff15e52f52e599,/Rooms,Bathroom had been cleaned just prior to our arrival,pos,pass,anonymous,,anonymous,0.0 104 | 978373b2df0f42538c13492cf0d6471a,/Facilities/Restaurant,Sauces a bit too hot,pos,fail,anonymous,,anonymous,0.9444596030643642 105 | 97af1b8a4a2948898130c48753328b54,/Rooms,View was different from promotional photos,pos,fail,anonymous,,anonymous,0.9764968971125088 106 | 9840372f7900424baff1e9258e58e3d7,/Rooms,AC wasn't working well,neg,pass,anonymous,,anonymous,0.020115314715187866 107 | 9a2ff1b8978846ae9f295d4846ec12b8,/Rooms,Didn't have hot water,neg,pass,anonymous,,anonymous,0.01998141284973178 108 | 9b1fef0c3b2d4a839cc5fb0b3d854861,/Price,Rather overpriced for what you get,pos,fail,anonymous,,anonymous,0.0 109 | 
9ba4cb824eef42ab878e8fc4ce572973,/Facilities/Pool,Pool service didn't include towels,neg,pass,anonymous,,,0.20532752267528293 110 | 9d5879a6b52f4d41add360507657519a,/Location,Horrendous traffic noise,neg,pass,anonymous,,,0.007439990382418224 111 | 9dfd09a7496f436dafb651947727d51e,/Facilities/Restaurant,Waited 10 minutes after seating for our waitress,pos,fail,anonymous,,anonymous,0.9440716364373023 112 | a1a82990388a4e61b3fd72db79ab5870,/Location,Can walk to downtown/transit,pos,pass,anonymous,,anonymous,0.0 113 | a25a7218116347f1b075c57e7043a373,/Facilities/Pool,Pool closed for repair,pos,fail,anonymous,,,0.9739708322546833 114 | a2c4512803b843a5adc5d4c572dc4912,/Facilities/Pool,"When we were there, the pool was closed",pos,fail,anonymous,,anonymous,0.9828934881076902 115 | a2d1674f57e040dead2d64eb2fbf9916,/Facilities/Restaurant,Like variety on buffet,pos,pass,anonymous,,anonymous,0.0 116 | a52c80339cd64af882265082ae3a42ae,/Service,I was not happy with the service,neg,pass,anonymous,,anonymous,0.0 117 | a6b3ce110f474bada036aa13ccf41548,/Service,,,topic_marker,anonymous,,, 118 | a76d528c740b40019f3f69d1ed1a6b15,/Rooms,Ocean view from our room,pos,pass,anonymous,,anonymous,0.0 119 | a77390fe811d45a187e1d04004df4350,/Facilities/Restaurant,Restaurant wasn't large but served excellent food,pos,pass,anonymous,,,0.0 120 | a9027c10b08d4bc6a865ddd21e4c480d,/Price,Cheap and very quiet,pos,pass,anonymous,,anonymous,0.0 121 | a92ccc2284254033af6974e63383edea,/Price,Very expensive for what you get,neg,pass,anonymous,,anonymous,0.0 122 | a933bcbd4c2542baafaa8475961d38e8,/Rooms,Rooms were dated and small,pos,fail,anonymous,,anonymous,0.9489458026560805 123 | ab3b7acf65b24cd1afa1729a773c4110,/Rooms,The room had no clock,neg,pass,anonymous,,anonymous,0.01801225306431987 124 | abfb9c8a7a364bfdbc142a6f2ac7cdac,/Facilities/Restaurant,Food was unremarkable,neg,pass,anonymous,,,0.0 125 | ac5ebdc88ed245919a6f3364340ef448,/Rooms,Cramped 
bathrooms,neg,pass,anonymous,,anonymous,0.003501463001233201 126 | adeffd2e8b79435f8aaf02494294c0c0,/Facilities/Restaurant,Late dining was an added bonus,pos,pass,anonymous,,anonymous,0.0 127 | af8f94e9389e49568bbac469c6d540a9,/Price,Fair price for a swanky strip hotel,neg,fail,anonymous,,anonymous,1.0 128 | afd78d1c77d243d2ba54c3686ac2ff3e,/Price,Fairly priced,pos,pass,anonymous,,,0.0 129 | aff3e0264a8d448c836daa4a109419a7,/Location,Hard to get to from tourist attractions,pos,fail,anonymous,,anonymous,0.9668508984189681 130 | b09f27d62f0f45cd851464b44aeefe2f,/Rooms,The smell of mold or mildew was heavy,neg,pass,anonymous,,anonymous,0.03170379541096625 131 | b0ad37a7942e4d0989833f90903f9d6d,/Location,Close to the beach,pos,pass,anonymous,,anonymous,0.0 132 | b0b16ba1a58c4c288a667aef6ccc76b7,/Location,Very few restaurants and shops,neg,pass,anonymous,,anonymous,0.03229419574689485 133 | b1aab53a4f2143b2894c09f1dc9fb82e,/Rooms,Bed was very comfy,pos,pass,anonymous,,anonymous,0.0 134 | b21f1062d7564a76b4a251189da8a1dd,/Facilities/Pool,Two hot tubs!,pos,pass,anonymous,,,0.0 135 | b63f3ac15004474eb013c3d4134e1e01,/Service,They were so willing to help us with everything,pos,pass,anonymous,,anonymous,0.0 136 | b8b872e399ae44c4baa73de896162065,/Service,Concierge got us a reservation at a restaurant,pos,pass,anonymous,,,0.0 137 | b8e1a50475604ee390feab4bfe14424a,/Service,Had been promised early checkin but were turned away,pos,fail,anonymous,,,1.0 138 | b9f5539ec780496aa72e3777d03fbaa7,/Rooms,Our room had an ocean view,pos,pass,anonymous,,,0.0 139 | ba328c365f3943488ecf109c455a5a60,/Location,Close to all the sights,pos,pass,anonymous,,anonymous,0.0 140 | bab0178fc46a427899447113248122c9,/Facilities/Pool,No lane swimming,pos,fail,anonymous,,,0.9611776900817911 141 | bac7bb1c7ffa41b485ca71e09aae6998,/Location,Away from the madness of downtown,pos,pass,anonymous,,anonymous,0.0 142 | bc8637a35f6a43ecab2e9f1e2a6b9517,/Facilities/Restaurant,"Food was terrible, should close 
down",neg,pass,anonymous,,anonymous,0.013274874064298664 143 | bc931528b0bf4c3dbb258a8a38d0cdc6,/Rooms,Bathroom fan broke,neg,pass,anonymous,,anonymous,0.043131239390885534 144 | be290ab5098c4ce5bee1abace7f6bc7b,/Facilities/Pool,Pool was closed for the season,pos,fail,anonymous,,anonymous,0.9801850118077637 145 | bf49f08e930448d9ad00343919b31886,/Price,This place is more expensive than it appears,pos,fail,anonymous,,anonymous,1.0 146 | bfb2e0d1194c4625898339bd1b2f24e9,/Rooms,The air conditioner was not working,neg,pass,anonymous,,anonymous,0.11590849003673025 147 | c22f5c0b880d40889e599f54ca38928a,/Rooms,Room was very small and dark,neg,pass,anonymous,,anonymous,0.023819851068293596 148 | c30ad576150f4343a129c9bd4872460e,/Facilities/Restaurant,Food portion was mini,pos,fail,anonymous,,anonymous,0.9234152849055209 149 | c58b3df66491447a994fb6eaca312640,/Rooms,Wardrobe lacked coathangers,neg,pass,anonymous,,,0.004106684871886362 150 | c5ced663935e4d14a0ea3c57e1ae343e,/Service,I asked for a late check-out but was told no problem,pos,pass,anonymous,,anonymous,0.0 151 | c6cbf8a3e9754d09b391ef1e6c9978f1,/Facilities/Restaurant,Drinks reasonably priced,pos,pass,anonymous,,anonymous,0.0 152 | c896f64736b54f539606a011b40cffdd,/Rooms,Problem with key card in the elevator,neg,pass,anonymous,,anonymous,0.4970642328262329 153 | c99d0927d84645178879e949ff37787c,/Service,Friendly concierge,pos,pass,anonymous,,anonymous,0.0 154 | ca06223f280541ac9a4e9c27cb7ff986,/Location,Road noise is intense,pos,fail,anonymous,,anonymous,0.9398255207305095 155 | ca8622d782ac41f195a4bb3efd4d4e2c,/Facilities/Restaurant,Dining room did not seem well maintained,neg,pass,anonymous,,anonymous,0.0 156 | caf8976f3f34409cae22dcf6361e2634,/Location,Isolated from the downtown,pos,fail,anonymous,,anonymous,0.5898931866947897 157 | ccf1a41f509148afb4b78b014b2a1848,/Rooms,Rooms have not been kept in good condition,pos,fail,anonymous,,anonymous,0.8474004485875167 158 | 
cd40428fd68d4db8b4a8a0bf547b943e,/Facilities/Restaurant,,,topic_marker,anonymous,,, 159 | cd92ed58725d4658be9c2d0b0cfb1671,/Price,Expensive but was worth it,pos,pass,anonymous,,anonymous,0.0 160 | ce1cc7a936574950bf822b54ae68b7ff,/Rooms,Plenty of space,neg,fail,anonymous,,anonymous,1.0 161 | ce5ac35f85e5461f82d6ef5d1cf8a653,/Facilities/Restaurant,They have milkshakes. Try them,pos,pass,anonymous,,anonymous,0.0 162 | cf2fc3107ba849628ca3ff4c44f10d7c,/Rooms,Where are the towels?,neg,pass,anonymous,,anonymous,0.039039845402230385 163 | d02af05847554e829fd8c2a54d25a682,/Facilities/Restaurant,Went to restaurant and were promptly seated,pos,pass,anonymous,,anonymous,0.0 164 | d38aa871131c411094fb7eff4329eaf2,/Price,Unreasonably high price,neg,pass,anonymous,,anonymous,0.0 165 | d3aee530ba1c497090a85682d4eedecc,/Rooms,Room was very small and hot and window was broken,neg,pass,anonymous,,anonymous,0.1914782162671484 166 | d438771ca25e4b789a166f42311a8778,/Service,Booked our son's ski lesson and was not told that was an extra fee,pos,fail,anonymous,,anonymous,1.0 167 | d55325e8dc5d48be92c94dd76ba92b0d,/Rooms,Windows need new screens,pos,fail,anonymous,,anonymous,0.9633868193121728 168 | d6fa4ecef1554b108c121c367cea2c09,/Rooms,Shower didn't drain well,neg,pass,anonymous,,anonymous,0.02177738050122612 169 | d7e456cfc5944b108bfd56e6a28c5fc4,/Rooms,The only view from the room was of the front parking lot,neg,pass,anonymous,,anonymous,0.005614943169105575 170 | d88da78b2a304af79d7451acac7fbb13,/Facilities/Restaurant,Baked potatoes taste off,neg,pass,anonymous,,anonymous,0.0027625570047995476 171 | d8e1dc7d1baa4ca696c611a9f15506f3,/Location,Near light rail,pos,pass,anonymous,,anonymous,0.0 172 | d94d5eb7522746cfa8b2a80aa55d38b8,/Service,Staff was friendly,pos,pass,anonymous,,anonymous,0.0 173 | df8ce65ef12847e58ea7052b1ecc4e2a,/Service,Housekeeping quick and efficient,pos,pass,anonymous,,,0.0 174 | e15d2b8a684c429d86d192c3194ae153,/Facilities/Pool,Lockers were 
broken,pos,fail,anonymous,,,0.9748728620309092 175 | e2765b3385b7450fa7196361c189bd2f,/Service,Receptionists very attentive,pos,pass,anonymous,,anonymous,0.0 176 | e33ec9dbd07b4abda6ff020c1e34fe18,/Facilities/Restaurant,There are better places to eat nearby,pos,fail,anonymous,,anonymous,0.9084235104448649 177 | e36343e0829c4c8a9564662b9b82c2df,/Price,Rigged to rip off,neg,pass,anonymous,,anonymous,0.0 178 | e58c6478b5e546deaf9c94f95c85f7f8,/Service,Hotel staff will always act like they don't understand what you are saying,neg,pass,anonymous,,anonymous,0.0 179 | e66ed9065fcd4820bd6f8765eeca2565,/Price,Fairly priced for what you get,pos,pass,anonymous,,anonymous,0.0 180 | e75cea1ce3274d3cb4e709cd45c62fa3,/Service,Every single person working there made me feel like I'm their only priority,neg,fail,anonymous,,anonymous,0.9832501604566285 181 | e7a6d4bfe0a845c0b8e2f01ab10cf8a8,/Rooms,Our room had a refrigerator and was very clean,pos,pass,anonymous,,anonymous,0.0 182 | e8ef75b760bc445585d4f998909425c3,/Service,Tried to reach a concierge twice and never responded to each of our requests.,neg,pass,anonymous,,anonymous,0.0 183 | e9b85bb224354a7fb311f68bd458be07,/Price,Kind of pricey,neg,pass,anonymous,,anonymous,0.0 184 | ece3562bf84044ed9cb56d27566766f6,/Location,"Nice, quiet neighborhood",pos,pass,anonymous,,anonymous,0.0 185 | ed88f3124b224e61a7286dbb26d694e8,/Price,Quite expensive,neg,pass,anonymous,,anonymous,0.0 186 | f2011cf43eb642f6a503c4ed612bb767,/Facilities/Restaurant,Served my meal in less than 15 minutes,neg,fail,anonymous,,anonymous,1.0 187 | f29cfe9996ad40c8957df6208abeda39,/Price,A bit pricey for what it is,pos,fail,anonymous,,anonymous,0.9759607533207704 188 | f30ef67be728417a956e6438bd1ff375,/Service,Clothes came back neatly folded from laundry,pos,pass,anonymous,,,0.0 189 | f3d016316bfa4d20bfb93e971c8c021c,/Price,Four-star amenities on a one-star budget,pos,pass,anonymous,,anonymous,0.0 190 | f44b30d794014d5f964c3908f5eb63be,/Location,Easy to get to 
transit,pos,pass,anonymous,,,0.0 191 | f7cd947b83754f7788cf3b6eee47a1f5,/Price,Very expensive for the quality,pos,fail,anonymous,,anonymous,0.976195779895292 192 | f9d57eb623564826bdc697f7dd43be77,/Price,Fair value for money,pos,pass,anonymous,,anonymous,0.0 193 | fa4d7f67a3f14fa79a380afa5764ef80,/Price,Cheap for what you get,neg,fail,anonymous,,anonymous,0.014309784957217955 194 | fb9866889c414e2593158a1ddaa0e610,/Price,Costs have crept up a lot since last time I was here,pos,fail,anonymous,,anonymous,1.0 195 | fc6f00cf6dea4459811bb3d9ba84cec4,/Rooms,,,topic_marker,anonymous,,, 196 | fedd0976a40b49ad84fad703f6d70837,/Rooms,The toilet kept flooding,pos,fail,anonymous,,anonymous,0.945428799227947 197 | fef922cf2a62466584f4582da641ca2c,/Rooms,No door to bathroom!,pos,fail,anonymous,,anonymous,0.5445216298103333 198 | ff38155ab75945caac5c9e08eba4e1c4,/Rooms,Rooms were clean,pos,pass,anonymous,,anonymous,0.0 199 | fff299cde43e42ab8f7ba0a64c7f1a7d,,New test,pos,pass,imputed,,anonymous,0.0 200 | --------------------------------------------------------------------------------