├── .github
│   └── workflows
│       └── publish-to-pypi.yml
├── .gitignore
├── LICENSE
├── README.md
├── dev-requirements.txt
├── pyproject.toml
├── setup.py
├── sketch
│   ├── __init__.py
│   ├── core.py
│   ├── metrics.py
│   ├── pandas_extension.py
│   ├── references.py
│   └── sketches.py
└── tests
    ├── conftest.py
    ├── test_calculate_sketches.py
    ├── test_metrics_from_sketchpads.py
    └── test_pandas_extension.py
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python 🐍 distributions 📦 to PyPI
2 | 
3 | on: push
4 | jobs:
5 |   tests:
6 |     name: Test package
7 |     runs-on: ubuntu-latest
8 |     strategy:
9 |       matrix:
10 |         # python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
11 |         python-version: ['3.8', '3.9', '3.10', '3.11']
12 |     steps:
13 |     - uses: actions/checkout@master
14 |     - name: Set up Python ${{ matrix.python-version }}
15 |       uses: actions/setup-python@v4
16 |       with:
17 |         python-version: ${{ matrix.python-version }}
18 |     - name: Install dependencies
19 |       run: |
20 |         python -m pip install --upgrade pip
21 |         python -m pip install tox tox-gh-actions
22 |     - name: Test with tox
23 |       run: tox
24 |   build-n-publish:
25 |     name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
26 |     runs-on: ubuntu-latest
27 |     needs: [tests]
28 |     steps:
29 |     - uses: actions/checkout@master
30 |     - name: Set up Python 3.11
31 |       uses: actions/setup-python@v3
32 |       with:
33 |         python-version: "3.11"
34 |     - name: Install pypa/build
35 |       run: >-
36 |         python -m
37 |         pip install
38 |         build
39 |         --user
40 |     - name: Build a binary wheel and a source tarball
41 |       run: >-
42 |         python -m
43 |         build
44 |         --sdist
45 |         --wheel
46 |         --outdir dist/
47 |     - name: Publish distribution 📦 to PyPI
48 |       if: startsWith(github.ref, 'refs/tags')
49 |       uses: pypa/gh-action-pypi-publish@release/v1
50 |       with:
51 |         password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/** 132 | 133 | .DS_Store 134 | .DS_Store? 135 | ._* 136 | .Spotlight-V100 137 | .Trashes 138 | ehthumbs.db 139 | Thumbs.db 140 | 141 | *.db 142 | *.db-wal 143 | *.db-shm 144 | *.index 145 | 146 | *.parquet 147 | 148 | *.csv 149 | *.xlsx 150 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Justin Waugh, Mike Biven 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Join the Discord](https://discord.gg/kW9nBQErGe)
2 | 
3 | # sketch
4 | 
5 | Sketch is an AI code-writing assistant for pandas users that understands the context of your data, greatly improving the relevance of suggestions. Sketch is usable in seconds and doesn't require adding a plugin to your IDE.
6 | 
7 | ```bash
8 | pip install sketch
9 | ```
10 | 
11 | ## Demo
12 | 
13 | Here we follow a "standard" (hypothetical) data-analysis workflow, showing a natural-language interface that successfully navigates many tasks in the data stack landscape.
14 | 
15 | - Data Cataloging:
16 |   - General tagging (e.g. PII identification)
17 |   - Metadata generation (names and descriptions)
18 | - Data Engineering:
19 |   - Data cleaning and masking (compliance)
20 |   - Derived feature creation and extraction
21 | - Data Analysis:
22 |   - Data questions
23 |   - Data visualization
24 | 
25 | https://user-images.githubusercontent.com/916073/212602281-4ebd090f-09c4-495d-b48d-0b4c37b9f665.mp4
26 | 
27 | Try it out in colab: [Open in Colab](https://colab.research.google.com/gist/bluecoconut/410a979d94613ea2aaf29987cf0233bc/sketch-demo.ipynb)
28 | 
29 | ## How to use
30 | 
31 | It's as simple as importing sketch and then using the `.sketch` extension on any pandas dataframe.
32 | 
33 | ```python
34 | import sketch
35 | ```
36 | 
37 | Now, any pandas dataframe you have will have an extension registered to it. Access this new extension with your dataframe's name followed by `.sketch`.
38 | 
39 | ### `.sketch.ask`
40 | 
41 | Ask is a basic question-answering system in sketch; it returns a text answer based on the summary statistics and description of the data.
42 | 
43 | Use ask to get an understanding of the data, get better column names, ask hypotheticals (how would I go about doing X with this data), and more.
44 | 
45 | ```python
46 | df.sketch.ask("Which columns are integer type?")
47 | ```
48 | 
49 | ### `.sketch.howto`
50 | 
51 | Howto is the basic "code-writing" prompt in sketch. It returns a code block you should be able to copy-paste and use as a starting point (or possibly an ending!) for any question you have about the data. Ask it how to clean the data, normalize, create new features, plot, and even build models!
52 | 
53 | ```python
54 | df.sketch.howto("Plot the sales versus time")
55 | ```
56 | 
57 | ### `.sketch.apply`
58 | 
59 | Apply is a more advanced prompt that is most useful for data generation. Use it to parse fields, generate new features, and more. It is built directly on [lambdaprompt](https://github.com/approximatelabs/lambdaprompt). In order to use it, you will need to set up a free account with OpenAI and set an environment variable with your API key: `OPENAI_API_KEY=YOUR_API_KEY`
60 | 
61 | ```python
62 | df['review_keywords'] = df.sketch.apply("Keywords for the review [{{ review_text }}] of product [{{ product_name }}] (comma separated):")
63 | ```
64 | 
65 | ```python
66 | df['capital'] = pd.DataFrame({'State': ['Colorado', 'Kansas', 'California', 'New York']}).sketch.apply("What is the capital of [{{ State }}]?")
67 | ```
68 | 
69 | ## Sketch currently uses `prompts.approx.dev` to help run with minimal setup
70 | 
71 | You can also directly use a few pre-built Hugging Face models (right now `MPT-7B` and `StarCoder`), which will run entirely locally (once you download the model weights from HF).
72 | Do this by setting 3 environment variables:
73 | 
74 | ```python
75 | os.environ['LAMBDAPROMPT_BACKEND'] = 'StarCoder'
76 | os.environ['SKETCH_USE_REMOTE_LAMBDAPROMPT'] = 'False'
77 | os.environ['HF_ACCESS_TOKEN'] = 'your_hugging_face_token'
78 | ```
79 | 
80 | You can also call OpenAI directly (and not use our endpoint) by using your own API key. To do this, set 2 environment variables:
81 | 
82 | (1) `SKETCH_USE_REMOTE_LAMBDAPROMPT=False`
83 | (2) `OPENAI_API_KEY=YOUR_API_KEY`
84 | 
85 | ## How it works
86 | 
87 | Sketch uses efficient approximation algorithms (data sketches) to quickly summarize your data and feed that information into language models. Right now it does this by summarizing the columns and writing these summary statistics as additional context to be used by the code-writing prompt. In the future we hope to feed these sketches directly into custom-made "data + language" foundation models to get more accurate results.
88 | 
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | tox
3 | isort
4 | black
5 | flake8
6 | pytest-asyncio
7 | pytest-mock
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name="sketch"
7 | description="Compute, store and operate on data sketches"
8 | readme = "README.md"
9 | requires-python = ">=3.8"
10 | keywords = ["data", "sketch", "model", "etl", "automatic", "join", "ai", "embedding", "profiling"]
11 | license = {file = "LICENSE"}
12 | classifiers = [
13 |     "Programming Language :: Python :: 3",
14 | ]
15 | dependencies = [
16 |     "pandas>=1.3.0",
17 |     "datasketch>=1.5.8",
18 |     "datasketches>=4.0.0",
19 |     "ipython",
20 |     "lambdaprompt>=0.6.1",
21 |     "packaging"
22 | ]
23 | urls = {homepage = "https://github.com/approximatelabs/sketch"}
24 | dynamic = ["version"]
25 | 
26 | [project.optional-dependencies]
27 | local = ["lambdaprompt[local]"]
28 | all = ["sketch[local]"]
29 | 
30 | [tool.setuptools_scm]
31 | 
32 | [tool.tox]
33 | legacy_tox_ini = """
34 | [tox]
35 | envlist = py38, py39, py310, py311
36 | 
37 | [gh-actions]
38 | python =
39 |     3.8: py38
40 |     3.9: py39
41 |     3.10: py310
42 |     3.11: py311
43 | 
44 | [testenv]
45 | deps= -rdev-requirements.txt
46 | commands = python -m pytest tests
47 | """
48 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup()
4 | 
--------------------------------------------------------------------------------
/sketch/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import Portfolio, SketchPad  # noqa
2 | from .pandas_extension import SketchHelper  # noqa
3 | 
--------------------------------------------------------------------------------
/sketch/core.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import heapq
3 | import logging
4 | import os
5 | import sqlite3
6 | import uuid
7 | 
8 | import pandas as pd
9 | from packaging import version
10 | 
11 | from .metrics import binary_metrics,
strings_from_sketchpad_sketches, unary_metrics 12 | from .references import ( 13 | PandasDataframeColumn, 14 | Reference, 15 | SqliteColumn, 16 | WikipediaTableColumn, 17 | ) 18 | from .sketches import SketchBase 19 | 20 | SKETCHCACHE = "~/.cache/sketch/" 21 | 22 | # TODO: These object models are possibly different than the ones in api models 23 | # and those are different than the ones in data models... need to rectify. 24 | # either use a single source of truth, or have good robust tests. 25 | # -- These feel more useful for the client, utility methods 26 | 27 | 28 | # TODO: consider sketchpad having the same "interface" as a sketch.. 29 | # maybe that's the "abstraction" here... 30 | class SketchPad: 31 | version = "0.0.1" 32 | sketch_classes = SketchBase.all_sketches() 33 | 34 | def __init__(self, reference, context=None, initialize_sketches=True): 35 | self.version = "0.0.1" 36 | self.id = str(uuid.uuid4()) 37 | self.metadata = { 38 | "id": self.id, 39 | "creation_start": datetime.datetime.utcnow().isoformat(), 40 | } 41 | self.reference = reference 42 | self.context = context or {} 43 | if initialize_sketches: 44 | self.sketches = [skcls.empty() for skcls in self.sketch_classes] 45 | else: 46 | self.sketches = [] 47 | 48 | def get_sketch_by_name(self, name): 49 | sketches = [sk for sk in self.sketches if sk.name == name] 50 | if len(sketches) == 1: 51 | return sketches[0] 52 | return None 53 | 54 | def get_sketchdata_by_name(self, name): 55 | sketch = self.get_sketch_by_name(name) 56 | return sketch.data if sketch else None 57 | 58 | def minhash_jaccard(self, other): 59 | self_minhash = self.get_sketchdata_by_name("MinHash") 60 | other_minhash = other.get_sketchdata_by_name("MinHash") 61 | if self_minhash is None or other_minhash is None: 62 | return None 63 | return self_minhash.jaccard(other_minhash) 64 | 65 | def compute_sketches(self, data): 66 | # data is assumed to be an iterable 67 | for row in data: 68 | for sk in self.sketches: 69 | sk.add_row(row) 70 | # freeze sketches 71 | for sk in self.sketches: 72 | sk.freeze() 73 | 74 | def to_dict(self): 75 | return { 76 | "version": self.version, 77 | "metadata": self.metadata, 78 | "reference": self.reference.to_dict(), 79 | "sketches": [s.to_dict() for s in self.sketches], 80 | "context": self.context, 81 | } 82 | 83 | @classmethod 84 | def from_series(cls, series: pd.Series, reference: Reference = None) -> "SketchPad": 85 | if reference is None: 86 | reference = PandasDataframeColumn("df", series.name) 87 | sp = cls(reference, initialize_sketches=False) 88 | for skcls in cls.sketch_classes: 89 | sp.sketches.append(skcls.from_series(series)) 90 | sp.metadata["creation_end"] = datetime.datetime.utcnow().isoformat() 91 | return sp 92 | 93 | @classmethod 94 | def from_dict(cls, data): 95 | assert data["version"] == cls.version 96 | sp = cls(Reference.from_dict(data["reference"])) 97 | sp.id = data["metadata"]["id"] 98 | sp.metadata = data["metadata"] 99 | sp.context = data["context"] 100 | sp.sketches = [SketchBase.from_dict(s) for s in data["sketches"]] 101 | return sp 102 | 103 | def get_metrics(self): 104 | return unary_metrics(self) 105 | 106 | def get_cross_metrics(self, other): 107 | return binary_metrics(self, other) 108 | 109 | def string_value_representation(self): 110 | return strings_from_sketchpad_sketches(self) 111 | 112 | 113 | class Portfolio: 114 | def __init__(self, sketchpads=None): 115 | self.sketchpads = {sp.id: sp for sp in (sketchpads or [])} 116 | 117 | @classmethod 118 | def from_dataframe(cls, df, 
dfname="df"):
119 |         return cls().add_dataframe(df, dfname=dfname)
120 | 
121 |     def add_dataframe(self, df, dfname="df"):
122 |         for col in df.columns:
123 |             reference = PandasDataframeColumn(dfname, col)
124 |             sp = SketchPad.from_series(df[col], reference)
125 |             self.add_sketchpad(sp)
126 |         return self
127 | 
128 |     @classmethod
129 |     def from_dataframes(cls, dfs):
130 |         return cls().add_dataframes(dfs)
131 | 
132 |     def add_dataframes(self, dfs):
133 |         # in general, this method is poor because of name tracking
134 |         for df in dfs:
135 |             self.add_dataframe(df)
136 |         return self
137 | 
138 |     @classmethod
139 |     def from_sqlite(cls, sqlite_db_path):
140 |         return cls().add_sqlite(sqlite_db_path)
141 | 
142 |     def add_wikitable(self, page, id, headers, pandas_df):
143 |         for col in pandas_df.columns:
144 |             reference = WikipediaTableColumn(page, id, headers, col)
145 |             sp = SketchPad.from_series(pandas_df[col], reference)
146 |             self.add_sketchpad(sp)
147 | 
148 |     def get_sketchpad_by_reference_id(self, reference_id):
149 |         for sketchpad in self.sketchpads.values():
150 |             if sketchpad.reference.id == reference_id:
151 |                 return sketchpad
152 |         return None
153 | 
154 |     def add_sqlite(self, sqlite_db_path):
155 |         if sqlite_db_path.startswith("http"):
156 |             os.system(f"wget -nc {sqlite_db_path} --directory-prefix={SKETCHCACHE} -q")
157 |             path = os.path.expanduser(os.path.join(SKETCHCACHE, os.path.split(sqlite_db_path)[1]))  # sqlite3 does not expand "~" itself
158 |         else:
159 |             path = sqlite_db_path
160 |         conn = sqlite3.connect(path)
161 |         conn.text_factory = lambda b: b.decode(errors="ignore")
162 |         # TODO: Consider using a cursor to avoid the need for this
163 |         meta_name = (
164 |             "sqlite_master"
165 |             if version.parse(sqlite3.sqlite_version) < version.Version("3.33.0")
166 |             else "sqlite_schema"
167 |         )
168 |         tables = pd.read_sql(
169 |             f"SELECT name FROM {meta_name} WHERE type='table' ORDER BY name;", conn
170 |         )
171 |         logging.info(f"Found {len(tables)} tables in file {sqlite_db_path}")
172 |         for i, table in enumerate(tables.name):
173 |             for column in pd.read_sql(f'PRAGMA table_info("{table}")', conn).name:
174 |                 query = f'SELECT "{column}" FROM "{table}"'
175 |                 reference = SqliteColumn(sqlite_db_path, query, column)
176 |                 # consider iterator here
177 |                 sp = SketchPad.from_series(
178 |                     pd.read_sql(query, conn)[f"{column}"],
179 |                     reference,
180 |                 )
181 |                 self.add_sketchpad(sp)
182 |         return self
183 | 
184 |     @classmethod
185 |     def from_sketchpad(cls, sketchpad):
186 |         return cls().add_sketchpad(sketchpad)
187 | 
188 |     def add_sketchpad(self, sketchpad):
189 |         self.sketchpads[sketchpad.id] = sketchpad
190 |         return self
191 | 
192 |     def get_approx_pk_sketchpads(self):
193 |         # is an estimated unique_key if unique count estimate
194 |         # is > 97% the number of rows
195 |         pf = Portfolio()
196 |         for sketchpad in self.sketchpads.values():
197 |             uq = sketchpad.get_sketchdata_by_name("HyperLogLog").count()
198 |             rows = int(sketchpad.get_sketchdata_by_name("Rows"))
199 |             if uq > 0.97 * rows:
200 |                 pf.add_sketchpad(sketchpad)
201 |         return pf
202 | 
203 |     def closest_overlap(self, sketchpad, n=5):
204 |         scores = []
205 |         for sp in self.sketchpads.values():
206 |             score = sketchpad.minhash_jaccard(sp)
207 |             if score is None:  # skip sketchpads that have no MinHash sketch
208 |                 continue
209 |             heapq.heappush(scores, (score, sp.id))
210 |         top_n = heapq.nlargest(n, scores, key=lambda x: x[0])
211 |         return [(s, self.sketchpads[i]) for s, i in top_n]
212 | 
--------------------------------------------------------------------------------
/sketch/metrics.py:
--------------------------------------------------------------------------------
1 | import datasketches
2 | import numpy as np
3 | 
4 | 
5 | def strings_from_sketchpad_sketches(sketchpad):
6 |     # FI (frequent items) and VO (var-opt sampling) are the two sketches that carry representative string values
7 |     output = ""
8 |     ds = sketchpad.get_sketchdata_by_name("DS_FI")
9 |     # consider showing the counts of frequent items?? Might be useful information.
10 |     output += " ".join(
11 |         [
12 |             x[0]
13 |             for x in ds.get_frequent_items(
14 |                 datasketches.frequent_items_error_type.NO_FALSE_POSITIVES
15 |             )
16 |         ]
17 |     )
18 |     output += "\n"
19 |     output += " ".join(
20 |         [
21 |             x[0]
22 |             for x in ds.get_frequent_items(
23 |                 datasketches.frequent_items_error_type.NO_FALSE_NEGATIVES
24 |             )
25 |         ]
26 |     )
27 |     output += "\n"
28 |     ds = sketchpad.get_sketchdata_by_name("DS_VO")
29 |     output += " ".join([x[0] for x in ds.get_samples()])
30 |     return output
31 | 
32 | 
33 | def unary_metrics(sketchpad):
34 |     # get metrics for a single sketchpad
35 |     # return a dict of metrics
36 |     metrics = {}
37 | 
38 |     metrics["rows"] = sketchpad.get_sketchdata_by_name("Rows")
39 |     metrics["count"] = sketchpad.get_sketchdata_by_name("Count")
40 | 
41 |     ds = sketchpad.get_sketchdata_by_name("DS_HLL")
42 | 
43 |     metrics["hll_lower_bound_2"] = ds.get_lower_bound(2)
44 |     metrics["hll_upper_bound_2"] = ds.get_upper_bound(2)
45 |     metrics["hll_estimate"] = ds.get_estimate()
46 | 
47 |     ds = sketchpad.get_sketchdata_by_name("DS_CPC")
48 |     metrics["cpc_lower_bound_2"] = ds.get_lower_bound(2)
49 |     metrics["cpc_upper_bound_2"] = ds.get_upper_bound(2)
50 |     metrics["cpc_estimate"] = ds.get_estimate()
51 | 
52 |     ds = sketchpad.get_sketchdata_by_name("DS_THETA")
53 |     metrics["theta_lower_bound_2"] = ds.get_lower_bound(2)
54 |     metrics["theta_upper_bound_2"] = ds.get_upper_bound(2)
55 |     metrics["theta_estimate"] = ds.get_estimate()
56 | 
57 |     ds = sketchpad.get_sketchdata_by_name("DS_FI")
58 |     # likely can't use these, as they are more... values of data than metrics
59 |     # metrics["fi_no_false_pos"] = ds.get_frequent_items(datasketches.frequent_items_error_type.NO_FALSE_POSITIVES)
60 |     # metrics["fi_no_false_neg"] = ds.get_frequent_items(datasketches.frequent_items_error_type.NO_FALSE_NEGATIVES)
61 | 
62 |     ds = sketchpad.get_sketchdata_by_name("DS_KLL")
63 |     # pts = ds.get_quantiles([0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99])
64 |     metrics["kll_quantile_0.01"] = ds.get_quantile(0.01)
65 |     metrics["kll_quantile_0.1"] = ds.get_quantile(0.1)
66 |     metrics["kll_quantile_0.25"] = ds.get_quantile(0.25)
67 |     metrics["kll_quantile_0.5"] = ds.get_quantile(0.5)
68 |     metrics["kll_quantile_0.75"] = ds.get_quantile(0.75)
69 |     metrics["kll_quantile_0.9"] = ds.get_quantile(0.9)
70 |     metrics["kll_quantile_0.99"] = ds.get_quantile(0.99)
71 | 
72 |     ds = sketchpad.get_sketchdata_by_name("DS_Quantiles")
73 |     metrics["quantiles_quantile_0.01"] = ds.get_quantile(0.01)
74 |     metrics["quantiles_quantile_0.1"] = ds.get_quantile(0.1)
75 |     metrics["quantiles_quantile_0.25"] = ds.get_quantile(0.25)
76 |     metrics["quantiles_quantile_0.5"] = ds.get_quantile(0.5)
77 |     metrics["quantiles_quantile_0.75"] = ds.get_quantile(0.75)
78 |     metrics["quantiles_quantile_0.9"] = ds.get_quantile(0.9)
79 |     metrics["quantiles_quantile_0.99"] = ds.get_quantile(0.99)
80 | 
81 |     ds = sketchpad.get_sketchdata_by_name("DS_REQ")
82 |     metrics["req_min_value"] = ds.get_min_value()
83 |     metrics["req_max_value"] = ds.get_max_value()
84 |     # not sure, should i include quantiles or specific "rank" get values?
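    # REQ sketches are designed for accuracy at the extreme ranks, so min/max are
    # the most trustworthy values to surface here; the mid-range quantiles are
    # already covered by the KLL and Quantiles sketches above.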
85 | 86 | # VO Sketch has failed 87 | # ds = wow.get_sketchdata_by_name("DS_VO") 88 | # print("=VO=".ljust(12, " "), ds.to_string(True)) 89 | 90 | ds = sketchpad.get_sketchdata_by_name("UnicodeMatches") 91 | metrics.update({f"unicode_{k}": v for k, v in ds.items()}) 92 | 93 | return metrics 94 | 95 | 96 | def max_delta(x1, y1, x2, y2): 97 | f1 = np.interp(np.concatenate([x1, x2]), x2, y2) 98 | f2 = np.interp(np.concatenate([x1, x2]), x1, y1) 99 | return np.max(np.abs(f1 - f2)) 100 | 101 | 102 | def get_CDF(s, N=100): 103 | yvals = [x / N for x in range(N + 1)] 104 | xvals = s.get_quantiles(yvals) 105 | return xvals, yvals 106 | 107 | 108 | def ks_estimate(s1, s2): 109 | # Need to do a smarter job of handling nulls or something 110 | x1, y1 = get_CDF(s1) 111 | x2, y2 = get_CDF(s2) 112 | return max_delta(x1, y1, x2, y2) 113 | 114 | 115 | def binary_metrics(sketchpad1, sketchpad2): 116 | metrics = {} 117 | 118 | ds1 = sketchpad1.get_sketchdata_by_name("DS_THETA") 119 | ds2 = sketchpad2.get_sketchdata_by_name("DS_THETA") 120 | 121 | lower, estimate, upper = datasketches.theta_jaccard_similarity.jaccard(ds1, ds2) 122 | metrics["theta_jaccard_lower_bound"] = lower 123 | metrics["theta_jaccard_upper_bound"] = upper 124 | metrics["theta_jaccard_estimate"] = estimate 125 | metrics["theta_exactly_equal"] = int( 126 | datasketches.theta_jaccard_similarity.exactly_equal(ds1, ds2) 127 | ) 128 | theta_1_not_2 = datasketches.theta_a_not_b().compute(ds1, ds2) 129 | metrics["theta_1_not_2"] = theta_1_not_2.get_estimate() 130 | theta_2_not_1 = datasketches.theta_a_not_b().compute(ds2, ds1) 131 | metrics["theta_2_not_1"] = theta_2_not_1.get_estimate() 132 | intersect = datasketches.theta_intersection() 133 | intersect.update(ds1) 134 | intersect.update(ds2) 135 | metrics["theta_intersection_estimate"] = intersect.get_result().get_estimate() 136 | 137 | # Share same frequent items 138 | ds1 = sketchpad1.get_sketchdata_by_name("DS_FI") 139 | ds2 = sketchpad2.get_sketchdata_by_name("DS_FI") 140 | 141 | fi1 = ds1.get_frequent_items( 142 | datasketches.frequent_items_error_type.NO_FALSE_POSITIVES 143 | ) 144 | fi2 = ds2.get_frequent_items( 145 | datasketches.frequent_items_error_type.NO_FALSE_POSITIVES 146 | ) 147 | fi1 = [x[0] for x in fi1] 148 | fi2 = [x[0] for x in fi2] 149 | metrics["fi_intersection"] = len(set(fi1).intersection(set(fi2))) 150 | metrics["fi_1_not_2"] = len(set(fi1).difference(set(fi2))) 151 | metrics["fi_2_not_1"] = len(set(fi2).difference(set(fi1))) 152 | 153 | # KS test 154 | ds1 = sketchpad1.get_sketchdata_by_name("DS_KLL") 155 | ds2 = sketchpad2.get_sketchdata_by_name("DS_KLL") 156 | 157 | metrics["ks_test_0.9"] = int(datasketches.ks_test(ds1, ds2, 0.9)) 158 | metrics["ks_test_0.5"] = int(datasketches.ks_test(ds1, ds2, 0.5)) 159 | metrics["ks_test_0.1"] = int(datasketches.ks_test(ds1, ds2, 0.1)) 160 | metrics["ks_test_0.01"] = int(datasketches.ks_test(ds1, ds2, 0.01)) 161 | metrics["ks_test_0.001"] = int(datasketches.ks_test(ds1, ds2, 0.001)) 162 | # if metrics["ks_test_0.5"]: 163 | # metrics["kll_ks_score"] = ks_estimate(ds1, ds2) 164 | # else: 165 | # metrics["kll_ks_score"] = 1.0 166 | return metrics 167 | -------------------------------------------------------------------------------- /sketch/pandas_extension.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import base64 3 | import importlib 4 | import inspect 5 | import json 6 | import logging 7 | import os 8 | import uuid 9 | 10 | import datasketches 11 | import numpy as np 12 | 
import pandas as pd
13 | import requests
14 | from IPython.display import HTML, display
15 | 
16 | import lambdaprompt
17 | import sketch
18 | 
19 | 
20 | def retrieve_name(var):
21 |     callers_local_vars = inspect.currentframe().f_back.f_back.f_back.f_locals.items()
22 |     return [var_name for var_name, var_val in callers_local_vars if var_val is var]
23 | 
24 | 
25 | def strtobool(val):
26 |     """Convert a string representation of truth to true (1) or false (0).
27 |     True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
28 |     are 'n', 'no', 'f', 'false', 'off', and '0'.  Raises ValueError if
29 |     'val' is anything else.
30 |     """
31 |     val = val.lower()
32 |     if val in ("y", "yes", "t", "true", "on", "1"):
33 |         return 1
34 |     elif val in ("n", "no", "f", "false", "off", "0"):
35 |         return 0
36 |     else:
37 |         raise ValueError("invalid truth value %r" % (val,))
38 | 
39 | 
40 | def string_repr_truncated(val, size=100):
41 |     result = str(val)
42 |     if len(result) > size:
43 |         result = result[: (size - 3)] + "..."
44 |     return result
45 | 
46 | 
47 | def get_top_n(ds, n=5, size=100, reject_all_1=True):
48 |     top_n = [
49 |         (count, string_repr_truncated(val, size=size))
50 |         for val, count, *_ in ds.get_frequent_items(
51 |             datasketches.frequent_items_error_type.NO_FALSE_POSITIVES
52 |         )
53 |     ][:n]
54 |     top_n = [] if (reject_all_1 and all([c <= 1 for c, _ in top_n])) else top_n
55 |     return {"counts": [c for c, _ in top_n], "values": [v for _, v in top_n]}
56 | 
57 | 
58 | def get_distribution(ds, n=5):
59 |     if ds.is_empty():
60 |         return {}
61 |     percents = np.linspace(0, 1, n)
62 |     return {p: v for p, v in zip(percents, ds.get_quantiles(percents))}
63 | 
64 | 
65 | def get_description_of_sketchpad(sketchpad):
66 |     description = {}
67 |     for sk in sketchpad.sketches:
68 |         if sk.name == "Rows":
69 |             description["rows"] = sk.data
70 |         elif sk.name == "Count":
71 |             description["count"] = sk.data
72 |         elif sk.name == "DS_THETA":
73 |             description["uniquecount-est"] = sk.data.get_estimate()
74 |         elif sk.name == "UnicodeMatches":
75 |             description["unicode"] = sk.data
76 |         elif sk.name == "DS_FI":
77 |             description["top-n"] = get_top_n(sk.data)
78 |         elif sk.name == "DS_KLL":
79 |             description["quantiles"] = get_distribution(sk.data)
80 |     return description
81 | 
82 | 
83 | def get_description_from_parts(
84 |     column_names, data_types, extra_information, index_col_name=None
85 | ):
86 |     descriptions = []
87 |     for colname, dtype, extra in zip(column_names, data_types, extra_information):
88 |         description = {
89 |             "column-name": colname,
90 |             "type": dtype,
91 |             "index": colname == index_col_name,
92 |         }
93 |         if not isinstance(extra, sketch.SketchPad):
94 |             # try and load it as a sketchpad
95 |             try:
96 |                 if "version" in extra:
97 |                     extra = sketch.SketchPad.from_dict(extra)
98 |             except Exception:  # extra may not be dict-like
99 |                 pass
100 |         if isinstance(extra, sketch.SketchPad):
101 |             extra = get_description_of_sketchpad(extra)
102 |         description.update(extra)
103 |         descriptions.append(description)
104 |     return descriptions
105 | 
106 | 
107 | def get_parts_from_df(df, useSketches=False):
108 |     index_col_name = df.index.name
109 |     df = df.reset_index()
110 |     column_names = [str(x) for x in df.columns]
111 |     data_types = [str(x) for x in df.dtypes]
112 |     if useSketches:
113 |         extras = list(sketch.Portfolio.from_dataframe(df).sketchpads.values())
114 |         # extras = [get_description_of_sketchpad(sketchpad) for sketchpad in sketchpads]
115 |     else:
116 |         extras = []
117 |         for col in df.columns:
118 |             extra = {
119 |                 "rows": len(df[col]),
120 |                 "count": int(df[col].count()),
121 | "uniquecount": int(df[col].apply(str).nunique()), 122 | "head-sample": str( 123 | [string_repr_truncated(x) for x in df[col].head(5).tolist()] 124 | ), 125 | } 126 | # if column is numeric, get quantiles 127 | if df[col].dtype in [np.float64, np.int64]: 128 | extra["quantiles"] = str( 129 | df[col].quantile([0, 0.25, 0.5, 0.75, 1]).tolist() 130 | ) 131 | extras.append(extra) 132 | return column_names, data_types, extras, index_col_name 133 | 134 | 135 | def to_b64(data): 136 | return base64.b64encode(json.dumps(data).encode("utf-8")).decode("utf-8") 137 | 138 | 139 | def from_b64(data): 140 | return json.loads(base64.b64decode(data.encode("utf-8")).decode("utf-8")) 141 | 142 | 143 | def call_prompt_on_dataframe(df, prompt, **kwargs): 144 | names = retrieve_name(df) 145 | name = "df" if len(names) == 0 else names[0] 146 | column_names, data_types, extras, index_col_name = get_parts_from_df(df) 147 | max_columns = int(os.environ.get("SKETCH_MAX_COLUMNS", "20")) 148 | if len(column_names) > max_columns: 149 | raise ValueError( 150 | f"Too many columns ({len(column_names)}), max is {max_columns} in current version (set SKETCH_MAX_COLUMNS to override)" 151 | ) 152 | prompt_kwargs = dict( 153 | dfname=name, 154 | column_names=to_b64(column_names), 155 | data_types=to_b64(data_types), 156 | extras=to_b64(extras), 157 | index_col_name=index_col_name, 158 | **kwargs, 159 | ) 160 | # We now have all of our vars, let's decide if we use an external service or local prompt 161 | if strtobool(os.environ.get("SKETCH_USE_REMOTE_LAMBDAPROMPT", "True")): 162 | url = os.environ.get("SKETCH_ENDPOINT_URL", "https://prompts.approx.dev") 163 | try: 164 | response = requests.get( 165 | f"{url}/prompt/{prompt.name}", 166 | params=prompt_kwargs, 167 | ) 168 | response.raise_for_status() 169 | text_to_copy = response.json() 170 | except Exception as e: 171 | print( 172 | f"""Failed to use remote {url}.. {str(e)}. 173 | Consider setting SKETCH_USE_REMOTE_LAMBDAPROMPT=False 174 | and run with your own open-ai key 175 | """ 176 | ) 177 | text_to_copy = f"SKETCH ERROR - see print logs for full error" 178 | else: 179 | # using local version 180 | text_to_copy = prompt(**prompt_kwargs) 181 | return text_to_copy 182 | 183 | 184 | howto_prompt = lambdaprompt.Completion( 185 | """ 186 | For the pandas dataframe ({{ dfname }}) the user wants code to solve a problem. 187 | Summary statistics and descriptive data of dataframe [`{{ dfname }}`]: 188 | ``` 189 | {{ data_description }} 190 | ``` 191 | The dataframe is loaded and in memory, and currently named [ {{ dfname }} ]. 192 | 193 | Code to solve [ {{ how }} ]?: 194 | ```python 195 | {% if previous_answer is defined %} 196 | {{ previous_answer }} 197 | ``` 198 | {{ previous_error }} 199 | 200 | Fixing for error, and trying again... 
201 | Code to solve [ {{ how }} ]?: 202 | ``` 203 | {% endif %} 204 | """, 205 | stop=["```"], 206 | # model_name="code-davinci-002", 207 | ) 208 | 209 | 210 | @lambdaprompt.prompt 211 | def howto_from_parts( 212 | dfname, column_names, data_types, extras, how, index_col_name=None 213 | ): 214 | column_names = from_b64(column_names) 215 | data_types = from_b64(data_types) 216 | extras = from_b64(extras) 217 | description = get_description_from_parts( 218 | column_names, data_types, extras, index_col_name 219 | ) 220 | description = pd.json_normalize(description).to_csv(index=False) 221 | code = howto_prompt(dfname=dfname, data_description=description, how=how) 222 | try: 223 | ast.parse(code) 224 | except SyntaxError as e: 225 | # if we get a syntax error, try again, but include the error message 226 | # only do 1 retry 227 | code = howto_prompt( 228 | dfname=dfname, 229 | data_description=description, 230 | how=how, 231 | previous_answer=code, 232 | previous_error=str(e), 233 | ) 234 | return code 235 | 236 | 237 | ask_prompt = lambdaprompt.Completion( 238 | """ 239 | For the pandas dataframe ({{ dfname }}) the user wants an answer to a question about the data. 240 | Summary statistics and descriptive data of dataframe [`{{ dfname }}`]: 241 | ``` 242 | {{ data_description }} 243 | ``` 244 | 245 | {{ question }} 246 | Answer: 247 | ``` 248 | """, 249 | stop=["```"], 250 | ) 251 | 252 | 253 | @lambdaprompt.prompt 254 | def ask_from_parts( 255 | dfname, column_names, data_types, extras, question, index_col_name=None 256 | ): 257 | column_names = from_b64(column_names) 258 | data_types = from_b64(data_types) 259 | extras = from_b64(extras) 260 | description = get_description_from_parts( 261 | column_names, data_types, extras, index_col_name 262 | ) 263 | description = pd.json_normalize(description).to_csv(index=False) 264 | return ask_prompt(dfname=dfname, data_description=description, question=question) 265 | 266 | 267 | def get_import_modules_from_codestring(code): 268 | """ 269 | Given a code string, return a list of import module 270 | 271 | eg `from sklearn import linear_model` would return `["sklearn"]` 272 | eg. `print(3)` would return `[]` 273 | eg. `import pandas as pd; import matplotlib.pyplot as plt` would return `["pandas", "matplotlib"]` 274 | """ 275 | # use ast to parse the code 276 | tree = ast.parse(code) 277 | # get all the import statements 278 | import_statements = [node for node in tree.body if isinstance(node, ast.Import)] 279 | # get all the import from statements 280 | import_from_statements = [ 281 | node for node in tree.body if isinstance(node, ast.ImportFrom) 282 | ] 283 | # get all the module names 284 | import_modules = [] 285 | for node in import_statements: 286 | for alias in node.names: 287 | import_modules.append(alias.name) 288 | import_modules += [node.module for node in import_from_statements] 289 | # only take parent module (eg. `matplotlib.pyplot` -> `matplotlib`) 290 | import_modules = [module.split(".")[0] for module in import_modules] 291 | return import_modules 292 | 293 | 294 | def validate_pycode_result(result): 295 | try: 296 | modules = get_import_modules_from_codestring(result) 297 | for module in modules: 298 | temp = importlib.util.find_spec(module) 299 | if temp is None: 300 | logging.warning( 301 | f"Module {module} not found, but part of suggestion. May need to pip install..." 
302 | ) 303 | except SyntaxError: 304 | logging.warning("Syntax error in suggestion -- might not work directly") 305 | 306 | 307 | @pd.api.extensions.register_dataframe_accessor("sketch") 308 | class SketchHelper: 309 | def __init__(self, pandas_obj): 310 | self._obj = pandas_obj 311 | 312 | def howto(self, how, call_display=True): 313 | result = call_prompt_on_dataframe(self._obj, howto_from_parts, how=how) 314 | validate_pycode_result(result) 315 | if not call_display: 316 | return result 317 | # output text in a
, also on the side (on top) include a `copy` button that puts it onto clipboard 318 | uid = uuid.uuid4() 319 | b64_encoded_result = to_b64(result) 320 | display( 321 | HTML( 322 | f"""323 |""" 326 | ) 327 | ) 328 | 329 | def ask(self, question, call_display=True): 330 | result = call_prompt_on_dataframe(self._obj, ask_from_parts, question=question) 331 | if not call_display: 332 | return result 333 | display(HTML(f"""{result}""")) 334 | 335 | def apply(self, prompt_template_string, **kwargs): 336 | row_limit = int(os.environ.get("SKETCH_ROW_OVERRIDE_LIMIT", "10")) 337 | if len(self._obj) > row_limit: 338 | raise RuntimeError( 339 | f"Too many rows for apply \n (SKETCH_ROW_OVERRIDE_LIMIT: {row_limit}, Actual: {len(self._obj)})" 340 | ) 341 | new_gpt3_prompt = lambdaprompt.Completion(prompt_template_string) 342 | named_args = new_gpt3_prompt.get_named_args() 343 | known_args = set(self._obj.columns) | set(kwargs.keys()) 344 | needed_args = set(named_args) 345 | if needed_args - known_args: 346 | raise RuntimeError( 347 | f"Missing: {needed_args - known_args}\nKnown: {known_args}" 348 | ) 349 | 350 | def apply_func(row): 351 | row_dict = row.to_dict() 352 | row_dict.update(kwargs) 353 | return new_gpt3_prompt(**row_dict) 354 | 355 | return self._obj.apply(apply_func, axis=1) 356 | 357 | # # Async version 358 | 359 | # new_gpt3_prompt = lambdaprompt.AsyncGPT3Prompt(prompt_template_string) 360 | # named_args = new_gpt3_prompt.get_named_args() 361 | # known_args = set(self._obj.columns) | set(kwargs.keys()) 362 | # needed_args = set(named_args) 363 | # if needed_args - known_args: 364 | # raise RuntimeError( 365 | # f"Missing: {needed_args - known_args}\nKnown: {known_args}" 366 | # ) 367 | 368 | # ind, vals = [], [] 369 | # for i, row in self._obj.iterrows(): 370 | # ind.append(i) 371 | # row_dict = row.to_dict() 372 | # row_dict.update(kwargs) 373 | # vals.append(new_gpt3_prompt(**row_dict)) 374 | 375 | # # gather the results 376 | # vals = asyncio.run(asyncio.gather(*vals)) 377 | 378 | # return pd.Series(vals, index=ind) 379 | -------------------------------------------------------------------------------- /sketch/references.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | 5 | try: 6 | from functools import cache 7 | except ImportError: 8 | from functools import lru_cache 9 | 10 | cache = lru_cache(maxsize=None) 11 | from typing import Dict 12 | 13 | 14 | def get_id_for_object(obj): 15 | serialized = json.dumps(obj, sort_keys=True) 16 | return hashlib.sha256(serialized.encode("utf-8")).hexdigest() 17 | 18 | 19 | class Reference: 20 | def __init__(self, **data): 21 | self.data = data 22 | self.id = get_id_for_object(self.data) 23 | self.type = self.__class__.__name__ 24 | 25 | def to_pyscript(self): 26 | raise NotImplementedError(f"{self.__class__}.to_usable_script") 27 | 28 | def to_searchable_string(self): 29 | raise NotImplementedError(f"{self.__class__}.to_searchable_string") 30 | 31 | def to_dict(self): 32 | return { 33 | "id": self.id, 34 | "type": self.type, 35 | "data": self.data, 36 | } 37 | 38 | def to_json(self): 39 | return json.dumps(self.to_dict()) 40 | 41 | @classmethod 42 | @property 43 | @cache 44 | def subclass_lookup(cls) -> Dict[str, "Reference"]: 45 | subclasses = {} 46 | for subclass in cls.__subclasses__(): 47 | subclasses[subclass.__name__] = subclass 48 | subclasses.update(subclass.subclass_lookup) 49 | return subclasses 50 | 51 | @classmethod 52 | def from_dict(cls, data): 53 | 
subclass = cls.subclass_lookup[data["type"]]
54 |         new_obj = subclass(**data["data"])
55 |         assert new_obj.id == data["id"]
56 |         return new_obj
57 | 
58 |     @classmethod
59 |     def from_json(cls, json_str):
60 |         data = json.loads(json_str)
61 |         return cls.from_dict(data)
62 | 
63 |     @property
64 |     def short_id(self):
65 |         return int.from_bytes(bytes.fromhex(self.id[:16]), "big", signed=True)
66 | 
67 | 
68 | # TODO: make the subclasses of Reference have smarter args
69 | # possibly make them a dataclass
70 | 
71 | # TODO: eventually consider a Sqlite Query reference (full tuple)
72 | # might replace this entire single column concept
73 | class SqliteColumn(Reference):
74 |     def __init__(self, path, query, column, friendly_name=None):
75 |         data = {
76 |             "path": path,
77 |             "column": column,
78 |             "query": query,
79 |             "friendly_name": friendly_name,
80 |         }
81 |         super().__init__(**data)
82 | 
83 |     def to_searchable_string(self):
84 |         base = f"{self.data['query']} {self.data['column']}"
85 |         base += f" {self.data['friendly_name']}" if self.data["friendly_name"] else ""
86 |         base += f" {self.data['path']}"
87 |         return base
88 | 
89 |     def to_pyscript(self):
90 |         commands = ["import os", "import pandas as pd", "import sqlite3"]
91 |         if self.data["path"].startswith("http"):
92 |             # assuming this is a downloadable path
93 |             commands.append(
94 |                 f"""os.system("wget -nc '{self.data['path']}' -P ~/.cache/sketch/")"""  # noqa
95 |             )
96 |             base = os.path.split(self.data["path"])[1]
97 |             localpath = f"~/.cache/sketch/{base}"
98 |         else:
99 |             localpath = self.data["path"]
100 |         commands.append(f"conn = sqlite3.connect('{localpath}')")
101 |         commands.append(f"df = pd.read_sql_query('{self.data['query']}', conn)")
102 |         commands.append(f"df = df['{self.data['column']}']")
103 |         return "\n".join(commands)
104 | 
105 | 
106 | class PandasDataframeColumn(Reference):
107 |     def __init__(self, dfname, column, **dfextra):  # dfname first, matching the call sites in core.py
108 |         super().__init__(dfname=dfname, column=column, **dfextra)
109 | 
110 |     def to_searchable_string(self):
111 |         base = " ".join([self.data["dfname"], self.data["column"]])
112 |         base += "".join([f" {k}={v}" for k, v in self.data.items() if k not in ("dfname", "column")])  # dfextra is stored flat in self.data
113 |         return base
114 | 
115 |     def to_pyscript(self):
116 |         commands = []
117 |         commands.append(f'df = {self.data["dfname"]}')
118 |         commands.append(f'df = df[["{self.data["column"]}"]]')
119 |         return "\n".join(commands)
120 | 
121 | 
122 | class WikipediaTableColumn(Reference):
123 |     def __init__(self, page, id, headers, column):
124 |         super().__init__(page=page, id=id, headers=headers, column=column)
125 | 
126 |     def to_searchable_string(self):
127 |         base = " ".join(
128 |             [
129 |                 self.data["page"],
130 |                 str(self.data["id"]),
131 |                 self.data["headers"],
132 |                 str(self.data["column"]),
133 |             ]
134 |         )
135 |         return base
136 | 
137 |     @property
138 |     def url(self):
139 |         return f"https://en.wikipedia.org/wiki/{self.data['page'].replace(' ', '_')}"
140 | 
141 |     def to_pyscript(self):
142 |         commands = []
143 |         commands.append("import pandas as pd")
144 |         commands.append(f'df = pd.read_html("{self.url}")[{self.data["id"]}]')  # read_html takes a quoted URL, not a bare page title
145 |         commands.append(f'df = df[["{self.data["column"]}"]]')
146 |         return "\n".join(commands)
147 | 
--------------------------------------------------------------------------------
/sketch/sketches.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | 
4 | import datasketch
5 | import datasketches
6 | 
7 | 
8 | def active(func):
9 |     def wrapper(self, *args, **kwargs):
10 |         assert
self.active, "Sketchpad is not active, cannot add a row" 11 | return func(self, *args, **kwargs) 12 | 13 | return wrapper 14 | 15 | 16 | class SketchBase: 17 | def __init__(self, data, active=False): 18 | self.name = self.__class__.__name__ 19 | self.data = data 20 | self.active = active 21 | 22 | @active 23 | def add_row(self, row): 24 | raise NotImplementedError(f"{self.__class__.__name__}.add_row") 25 | 26 | @classmethod 27 | def from_series(cls, series): 28 | result = cls(data=cls.empty_data(), active=True) 29 | for d in series: 30 | result.add_row(d) 31 | return result 32 | 33 | def pack(self): 34 | return self.data 35 | 36 | @classmethod 37 | def unpack(cls, data): 38 | return data 39 | 40 | def to_dict(self): 41 | return {"name": self.__class__.__name__, "data": self.pack()} 42 | 43 | @classmethod 44 | def empty_data(cls): 45 | raise NotImplementedError(f"{cls.__name__}.empty_data") 46 | 47 | @classmethod 48 | def from_dict(cls, data): 49 | tcls = cls 50 | if data["name"] != cls.__name__: 51 | for subclass in cls.all_sketches(): 52 | if subclass.__name__ == data["name"]: 53 | tcls = subclass 54 | return tcls(data=tcls.unpack(data["data"])) 55 | 56 | @classmethod 57 | def all_sketches(cls): 58 | subclasses = cls.__subclasses__() 59 | for subclass in list(subclasses): 60 | subclasses.extend(subclass.all_sketches()) 61 | # filter 62 | subclasses = [s for s in subclasses if s.__name__ != "DataSketchesSketchBase"] 63 | return subclasses 64 | 65 | @classmethod 66 | def empty(cls): 67 | return cls(data=cls.empty_data(), active=True) 68 | 69 | def freeze(self): 70 | self.active = False 71 | 72 | def merge(self, sketch): 73 | raise NotImplementedError(f"{self.__class__.__name__}.merge") 74 | 75 | 76 | class Rows(SketchBase): 77 | @active 78 | def add_row(self, row): 79 | self.data += 1 80 | 81 | @classmethod 82 | def from_series(cls, series): 83 | return cls(data=int(series.size)) 84 | 85 | @classmethod 86 | def empty_data(cls): 87 | return 0 88 | 89 | 90 | class Count(SketchBase): 91 | @active 92 | def add_row(self, row): 93 | self.data += 1 if row is not None else 0 94 | 95 | @classmethod 96 | def from_series(cls, series): 97 | return cls(data=int(series.count())) 98 | 99 | @classmethod 100 | def empty_data(cls): 101 | return 0 102 | 103 | 104 | class MinHash(SketchBase): 105 | @active 106 | def add_row(self, row): 107 | # TODO: ensure row is 'bytes' 108 | self.data.update(str(row).encode("utf-8")) 109 | 110 | @classmethod 111 | def from_series(cls, series): 112 | minhash = datasketch.MinHash() 113 | minhash.update_batch([str(x).encode("utf-8") for x in series]) 114 | lmh = datasketch.LeanMinHash(minhash) 115 | return cls(data=lmh) 116 | 117 | def pack(self): 118 | if self.active: 119 | raise RuntimeError("Cannot pack an active MinHash") 120 | buf = bytearray(self.data.bytesize()) 121 | self.data.serialize(buf) 122 | return base64.b64encode(buf).decode("utf-8") 123 | 124 | @classmethod 125 | def unpack(cls, data): 126 | return datasketch.LeanMinHash.deserialize(base64.b64decode(data)) 127 | 128 | @classmethod 129 | def empty_data(cls): 130 | return datasketch.MinHash() 131 | 132 | def freeze(self): 133 | self.data = datasketch.LeanMinHash(self.data) 134 | super().freeze() 135 | 136 | 137 | class HyperLogLog(SketchBase): 138 | @active 139 | def add_row(self, row): 140 | # TODO: ensure row is 'bytes' 141 | self.data.update(str(row).encode("utf-8")) 142 | 143 | @classmethod 144 | def from_series(cls, series): 145 | hllpp = datasketch.HyperLogLogPlusPlus() 146 | for d in series: 147 | 
hllpp.update(str(d).encode("utf-8"))
148 |         return cls(data=hllpp)
149 | 
150 |     def pack(self):
151 |         buf = bytearray(self.data.bytesize())
152 |         self.data.serialize(buf)
153 |         return base64.b64encode(buf).decode("utf-8")
154 | 
155 |     @classmethod
156 |     def unpack(cls, data):
157 |         return datasketch.HyperLogLogPlusPlus.deserialize(base64.b64decode(data))
158 | 
159 |     @classmethod
160 |     def empty_data(cls):
161 |         return datasketch.HyperLogLogPlusPlus()
162 | 
163 | 
164 | class DataSketchesSketchBase(SketchBase):
165 |     sketch_class = None
166 |     init_args = ()
167 | 
168 |     @active
169 |     def add_row(self, row):
170 |         self.data.update(str(row).encode("utf-8"))
171 | 
172 |     def pack(self):
173 |         return base64.b64encode(self.data.serialize()).decode("utf-8")
174 | 
175 |     @classmethod
176 |     def unpack(cls, data):
177 |         return cls.sketch_class.deserialize(base64.b64decode(data))
178 | 
179 |     @classmethod
180 |     def empty_data(cls):
181 |         return cls.sketch_class(*cls.init_args)
182 | 
183 | 
184 | class DS_HLL(DataSketchesSketchBase):
185 |     sketch_class = datasketches.hll_sketch
186 |     init_args = (12, datasketches.tgt_hll_type.HLL_8)
187 | 
188 |     def pack(self):
189 |         return base64.b64encode(self.data.serialize_compact()).decode("utf-8")
190 | 
191 | 
192 | class DS_CPC(DataSketchesSketchBase):
193 |     sketch_class = datasketches.cpc_sketch
194 |     init_args = (12,)
195 | 
196 | 
197 | class DS_FI(DataSketchesSketchBase):
198 |     sketch_class = datasketches.frequent_strings_sketch
199 |     init_args = (10,)
200 | 
201 | 
202 | class DS_KLL(DataSketchesSketchBase):
203 |     sketch_class = datasketches.kll_floats_sketch
204 |     init_args = (160,)
205 | 
206 |     @active
207 |     def add_row(self, row):
208 |         if isinstance(row, (int, float)):
209 |             self.data.update(row)
210 | 
211 | 
212 | class DS_Quantiles(DataSketchesSketchBase):
213 |     sketch_class = datasketches.quantiles_floats_sketch
214 |     init_args = (128,)
215 | 
216 |     @active
217 |     def add_row(self, row):
218 |         if isinstance(row, (int, float)):
219 |             self.data.update(row)
220 | 
221 | 
222 | class DS_REQ(DataSketchesSketchBase):
223 |     sketch_class = datasketches.req_floats_sketch
224 |     init_args = (12,)
225 | 
226 |     @active
227 |     def add_row(self, row):
228 |         if isinstance(row, (int, float)):
229 |             self.data.update(row)
230 | 
231 | 
232 | class DS_THETA(DataSketchesSketchBase):
233 |     sketch_class = datasketches.update_theta_sketch
234 |     init_args = (12,)
235 | 
236 |     def pack(self):
237 |         try:
238 |             return base64.b64encode(self.data.compact().serialize()).decode("utf-8")
239 |         except AttributeError:
240 |             return base64.b64encode(self.data.serialize()).decode("utf-8")
241 | 
242 |     @classmethod
243 |     def unpack(cls, data):
244 |         return datasketches.compact_theta_sketch.deserialize(base64.b64decode(data))
245 | 
246 | 
247 | class PyUnicodeStringsSerDe(datasketches.PyObjectSerDe):
248 |     def get_size(self, item):
249 |         return int(4 + len(item.encode("utf-8")))
250 | 
251 |     def to_bytes(self, item: str):
252 |         b = bytearray()
253 |         b.extend(len(item.encode("utf-8")).to_bytes(4, "little"))
254 |         b.extend(item.encode("utf-8"))
255 |         return bytes(b)
256 | 
257 |     def from_bytes(self, data: bytes, offset: int):
258 |         num_chars = int.from_bytes(data[offset : offset + 4], "little")  # to_bytes writes a 4-byte little-endian length prefix
259 |         if num_chars < 0 or offset + 4 + num_chars > len(data):
260 |             raise IndexError(
261 |                 f"num_chars read must be non-negative and not larger than the buffer. Found {num_chars}"
262 |             )
263 |         s = data[offset + 4 : offset + 4 + num_chars].decode("utf-8")
264 |         return (s, 4 + num_chars)
265 | 
266 | 
267 | class DS_VO(DataSketchesSketchBase):
268 |     sketch_class = datasketches.var_opt_sketch
269 |     init_args = (50,)
270 | 
271 |     @active
272 |     def add_row(self, row):
273 |         self.data.update(str(row))
274 | 
275 |     def pack(self):
276 |         return base64.b64encode(self.data.serialize(PyUnicodeStringsSerDe())).decode(
277 |             "utf-8"
278 |         )
279 | 
280 |     @classmethod
281 |     def unpack(cls, data):
282 |         return cls.sketch_class.deserialize(
283 |             base64.b64decode(data), PyUnicodeStringsSerDe()
284 |         )
285 | 
286 | 
287 | class UnicodeMatches(SketchBase):
288 |     unicode_ranges = {
289 |         "emoticon": (0x1F600, 0x1F64F),
290 |         "control": (0x00, 0x1F),
291 |         "digits": (0x30, 0x39),
292 |         "latin-lower": (0x61, 0x7A),  # a-z
293 |         "latin-upper": (0x41, 0x5A),  # A-Z
294 |         "basic-latin": (0x00, 0x7F),
295 |         "extended-latin": (0x0080, 0x02AF),
296 |         "UNKNOWN": (0x00, 0x00),
297 |     }
298 | 
299 |     @active
300 |     def add_row(self, row):
301 |         if isinstance(row, str):
302 |             for c in row:
303 |                 found = False
304 |                 for name, (start, end) in self.unicode_ranges.items():
305 |                     if start <= ord(c) <= end:
306 |                         self.data[name] += 1
307 |                         found = True
308 |                 if not found:
309 |                     self.data["UNKNOWN"] += 1
310 | 
311 |     def pack(self):
312 |         return base64.b64encode(json.dumps(self.data).encode("utf-8")).decode("utf-8")
313 | 
314 |     @classmethod
315 |     def unpack(cls, data):
316 |         return json.loads(base64.b64decode(data))
317 | 
318 |     @classmethod
319 |     def empty_data(cls):
320 |         return {name: 0 for name in cls.unicode_ranges}
321 | 
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 | 
5 | 
6 | @pytest.fixture
7 | def df():
8 |     return pd.DataFrame(np.random.randint(0, 100, size=(15, 4)), columns=list("ABCD"))
9 | 
--------------------------------------------------------------------------------
/tests/test_calculate_sketches.py:
--------------------------------------------------------------------------------
1 | import sketch
2 | 
3 | 
4 | def test_calculate_sketches(df):
5 |     p = sketch.Portfolio.from_dataframe(df)
6 |     assert len(p.sketchpads) == 4
7 | 
8 | 
9 | def test_calculate_sketchpad(df):
10 |     s1 = df["A"]
11 |     sp = sketch.SketchPad.from_series(s1)
12 |     assert len(sp.sketches) > 2
13 | 
--------------------------------------------------------------------------------
/tests/test_metrics_from_sketchpads.py:
--------------------------------------------------------------------------------
1 | import sketch
2 | 
3 | 
4 | def test_calculate_unary_metrics(df):
5 |     p = sketch.Portfolio.from_dataframe(df)
6 |     for s in p.sketchpads.values():
7 |         metrics = s.get_metrics()
8 |         assert metrics  # every sketchpad should yield a non-empty metrics dict
9 | 
10 | 
11 | def test_calculate_cross_metrics(df):
12 |     p = sketch.Portfolio.from_dataframe(df)
13 |     for s1 in p.sketchpads.values():
14 |         for s2 in p.sketchpads.values():
15 |             metrics = s1.get_cross_metrics(s2)
16 |             assert metrics
17 | 
--------------------------------------------------------------------------------
/tests/test_pandas_extension.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | 
3 | import sketch  # noqa
4 | 
5 | 
6 | class FakeResponse:
7 |     def __init__(self, data):
8 |         self.data = data
9 | 
10 |     def json(self):
11 |         return self.data
12 | 
13 |     def raise_for_status(self):
14 |         pass
15 | 
16 | 
17 | def test_sketch(mocker):
18 |     mocker.patch("requests.get", return_value=FakeResponse("Hello World"))
19 |     df = pd.DataFrame(
20 |         {
21 |             "a": [1, 2, 3],
22 |             "b": ["a", "b", "c"],
23 |             "c": [None, 4.1, 3],
24 |             "d": ["010222", "010222", "010222"],
25 |             "e": [[1, 2, 3], [3, 1], []],
26 |         }
27 |     )
28 |     result = df.sketch.ask("What is in column e?", call_display=False)
29 |     # with the endpoint mocked, ask() should return the canned response verbatim
30 |     assert result == "Hello World"
31 | --------------------------------------------------------------------------------
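A hypothetical companion test (not part of the repository) sketching how `.sketch.howto` could be exercised with the same mocked endpoint; the test name and the "df.head()" payload are illustrative assumptions, chosen so the canned response parses as valid Python and `validate_pycode_result` emits no warnings:

def test_sketch_howto(mocker):
    # assumption: the remote prompt endpoint is mocked exactly as in test_sketch above
    mocker.patch("requests.get", return_value=FakeResponse("df.head()"))
    df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    # call_display=False returns the suggested code instead of rendering HTML
    result = df.sketch.howto("Show the first rows", call_display=False)
    assert result == "df.head()"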