├── .gitignore ├── README.md ├── demo.html ├── flake.lock ├── flake.nix ├── hammock ├── __init__.py ├── anki.py ├── cache.py ├── calibre │ ├── core.py │ └── detect_core.py ├── cluster.py ├── color.py ├── core.py ├── embedding.py ├── gunicorn.py ├── plot.py ├── util.py └── web.py ├── poetry.lock ├── pyproject.toml ├── pyrightconfig.json ├── screenshot.png ├── stubs ├── InstructorEmbedding │ └── __init__.pyi ├── arxiv │ └── __init__.pyi ├── flask_compress │ └── __init__.pyi ├── gunicorn │ ├── app │ │ └── wsgiapp.pyi │ └── config.pyi ├── gutenbergpy │ ├── __init__.py │ ├── gutenbergcache.pyi │ └── textget.pyi ├── hdbscan │ └── __init__.pyi ├── networkx │ └── __init__.pyi ├── nltk │ ├── __init__.pyi │ └── tokenize.pyi ├── numba │ └── core │ │ └── errors.pyi ├── numpy │ ├── __init__.pyi │ ├── lib │ │ └── stride_tricks.pyi │ └── linalg.pyi ├── plotly │ ├── __init__.pyi │ ├── colors │ │ ├── __init__.pyi │ │ ├── qualitative.pyi │ │ └── sequential.pyi │ ├── graph_objects │ │ ├── __init__.pyi │ │ └── layout │ │ │ ├── __init__.pyi │ │ │ └── scene.pyi │ └── subplots.pyi ├── scipy │ ├── __init__.pyi │ └── spatial │ │ └── distance.pyi ├── sentence_transformers │ ├── __init__.pyi │ └── models.pyi ├── sklearn │ ├── __init__.pyi │ ├── _base.pyi │ ├── linear_model.pyi │ ├── metrics.pyi │ └── preprocessing.pyi ├── sklearn_extra │ └── cluster.pyi ├── transformers │ └── __init__.pyi ├── umap │ └── __init__.pyi └── wikipedia │ └── __init__.pyi ├── templates ├── books.html ├── index.html ├── main.html └── plotly.js └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | .direnv/ 2 | dist/ 3 | output/ 4 | result 5 | cache/ 6 | *.bak 7 | texts/ 8 | .envrc 9 | *__pycache__* 10 | gutenbergindex.db 11 | private-topics.txt 12 | .venv/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visualize and compare embeddings for text sequences 2 | 3 | This is a small project for visualizing embeddings. 4 | 5 | ![A screenshot of the interactive output visualizing embeddings for Anki cards](screenshot.png) 6 | 7 | See the screenshot above or the interactive demo output [here](https://raw.githack.com/colehaus/hammock-public/main/demo.html). 8 | 9 | The overall flow is: 10 | 11 | 1. One or more text sequences is split into sentences or paragraphs. 12 | 2. Each resulting text fragment is embedded using the requested embedding model. 13 | 3. The embeddings are reduced to two or three dimensions (as requested) via [UMAP](https://umap-learn.readthedocs.io/en/latest/index.html). 14 | 4. The low dimensionality embeddings are optionally clustered with [hdbscan](https://hdbscan.readthedocs.io/en/latest/index.html). A single set of embeddings can be clustered at multiple granularities. 15 | 5. Clusters are optionally summarized using the requested language model. 16 | 17 | The resulting points and cluster info are plotted in an interactive 3D scatter plot. 18 | - Plotly provides a number of default handlers for interaction. 19 | - Left and right arrow keys step through text fragments in original text order (i.e. you could, in theory, read a book by stepping through its fragments and simultaneously see how the fragments relate to each other in embedding space). 20 | - Up and down arrow keys step through clustering granularities. (i.e. 
you can see a very high-level "table" of contents and then "zoom in" to more and more granular "tables" of contents) 21 | 22 | There's also some support and integration with a few existing text sources. A small web interface (`python -m hammock.gunicorn`) is provided for: 23 | 24 | - Visualizing arbitrary text submitted in a textarea 25 | - Fetching one or more books from Project Gutenberg by title and visualizing them 26 | - Fetching one or more articles from Wikipedia by title and visualizing them 27 | 28 | Modules can be accessed from the command line to: 29 | 30 | - Batch process epubs from [Calibre](https://calibre-ebook.com/) (`python -m hammock.calibre.core`) 31 | - Fetch cards from an [Anki](https://apps.ankiweb.net/) database and visualize them (`python -m hammock.anki -d -p `) 32 | 33 | If you have `nix`, you can simply do `nix run` from the project directory to set up the project launch the web server on localhost. `nix develop` will dump you into a shell with all the dependencies set up. 34 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1687709756, 9 | "narHash": "sha256-Y5wKlQSkgEK2weWdOu4J3riRd+kV/VCgHsqLNTTWQ/0=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "dbabf0ca0c0c4bce6ea5eaf65af5cb694d2082c7", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "flake-utils_2": { 22 | "inputs": { 23 | "systems": "systems_2" 24 | }, 25 | "locked": { 26 | "lastModified": 1687709756, 27 | "narHash": "sha256-Y5wKlQSkgEK2weWdOu4J3riRd+kV/VCgHsqLNTTWQ/0=", 28 | "owner": "numtide", 29 | "repo": "flake-utils", 30 | "rev": "dbabf0ca0c0c4bce6ea5eaf65af5cb694d2082c7", 31 | "type": "github" 32 | }, 33 | "original": { 34 | "owner": "numtide", 35 | "repo": "flake-utils", 36 | "type": "github" 37 | } 38 | }, 39 | "nixpkgs": { 40 | "locked": { 41 | "lastModified": 1688590700, 42 | "narHash": "sha256-ZF055rIUP89cVwiLpG5xkJzx00gEuuGFF60Bs/LM3wc=", 43 | "owner": "NixOS", 44 | "repo": "nixpkgs", 45 | "rev": "f292b4964cb71f9dfbbd30dc9f511d6165cd109b", 46 | "type": "github" 47 | }, 48 | "original": { 49 | "owner": "NixOS", 50 | "ref": "nixos-unstable", 51 | "repo": "nixpkgs", 52 | "type": "github" 53 | } 54 | }, 55 | "poetry2nix": { 56 | "inputs": { 57 | "flake-utils": "flake-utils_2", 58 | "nixpkgs": [ 59 | "nixpkgs" 60 | ] 61 | }, 62 | "locked": { 63 | "lastModified": 1688732421, 64 | "narHash": "sha256-fy5CYRNkwcjEBeh9oJpNtKHj1BzitMku87OG6LzdH7A=", 65 | "owner": "nix-community", 66 | "repo": "poetry2nix", 67 | "rev": "02e4a29cb4ec64f2f5e8989084b80951df2bbb64", 68 | "type": "github" 69 | }, 70 | "original": { 71 | "owner": "nix-community", 72 | "repo": "poetry2nix", 73 | "type": "github" 74 | } 75 | }, 76 | "root": { 77 | "inputs": { 78 | "flake-utils": "flake-utils", 79 | "nixpkgs": "nixpkgs", 80 | "poetry2nix": "poetry2nix" 81 | } 82 | }, 83 | "systems": { 84 | "locked": { 85 | "lastModified": 1681028828, 86 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 87 | "owner": "nix-systems", 88 | "repo": "default", 89 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 90 | "type": "github" 91 | }, 92 | "original": { 93 | "owner": "nix-systems", 94 | "repo": "default", 95 | "type": "github" 96 | } 97 | }, 98 | "systems_2": { 99 
| "locked": { 100 | "lastModified": 1681028828, 101 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 102 | "owner": "nix-systems", 103 | "repo": "default", 104 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 105 | "type": "github" 106 | }, 107 | "original": { 108 | "owner": "nix-systems", 109 | "repo": "default", 110 | "type": "github" 111 | } 112 | } 113 | }, 114 | "root": "root", 115 | "version": 7 116 | } 117 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Application packaged using poetry2nix"; 3 | 4 | inputs.flake-utils.url = "github:numtide/flake-utils"; 5 | inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; 6 | inputs.poetry2nix = { 7 | url = "github:nix-community/poetry2nix"; 8 | inputs.nixpkgs.follows = "nixpkgs"; 9 | }; 10 | 11 | outputs = { self, nixpkgs, flake-utils, poetry2nix }: 12 | flake-utils.lib.eachDefaultSystem (system: 13 | let 14 | inherit (poetry2nix.legacyPackages.${system}) 15 | mkPoetryApplication mkPoetryEnv overrides; 16 | pkgs = nixpkgs.legacyPackages.${system}; 17 | python = pkgs.python311; 18 | myOverrides = overrides.withDefaults (final: prev: 19 | # Missing setup tools dependency declarations 20 | (pkgs.lib.genAttrs [ 21 | "httpsproxy-urllib2" 22 | "instructorembedding" 23 | "wikipedia" 24 | ] (name: 25 | prev.${name}.overridePythonAttrs (old: { 26 | buildInputs = (old.buildInputs or [ ]) ++ [ prev.setuptools ]; 27 | }))) // 28 | # Miscellaneous build problems that are most easily fixed by using wheels 29 | (pkgs.lib.genAttrs [ 30 | "cmake" 31 | "ruff" 32 | "safetensors" 33 | "tokenizers" 34 | "pybind11" 35 | "scipy" 36 | "urllib3" 37 | ] (name: prev.${name}.override { preferWheel = true; }))); 38 | poetryAttrs = { 39 | projectDir = ./.; 40 | preferWheels = false; 41 | python = python; 42 | overrides = myOverrides; 43 | }; 44 | in rec { 45 | formatter = pkgs.nixfmt; 46 | defaultApp = mkPoetryApplication poetryAttrs; 47 | devShells.default = (mkPoetryEnv poetryAttrs).env.overrideAttrs 48 | (final: prev: { 49 | nativeBuildInputs = (prev.nativeBuildInputs or [ ]) ++ [ 50 | poetry2nix.packages.${system}.poetry 51 | pkgs.typescript 52 | pkgs.nodePackages.prettier 53 | ]; 54 | }); 55 | }); 56 | } 57 | -------------------------------------------------------------------------------- /hammock/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /hammock/anki.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import re 4 | import sqlite3 5 | from typing import Mapping, NamedTuple, NewType 6 | 7 | import html2text 8 | 9 | from hammock.plot import Source 10 | 11 | from .core import plot_single 12 | from .cluster import CCColorAndSummarize 13 | from .embedding import instructor_large 14 | 15 | 16 | def strip_html(text: str) -> str: 17 | h = html2text.HTML2Text() 18 | h.ignore_links = True 19 | h.ignore_images = True 20 | h.ignore_tables = True 21 | h.ignore_emphasis = True 22 | # Do it twice because some cards about HTML turn into valid HTML after the first pass! 
23 | return h.handle(h.handle(text)).strip() 24 | 25 | 26 | class StripClozeResults(NamedTuple): 27 | text: str 28 | had_cloze_deletions: bool 29 | 30 | 31 | def strip_cloze_deletions(text: str) -> StripClozeResults: 32 | out = re.sub(r"{{c\d+::(.*?)(::.*?)?}}", r"\1", text, flags=re.DOTALL) 33 | return StripClozeResults(out, text != out) 34 | 35 | 36 | def compress_spaces(text: str) -> str: 37 | return re.sub(r"\s+", " ", text) 38 | 39 | 40 | NoteID = NewType("NoteID", int) 41 | 42 | 43 | def extract_text_from_anki(private_path: Path, db_path: Path) -> Mapping[NoteID, str]: 44 | with open(private_path) as f: 45 | private_topics = [line.strip() for line in f.readlines()] 46 | with sqlite3.connect(db_path) as conn: 47 | cursor = conn.cursor() 48 | cursor.execute("SELECT id, flds FROM notes") 49 | field_separator = "\x1f" 50 | return { 51 | NoteID(note_id): stripped.text 52 | for note_id, stripped in [ 53 | (row[0], strip_cloze_deletions(compress_spaces(strip_html(row[1].split(field_separator)[0])))) 54 | for row in cursor.fetchall() 55 | ] 56 | if stripped.had_cloze_deletions 57 | # and " 4 and all(sub not in stripped.text for sub in private_topics) 59 | } 60 | 61 | 62 | summary_prompt = ( 63 | "Please choose an academic subfield as topic to label the following cluster of snippets. " 64 | "The topic should cover as many of the snippets as reasonably possible. " 65 | "Make sure to look at ALL snippets and include them ALL in your analysis. " 66 | "A topic should typically be a noun phrase. " 67 | "Snippets begin here: " 68 | ) 69 | embedding_instruction = "Represent the Academic paragraph for clustering by topic: " 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser(description="Little helper for generating plots of Anki cards.") 73 | parser.add_argument("-d", "--db-path", type=Path, help="Path to Anki database.") 74 | parser.add_argument( 75 | "-p", 76 | "--private-path", 77 | type=Path, 78 | help=( 79 | "Path to file containing private topics to exclude from visualization. " 80 | "Format is one substring per line." 81 | ), 82 | ) 83 | args = parser.parse_args() 84 | notes = extract_text_from_anki(args.private_path, args.db_path) 85 | plot_single( 86 | instructor_large, 87 | embedding_instruction, 88 | dimensions=3, 89 | cluster_control=CCColorAndSummarize("google/flan-t5-large", summary_prompt, [60, 30]), 90 | source=Source("Anki Cards", list(notes.values())), 91 | include_labels="include_labels", 92 | ) 93 | -------------------------------------------------------------------------------- /hammock/cache.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import wraps 4 | from hashlib import sha1 5 | from inspect import getfullargspec 6 | import json 7 | import pickle 8 | from pathlib import Path 9 | from typing import ( 10 | IO, 11 | TYPE_CHECKING, 12 | Callable, 13 | Any, 14 | Generic, 15 | Literal, 16 | NamedTuple, 17 | ParamSpec, 18 | TypeVar, 19 | TypedDict, 20 | cast, 21 | overload, 22 | ) 23 | 24 | import numpy as np 25 | 26 | from .util import clean_filename 27 | 28 | if TYPE_CHECKING: 29 | from _typeshed import SupportsWrite 30 | 31 | P = ParamSpec("P") 32 | T = TypeVar("T") 33 | 34 | 35 | def _stringify(x: Any) -> str: # pylint: disable=too-many-return-statements 36 | """Generate a cache key for any value we use. 
This is a little ad-hoc, but it works for now.""" 37 | digest_prefix_length = 7 38 | match x: 39 | case int(): 40 | return str(x) 41 | case float(): 42 | return clean_filename(f"{x:.2f}") 43 | case str(): 44 | return clean_filename(x) if len(x) < 30 else sha1(x.encode()).hexdigest()[:digest_prefix_length] 45 | case Path(): 46 | return x.as_posix() 47 | case dict(): 48 | return sha1("".join(str(k) + str(v) for k, v in cast(dict[Any, Any], x).items()).encode()).hexdigest()[ 49 | :digest_prefix_length 50 | ] 51 | case list(): 52 | return sha1("".join(str(y) for y in cast(list[Any], x)).encode()).hexdigest()[:digest_prefix_length] 53 | case np.ndarray(): 54 | return sha1(cast(bytes, x)).hexdigest()[:digest_prefix_length] 55 | # NamedTuple 56 | case _ if hasattr(x, "_asdict"): 57 | return clean_filename("-".join(f"{k}-{_stringify(v)}" for k, v in x._asdict().items())) 58 | case _: 59 | raise ValueError(f"Cannot stringify value of type {type(x)}: {x}") 60 | 61 | 62 | def _mk_file_name(func_name: str, *args: Any, **kwargs: Any) -> str: 63 | """Choose a file name for the cached result. The file name acts as a cache key and reflects 64 | the function name and arguments.""" 65 | args_fragment = [clean_filename("-".join(_stringify(a) for a in args))] if args else [] 66 | kwargs_fragment = ( 67 | [clean_filename("-".join(f"{_stringify(k)}-{_stringify(v)}" for k, v in kwargs.items()))] if kwargs else [] 68 | ) 69 | return f"{'-'.join([func_name] + args_fragment + kwargs_fragment)}" 70 | 71 | 72 | class JsonCache(TypedDict): 73 | format: Literal["json"] 74 | ext: Literal["json"] 75 | 76 | 77 | json_cache: JsonCache = {"format": "json", "ext": "json"} 78 | 79 | 80 | class PickleCache(TypedDict): 81 | format: Literal["pickle"] 82 | ext: Literal["pickle"] 83 | 84 | 85 | pickle_cache: PickleCache = {"format": "pickle", "ext": "pickle"} 86 | 87 | 88 | class BytesCache(TypedDict): 89 | format: Literal["bytes"] 90 | ext: str 91 | 92 | 93 | CacheType = JsonCache | PickleCache | BytesCache 94 | 95 | 96 | class Serde(NamedTuple, Generic[T]): 97 | load: Callable[[IO[bytes]], T] 98 | dump: Callable[[T, SupportsWrite[bytes | str]], None] 99 | # For use with `open` 100 | mode: Literal["b", ""] 101 | 102 | 103 | class NumpyEncoder(json.JSONEncoder): 104 | """Custom JSON encoder which also handles Numpy ints.""" 105 | 106 | def encode(self, o: Any): 107 | if isinstance(o, dict): 108 | return super().encode( 109 | { 110 | f"np.int64:{k}" if isinstance(k, np.int64) else k: v 111 | for k, v in o.items() # pyright: ignore[reportUnknownVariableType] 112 | } 113 | ) 114 | else: 115 | return super().encode(o) 116 | 117 | def iterencode(self, o: Any, _one_shot: bool = False): 118 | if isinstance(o, dict): 119 | return super().iterencode( 120 | { 121 | f"np.int64:{k}" if isinstance(k, np.int64) else k: v 122 | for k, v in o.items() # pyright: ignore[reportUnknownVariableType] 123 | }, 124 | _one_shot, 125 | ) 126 | else: 127 | return super().iterencode(o, _one_shot) 128 | 129 | 130 | class NumpyDecoder(json.JSONDecoder): 131 | def decode(self, s: Any, _w: Any = ...) -> Any: 132 | obj = super().decode(s) 133 | if isinstance(obj, dict): 134 | return { 135 | ( 136 | np.int64(k.split(":")[1]) if isinstance(k, str) and k.startswith("np.int64:") else cast(Any, k) 137 | ): v 138 | for k, v in obj.items() # pyright: ignore[reportUnknownVariableType] 139 | } 140 | return obj 141 | 142 | 143 | @overload 144 | def _format_str_to_serde(x: Literal["json"]) -> Serde[Any]: 145 | ... 
146 | 147 | 148 | @overload 149 | def _format_str_to_serde(x: Literal["pickle"]) -> Serde[Any]: 150 | ... 151 | 152 | 153 | @overload 154 | def _format_str_to_serde(x: Literal["bytes"]) -> Serde[bytes]: 155 | ... 156 | 157 | 158 | def _format_str_to_serde(x: Literal["json", "pickle", "bytes"]) -> Serde[Any]: 159 | match x: 160 | case "json": 161 | return Serde( 162 | lambda fp: json.load(fp, cls=NumpyDecoder), 163 | lambda data, handle: json.dump(data, handle, indent=2, cls=NumpyEncoder), 164 | "", 165 | ) 166 | case "pickle": 167 | return Serde(pickle.load, pickle.dump, "b") 168 | case "bytes": 169 | 170 | def writer(data: bytes, handle: SupportsWrite[bytes | str]) -> None: 171 | """Just a little wrapper that returns `None`.""" 172 | handle.write(data) 173 | return None 174 | 175 | return Serde(lambda handle: handle.read(), writer, "b") 176 | 177 | 178 | class CacheResult(NamedTuple, Generic[T]): 179 | cache_result: T 180 | cache_path: Path 181 | 182 | 183 | def cache_with_path( 184 | cache_dir: Path, cache_type: CacheType = pickle_cache 185 | ) -> Callable[[Callable[P, T]], Callable[P, CacheResult[T]]]: 186 | """Decorator to cache the result of a function call to disk. Returns result of function and path to cache 187 | (which is computed using the function name and arguments as cache keys). 188 | Note that the `bytes` format should only be used with functions that return `T = bytes`. 189 | Unfortunately, encoding this requirement with `@overload` is at least very hairy and maybe impossible 190 | due to covariance and contravariance issues.""" 191 | 192 | def cache_decorator(func: Callable[P, T]) -> Callable[P, CacheResult[T]]: 193 | @wraps(func) 194 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> CacheResult[T]: 195 | arg_names = getfullargspec(func).args 196 | file_name = _mk_file_name(func.__name__, **(kwargs | dict(zip(arg_names, args)))) 197 | file_path = cache_dir / f"{file_name}.{cache_type['ext']}" 198 | serde = _format_str_to_serde(cache_type["format"]) 199 | if file_path.exists(): 200 | print(f"Reading from cache at {file_path}") 201 | with open(file_path, f"r{serde.mode}") as f: 202 | result = cast(T, serde.load(f)) 203 | else: 204 | print(f"Writing to cache at {file_path}") 205 | result = func(*args, **kwargs) 206 | if cache_type["format"] == "bytes": 207 | assert type(result) == bytes 208 | with open(file_path, f"w{serde.mode}") as f: 209 | # pyright can't figure out that the type of `result` is matched to the type of `serde.dump` 210 | serde.dump(cast(Any, result), cast(Any, f)) 211 | return CacheResult(result, file_path) 212 | 213 | return wrapper 214 | 215 | return cache_decorator 216 | 217 | 218 | def cache(cache_dir: Path, cache_type: CacheType = pickle_cache) -> Callable[[Callable[P, T]], Callable[P, T]]: 219 | """Decorator to cache the result of a function call to disk. Returns result of function. 
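    A minimal illustrative sketch (the directory and function names here are hypothetical):

        @cache(Path("cache/lengths"))
        def slow_len(text: str) -> int: ...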
220 | Should be transparent to caller.""" 221 | 222 | def cache_decorator(func: Callable[P, T]) -> Callable[P, T]: 223 | @wraps(func) 224 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 225 | return cache_with_path(cache_dir, cache_type)(func)(*args, **kwargs).cache_result 226 | 227 | return wrapper 228 | 229 | return cache_decorator 230 | -------------------------------------------------------------------------------- /hammock/calibre/core.py: -------------------------------------------------------------------------------- 1 | import curses 2 | import json 3 | import os 4 | from pathlib import Path 5 | import re 6 | import subprocess 7 | from typing import Callable, Literal, Mapping, Sequence, TypedDict, cast 8 | 9 | os.environ["OMP_NUM_THREADS"] = "4" 10 | 11 | from nltk.tokenize import sent_tokenize, word_tokenize 12 | 13 | from ..cache import cache, json_cache 14 | from ..cluster import CCColorAndSummarize 15 | from ..core import TextUnit, clip_string, para_tokenize, plot_single, tokenize 16 | from ..embedding import instructor_large 17 | from ..plot import Source 18 | from ..util import clean_filename 19 | from .detect_core import get_core 20 | 21 | 22 | class RawBook(TypedDict): 23 | title: str 24 | authors: str 25 | rating: int 26 | formats: list[str] 27 | 28 | 29 | class EpubBook(TypedDict): 30 | title: str 31 | authors: str 32 | rating: int 33 | path: Path 34 | 35 | 36 | def get_books_list(library_path: Path) -> Sequence[EpubBook]: 37 | """Fetch collection of epubs known by Calibre.""" 38 | result = subprocess.run( 39 | [ 40 | "calibredb", 41 | "list", 42 | "--library-path", 43 | library_path, 44 | "--for-machine", 45 | "--fields", 46 | "title,authors,rating,formats", 47 | ], 48 | capture_output=True, 49 | text=True, 50 | check=True, 51 | ) 52 | raw: list[RawBook] = json.loads(result.stdout) 53 | epubs = [book for book in raw if any(path.endswith("epub") for path in book["formats"])] 54 | return [ 55 | cast( 56 | EpubBook, {**book, "path": Path(next(iter(path for path in book["formats"] if path.endswith("epub"))))} 57 | ) 58 | for book in epubs 59 | ] 60 | 61 | 62 | error_list: list[Path] = [] 63 | 64 | 65 | def convert_book_to_text(book_path: Path, output_dir: Path): 66 | convert_destination = output_dir / f"{clean_filename(book_path.stem)}.txt" 67 | if not convert_destination.exists(): 68 | res = subprocess.run(["ebook-convert", book_path.as_posix(), convert_destination.as_posix()], check=False) 69 | if res.returncode != 0: 70 | error_list.append(book_path) 71 | 72 | 73 | def convert_epubs(): 74 | books = get_books_list(Path(os.path.expanduser("~/calibre-library"))) 75 | for book in books: 76 | convert_book_to_text(book["path"], Path("cache/calibre-conversions")) 77 | for error in error_list: 78 | print(error) 79 | 80 | 81 | # A = TypeVar("A") 82 | def summary_prompt(book_descr: str, embedding_area: str) -> str: 83 | return ( 84 | "These are clustered paragraphs from a book. " 85 | # f"These are clustered paragraphs from {book_descr}. " 86 | "Please provide a topic summarizing and describing the cluster. " 87 | # "The topic should cover as many of the paragraphs as reasonably possible. " 88 | # "Make sure to look at ALL paragraphs and include them ALL in your analysis. " 89 | # f"Your response should NOT just be {embedding_area}. " 90 | # "A topic should typically be a noun phrase. 
" 91 | "Paragraphs begin here: " 92 | ) 93 | 94 | 95 | def handle_book( 96 | book_path: Path, start_anchor: str, end_anchor: str, title: str, embedding_area: str, summary_book_descr: str 97 | ): 98 | """Produce visualization for given book.""" 99 | with open(book_path, encoding="utf-8") as f: 100 | text = f.read() 101 | core_text = clip_string(text, start_anchor=start_anchor, end_anchor=end_anchor) 102 | assert core_text is not None 103 | paras = tokenize( 104 | TextUnit("paragraph", para_newline_type="double"), 105 | always_bad_sentence_predicates=[ 106 | lambda sent: sent.startswith("Source"), 107 | ], 108 | standalone_bad_sentence_predicates=[ 109 | lambda sent: bool(re.search(r"^\d+\.", sent)), # Corresponds to questions usually 110 | lambda sent: len(sent.split()) < 4, 111 | lambda sent: bool(re.search(r" p. ", sent)), 112 | lambda sent: bool(re.search(r" pp. ", sent)), 113 | ], 114 | bad_para_predicates=[ 115 | lambda para: all(s.endswith("?") for s in sent_tokenize(para)), 116 | lambda para: para.lower().startswith("chapter") or para.lower().startswith("section"), 117 | lambda para: para.lower().startswith("table") or para.lower().startswith("figure"), 118 | lambda para: para.lower().startswith("exercise"), 119 | lambda para: len(word_tokenize(para)) < 12, 120 | lambda para: len(word_tokenize(para)) < 24 121 | and (bool(re.search(r" p. ", para)) or bool(re.search(r" pp. ", para))), 122 | ], 123 | text=core_text, 124 | ) 125 | return plot_single( 126 | instructor_large, 127 | f"Represent the {embedding_area} paragraph for clustering: ", 128 | dimensions=3, 129 | # cluster_control=CCColor(min_cluster_sizes=[20, 5]), 130 | cluster_control=CCColorAndSummarize( 131 | "google/flan-t5-large", summary_prompt(summary_book_descr, embedding_area), [20, 12, 8, 5] 132 | ), 133 | source=Source(title, paras), 134 | include_labels="include_labels", 135 | ) 136 | 137 | 138 | BookArgs = TypedDict( 139 | "BookArgs", 140 | { 141 | "start_anchor": str, 142 | "end_anchor": str, 143 | "title": str, 144 | "embedding_area": str, 145 | "summary_book_descr": str, 146 | }, 147 | ) 148 | 149 | 150 | def choose_snippet(snippets: list[str]) -> str: 151 | """Interactively choose text as start or end anchor demarcating pre- and postamble from main content.""" 152 | selected_index = 0 153 | 154 | stdscr = curses.initscr() 155 | curses.curs_set(0) 156 | stdscr.nodelay(True) 157 | stdscr.timeout(100) 158 | 159 | win = curses.newwin(curses.LINES, curses.COLS, 0, 0) 160 | 161 | try: 162 | while True: 163 | win.erase() 164 | 165 | win.addstr(0, 0, "Choose a snippet with j/J/k/K, press q to finalize") 166 | 167 | # Display snippets with the selected snippet highlighted 168 | context_length = 25 169 | start_index = max(0, selected_index - context_length) 170 | for i, s in enumerate( 171 | snippets[ 172 | start_index : min(len(snippets), max(selected_index + context_length, 2 * context_length)) 173 | ], 174 | start_index, 175 | ): 176 | if i == selected_index: 177 | win.attron(curses.A_REVERSE) 178 | win.addstr(i - start_index + 1, 0, s) 179 | win.attroff(curses.A_REVERSE) 180 | 181 | win.refresh() 182 | 183 | key = stdscr.getch() 184 | 185 | if key == ord("q"): 186 | curses.endwin() 187 | return snippets[selected_index] 188 | elif key == ord("k") and selected_index > 0: 189 | selected_index -= 1 190 | elif key == ord("K") and selected_index > 50: 191 | selected_index -= 50 192 | elif key == ord("j") and selected_index < len(snippets) - 1: 193 | selected_index += 1 194 | elif key == ord("J") and selected_index < 
len(snippets) - 50: 195 | selected_index += 50 196 | finally: 197 | curses.endwin() 198 | 199 | 200 | book_args_dir = Path("cache/book-args") 201 | 202 | 203 | @cache(book_args_dir, cache_type=json_cache) 204 | def collect_book_args(anchor_method: Literal["manual", "automatic"], book_path: Path, title: str) -> BookArgs: 205 | match anchor_method: 206 | case "automatic": 207 | start_anchor, end_anchor = get_core(book_path) 208 | pruned_title = re.sub(r"\([^)]*\)", "", title).split(":")[0] 209 | embedding_area = pruned_title 210 | summary_book_descr = "a book titled " + pruned_title 211 | case "manual": 212 | with open(book_path, encoding="utf-8") as f: 213 | text = f.read() 214 | start_anchor = choose_snippet(para_tokenize("double", text)) 215 | end_anchor = choose_snippet(list(reversed(para_tokenize("double", text)))) 216 | print("Start: ", repr(start_anchor)) 217 | print("End: ", repr(end_anchor)) 218 | print("Title: ", title) 219 | embedding_area = input("Embedding area (e.g. 'Anthropology'): ") 220 | summary_book_descr = input("Summary book description (e.g. 'an anthropology textbook'): ") 221 | 222 | return BookArgs( 223 | start_anchor=start_anchor, 224 | end_anchor=end_anchor, 225 | title=title, 226 | embedding_area=embedding_area, 227 | summary_book_descr=summary_book_descr, 228 | ) 229 | 230 | 231 | completed: Mapping[str, str] = {} 232 | 233 | 234 | def run_epub(anchor_method: Literal["manual", "automatic"], book_path: Path, title: str): 235 | book_args = collect_book_args(anchor_method, book_path, title) 236 | completed[title] = handle_book(book_path, **book_args).as_posix() 237 | with open("output/calibre.json", "w") as f: 238 | json.dump(completed, f, indent=2) 239 | 240 | 241 | def paths_to_titles() -> Callable[[Path], str]: 242 | """Convert a path to a title by looking up the path in the calibre library. 243 | Return a `Callable` so that we keep the book list in memory instead of 244 | shelling out and hitting the disk each time.""" 245 | books = get_books_list(Path(os.path.expanduser("~/calibre-library"))) 246 | return lambda x: next(book["title"] for book in books if clean_filename(book["path"].stem) == x.stem) 247 | 248 | 249 | # Batch process books from Calibre library 250 | if __name__ == "__main__": 251 | ptt = paths_to_titles() 252 | for book_path in Path("cache/calibre-conversions").glob("*Very-Short*.txt"): 253 | run_epub("automatic", book_path, ptt(book_path)) 254 | -------------------------------------------------------------------------------- /hammock/calibre/detect_core.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | from pathlib import Path 3 | from typing import Callable, Generic, Literal, NamedTuple, Sequence, TypeVar 4 | 5 | from nltk.tokenize import word_tokenize 6 | import numpy as np 7 | from numpy.lib.stride_tricks import sliding_window_view 8 | from numpy import ndarray 9 | import plotly.graph_objects as go 10 | import plotly.subplots as sp 11 | from sklearn.linear_model import LinearRegression 12 | from sklearn.metrics import mean_squared_error 13 | 14 | from ..core import para_tokenize 15 | 16 | 17 | def get_core(book_path: Path) -> tuple[str, str]: 18 | """Books generally have a preamble and a postamble that we're not really interested in. 19 | We want to automatically prune that content. 20 | We do that heuristically by looking at the number of words per paragraph since 21 | the preamble and postamble tend to have shorter paragraphs. 
22 | We pick a breakpoint in the distribution of paragraph lengths and 23 | then find the longest contiguous section of paragraphs that are above that breakpoint.""" 24 | with open(book_path, encoding="utf-8") as f: 25 | text = f.read() 26 | paras = para_tokenize("double", text) 27 | word_counts = np.array([len(word_tokenize(para)) for para in paras]) 28 | percentiles = np.linspace(0, 100, 101) 29 | percentile_values = np.percentile(word_counts, percentiles) 30 | # Sometimes the very ends of the percentile graphs get erroneously registered as break points. 31 | # So we clip them. 32 | percentile = find_breakpoint(percentiles[2:-2], percentile_values[2:-2]) 33 | percentile_margin = 1 34 | threshold = np.percentile(word_counts, percentile + percentile_margin) 35 | start, end = longest_contiguous_above_threshold( 36 | word_counts, threshold, supermajority=0.4, window_size=20, end_margin=0.05 37 | ) 38 | return paras[start], paras[end] 39 | # print(book_path.stem, start, end) 40 | # plot_diagnostic( 41 | # word_counts, 42 | # paras, 43 | # AnalysisResults( 44 | # percentiles=percentiles, 45 | # percentile_values=percentile_values, 46 | # chosen_percentile=percentile + percentile_margin, 47 | # start=start, 48 | # end=end, 49 | # threshold=threshold, 50 | # ), 51 | # Path(f"tmp-out/{book_path.stem}.html"), 52 | # ) 53 | 54 | 55 | NPercentiles = TypeVar("NPercentiles") 56 | 57 | 58 | class AnalysisResults(NamedTuple, Generic[NPercentiles]): 59 | percentiles: ndarray[float, NPercentiles] 60 | percentile_values: ndarray[float, NPercentiles] 61 | chosen_percentile: float 62 | start: int 63 | end: int 64 | threshold: float 65 | 66 | 67 | def plot_diagnostic( 68 | word_counts: Sequence[int], paras: Sequence[str], results: AnalysisResults[NPercentiles], out_path: Path 69 | ): 70 | """Visualizes multiple pieces of data to diagnose our heuristics. 
71 | - Plot showing the word counts by paragraph, in order 72 | - Plot showing the percentile values by percentile 73 | - Vertical line on percentile plot showing the percentile chosen by breakpoint detection 74 | - Horizontal lines on word plot showing the correspoding inferred threshold of words per paragraph 75 | - Vertical lines on word plot showing the start and end of the longest contiguous region above the threshold 76 | """ 77 | fig = sp.make_subplots(rows=1, cols=2, subplot_titles=["Word counts", "Percentiles"]) 78 | fig.add_trace( 79 | go.Scatter(x=list(range(len(word_counts))), y=word_counts, text=paras, mode="lines"), 80 | row=1, 81 | col=1, 82 | ) 83 | fig.add_trace(go.Scatter(x=results.percentiles, y=results.percentile_values, mode="markers"), row=1, col=2) 84 | fig.update_layout( 85 | shapes=[ 86 | go.layout.Shape( 87 | type="line", 88 | x0=results.chosen_percentile, 89 | x1=results.chosen_percentile, 90 | y0=0, 91 | y1=max(word_counts), 92 | xref="x2", 93 | yref="y2", 94 | ), 95 | go.layout.Shape( 96 | type="line", 97 | x0=results.start, 98 | x1=results.start, 99 | y0=0, 100 | y1=word_counts[results.start], 101 | xref="x1", 102 | yref="y1", 103 | line=dict(color="red"), 104 | ), 105 | go.layout.Shape( 106 | type="line", 107 | x0=results.end, 108 | x1=results.end, 109 | y0=0, 110 | y1=word_counts[results.end], 111 | xref="x1", 112 | yref="y1", 113 | line=dict(color="red"), 114 | ), 115 | go.layout.Shape( 116 | type="line", 117 | x0=0, 118 | x1=len(word_counts), 119 | y0=results.threshold, 120 | y1=results.threshold, 121 | xref="x1", 122 | yref="y1", 123 | ), 124 | ] 125 | ) 126 | fig.write_html(file=out_path) 127 | 128 | 129 | A = TypeVar("A") 130 | X = TypeVar("X") 131 | WinSize = TypeVar("WinSize", bound=int) 132 | 133 | 134 | def in_window_satisfying_predicate( 135 | values: ndarray[A, X], window_size: WinSize, predicate: Callable[[ndarray[A, X, WinSize]], ndarray[bool, X]] 136 | ) -> ndarray[bool, X]: 137 | """Returns a boolean array indicating whether each element is in any rolling window 138 | where that window satisfies the predicate.""" 139 | 140 | windows = sliding_window_view(values, window_size) 141 | bools = predicate(windows) 142 | result = np.zeros(len(values), dtype=bool) 143 | for i in range(window_size): 144 | result[i : -(window_size - i - 1) if window_size - i - 1 > 0 else None] |= bools 145 | return result 146 | 147 | 148 | NParas = TypeVar("NParas") 149 | 150 | 151 | def find_breakpoint(x: ndarray[float, NParas], y: ndarray[float, NParas]) -> float: 152 | """Fit two linear regressions to the data, one for each side of the breakpoint. 
153 | Find the breakpoint that minimizes the sum of the MSEs of the two regressions.""" 154 | 155 | def mse_for_breakpoint(i: int) -> float: 156 | model1 = LinearRegression[Literal[1]]().fit(x[:i].reshape((-1, 1)), y[:i]) 157 | model2 = LinearRegression[Literal[1]]().fit(x[i:].reshape((-1, 1)), y[i:]) 158 | mse1 = mean_squared_error(y[:i], model1.predict(x[:i].reshape((-1, 1)))) 159 | mse2 = mean_squared_error(y[i:], model2.predict(x[i:].reshape((-1, 1)))) 160 | return mse1 + mse2 161 | 162 | return x[np.argmin([mse_for_breakpoint(i) for i in range(1, len(x) - 1)]) + 1] 163 | 164 | 165 | def segment_average_from_cumulative(segment: tuple[int, int], cumulative_counts: ndarray[int, NParas]) -> float: 166 | """Faster way to compute the average in a segment.""" 167 | return (cumulative_counts[segment[1]] - cumulative_counts[segment[0] - 1]) / (segment[1] - segment[0] + 1) 168 | 169 | 170 | def longest_contiguous_above_threshold( 171 | word_counts: ndarray[int, NParas], 172 | threshold: float, 173 | supermajority: float, 174 | window_size: int, 175 | end_margin: float, 176 | ) -> tuple[int, int]: 177 | """Roughly, Find the longest contiguous segment of paragraphs where the word count is above the threshold. 178 | Parameters 179 | ---------- 180 | counts : Sequence[int] 181 | The number of words in each paragraph. 182 | threshold : float 183 | The value tending to distinguish pre- and postamble from the main body. Unit is words per paragraph. 184 | supermajority : float 185 | The fraction of paragraphs in the whole segment that must be above the threshold. 186 | We can't require /all/ paragraphs to be above the threshold because 187 | there will occasionally be short paragraphs. But we want a large fraction above the threshold. 188 | window_size : int 189 | The size of the window used to smooth the evaluation. 190 | We pick candidate segment endpoints based on whether their windows satisfy certain criteria. 191 | end_margin : float 192 | Our candidate segments might end up containing clusters of low word count paragraphs 193 | near the beginning or end of the overall content. If these clusters are too close to the ends, 194 | we reject the candidate segment on the theory that the low clusters are pre- or postamble 195 | and our candidate endpoints are actually anomalous pre- or postamble content. 196 | """ 197 | in_average_above_threshold_window = in_window_satisfying_predicate( 198 | word_counts, window_size, lambda windows: np.mean(windows, axis=1) >= threshold 199 | ) 200 | in_supermajority_above_threshold_window = in_window_satisfying_predicate( 201 | word_counts, 202 | window_size, 203 | lambda windows: np.mean(windows >= threshold, axis=1) >= supermajority, 204 | ) 205 | # Paragraphs that are candidate endpoints of a contiguous segment satisfy three criteria: 206 | # 1. The average word count of their window is above the threshold. 207 | # 2. The supermajority of their window is above the threshold. 208 | # 3. Their word count is above the threshold. 209 | all_segments = combinations( 210 | [ 211 | i 212 | for i in np.nonzero(in_average_above_threshold_window & in_supermajority_above_threshold_window)[0] 213 | if word_counts[i] >= threshold 214 | ], 215 | 2, 216 | ) 217 | 218 | def has_low_cluster_near_end(segment: tuple[int, int]) -> bool: 219 | """If the candidate segment has a cluster of short paragraphs near the beginning or end 220 | of the overall content, we reject it. 
221 | This suggests that the candidate segment is actually including pre- or postamble and 222 | the longer paragraphs that are our segment endpoints are just anomolous pre- or postamble content.""" 223 | if segment[1] - segment[0] < window_size: 224 | return True 225 | else: 226 | low_clusters = np.nonzero( 227 | in_window_satisfying_predicate( 228 | word_counts[segment[0] : segment[1]], 229 | window_size, 230 | lambda windows: ~(np.mean(windows >= threshold, axis=1) >= supermajority), 231 | ) 232 | )[0] 233 | if low_clusters.size == 0: 234 | return False 235 | else: 236 | earliest_low = low_clusters[0] + segment[0] 237 | latest_low = low_clusters[-1] + segment[0] 238 | return earliest_low / len(word_counts) < end_margin or latest_low / len(word_counts) > ( 239 | 1 - end_margin 240 | ) 241 | 242 | cumulative_counts = np.cumsum(word_counts) 243 | cumulative_threshold_counts = np.cumsum(np.where(word_counts >= threshold, 1, 0)) 244 | valid_segments = ( 245 | segment 246 | for segment in all_segments 247 | if segment_average_from_cumulative(segment, cumulative_counts) >= threshold and 248 | # Ensure that most values are above the threshold 249 | segment_average_from_cumulative(segment, cumulative_threshold_counts) >= supermajority 250 | ) 251 | 252 | longest_segments = reversed(sorted(valid_segments, key=lambda segment: segment[1] - segment[0])) 253 | for s in longest_segments: 254 | # Special handling for this predicate because it's relatively slow 255 | if not has_low_cluster_near_end(s): 256 | return s 257 | raise RuntimeError("No valid segment found") 258 | 259 | 260 | # Doesn't really work. At least with 'google/flan-t5-base' or 'google/flan-t5-large'. 261 | # def identify_preamble(model_name: str): 262 | # batch_size = 1 263 | # prompt = ( 264 | # "Is the focused paragraph in the input the first paragraph that's not part of a book's preamble? " 265 | # "i.e. not the table of contents, preface, copyright, etc. 
" 266 | # "Just answer with a percentage confidence that it is like 0.2, 0.4, 0.6, or 0.8 " 267 | # "The paragraphs follow: " 268 | # ) 269 | -------------------------------------------------------------------------------- /hammock/cluster.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import defaultdict 4 | import functools 5 | from pathlib import Path 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Any, 9 | Callable, 10 | Generic, 11 | Literal, 12 | Mapping, 13 | NamedTuple, 14 | NewType, 15 | Sequence, 16 | TypeAlias, 17 | TypeVar, 18 | cast, 19 | ) 20 | 21 | from hdbscan import HDBSCAN 22 | from sklearn_extra.cluster import KMedoids 23 | import numpy as np 24 | from numpy import ndarray 25 | from plotly import colors 26 | import torch 27 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BatchEncoding 28 | 29 | from .cache import cache, json_cache 30 | from .color import cluster_colors_by_graph 31 | from .util import flatmap, time_segment 32 | 33 | if TYPE_CHECKING: 34 | from plotly.colors import RGBStr, Tuple1, Tuple255 35 | 36 | profile = False 37 | 38 | A = TypeVar("A") 39 | 40 | ClusterID = NewType("ClusterID", np.int64) 41 | 42 | NPoints = TypeVar("NPoints") 43 | SmallD = TypeVar("SmallD", bound=Literal[2, 3]) 44 | NClusterPoints = TypeVar("NClusterPoints") 45 | 46 | 47 | class ClusterResult(NamedTuple, Generic[A, NPoints, NClusterPoints, SmallD]): 48 | cluster_by_point: ndarray[ClusterID, NPoints] 49 | # `NClusterPoints` is slightly sketchy because it's not the case that 50 | # each cluster has the same number of points. But this is at least better than `Any`. 51 | points_by_cluster: Mapping[ClusterID, tuple[list[A], ndarray[float, NClusterPoints, SmallD]]] 52 | 53 | 54 | def _cluster( 55 | min_cluster_size: int, items: list[A], embeddings: ndarray[float, NPoints, SmallD] 56 | ) -> ClusterResult[A, NPoints, NClusterPoints, SmallD]: 57 | """Each item `A` in `items` should correspond to a row in `embeddings`. 58 | Returns a mapping from cluster labels to a tuple of the items in that cluster and their embeddings.""" 59 | 60 | cluster_ids = cast( 61 | ndarray[ClusterID, NPoints], 62 | (HDBSCAN(min_cluster_size=min_cluster_size, min_samples=1).fit_predict(embeddings)), 63 | ) 64 | items_by_cluster: dict[ClusterID, list[A]] = defaultdict(list) 65 | for label, item in zip(cluster_ids, items, strict=True): 66 | items_by_cluster[label].append(item) 67 | return ClusterResult( 68 | cluster_ids, 69 | { 70 | label: ( 71 | items_for_cluster, 72 | embeddings[cluster_ids == label, :], 73 | ) 74 | for label, items_for_cluster in items_by_cluster.items() 75 | }, 76 | ) 77 | 78 | 79 | NMedoids = TypeVar("NMedoids", bound=int) 80 | 81 | 82 | def _search_for_threshold( 83 | first_guess: int, 84 | lower_bound: int, 85 | upper_bound: int, 86 | fn: Callable[[int], A], 87 | pred: Callable[[A], bool], 88 | ) -> A: 89 | """Binary search for the threshold where `pred(fn(i))` is false and `pred(fn(i+1))` is true. 
90 | If `pred` can't be satisfied, returns largest value seen.""" 91 | current_guess = first_guess 92 | smallest_saturater: tuple[int, A] | None = None 93 | largest_non_saturater: A | None = None 94 | while lower_bound <= upper_bound: 95 | a = fn(current_guess) 96 | if pred(a): 97 | # `lower_bound - 1` must not have saturated so the current value is a solution 98 | if lower_bound == current_guess: 99 | return a 100 | else: 101 | smallest_saturater = current_guess, a 102 | upper_bound = current_guess - 1 103 | current_guess = (lower_bound + upper_bound) // 2 104 | else: 105 | if smallest_saturater is not None and smallest_saturater[0] == current_guess + 1: 106 | return smallest_saturater[1] 107 | else: 108 | largest_non_saturater = a 109 | lower_bound = current_guess + 1 110 | current_guess = (lower_bound + upper_bound) // 2 111 | match largest_non_saturater: 112 | case None: 113 | raise RuntimeError("Should be impossible. In `_search_for_threshold`") 114 | case a: 115 | return a 116 | 117 | 118 | def choose_reps_within_length( 119 | model_name: str, 120 | prompt: str, 121 | max_length: int, 122 | texts: Sequence[str], 123 | embeddings: ndarray[float, NClusterPoints, SmallD], 124 | ) -> tuple[Sequence[str], ndarray[float, NMedoids, SmallD]]: 125 | """The summarization model has a maximum or effective maximum input length. 126 | We want to choose good representatives from the cluster that total up to that length. 127 | We do that via k-medoids clustering. 128 | (k-medoids is like k-means, but the cluster centers are actual data points which is essential here.)""" 129 | 130 | def k_medoids(num_reps: int) -> KMedoids[NMedoids, SmallD]: 131 | km: KMedoids[NMedoids, SmallD] = KMedoids( 132 | n_clusters=cast(NMedoids, min(len(texts), num_reps)), method="pam", init="k-medoids++" 133 | ) 134 | return km.fit(embeddings) 135 | 136 | def len_of_reps(kmeds: KMedoids[NMedoids, SmallD]) -> int: 137 | reps = [texts[idx] for idx in kmeds.medoid_indices_] 138 | ret = len(batch_encode_clusters(model_name, prompt, [reps], max_length=max_length)["input_ids"][0]) 139 | return ret 140 | 141 | with time_segment(f"Choosing reps for sequence of length {len(texts)}", active=profile): 142 | avg_fragment_length = sum(len(text) for text in texts) / len(texts) 143 | first_guess = int(max_length / avg_fragment_length) - 1 144 | meds = _search_for_threshold( 145 | first_guess, 146 | lower_bound=2, 147 | upper_bound=int(max_length / avg_fragment_length * 10), 148 | fn=k_medoids, 149 | pred=lambda x: len_of_reps(x) == max_length, 150 | ) 151 | return [texts[idx] for idx in meds.medoid_indices_], meds.cluster_centers_ 152 | 153 | 154 | def _cluster_centers( 155 | embeddings: ndarray[float, NPoints, SmallD], cluster_ids: ndarray[ClusterID, NPoints] 156 | ) -> Mapping[ClusterID, ndarray[float, SmallD]]: 157 | """Each row in `embeddings` should correspond to a cluster label in `cluster_labels`.""" 158 | 159 | cluster_label_array = cluster_ids 160 | return { 161 | cluster_id: np.mean(embeddings[cluster_label_array == cluster_id, :], axis=0) 162 | for cluster_id in set(cluster_ids) 163 | } 164 | 165 | 166 | noise_cluster_id = np.int64(-1) 167 | 168 | summary_dir = Path("cache/summary") 169 | 170 | 171 | def batch_encode_clusters( 172 | model_name: str, prompt: str, clusters: Sequence[Sequence[str]], max_length: int 173 | ) -> BatchEncoding: 174 | tokenizer = AutoTokenizer.from_pretrained(model_name) 175 | return tokenizer( 176 | [f"{prompt}\n" + "\n".join(cluster) for cluster in clusters], 177 | truncation=True, 178 | 
max_length=max_length, 179 | padding=True, 180 | return_tensors="pt", 181 | ) 182 | 183 | 184 | @functools.cache 185 | def _summary_model(model_name: SummaryModelName) -> tuple[AutoModelForSeq2SeqLM, AutoTokenizer]: 186 | return ( 187 | AutoModelForSeq2SeqLM.from_pretrained(model_name, low_cpu_mem_usage=True), 188 | AutoTokenizer.from_pretrained(model_name), 189 | ) 190 | 191 | 192 | def _capwords_preserve_case(s: str): 193 | """Capitalize the first letter of each word in a string. 194 | (`.capwords()` will lowercase the rest of the word which doesn't work for e.g. 'AI' or 'NP'.)""" 195 | return " ".join(word[0].upper() + word[1:] for word in s.split()) 196 | 197 | 198 | @cache(summary_dir, cache_type=json_cache) 199 | def _summarize_clusters( 200 | model_name: SummaryModelName, 201 | max_length: int, 202 | prompt: str, 203 | clusters: Mapping[ClusterID, tuple[Sequence[str], ndarray[float, NClusterPoints, SmallD]]], 204 | ) -> Mapping[ClusterID, str]: 205 | """Passes text fragments from cluster to summarization model and returns summaries. 206 | In `clusters`, we expect each row of the embeddings array to correspend to one text fragment `str` in the list. 207 | We need the embeddings so that we can do k-medoids clustering and choose good representatives for each cluster. 208 | (We can't just summarize /every/ text fragment in the cluster because 209 | the summarization model has a maximum input length.) 210 | """ 211 | print(f"Num clusters: {len(clusters.keys())}") 212 | 213 | model, tokenizer = _summary_model(model_name) 214 | tokenized_input = batch_encode_clusters( 215 | model_name, 216 | prompt, 217 | [ 218 | choose_reps_within_length(model_name, prompt, max_length, *v)[0] 219 | for k, v in clusters.items() 220 | if k != noise_cluster_id 221 | ], 222 | max_length=max_length, 223 | ) 224 | batch_size = 1 225 | batch_count = 0 226 | 227 | def generate(input_ids: torch.Tensor) -> Sequence[str]: 228 | nonlocal batch_count 229 | batch_count += 1 230 | # print(f"Batch: {batch_count}") 231 | # for input_ in tokenizer.batch_decode(input_ids, skip_special_tokens=True): 232 | # print(input_) 233 | # print("=======================================") 234 | 235 | with time_segment(f"summary batch. Size of {batch_size}", active=profile): 236 | batch_summary_ids = model.generate( 237 | input_ids, 238 | min_length=2, 239 | max_length=20, 240 | length_penalty=0.8, 241 | repetition_penalty=3.0, 242 | ) 243 | batch_summary_texts = [ 244 | _capwords_preserve_case(t.strip()) 245 | for t in tokenizer.batch_decode(batch_summary_ids, skip_special_tokens=True) 246 | ] 247 | print(batch_summary_texts) 248 | return batch_summary_texts 249 | 250 | return dict( 251 | zip( 252 | list(clusters.keys()), 253 | flatmap(generate, torch.split(tokenized_input["input_ids"], batch_size)), 254 | strict=True, 255 | ) 256 | ) 257 | 258 | 259 | # These two summarization models are much worse than t2t: 260 | # "facebook/bart-large-cnn" 261 | # "sshleifer/distilbart-cnn-12-6" 262 | # Too big to run consistently 263 | # "google/flan-t5-xl" 264 | 265 | SummaryModelName: TypeAlias = Literal["google/flan-t5-large", "google/flan-t5-base"] 266 | 267 | # Callers may request different clustering functionality. 268 | # They can describe the requested functionality via the `ClusterControl` types. 269 | # We then "interpret" the clustering specification into different `ClusterResult` types via `handle_clustering`. 
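# For example (an illustrative sketch only; the prompt string is made up, the other names are defined below):
#
#     control = CCColorAndSummarize(
#         model_name="google/flan-t5-large",
#         summary_prompt="Provide a topic for this cluster of snippets: ",
#         min_cluster_sizes=[20, 5],
#     )
#     results = handle_clustering(control, texts, reduced_embeddings)
#
# This would yield one `CRColorAndSummarize` per requested `min_cluster_size`, each carrying
# per-point colors plus labeled, colored cluster centers; a `CCColor` request yields `CRColor`
# results and `CCNeither` yields `None`.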
270 | 271 | 272 | class CCColorAndSummarize(NamedTuple): 273 | model_name: SummaryModelName 274 | summary_prompt: str 275 | min_cluster_sizes: Sequence[int] 276 | 277 | 278 | class CCNeither(NamedTuple): 279 | ... 280 | 281 | 282 | class CCColor(NamedTuple): 283 | min_cluster_sizes: Sequence[int] 284 | 285 | 286 | ClusterControl: TypeAlias = CCColorAndSummarize | CCNeither | CCColor 287 | 288 | # Color for each point 289 | PointColoring = NewType("PointColoring", "Sequence[RGBStr]") 290 | 291 | 292 | class CRColor(NamedTuple): 293 | # Color for each point 294 | colors: PointColoring 295 | 296 | 297 | class ClusterData(NamedTuple, Generic[SmallD]): 298 | center: ndarray[float, SmallD] 299 | label: str 300 | color: RGBStr 301 | 302 | 303 | class CRColorAndSummarize(NamedTuple, Generic[SmallD]): 304 | clusters: Sequence[ClusterData[SmallD]] 305 | # Color for each point according to its cluster 306 | # (so the color info here isn't duplicative with the info in `clusters` which is only one per cluster) 307 | colors: PointColoring 308 | 309 | 310 | ID = TypeVar("ID") 311 | 312 | 313 | def compute_cluster_info( 314 | min_cluster_size: int, 315 | texts: list[str], 316 | reduced_embeddings: ndarray[float, NPoints, SmallD], 317 | ) -> tuple[ 318 | ndarray[ClusterID, NPoints], 319 | Mapping[ClusterID, tuple[Sequence[str], ndarray[float, Any, SmallD]]], 320 | Mapping[ClusterID, ndarray[float, SmallD]], 321 | Mapping[ClusterID, Tuple1], 322 | ]: 323 | clusters: ClusterResult[str, NPoints, Any, SmallD] = _cluster(min_cluster_size, texts, reduced_embeddings) 324 | centers = { 325 | k: v 326 | for k, v in _cluster_centers(reduced_embeddings, clusters.cluster_by_point).items() 327 | if k != noise_cluster_id 328 | } 329 | colors = cluster_colors_by_graph(centers) 330 | points_by_cluster = {k: v for k, v in clusters.points_by_cluster.items() if k != noise_cluster_id} 331 | return ( 332 | clusters.cluster_by_point, 333 | points_by_cluster, 334 | centers, 335 | colors, 336 | ) 337 | 338 | 339 | def handle_clustering( 340 | cluster_control: ClusterControl, texts: list[str], reduced_embeddings: ndarray[float, NPoints, SmallD] 341 | ) -> Sequence[CRColorAndSummarize[SmallD]] | Sequence[CRColor] | None: 342 | """One cluster result per `min_cluster_size`. 
343 | This allows us to view the plot at different clustering granularities.""" 344 | 345 | def color_by_point_from_cluster( 346 | cluster_by_point: ndarray[ClusterID, NPoints], cluster_to_color: Mapping[ClusterID, Tuple1] 347 | ) -> PointColoring: 348 | return PointColoring( 349 | [ 350 | colors.label_rgb( 351 | cast("Tuple255", (255, 255, 255)) 352 | if cluster_id == noise_cluster_id 353 | else colors.convert_to_RGB_255(cluster_to_color[cluster_id]) 354 | ) 355 | for cluster_id in cluster_by_point 356 | ] 357 | ) 358 | 359 | match cluster_control: 360 | case CCColor(min_cluster_size): 361 | 362 | def _c_inner(min_cluster_size: int): 363 | cluster_by_point, _, _, cluster_to_color = compute_cluster_info( 364 | min_cluster_size, texts, reduced_embeddings 365 | ) 366 | return CRColor(colors=color_by_point_from_cluster(cluster_by_point, cluster_to_color)) 367 | 368 | return [_c_inner(mcs) for mcs in min_cluster_size] 369 | case CCColorAndSummarize(summary_model_name, summary_prompt, min_cluster_size): 370 | 371 | def _cs_inner(min_cluster_size: int): 372 | cluster_by_point, clusters, cluster_to_center, cluster_to_color = compute_cluster_info( 373 | min_cluster_size, texts, reduced_embeddings 374 | ) 375 | summaries = _summarize_clusters( 376 | summary_model_name, 377 | max_length=2048, 378 | prompt=summary_prompt, 379 | clusters=clusters, 380 | ) 381 | assert cluster_to_center.keys() == summaries.keys(), (cluster_to_center.keys(), summaries.keys()) 382 | assert cluster_to_center.keys() == cluster_to_color.keys(), ( 383 | cluster_to_center.keys(), 384 | cluster_to_color.keys(), 385 | ) 386 | return CRColorAndSummarize( 387 | [ 388 | ClusterData( 389 | cluster_to_center[k], 390 | summaries[k], 391 | colors.label_rgb(colors.convert_to_RGB_255(cluster_to_color[k])), 392 | ) 393 | for k in cluster_to_center.keys() 394 | ], 395 | color_by_point_from_cluster(cluster_by_point, cluster_to_color), 396 | ) 397 | 398 | return [_cs_inner(mcs) for mcs in min_cluster_size] 399 | case CCNeither(): 400 | return None 401 | -------------------------------------------------------------------------------- /hammock/color.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from itertools import combinations 4 | from typing import TYPE_CHECKING, Mapping, NewType, Sequence, TypeVar, cast 5 | 6 | from networkx import Graph, equitable_color 7 | import numpy as np 8 | from numpy import ndarray 9 | from plotly import colors 10 | from scipy.spatial.distance import pdist, squareform 11 | from sklearn.preprocessing import minmax_scale 12 | 13 | from .util import time_segment 14 | 15 | if TYPE_CHECKING: 16 | from networkx import WeightDict 17 | from plotly.colors import HexStr, PlotlyScales, RGBStr, Tuple1, Tuple255 18 | 19 | profile = False 20 | noise_cluster_id = -1 21 | 22 | ID = TypeVar("ID") 23 | 24 | # 2D or 3D 25 | SmallD = TypeVar("SmallD") 26 | 27 | 28 | def graph_from_centers(cluster_centers: Mapping[ID, ndarray[float, SmallD]], percentile: float) -> Graph[ID]: 29 | """Given a collection of points representing the centers of clusters, 30 | create an undirected graph connecting those points. 
31 | Two cluster centers are connected if the (Euclidean) distance between them is 32 | less than the given percentile of all pairwise distances.""" 33 | 34 | nodes: list[ID] = list(cluster_centers.keys()) 35 | node_pairs: list[tuple[ID, ID]] = list(combinations(nodes, 2)) 36 | all_distances: list[float] = [np.linalg.norm(cluster_centers[u] - cluster_centers[v]) for u, v in node_pairs] 37 | threshold: float = np.percentile(np.array(all_distances), percentile) 38 | edges: list[tuple[ID, ID, WeightDict]] = [ 39 | (u, v, {"weight": d}) for (u, v), d in zip(node_pairs, all_distances, strict=True) if d < threshold 40 | ] 41 | graph: Graph[ID] = Graph() 42 | graph.add_nodes_from(nodes) 43 | graph.add_edges_from(edges) 44 | return graph 45 | 46 | 47 | def cluster_colors_by_graph(cluster_centers: Mapping[ID, ndarray[float, SmallD]]) -> Mapping[ID, Tuple1]: 48 | """Given a collection of points representing the centers of clusters, 49 | choose colors for those clusters by coloring the nodes of the corresponding graph 50 | such that no two adjacent nodes have the same color.""" 51 | 52 | graph = graph_from_centers(cluster_centers, percentile=10) 53 | max_degree = max(d for _, d in graph.degree()) 54 | with time_segment("coloring", active=profile): 55 | cluster_label_to_color_index: Mapping[ID, int] = equitable_color(graph, max(max_degree + 1, 10)) 56 | color_scale = colors.sample_colorscale( 57 | "Phase", len(set(cluster_label_to_color_index.values())), colortype="tuple" 58 | ) 59 | return {k: color_scale[v] for k, v in cluster_label_to_color_index.items()} 60 | 61 | 62 | NPoints = TypeVar("NPoints") 63 | NDims = TypeVar("NDims") 64 | NColors = NewType("NColors", int) 65 | EmbeddingDim = TypeVar("EmbeddingDim") 66 | 67 | 68 | def _colors_from_scale( 69 | scale_name: PlotlyScales, values: ndarray[float, NPoints] 70 | ) -> ndarray[float, NPoints, Tuple1]: 71 | """Plotly sequential color scales have a finite list of colors. 72 | We transform that into a truly continuous color scale by interpolating. 73 | Also, it's vectorized because this turned out to be a performance bottleneck. 
74 | But morally, we want 75 | `_colors_from_scale(scale_name: PlotlyScales) -> Callable[[float], tuple[float, float, float]]`.""" 76 | 77 | clamped = np.clip(values, 0, 1) 78 | colorscale = colors.convert_colors_to_same_type(getattr(colors.sequential, scale_name), "tuple")[0] 79 | colorscale_array = cast(ndarray[float, NColors, Tuple1], np.array(colorscale)) 80 | # Our actual fractional 'index' 81 | index: ndarray[float, NPoints] = clamped * (len(colorscale) - 1) 82 | # Closest integral indices to pick from provided colors in scale and interpolate between 83 | index_low = np.floor(index) 84 | index_high = np.ceil(index) 85 | interp = index - index_low 86 | return np.floor( 87 | ( 88 | colorscale_array[index_low.astype(int), :] * (1 - interp)[:, np.newaxis] 89 | + colorscale_array[index_high.astype(int), :] * interp[:, np.newaxis] 90 | ) 91 | ) 92 | 93 | 94 | def _compute_distances(data: ndarray[float, NPoints, NDims]) -> ndarray[float, NPoints]: 95 | return np.mean(squareform(pdist(data, metric="euclidean")), axis=0) 96 | 97 | 98 | def distance_colors(embeddings: ndarray[float, NPoints, EmbeddingDim]) -> Sequence[RGBStr]: 99 | """Color points using a continuous scale based on their average distance to all other points.""" 100 | norm_distances = minmax_scale(_compute_distances(embeddings)) 101 | return [ 102 | colors.label_rgb(colors.convert_to_RGB_255(cast("Tuple1", c))) 103 | for c in _colors_from_scale("thermal", norm_distances) 104 | ] 105 | 106 | 107 | def qualitative_colors(num_colors: int) -> Sequence[RGBStr]: 108 | def _to_rgb(x: str) -> Tuple255: 109 | match x: 110 | case x if x.startswith("rgb"): 111 | return colors.unlabel_rgb(cast("RGBStr", x)) 112 | case x if x.startswith("#"): 113 | return colors.hex_to_rgb(cast("HexStr", x)) 114 | case _: 115 | raise ValueError(f"Unknown color format: {x}") 116 | 117 | return [colors.label_rgb(_to_rgb(c)) for c in colors.qualitative.Plotly[:num_colors]] 118 | -------------------------------------------------------------------------------- /hammock/core.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | import re 5 | from typing import ( 6 | Callable, 7 | Literal, 8 | NamedTuple, 9 | NewType, 10 | Optional, 11 | Sequence, 12 | TypeAlias, 13 | TypeVar, 14 | cast, 15 | ) 16 | 17 | from gutenbergpy.gutenbergcache import GutenbergCache 18 | from gutenbergpy.textget import get_text_by_id, strip_headers 19 | import nltk 20 | from nltk.tokenize import sent_tokenize 21 | import wikipedia 22 | 23 | from .cache import cache 24 | from .cluster import ClusterControl, CCNeither 25 | from .embedding import EmbeddingModelName 26 | from .plot import Source, plot_single, plot_multiple 27 | from .util import Either, Failure, Success, flatmap, thunkify, time_segment, unzip 28 | 29 | nltk.download("punkt", quiet=True) 30 | 31 | profile = True 32 | 33 | EmbeddingDim = TypeVar("EmbeddingDim") 34 | 35 | 36 | class TextUnit(NamedTuple): 37 | unit: Literal["paragraph", "sentence"] 38 | # Even sentences get a `para_newline_type` because 39 | # `sent_tokenize` sometimes considers double newlines to be part of a single paragraph. 40 | # So we manually split on paragraphs first. 41 | para_newline_type: NewlineType 42 | 43 | 44 | NewlineType: TypeAlias = Literal["double", "single"] 45 | 46 | 47 | def para_tokenize(newline_type: NewlineType, text: str) -> list[str]: 48 | """Split a text block into paragraphs. 
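Going back to `_colors_from_scale` in hammock/color.py above: a scalar sketch of the interpolation its docstring describes, i.e. roughly the `Callable[[float], tuple[float, float, float]]` shape it says it "morally" wants. This illustrates the idea only; it is not a drop-in replacement for the vectorized version.

```python
from plotly import colors

def color_from_scale(scale_name: str, value: float) -> tuple[float, float, float]:
    scale = colors.convert_colors_to_same_type(getattr(colors.sequential, scale_name), "tuple")[0]
    clamped = min(max(value, 0.0), 1.0)
    index = clamped * (len(scale) - 1)       # fractional position within the finite scale
    low = int(index)
    high = min(low + 1, len(scale) - 1)
    interp = index - low
    # Linear interpolation between the two neighbouring scale colors, channel by channel.
    return tuple(a * (1 - interp) + b * interp for a, b in zip(scale[low], scale[high]))

print(color_from_scale("thermal", 0.5))  # an (r, g, b) tuple with components in [0, 1]
```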
49 | Some sources use double newlines to separate paragraphs and some use single newlines.""" 50 | return [ 51 | paragraph.strip() 52 | for paragraph in text.split("\n\n" if newline_type == "double" else "\n") 53 | if paragraph.strip() != "" 54 | ] 55 | 56 | 57 | def tokenize( 58 | text_unit: TextUnit, 59 | always_bad_sentence_predicates: list[Callable[[str], bool]], 60 | standalone_bad_sentence_predicates: list[Callable[[str], bool]], 61 | bad_para_predicates: list[Callable[[str], bool]], 62 | text: str, 63 | ) -> list[str]: 64 | """Some sentences we always want to filter out (e.g. chapter declarations) 65 | while some are okay in the context of a paragraph but not as standalone sentences 66 | (e.g. really short sentences). 67 | We control this filtering via the predicate arguments here.""" 68 | 69 | # We paragraph tokenize first because 70 | # `sent_tokenize` sometimes considers double newlines to be part of a single sentence. 71 | paras = para_tokenize(text_unit.para_newline_type, text) 72 | match text_unit.unit: 73 | case "sentence": 74 | return [ 75 | sent 76 | for sent in flatmap(sent_tokenize, paras) 77 | if not any(p(sent) for p in always_bad_sentence_predicates + standalone_bad_sentence_predicates) 78 | ] 79 | case "paragraph": 80 | 81 | def sentence_preds_for_para(para: list[str]) -> list[Callable[[str], bool]]: 82 | match len(para): 83 | case 0 | 1: 84 | return always_bad_sentence_predicates + standalone_bad_sentence_predicates 85 | case _: 86 | return always_bad_sentence_predicates 87 | 88 | sentences_for_paras: list[list[str]] = [sent_tokenize(p) for p in paras] 89 | filtered_sentences_for_paras: list[list[str]] = [ 90 | [ 91 | sent 92 | for sent in sentences_for_para 93 | if not any(p(sent) for p in sentence_preds_for_para(sentences_for_para)) 94 | ] 95 | for para, sentences_for_para, in zip(paras, sentences_for_paras, strict=True) 96 | if not any(p(para) for p in bad_para_predicates) 97 | ] 98 | return [" ".join(sentences) for sentences in filtered_sentences_for_paras if len(sentences) > 0] 99 | 100 | 101 | def clip_string(s: str, start_anchor: str, end_anchor: str) -> Optional[str]: 102 | """Some texts have preamble and postamble that we want to ignore. 
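A small usage sketch for `tokenize` above; the sample text and predicates are invented, and importing hammock.core also triggers the `punkt` download the module performs. Paragraph mode keeps multi-sentence paragraphs intact while the "always bad" predicates still strip chapter headings, and short single-sentence paragraphs are dropped by the standalone predicates.

```python
from hammock.core import TextUnit, tokenize

text = (
    "CHAPTER I\n\n"
    "It was a bright cold day in April. The clocks were striking thirteen.\n\n"
    "Ouch."
)
fragments = tokenize(
    TextUnit("paragraph", para_newline_type="double"),
    always_bad_sentence_predicates=[lambda sent: "chapter" in sent.lower()],
    standalone_bad_sentence_predicates=[lambda sent: len(sent.split()) < 3],
    bad_para_predicates=[],
    text=text,
)
print(fragments)
# ['It was a bright cold day in April. The clocks were striking thirteen.']
# "CHAPTER I" is always filtered; "Ouch." is too short to stand alone as its own paragraph.
```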
103 | We do this by clipping the text to the specified anchors.""" 104 | match re.compile(f"({re.escape(start_anchor)}.*?{re.escape(end_anchor)})", re.DOTALL).search(s): 105 | case None: 106 | return None 107 | case res: 108 | return res.group(1) 109 | 110 | 111 | # Freeform 112 | 113 | 114 | def plot_freeform( 115 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 116 | embedding_instruction: str, 117 | texts: list[str], 118 | cluster_control: ClusterControl = CCNeither(), 119 | include_labels: Literal["include_labels", "exclude_labels"] = "include_labels", 120 | ) -> Path: 121 | return plot_single( 122 | embedding_model_name, 123 | embedding_instruction, 124 | dimensions=3, 125 | cluster_control=cluster_control, 126 | source=Source("Freeform", texts), 127 | include_labels=include_labels, 128 | ) 129 | 130 | 131 | # Project Gutenberg 132 | 133 | 134 | def gutenberg_embedding_instruction(text_unit: TextUnit): 135 | return f"Represent the Fiction {text_unit.unit} for clustering by theme: " 136 | 137 | 138 | class GutenbergArgs(NamedTuple): 139 | title: str 140 | start_anchor: str 141 | end_anchor: str 142 | 143 | 144 | GutenbergID = NewType("GutenbergID", int) 145 | 146 | 147 | def get_gutenberg_text(gutenberg_id: GutenbergID) -> str: 148 | return strip_headers(get_text_by_id(gutenberg_id)).decode("utf-8") 149 | 150 | 151 | gutenberg_cache = thunkify(GutenbergCache.get_cache) 152 | 153 | 154 | def fetch_gutenberg(title: str, text_unit: TextUnit, start_anchor: str, end_anchor: str) -> Either[str, list[str]]: 155 | ids = cast(list[GutenbergID], gutenberg_cache().query(titles=[title])) 156 | text = None 157 | gutenberg_id = None 158 | for i in ids: 159 | try: 160 | text = get_gutenberg_text(i) 161 | gutenberg_id = i 162 | break 163 | except Exception as e: 164 | print( 165 | f"Failed while fetching Project Guterberg book {title} at ID {i}:\n {e}\n" 166 | f"Continuing with IDs from {ids}" 167 | ) 168 | continue 169 | if text is None or gutenberg_id is None: 170 | return Failure(f"Could not find {title}") 171 | clipped = clip_string(text, start_anchor, end_anchor) 172 | if clipped is None: 173 | return Failure(f"Could not find {start_anchor} or/and {end_anchor} in {title}") 174 | always_bad_sentence_predicates: list[Callable[[str], bool]] = [ 175 | lambda sent: "chapter" in sent.lower(), 176 | lambda sent: "book" in sent.lower(), 177 | ] 178 | standalone_bad_sentence_predicates: list[Callable[[str], bool]] = [ 179 | lambda sent: len(sent.split()) < 3, 180 | lambda sent: not any(char.islower() for char in sent), 181 | ] 182 | return Success( 183 | tokenize( 184 | text_unit, 185 | always_bad_sentence_predicates=always_bad_sentence_predicates, 186 | standalone_bad_sentence_predicates=standalone_bad_sentence_predicates, 187 | bad_para_predicates=[], 188 | text=clipped, 189 | ) 190 | ) 191 | 192 | 193 | def plot_single_gutenberg( 194 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 195 | text_unit: TextUnit, 196 | cluster_control: ClusterControl, 197 | title: str, 198 | start_anchor: str, 199 | end_anchor: str, 200 | ) -> Either[str, Path]: 201 | match fetch_gutenberg(title, text_unit, start_anchor, end_anchor): 202 | case Failure() as f: 203 | return f 204 | case Success(texts): 205 | return Success( 206 | plot_single( 207 | embedding_model_name, 208 | gutenberg_embedding_instruction(text_unit), 209 | dimensions=2, 210 | cluster_control=cluster_control, 211 | source=Source(title, texts), 212 | include_labels="include_labels", 213 | ) 214 | ) 215 | 216 | 217 | def 
plot_multiple_gutenberg( 218 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 219 | text_unit: TextUnit, 220 | cluster_control: ClusterControl, 221 | gutenberg_argss: Sequence[GutenbergArgs], 222 | ) -> Path: 223 | titles, _, _ = unzip(gutenberg_argss) 224 | textss = [ 225 | x.value 226 | for x in [fetch_gutenberg(text_unit=text_unit, **args._asdict()) for args in gutenberg_argss] 227 | if not isinstance(x, Failure) 228 | ] 229 | return plot_multiple( 230 | embedding_model_name, 231 | gutenberg_embedding_instruction(text_unit), 232 | dimensions=2, 233 | cluster_control=cluster_control, 234 | sources=[Source(title, texts) for title, texts in zip(titles, textss, strict=True)], 235 | include_labels="include_labels", 236 | ) 237 | 238 | 239 | # Wikipedia 240 | 241 | 242 | def wiki_embedding_instruction(text_unit: TextUnit): 243 | return f"Represent the Wikipedia {text_unit.unit} for clustering by topic: " 244 | 245 | 246 | def wiki_summary_prompt(text_unit: TextUnit, titles: Sequence[str]) -> str: 247 | return ( 248 | f"These are clustered {text_unit.unit}s from Wikipedia. " 249 | "Please provide a topic summarizing and describing the cluster. " 250 | "The topic should cover as many of the paragraphs as reasonably possible. " 251 | "Make sure to look at ALL paragraphs and include them ALL in your analysis. " 252 | "A topic should typically be a noun phrase. " 253 | f"Make sure the topic you choose is NOT just one of: {'; '.join(titles)}. " 254 | f"{text_unit.unit}s begin here: " 255 | ) 256 | 257 | 258 | text_dir = Path("cache/texts") 259 | 260 | 261 | @cache(text_dir) 262 | def fetch_wiki(wiki_title: str, text_unit: TextUnit) -> list[str]: 263 | with time_segment("Wiki download", active=profile): 264 | page = wikipedia.page(title=wiki_title, auto_suggest=False).content 265 | page_body = re.sub(r"== See also ==.*", "", page, flags=re.DOTALL) 266 | no_headers = re.sub(r"(=+ [^=]+? 
=+)", "", page_body, flags=re.DOTALL) 267 | return tokenize( 268 | text_unit, 269 | always_bad_sentence_predicates=[], 270 | standalone_bad_sentence_predicates=[ 271 | lambda sent: len(sent.split()) > 4, 272 | ], 273 | bad_para_predicates=[], 274 | text=no_headers, 275 | ) 276 | 277 | 278 | def plot_single_wiki( 279 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 280 | text_unit: TextUnit, 281 | cluster_control: ClusterControl, 282 | title: str, 283 | ) -> Path: 284 | return plot_single( 285 | embedding_model_name, 286 | wiki_embedding_instruction(text_unit), 287 | source=Source(title, fetch_wiki(title, text_unit)), 288 | dimensions=3, 289 | cluster_control=cluster_control, 290 | include_labels="include_labels", 291 | ) 292 | 293 | 294 | def plot_multiple_wiki( 295 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 296 | text_unit: TextUnit, 297 | cluster_control: ClusterControl, 298 | titles: Sequence[str], 299 | ) -> Path: 300 | textss = [fetch_wiki(title, text_unit) for title in titles] 301 | return plot_multiple( 302 | embedding_model_name, 303 | wiki_embedding_instruction(text_unit), 304 | dimensions=3, 305 | cluster_control=cluster_control, 306 | sources=[Source(t, ss) for (t, ss) in zip(titles, textss, strict=True)], 307 | include_labels="include_labels", 308 | ) 309 | -------------------------------------------------------------------------------- /hammock/embedding.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from pathlib import Path 5 | from typing import Generic, Literal, NamedTuple, Protocol, TypeVar, cast 6 | import warnings 7 | 8 | from InstructorEmbedding import INSTRUCTOR 9 | from numba.core.errors import NumbaDeprecationWarning 10 | from numpy import ndarray 11 | from scipy.spatial.distance import pdist, squareform 12 | from sentence_transformers import SentenceTransformer 13 | 14 | with warnings.catch_warnings(): 15 | warnings.filterwarnings("ignore", category=NumbaDeprecationWarning) 16 | from umap import UMAP 17 | 18 | from .cache import cache 19 | from .util import time_segment 20 | 21 | profile = True 22 | 23 | embeddings_dir = Path("cache/embeddings") 24 | 25 | NTexts = TypeVar("NTexts") 26 | EmbeddingDim = TypeVar("EmbeddingDim") 27 | SmallD = TypeVar("SmallD", bound=Literal[2, 3]) 28 | 29 | 30 | # Just some convenience protocols to make our `_embedding_model` type signature a little more readable 31 | class Embedder(Protocol[EmbeddingDim]): 32 | def __call__(self, text_fragments: list[str]) -> ndarray[float, NTexts, EmbeddingDim]: 33 | ... 34 | 35 | 36 | class MkEmbedder(Protocol[EmbeddingDim]): 37 | def __call__(self, embedding_instruction: str) -> Embedder[EmbeddingDim]: 38 | ... 
39 | 40 | 41 | @functools.cache 42 | def _embedding_model(model_desc: EmbeddingModelName[EmbeddingDim]) -> MkEmbedder[EmbeddingDim]: 43 | """The type signature is a little clumsy here but we want to ensure our `@functools.cache` shares 44 | the same model instance across all calls with a given `model_name`.""" 45 | if model_desc.name.startswith("hkunlp/instructor"): 46 | instructor: INSTRUCTOR[EmbeddingDim] = INSTRUCTOR(model_desc.name) 47 | return lambda embedding_instruction: lambda text_fragments: instructor.encode( 48 | [[embedding_instruction, t] for t in text_fragments], 49 | batch_size=2, 50 | show_progress_bar=True, 51 | ) 52 | else: 53 | st: SentenceTransformer[EmbeddingDim] = SentenceTransformer(model_desc.name) 54 | return lambda embedding_instruction: lambda text_fragments: st.encode( 55 | text_fragments, batch_size=2, show_progress_bar=True 56 | ) 57 | 58 | 59 | @cache(embeddings_dir) 60 | def get_embeddings( 61 | model_name: EmbeddingModelName[EmbeddingDim], 62 | embedding_instruction: str, 63 | text_fragments: list[str], 64 | ) -> ndarray[float, NTexts, EmbeddingDim]: 65 | with time_segment("embedding", active=profile): 66 | # Have to cast to tie EmbeddingDim@get_embeddings to EmbeddingDim@_embedding_model 67 | # In theory, we could attach the `EmbeddingDim` to `EmbeddingModelName` as a phantom type, 68 | # but that seems like it's more work than it's worth ATM. 69 | return cast( 70 | "ndarray[float, NTexts, EmbeddingDim]", 71 | _embedding_model(model_name)(embedding_instruction)(text_fragments), 72 | ) 73 | 74 | 75 | # Can't load these because they're too big for my laptop 76 | # "hkunlp/instructor-xl" 77 | # "sentence-t5-xxl" 78 | # Doesn't seem that great? 79 | # i.e. Pareto-dominated by "instructor" line which seem get better performance at smaller size. 80 | # "sentence-t5-xl" 81 | 82 | 83 | class EmbeddingModelName(NamedTuple, Generic[EmbeddingDim]): 84 | # hkunlp/instructors are clearly better but notably slower. The other two can be nice for quick testing. 85 | name: Literal["hkunlp/instructor-large", "hkunlp/instructor-base", "all-mpnet-base-v2", "all-MiniLM-L6-v2"] 86 | 87 | 88 | instructor_large = EmbeddingModelName[Literal["instructor-large-dim"]]("hkunlp/instructor-large") 89 | instructor_base = EmbeddingModelName[Literal["instructor-base-dim"]]("hkunlp/instructor-base") 90 | all_MP_net_base = EmbeddingModelName[Literal["mpnet-base-dim"]]("all-mpnet-base-v2") 91 | all_mini_LM = EmbeddingModelName[Literal["minilm-dim"]]("all-MiniLM-L6-v2") 92 | 93 | 94 | @cache(embeddings_dir) 95 | def reduce_embeddings( 96 | dimensions: Literal[2, 3], n_neighbors: int, embeddings: ndarray[float, NTexts, EmbeddingDim] 97 | ) -> ndarray[float, NTexts, SmallD]: 98 | """Use UMAP to reduce the dimensionality of the embeddings to 2 or 3 dimensions. 
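An end-to-end sketch of the path through `get_embeddings` above and `reduce_embeddings` defined next, using the small `all_mini_LM` model so it finishes quickly. The sentences are placeholders; the first run downloads the model and caches results under cache/embeddings.

```python
from hammock.embedding import all_mini_LM, get_embeddings, reduce_embeddings

fragments = [f"Placeholder sentence number {i} about embeddings." for i in range(30)]
embeddings = get_embeddings(all_mini_LM, "Represent the sentence for clustering: ", fragments)
reduced = reduce_embeddings(3, n_neighbors=10, embeddings=embeddings)
print(embeddings.shape, reduced.shape)  # (30, 384) for MiniLM, then (30, 3)
```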
99 | `n_neighbors` seems to be the most important parameter for our use case: 100 | https://umap-learn.readthedocs.io/en/latest/parameters.html#n-neighbors 101 | """ 102 | with time_segment(f"UMAP {dimensions}D", active=profile): 103 | # segfaults if we don't precompute for values larger than 4096 (on some data) 104 | with warnings.catch_warnings(): 105 | warnings.filterwarnings( 106 | "ignore", message="using precomputed metric; inverse_transform will be unavailable" 107 | ) 108 | return UMAP( 109 | n_neighbors=n_neighbors, 110 | min_dist=0, 111 | n_components=cast(SmallD, dimensions), 112 | metric="precomputed", 113 | verbose=True, 114 | ).fit_transform(squareform(pdist(embeddings, metric="euclidean"))) 115 | -------------------------------------------------------------------------------- /hammock/gunicorn.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from gunicorn.app.wsgiapp import WSGIApplication 4 | 5 | 6 | class GunicornApp(WSGIApplication): 7 | def init(self, parser: Any, opts: Any, args: list[Any]): 8 | # Tell gunicorn which app to run 9 | # Like running `gunicorn hammock.web:app` from the command line 10 | args.insert(0, "hammock.web:app") 11 | super().init(parser, opts, args) 12 | self.cfg.set("timeout", 600) 13 | # Auto-reload 14 | self.cfg.set("reload", True) 15 | 16 | 17 | # Need this as a separate function so that it can be specified in 18 | # the `tool.poetry.scripts` section of `pyproject.toml` 19 | def main(): 20 | GunicornApp().run() 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /hammock/plot.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from io import StringIO 4 | from pathlib import Path 5 | import textwrap 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Any, 9 | Generic, 10 | Literal, 11 | Mapping, 12 | NamedTuple, 13 | NewType, 14 | Optional, 15 | Sequence, 16 | TypeAlias, 17 | TypeVar, 18 | cast, 19 | ) 20 | import typing 21 | 22 | import numpy as np 23 | from numpy import ndarray 24 | from plotly.graph_objects import Scatter3d 25 | import plotly.graph_objects as go 26 | 27 | from .cache import BytesCache, cache_with_path 28 | from .cluster import ( 29 | ClusterControl, 30 | ClusterData, 31 | CRColorAndSummarize, 32 | CRColor, 33 | PointColoring, 34 | handle_clustering, 35 | ) 36 | from .color import distance_colors, qualitative_colors 37 | from .embedding import EmbeddingModelName, get_embeddings, reduce_embeddings 38 | from .util import flatmap, flatten, transpose, unstack, unzip 39 | 40 | if TYPE_CHECKING: 41 | from plotly.colors import RGBStr 42 | 43 | profile = True 44 | 45 | output_dir = Path("output") 46 | 47 | MarkerType: TypeAlias = Literal["circle", "square", "circle-open", "diamond", "square-open", "diamond-open"] 48 | marker_types: Sequence[MarkerType] = typing.get_args(MarkerType) 49 | 50 | NPoints = TypeVar("NPoints") 51 | SmallD = TypeVar("SmallD", bound=Literal[2, 3]) 52 | 53 | 54 | class PlotData(NamedTuple, Generic[NPoints, SmallD]): 55 | embeddings: ndarray[float, NPoints, SmallD] 56 | text_fragments: Sequence[str] 57 | title: str 58 | marker_colors: Sequence[RGBStr] | RGBStr 59 | marker: MarkerType 60 | # Used when we want to toggle between different levels of clustering 61 | alt_colors: Optional[Sequence[PointColoring]] = None 62 | 63 | 64 | def _cluster_marker_trace(clustering: 
Sequence[ClusterData[SmallD]], visible: bool, index: int) -> Scatter3d: 65 | """Add "passive" trace for cluster markers to the figure. 66 | Most will be hidden but we can use JS to toggle them on/off.""" 67 | center_array = np.array([c.center for c in clustering]) 68 | return go.Scatter3d( 69 | x=center_array[:, 0], 70 | y=center_array[:, 1], 71 | z=center_array[:, 2], 72 | text=[c.label for c in clustering], 73 | mode="markers+text", 74 | showlegend=False, 75 | hovertemplate="", 76 | visible=visible, 77 | textfont=dict( 78 | size=16, 79 | family="Arial Black", 80 | ), 81 | marker=dict( 82 | size=4, 83 | color=[c.color for c in clustering], 84 | showscale=False, 85 | symbol="x", 86 | ), 87 | name=f"cluster_markers_{index}", 88 | ) 89 | 90 | 91 | def _marker_type_size_adjustment(marker_type: MarkerType) -> float: 92 | """Not all markers have the same visual impact. 93 | We adjust the size depending on the marker type to make them more visually similar.""" 94 | 95 | match marker_type: 96 | case "circle" | "circle-open": 97 | return 1.0 98 | case "square" | "square-open": 99 | return 0.8 100 | case "diamond" | "diamond-open": 101 | return 0.7 102 | 103 | 104 | def _point_traces_for_single_source( 105 | plot_data: PlotData[NPoints, SmallD], 106 | include_labels: Literal["include_labels", "exclude_labels"], 107 | ) -> Sequence[Scatter3d]: 108 | """A plot may have multiple sources (e.g. different books). 109 | This function handles all the traces that are exclusive to a single source.""" 110 | traces: Sequence[Scatter3d] = [] 111 | match plot_data.embeddings.shape[1]: 112 | case 3: 113 | data = plot_data.embeddings 114 | case 2: 115 | # If we have 2D embeddings, we make the third dimension a "chronological" dimension 116 | # reflecting the original text order. 117 | data = np.insert( 118 | plot_data.embeddings, 0, np.arange(plot_data.embeddings.shape[0], dtype=float), axis=1 119 | ) 120 | case _: 121 | raise ValueError(f"Unsupported number of dimensions: {plot_data.embeddings.shape[1]}") 122 | # Outer sequence is one element per point 123 | # Inner sequence is one element per alternative cluster coloring scheme 124 | alt_colors_as_custom_data: Sequence[Mapping[str, Sequence[RGBStr]]] = ( 125 | [] if plot_data.alt_colors is None else [{"colors": c} for c in transpose(plot_data.alt_colors)] 126 | ) 127 | traces.append( 128 | go.Scatter3d( 129 | customdata=alt_colors_as_custom_data, 130 | x=data[:, 0], 131 | y=data[:, 1], 132 | z=data[:, 2], 133 | mode="markers", 134 | showlegend=False, 135 | text=["
".join(textwrap.wrap(frag, width=120)) for frag in plot_data.text_fragments] 136 | if include_labels == "include_labels" 137 | else None, 138 | marker=dict( 139 | size=6 * _marker_type_size_adjustment(plot_data.marker), 140 | color=plot_data.marker_colors, 141 | showscale=False, 142 | opacity=0.4, 143 | symbol=plot_data.marker, 144 | line_width=2 if "-open" in plot_data.marker else None, 145 | line_color=plot_data.marker_colors, 146 | ), 147 | **cast(Any, 148 | dict(hovertemplate="%{text}") 149 | if include_labels == "include_labels" 150 | else dict(hoverinfo="none") 151 | ), 152 | ) 153 | ) 154 | # If we have linear embeddings, add a little marker highlighting the first text fragment 155 | if plot_data.embeddings.shape[1] == 2: 156 | traces.append( 157 | go.Scatter3d( 158 | x=[data[0, 0]], 159 | y=[data[0, 1]], 160 | z=[data[0, 2]], 161 | mode="markers", 162 | showlegend=False, 163 | text=["Start"], 164 | hovertemplate="%{text}", 165 | marker=dict(symbol="circle-open", size=10, color="yellow", line=dict(color="yellow", width=2)), 166 | ) 167 | ) 168 | return traces 169 | 170 | 171 | # HTML in `bytes` form 172 | HTMLBytes = NewType("HTMLBytes", bytes) 173 | output_dir = Path("output") 174 | # pyright just types this as `dict[str, str]` if we declare it inline 175 | bytes_html: BytesCache = {"format": "bytes", "ext": "html"} 176 | 177 | 178 | @cache_with_path(output_dir, bytes_html) 179 | def _plot( 180 | plot_datas: Sequence[PlotData[NPoints, SmallD]], 181 | clusterss: Sequence[Sequence[ClusterData[SmallD]]], 182 | include_script: Literal["include_script", "exclude_script"] = "include_script", 183 | include_labels: Literal["include_labels", "exclude_labels"] = "include_labels", 184 | ) -> HTMLBytes: 185 | fig = go.Figure() 186 | 187 | titles = [pt.title for pt in plot_datas] 188 | for trace in flatmap(lambda p: _point_traces_for_single_source(p, include_labels=include_labels), plot_datas): 189 | fig.add_trace(trace) 190 | for i, clustering in enumerate(clusterss): 191 | fig.add_trace(_cluster_marker_trace(clustering, visible=(i == 0), index=i)) 192 | 193 | background_color = "rgb(168, 168, 192)" 194 | 195 | fig.add_annotation( 196 | text=( 197 | ( 198 | "Use the left and right keys on the keyboard to step through fragments.
" 199 | if include_labels == "include_labels" 200 | else "" 201 | ) 202 | + "Use the up and down keys to change clustering level." 203 | ), 204 | xref="paper", 205 | yref="paper", 206 | x=0, 207 | y=0, 208 | showarrow=False, 209 | ) 210 | 211 | invisible_axis = dict( 212 | showticklabels=False, 213 | showspikes=False, 214 | backgroundcolor=background_color, 215 | gridcolor=background_color, 216 | zerolinecolor=background_color, 217 | ) 218 | 219 | # Make the plot a a vast, formless void 220 | fig.update_layout( 221 | paper_bgcolor=background_color, 222 | plot_bgcolor=background_color, 223 | scene=dict( 224 | xaxis_title="", 225 | yaxis_title="", 226 | zaxis_title="", 227 | xaxis=invisible_axis, 228 | yaxis=invisible_axis, 229 | zaxis=invisible_axis, 230 | ), 231 | title="
".join(titles), 232 | ) 233 | 234 | with StringIO() as s: 235 | match include_script: 236 | case "include_script": 237 | with open("templates/plotly.js", encoding="utf-8") as f: 238 | fig.write_html(s, include_plotlyjs="cdn", post_script=f.read()) 239 | case "exclude_script": 240 | fig.write_html(s, include_plotlyjs="cdn") 241 | return HTMLBytes(s.getvalue().encode("utf-8")) 242 | 243 | 244 | class Source(NamedTuple): 245 | title: str 246 | text_fragments: list[str] 247 | 248 | 249 | EmbeddingDim = TypeVar("EmbeddingDim") 250 | 251 | 252 | def _mk_plot_and_cluster_data( 253 | unstacked_reduced_embeddings: Sequence[ndarray[float, NPoints, SmallD]], 254 | original_embeddings: Optional[ndarray[float, NPoints, EmbeddingDim]], 255 | cluster_control: ClusterControl, 256 | sources: Sequence[Source], 257 | ) -> tuple[Sequence[PlotData[NPoints, SmallD]], Sequence[Sequence[ClusterData[SmallD]]]]: 258 | titles, textss = unzip(sources) 259 | split_indices = np.cumsum([e.shape[0] for e in unstacked_reduced_embeddings][:-1]) 260 | match handle_clustering(cluster_control, flatten(textss), np.vstack(unstacked_reduced_embeddings)): 261 | case None: 262 | return [ 263 | PlotData(*t) 264 | for t in zip( 265 | unstacked_reduced_embeddings, 266 | textss, 267 | titles, 268 | # Either qualitatively by source or 269 | # by distance in the original (i.e non-reduced) embedding dimension 270 | qualitative_colors(len(unstacked_reduced_embeddings)) 271 | if original_embeddings is None 272 | else unstack(distance_colors(original_embeddings), split_indices.tolist()), 273 | marker_types[: len(unstacked_reduced_embeddings)], 274 | strict=True, 275 | ) 276 | ], [] 277 | case [CRColorAndSummarize(), *_] as result: 278 | return ( 279 | [ 280 | PlotData(*t) 281 | for t in zip( 282 | unstacked_reduced_embeddings, 283 | textss, 284 | titles, 285 | unstack(result[0].colors, split_indices.tolist()), 286 | marker_types[: len(unstacked_reduced_embeddings)], 287 | transpose([unstack(s.colors, split_indices.tolist()) for s in result]), 288 | strict=True, 289 | ) 290 | ], 291 | [s.clusters for s in result], 292 | ) 293 | case [CRColor(), *_] as colors: 294 | return ( 295 | [ 296 | PlotData(*t) 297 | for t in zip( 298 | unstacked_reduced_embeddings, 299 | textss, 300 | titles, 301 | unstack(colors[0].colors, split_indices.tolist()), 302 | marker_types[: len(unstacked_reduced_embeddings)], 303 | transpose([unstack(c.colors, split_indices.tolist()) for c in colors]), 304 | strict=True, 305 | ) 306 | ], 307 | [], 308 | ) 309 | case []: 310 | raise RuntimeError("Unexpected empty list in plot_multiple") 311 | case _: 312 | raise AssertionError("Pyright can't tell this is already exhaustive") 313 | 314 | 315 | def plot_single( 316 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 317 | embedding_instruction: str, 318 | dimensions: Literal[2, 3], 319 | cluster_control: ClusterControl, 320 | source: Source, 321 | include_labels: Literal["include_labels", "exclude_labels"] = "include_labels", 322 | ) -> Path: 323 | embeddings = get_embeddings(embedding_model_name, embedding_instruction, source.text_fragments) 324 | reduced_embeddings = reduce_embeddings(dimensions, n_neighbors=n_neighbors, embeddings=embeddings) 325 | plot_data, cluster_data = _mk_plot_and_cluster_data( 326 | [reduced_embeddings], embeddings, cluster_control, [source] 327 | ) 328 | return _plot( 329 | plot_data, cluster_data, include_script="include_script", include_labels=include_labels 330 | ).cache_path 331 | 332 | 333 | n_neighbors = 10 334 | 335 | 336 | def 
plot_multiple( 337 | embedding_model_name: EmbeddingModelName[EmbeddingDim], 338 | embedding_instruction: str, 339 | dimensions: Literal[2, 3], 340 | cluster_control: ClusterControl, 341 | sources: Sequence[Source], 342 | include_labels: Literal["include_labels", "exclude_labels"] = "include_labels", 343 | ) -> Path: 344 | embeddings = [ 345 | get_embeddings(embedding_model_name, embedding_instruction, source.text_fragments) for source in sources 346 | ] 347 | stacked_embeddings = reduce_embeddings(dimensions, n_neighbors=n_neighbors, embeddings=np.vstack(embeddings)) 348 | split_indices = np.cumsum([e.shape[0] for e in embeddings][:-1]) 349 | unstacked_embeddings = np.vsplit(stacked_embeddings, split_indices) 350 | plot_data, cluster_data = _mk_plot_and_cluster_data(unstacked_embeddings, None, cluster_control, sources) 351 | return _plot( 352 | plot_data, cluster_data, include_script="include_script", include_labels=include_labels 353 | ).cache_path 354 | -------------------------------------------------------------------------------- /hammock/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from contextlib import contextmanager 4 | from dataclasses import dataclass 5 | from datetime import timedelta 6 | import re 7 | import time 8 | from typing import ( 9 | Any, 10 | Callable, 11 | Generic, 12 | Sequence, 13 | TypeAlias, 14 | TypeVar, 15 | overload, 16 | ) 17 | 18 | A = TypeVar("A") 19 | B = TypeVar("B") 20 | C = TypeVar("C") 21 | 22 | 23 | def clean_filename(filename: str) -> str: 24 | return re.sub(r'[\s./\*\?\[\]\'"|(){}<>!@#$%^&:;~`_,;]+', "-", filename) 25 | 26 | 27 | def declare(type_: type[A], value: A) -> A: 28 | """Declare a type inline""" 29 | return value 30 | 31 | 32 | @overload 33 | def unzip(x: Sequence[tuple[A, B, C]]) -> tuple[Sequence[A], Sequence[B], Sequence[C]]: 34 | ... 35 | 36 | 37 | @overload 38 | def unzip(x: Sequence[tuple[A, B]]) -> tuple[Sequence[A], Sequence[B]]: 39 | ... 40 | 41 | 42 | def unzip(x: Sequence[tuple[Any, ...]]) -> tuple[Sequence[Any], ...]: 43 | return tuple(zip(*x)) 44 | 45 | 46 | @contextmanager 47 | def time_segment(name: str, active: bool = True): 48 | """Quick and dirty profiling helper. 
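The stack/split bookkeeping used by `plot_multiple` above, shown in isolation: embeddings from several sources are stacked so UMAP reduces them in one shared space, then split back per source. Toy shapes only.

```python
import numpy as np

per_source = [np.zeros((3, 5)), np.zeros((2, 5)), np.zeros((4, 5))]  # three sources of 5-dim embeddings
stacked = np.vstack(per_source)                                       # shape (9, 5)
split_indices = np.cumsum([e.shape[0] for e in per_source][:-1])      # [3, 5]
unstacked = np.vsplit(stacked, split_indices)
print([u.shape for u in unstacked])                                   # [(3, 5), (2, 5), (4, 5)]
```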
Reports time spent in block.""" 49 | if active: 50 | print(f"Entering {name}") 51 | ts = time.time() 52 | yield 53 | print(f"Exiting {name} after {timedelta(seconds=time.time() - ts)}") 54 | else: 55 | yield 56 | 57 | 58 | def transpose(lists: Sequence[Sequence[A]]) -> Sequence[Sequence[A]]: 59 | return [list(sub_list) for sub_list in zip(*lists)] 60 | 61 | 62 | def thunkify(fn: Callable[[], A]) -> Callable[[], A]: 63 | """Memoize a function with no arguments""" 64 | x = None 65 | 66 | def inner(): 67 | nonlocal x 68 | if x is None: 69 | x = fn() 70 | return x 71 | else: 72 | return x 73 | 74 | return inner 75 | 76 | 77 | X_co = TypeVar("X_co", covariant=True) 78 | Y_co = TypeVar("Y_co", covariant=True) 79 | 80 | 81 | @dataclass(frozen=True) 82 | class Failure(Generic[X_co]): 83 | value: X_co 84 | 85 | 86 | @dataclass(frozen=True) 87 | class Success(Generic[X_co]): 88 | value: X_co 89 | 90 | 91 | Either: TypeAlias = Failure[X_co] | Success[Y_co] 92 | 93 | 94 | def flatmap(fn: Callable[[A], Sequence[B]], input_list: Sequence[A]) -> Sequence[B]: 95 | return [item for sublist in (fn(x) for x in input_list) for item in sublist] 96 | 97 | 98 | def flatten(x: Sequence[Sequence[A]]) -> list[A]: 99 | return [item for sublist in x for item in sublist] 100 | 101 | 102 | def unstack(x: Sequence[A], indices: list[int]) -> Sequence[Sequence[A]]: 103 | """Split a sequence into subsequences at the given indices""" 104 | return [x[start:end] for start, end in zip([0] + indices, indices + [len(x)])] 105 | -------------------------------------------------------------------------------- /hammock/web.py: -------------------------------------------------------------------------------- 1 | import json 2 | from flask import Flask, render_template, redirect, request, url_for 3 | from flask_compress import Compress 4 | 5 | from .core import ( 6 | GutenbergArgs, 7 | TextUnit, 8 | plot_multiple_gutenberg, 9 | plot_multiple_wiki, 10 | plot_freeform, 11 | plot_single_gutenberg, 12 | plot_single_wiki, 13 | ) 14 | from .cluster import CCColor, SummaryModelName 15 | from .embedding import instructor_base 16 | from .util import Failure, Success 17 | 18 | app = Flask(__name__, template_folder="../templates", static_folder="../output") 19 | Compress(app) 20 | 21 | 22 | @app.route("/books", methods=["GET"]) 23 | def books(): 24 | with open("output/calibre.json", "r") as f: 25 | return render_template("books.html", books=json.load(f)) 26 | 27 | 28 | @app.route("/", methods=["GET", "POST"]) 29 | def index(): 30 | return render_template("index.html") 31 | 32 | 33 | @app.route("/freeform", methods=["GET", "POST"]) 34 | def freeform(): 35 | match request.method: 36 | case "POST": 37 | return redirect( 38 | plot_freeform( 39 | default_embedding_model, 40 | "Represent the text for clustering:", 41 | request.form["input_field"].splitlines(), 42 | ).as_posix() 43 | ) 44 | case "GET": 45 | return render_template( 46 | "main.html", 47 | title="custom sentences", 48 | label="Sentences separated by newlines (minimum of 7)", 49 | placeholder="Sphinx of black quartz, judge my vow.", 50 | action=url_for("freeform"), 51 | ) 52 | case method: 53 | raise ValueError(f"Unexpected method: {method}") 54 | 55 | 56 | @app.route("/gutenberg", methods=["GET", "POST"]) 57 | def gutenberg(): 58 | match request.method: 59 | case "POST": 60 | match [ 61 | GutenbergArgs(*s) for s in [line.split("|") for line in request.form["input_field"].splitlines()] 62 | ]: 63 | case [gutenberg_args]: 64 | match plot_single_gutenberg( 65 | default_embedding_model, 66 | 
TextUnit("paragraph", para_newline_type="double"), 67 | CCColor(min_cluster_sizes=[20, 5]), 68 | gutenberg_args.title, 69 | gutenberg_args.start_anchor, 70 | gutenberg_args.end_anchor, 71 | ): 72 | case Failure(err): 73 | raise ValueError(err) 74 | case Success(out_file): 75 | return redirect(out_file.as_posix()) 76 | case gutenberg_argss: 77 | return redirect( 78 | plot_multiple_gutenberg( 79 | default_embedding_model, 80 | TextUnit("paragraph", para_newline_type="double"), 81 | CCColor(min_cluster_sizes=[20, 5]), 82 | gutenberg_argss, 83 | ).as_posix() 84 | ) 85 | case "GET": 86 | return render_template( 87 | "main.html", 88 | title="Project Gutenberg books", 89 | label="Exact title of Project Gutenberg books (one per line (or just one)." 90 | "
Include a fragment after '|' to start with that fragment" 91 | " and another fragment after another '|' to mark the end.", 92 | placeholder=( 93 | "Frankenstein; Or, The Modern Prometheus|" 94 | "You will rejoice to hear that no disaster|lost in darkness and distance." 95 | ), 96 | action=url_for("gutenberg"), 97 | ) 98 | case method: 99 | raise ValueError(f"Unexpected method: {method}") 100 | 101 | 102 | default_embedding_model = instructor_base 103 | default_summary_model: SummaryModelName = "google/flan-t5-base" 104 | 105 | 106 | @app.route("/wiki", methods=["GET", "POST"]) 107 | def wiki(): 108 | match request.method: 109 | case "POST": 110 | match request.form["input_field"].splitlines(): 111 | case [title]: 112 | return redirect( 113 | plot_single_wiki( 114 | default_embedding_model, 115 | TextUnit("paragraph", para_newline_type="single"), 116 | CCColor( 117 | min_cluster_sizes=[8, 3], 118 | ), 119 | title, 120 | ).as_posix() 121 | ) 122 | case titles: 123 | return redirect( 124 | plot_multiple_wiki( 125 | default_embedding_model, 126 | TextUnit("paragraph", para_newline_type="single"), 127 | CCColor(min_cluster_sizes=[8, 3]), 128 | titles, 129 | ).as_posix() 130 | ) 131 | case "GET": 132 | return render_template( 133 | "main.html", 134 | title="Wikipedia articles", 135 | label="Exact title of Wikipedia articles (one per line (or just one))", 136 | placeholder="New York City", 137 | action=url_for("wiki"), 138 | ) 139 | case method: 140 | raise ValueError(f"Unexpected method: {method}") 141 | 142 | 143 | if __name__ == "__main__": 144 | app.run() 145 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "hammock" 3 | version = "0.1.0" 4 | description = "Visualize sentence-by-sentence embeddings for a variety of texts" 5 | authors = ["Cole Haus"] 6 | license = "AGPL-3.0-or-later" 7 | 8 | include = ["templates/*"] 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.11" 12 | # Required for `low_cpu_mem_usage=True` 13 | accelerate = "^0.20.3" 14 | flask = { extras = ["async"], version = "^2.2.3" } 15 | flask-compress = "^1.13" 16 | gunicorn = "^20.1.0" 17 | gutenbergpy = "^0.3.5" 18 | hdbscan = "^0.8.29" 19 | html2text = "^2020.1.16" 20 | instructorembedding = "^1.0.0" 21 | networkx = "^3.1" 22 | plotly = "^5.14.1" 23 | scikit-learn-extra = "^0.3.0" 24 | # https://github.com/UKPLab/sentence-transformers/issues/1590 25 | sentence-transformers = { git = "https://github.com/UKPLab/sentence-transformers.git", rev = "3e1929fddef16df94f8bc6e3b10598a98f46e62d" } 26 | # Even though we don't directly depend on torch, we specify it here so we can version lock it to 2.0.0 27 | # 2.0.1 fails to include a bunch of nvidia cuda stuff as dependencies 28 | torch = "2.0.0" 29 | umap-learn = "^0.5.3" 30 | wikipedia = "^1.4.0" 31 | 32 | [tool.poetry.group.dev.dependencies] 33 | black = "^23.3.0" 34 | pyright = "^1.1.316" 35 | ruff = "^0.0.263" 36 | 37 | [build-system] 38 | requires = ["poetry-core>=1.0.0"] 39 | build-backend = "poetry.core.masonry.api" 40 | 41 | [tool.poetry.scripts] 42 | hammock = 'hammock.gunicorn:main' 43 | 44 | [tool.black] 45 | line-length = 115 46 | include = '\.pyi?$' 47 | 48 | [tool.ruff] 49 | line-length = 115 50 | ignore = [ 51 | # We'd rather rely on pyright for these 52 | "F403", 53 | "F405", 54 | "F821" 55 | ] -------------------------------------------------------------------------------- /pyrightconfig.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "typeCheckingMode": "strict", 3 | "useLibraryCodeForTypes": true, 4 | "venv": "7s3vm4swksym5jh7bizribq0gd6cs6pa-python3-3.11.4-env", 5 | "venvPath": "/nix/store", 6 | "stubPath": "stubs", 7 | "reportInvalidStubStatement": false, 8 | "reportInvalidTypeVarUse": false, 9 | "reportUnnecessaryIsInstance": false, 10 | "reportUnnecessaryTypeIgnoreComment": true, 11 | "reportUnknownLambdaType": false, 12 | "reportUnusedFunction": "warning", 13 | "ignore": ["tmp", "result"] 14 | } 15 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/colehaus/hammock-public/eec733e509eddc775304f570bb4cbdf408561b2c/screenshot.png -------------------------------------------------------------------------------- /stubs/InstructorEmbedding/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | from sentence_transformers import SentenceTransformer 4 | from transformers import T5TokenizerFast 5 | 6 | EmbeddingDim = TypeVar("EmbeddingDim") 7 | 8 | class INSTRUCTOR(SentenceTransformer[EmbeddingDim]): 9 | tokenizer: T5TokenizerFast 10 | -------------------------------------------------------------------------------- /stubs/arxiv/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from enum import Enum 4 | from typing import Generator 5 | 6 | class SortCriterion(Enum): 7 | Relevance = "relevance" 8 | LastUpdatedDate = "lastUpdatedDate" 9 | SubmittedDate = "submittedDate" 10 | 11 | class Result: 12 | summary: str 13 | title: str 14 | pdf_url: str 15 | def download_pdf(self, filename: str) -> None: ... 16 | 17 | class Search: 18 | def __init__(self, query: str, max_results: int, sort_by: SortCriterion) -> None: ... 19 | def results(self) -> Generator[Result, None, None]: ... 20 | -------------------------------------------------------------------------------- /stubs/flask_compress/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from flask import Flask 4 | 5 | class Compress: 6 | def __init__(self, app: Flask) -> None: ... 7 | -------------------------------------------------------------------------------- /stubs/gunicorn/app/wsgiapp.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Any 4 | 5 | from gunicorn.config import Config 6 | 7 | class WSGIApplication: 8 | def init(self, parser: Any, opts: Any, args: list[Any]) -> None: ... 9 | def run(self) -> None: ... 10 | cfg: Config 11 | -------------------------------------------------------------------------------- /stubs/gunicorn/config.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Any 4 | 5 | class Config: 6 | def set(self, key: str, value: Any) -> None: ... 
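Stepping back to the Flask routes in hammock/web.py above: they can also be exercised without a running gunicorn process via Flask's built-in test client. The article title here is just an example, and the request runs the full fetch/embed/cluster pipeline, so the first call is slow.

```python
from hammock.web import app

with app.test_client() as client:
    response = client.post("/wiki", data={"input_field": "New York City"})
    # The route answers with a redirect pointing at the generated plot HTML under output/.
    print(response.status_code, response.headers.get("Location"))
```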
7 | -------------------------------------------------------------------------------- /stubs/gutenbergpy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/colehaus/hammock-public/eec733e509eddc775304f570bb4cbdf408561b2c/stubs/gutenbergpy/__init__.py -------------------------------------------------------------------------------- /stubs/gutenbergpy/gutenbergcache.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | class GutenbergCache: 4 | @staticmethod 5 | def get_cache() -> GutenbergCache: ... 6 | def query(self, titles: list[str]) -> list[int]: ... 7 | -------------------------------------------------------------------------------- /stubs/gutenbergpy/textget.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | def get_text_by_id(index: int) -> bytes: ... 4 | def strip_headers(text: bytes) -> bytes: ... 5 | -------------------------------------------------------------------------------- /stubs/hdbscan/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Optional, TypeVar 4 | 5 | import numpy as np 6 | from numpy import ndarray 7 | 8 | NSamples = TypeVar("NSamples") 9 | NFeatures = TypeVar("NFeatures") 10 | 11 | class HDBSCAN: 12 | def __init__(self, min_cluster_size: int = 5, min_samples: Optional[int] = None) -> None: ... 13 | def fit_predict(self, X: ndarray[float, NSamples, NFeatures]) -> ndarray[np.int64, NSamples]: ... 14 | -------------------------------------------------------------------------------- /stubs/networkx/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Generic, Iterable, Iterator, TypeVar, TypedDict 4 | 5 | N = TypeVar("N") 6 | 7 | WeightDict = TypedDict("WeightDict", {"weight": float}) 8 | 9 | class DegreesView(Generic[N]): 10 | def __getitem__(self, index: int) -> N: ... 11 | def __iter__(self) -> Iterator[tuple[N, int]]: ... 12 | def __len__(self) -> int: ... 13 | 14 | class Graph(Generic[N]): 15 | def add_nodes_from(self, nodes_for_adding: Iterable[N]) -> None: ... 16 | def add_edges_from(self, edges_for_adding: Iterable[tuple[N, N, WeightDict]]) -> None: ... 17 | def degree(self) -> DegreesView[N]: ... 18 | 19 | def equitable_color(G: Graph[N], num_colors: int) -> dict[N, int]: ... 20 | -------------------------------------------------------------------------------- /stubs/nltk/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | def download(info_or_id: str, quiet: bool = False) -> None: ... 4 | -------------------------------------------------------------------------------- /stubs/nltk/tokenize.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | def sent_tokenize(text: str) -> list[str]: ... 4 | def word_tokenize(text: str) -> list[str]: ... 5 | -------------------------------------------------------------------------------- /stubs/numba/core/errors.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | class NumbaWarning(Warning): ... 4 | class NumbaDeprecationWarning(NumbaWarning): ... 
5 | -------------------------------------------------------------------------------- /stubs/numpy/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Any, Generic, Iterator, Literal, Sequence, TypeVar, TypeVarTuple, overload 4 | 5 | from numpy import linalg as linalg 6 | 7 | DType = TypeVar("DType") 8 | DType2 = TypeVar("DType2") 9 | Shape = TypeVarTuple("Shape") 10 | Dim1 = TypeVar("Dim1") 11 | Dim2 = TypeVar("Dim2") 12 | Dim3 = TypeVar("Dim3") 13 | 14 | class ndarray(Generic[DType, *Shape]): 15 | def __and__(self: ndarray[bool, *Shape], other: ndarray[bool, *Shape]) -> ndarray[bool, *Shape]: ... 16 | def __invert__(self: ndarray[bool, *Shape]) -> ndarray[bool, *Shape]: ... 17 | def __len__(self) -> int: ... 18 | def __eq__(self, other: DType) -> ndarray[bool, *Shape]: ... 19 | def __ior__(self: ndarray[bool, *Shape], other: ndarray[bool, *Shape]) -> ndarray[bool, *Shape]: ... 20 | @overload 21 | def __ge__(self, other: DType) -> ndarray[bool, *Shape]: ... 22 | @overload 23 | def __ge__(self: ndarray[int, *Shape], other: float) -> ndarray[bool, *Shape]: ... 24 | def __add__(self, other: ndarray[DType, *Shape]) -> ndarray[DType, *Shape]: ... 25 | def __sub__(self, other: ndarray[DType, *Shape]) -> ndarray[DType, *Shape]: ... 26 | def __rsub__(self, other: DType) -> ndarray[DType, *Shape]: ... 27 | @overload 28 | def __mul__(self, other: DType) -> ndarray[DType, *Shape]: ... 29 | @overload 30 | def __mul__(self, other: ndarray[DType, *Shape]) -> ndarray[DType, *Shape]: ... 31 | def astype(self, dtype: type[DType2]) -> ndarray[DType2, *Shape]: ... 32 | @overload 33 | def __getitem__(self: ndarray[DType, Dim1], key: int) -> DType: ... 34 | @overload 35 | def __getitem__( 36 | self: ndarray[DType, Dim1, Dim2], key: tuple[ndarray[int, Dim3], slice] 37 | ) -> ndarray[DType, Any, Dim2]: ... 38 | @overload 39 | def __getitem__(self: ndarray[DType, Dim1], key: tuple[slice, None]) -> ndarray[DType, Dim1, Any]: ... 40 | @overload 41 | def __getitem__( 42 | self: ndarray[DType, Dim1, Dim2], key: tuple[ndarray[bool, Dim1], slice] 43 | ) -> ndarray[DType, Any, Dim2]: ... 44 | @overload 45 | def __getitem__(self: ndarray[DType, Dim1, Dim2], key: tuple[slice, int]) -> ndarray[DType, Dim1]: ... 46 | @overload 47 | def __getitem__(self: ndarray[DType, Dim1, Dim2], key: tuple[int, int]) -> DType: ... 48 | @overload 49 | def __getitem__(self: ndarray[DType, Dim1], key: slice) -> ndarray[DType, Dim1]: ... 50 | def __setitem__(self, key: slice, value: ndarray[DType, *Shape]) -> None: ... 51 | def __iter__(self) -> Iterator[DType]: ... 52 | @property 53 | def shape(self: ndarray[DType, Dim1, Dim2]) -> tuple[int, int]: ... 54 | def tolist(self: ndarray[DType, Dim1]) -> list[DType]: ... 55 | def reshape( 56 | self: ndarray[DType, Dim1], shape: tuple[Literal[-1], Literal[1]] 57 | ) -> ndarray[DType, Dim1, Literal[1]]: ... 58 | size: int 59 | 60 | class integer: 61 | def __int__(self) -> int: ... 62 | 63 | class signedinteger(integer): ... 64 | 65 | class int64(signedinteger): 66 | def __init__(self, value: int | str) -> None: ... 67 | 68 | Scalar = TypeVar("Scalar", int, float) 69 | 70 | @overload 71 | def percentile(a: ndarray[float, *Shape] | ndarray[int, *Shape], q: float) -> float: ... 72 | @overload 73 | def percentile( 74 | a: ndarray[float | int, *Shape] | ndarray[int, *Shape], q: ndarray[float, Dim1] 75 | ) -> ndarray[float, Dim1]: ... 
76 | @overload 77 | def array(object: Sequence[ndarray[DType, *Shape]]) -> ndarray[DType, Any, *Shape]: ... 78 | @overload 79 | def array(object: Sequence[Scalar]) -> ndarray[Scalar, Any]: ... 80 | @overload 81 | def array(object: Sequence[tuple[float, ...]]) -> ndarray[float, Any, Any]: ... 82 | def clip(a: ndarray[float, *Shape], a_min: float, a_max: float) -> ndarray[float, *Shape]: ... 83 | def floor(x: ndarray[float, *Shape]) -> ndarray[float, *Shape]: ... 84 | def ceil(x: ndarray[float, *Shape]) -> ndarray[float, *Shape]: ... 85 | @overload 86 | def mean(a: ndarray[float, Dim1, *Shape], axis: Literal[0]) -> ndarray[float, *Shape]: ... 87 | @overload 88 | def mean(a: ndarray[float, Dim1, Dim2, *Shape], axis: Literal[1]) -> ndarray[float, Dim1, *Shape]: ... 89 | @overload 90 | def mean(a: ndarray[bool, Dim1, Dim2, *Shape], axis: Literal[1]) -> ndarray[float, Dim1, *Shape]: ... 91 | @overload 92 | def mean(a: ndarray[int, Dim1, Dim2, *Shape], axis: Literal[1]) -> ndarray[float, Dim1, *Shape]: ... 93 | def arange(stop: int, dtype: type[DType]) -> ndarray[DType, Any]: ... 94 | def insert( 95 | arr: ndarray[DType, Dim1, Dim2], obj: int, values: ndarray[DType, Dim1], axis: Literal[1] 96 | ) -> ndarray[DType, Dim1, Dim2]: ... 97 | def vstack(tup: Sequence[ndarray[DType, *Shape]]) -> ndarray[DType, *Shape]: ... 98 | @overload 99 | def cumsum(a: ndarray[DType, Dim1]) -> ndarray[DType, Dim1]: ... 100 | @overload 101 | def cumsum(a: Sequence[DType]) -> ndarray[DType, Any]: ... 102 | def vsplit( 103 | ary: ndarray[DType, *Shape], indices_or_sections: ndarray[int, Any] 104 | ) -> list[ndarray[DType, *Shape]]: ... 105 | def linspace(start: float, stop: float, num: int) -> ndarray[float, Any]: ... 106 | def zeros(shape: int, dtype: type[DType]) -> ndarray[DType, Any]: ... 107 | def nonzero(a: ndarray[bool, Dim1]) -> tuple[ndarray[int, Dim1]]: ... 108 | def argmin(a: Sequence[DType]) -> int: ... 109 | def where(condition: ndarray[bool, Dim1], x: DType, y: DType) -> ndarray[DType, Dim1]: ... 110 | 111 | newaxis: None 112 | -------------------------------------------------------------------------------- /stubs/numpy/lib/stride_tricks.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import TypeVar 4 | from numpy import ndarray 5 | 6 | X = TypeVar("X") 7 | Dim1 = TypeVar("Dim1") 8 | NWindowShape = TypeVar("NWindowShape", bound=int) 9 | 10 | # The output is actually a little shorter than `Dim1`. How much shorter depends on the size of the window. 11 | def sliding_window_view(x: ndarray[X, Dim1], window_shape: NWindowShape) -> ndarray[X, Dim1, NWindowShape]: ... 12 | -------------------------------------------------------------------------------- /stubs/numpy/linalg.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Any 4 | 5 | from numpy import ndarray 6 | 7 | def norm(x: ndarray[float, Any]) -> float: ... 
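The numpy stub above threads axis names through `ndarray` as phantom type parameters. A sketch of how signatures in the package lean on that (the function here is illustrative, not part of the repo); with `from __future__ import annotations` the subscripted annotations are never evaluated at runtime.

```python
from __future__ import annotations

from typing import TypeVar

import numpy as np
from numpy import ndarray

NPoints = TypeVar("NPoints")
EmbeddingDim = TypeVar("EmbeddingDim")

def centroid(embeddings: ndarray[float, NPoints, EmbeddingDim]) -> ndarray[float, EmbeddingDim]:
    # Per the stub's `mean` overloads, averaging over axis 0 drops the NPoints dimension,
    # so swapping the axes shows up as a type error instead of a silent shape bug.
    return np.mean(embeddings, axis=0)

print(centroid(np.ones((4, 3))))  # [1. 1. 1.]
```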
8 | -------------------------------------------------------------------------------- /stubs/plotly/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | -------------------------------------------------------------------------------- /stubs/plotly/colors/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Literal, NewType, TypeAlias, overload 4 | 5 | from .sequential import * 6 | from .qualitative import * 7 | 8 | PlotlyScales: TypeAlias = Literal[ 9 | "Edge", "HSV", "Icefire", "Phase", "Rainbow", "Turbo", "Viridis", "haline", "mrybm", "thermal" 10 | ] 11 | 12 | HexStr = NewType("HexStr", str) 13 | # Values in string should be out 255 14 | RGBStr = NewType("RGBStr", str) 15 | # Values range from 0 to 1 16 | Tuple1 = NewType("Tuple1", tuple[float, float, float]) 17 | # Values range from 0 to 255 18 | Tuple255 = NewType("Tuple255", tuple[int, int, int]) 19 | 20 | def unlabel_rgb(colors: RGBStr) -> Tuple255: ... 21 | def label_rgb(colors: Tuple255) -> RGBStr: ... 22 | @overload 23 | def convert_colors_to_same_type( 24 | colors: PlotlyScales, colortype: Literal["tuple"] 25 | ) -> tuple[list[Tuple1], None]: ... 26 | @overload 27 | def convert_colors_to_same_type(colors: PlotlyScales, colortype: Literal["rgb"]) -> tuple[list[RGBStr], None]: ... 28 | def convert_to_RGB_255(colors: Tuple1) -> Tuple255: ... 29 | def hex_to_rgb(hex: HexStr) -> Tuple255: ... 30 | @overload 31 | def sample_colorscale( 32 | colorscale: PlotlyScales, samplepoints: int, colortype: Literal["tuple"] 33 | ) -> list[Tuple1]: ... 34 | @overload 35 | def sample_colorscale( 36 | colorscale: PlotlyScales, samplepoints: int, colortype: Literal["rgb"] = "rgb" 37 | ) -> list[RGBStr]: ... 38 | def find_intermediate_color(lowcolor: Tuple1, highcolor: Tuple1, intermed: float) -> Tuple1: ... 39 | -------------------------------------------------------------------------------- /stubs/plotly/colors/qualitative.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | # Hex 4 | Plotly: list[str] 5 | # RGB 6 | Antique: list[str] 7 | # RGB 8 | Pastel: list[str] 9 | -------------------------------------------------------------------------------- /stubs/plotly/colors/sequential.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | -------------------------------------------------------------------------------- /stubs/plotly/graph_objects/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from io import StringIO 4 | from pathlib import Path 5 | from typing import Any, Literal, Mapping, Optional, Sequence, TypeVar, overload 6 | 7 | from numpy import ndarray 8 | from plotly.graph_objects.layout.scene import Annotation 9 | 10 | from .layout import * 11 | 12 | class Trace: ... 13 | 14 | NSamples = TypeVar("NSamples") 15 | 16 | class Scatter(Trace): 17 | def __init__( 18 | self, 19 | x: Sequence[float] | ndarray[float, NSamples], 20 | y: Sequence[float] | ndarray[float, NSamples], 21 | mode: Literal["markers", "lines"], 22 | text: Optional[Sequence[str]] = None, 23 | ) -> None: ... 
24 | 25 | class Scatter3d(Trace): 26 | def __init__( 27 | self, 28 | x: ndarray[float, NSamples] | Sequence[float], 29 | y: ndarray[ 30 | float, 31 | NSamples, 32 | ] 33 | | Sequence[float], 34 | z: ndarray[float, NSamples] | Sequence[float], 35 | mode: Literal["markers+lines", "markers", "markers+text"], 36 | showlegend: bool, 37 | marker: dict[str, Any], 38 | textfont: Optional[dict[str, Any]] = None, 39 | visible: bool = True, 40 | name: Optional[str] = None, 41 | text: Optional[Sequence[str]] = None, 42 | hovertemplate: Optional[str] = None, 43 | hoverinfo: Optional[Literal["text", "none"]] = None, 44 | line: Optional[Mapping[str, Any]] = None, 45 | customdata: Optional[Sequence[Any]] = None, 46 | ) -> None: ... 47 | 48 | class Figure: 49 | def __init__(self) -> None: ... 50 | def add_trace(self, trace: Trace, row: Optional[int] = None, col: Optional[int] = None) -> Figure: ... 51 | @overload 52 | def add_annotation(self, arg: Annotation) -> Figure: ... 53 | @overload 54 | def add_annotation( 55 | self, 56 | text: str, 57 | xref: Literal["paper"], 58 | yref: Literal["paper"], 59 | x: float, 60 | y: float, 61 | showarrow: Literal[False], 62 | ) -> Figure: ... 63 | def update_layout( 64 | self, 65 | scene: Optional[Mapping[str, Any]] = None, 66 | title: Optional[str] = None, 67 | hovermode: Optional[Literal["closest"]] = None, 68 | hoverlabel: Optional[Mapping[str, Any]] = None, 69 | paper_bgcolor: Optional[str] = None, 70 | plot_bgcolor: Optional[str] = None, 71 | shapes: Optional[Sequence[Shape]] = None, 72 | ) -> Figure: ... 73 | def write_html( 74 | self, 75 | file: str | Path | StringIO, 76 | include_plotlyjs: bool | Literal["cdn"] = True, 77 | div_id: Optional[str] = None, 78 | post_script: Optional[str] = None, 79 | ) -> None: ... 80 | -------------------------------------------------------------------------------- /stubs/plotly/graph_objects/layout/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Any, Literal, Mapping, Optional 4 | 5 | class Shape: 6 | def __init__( 7 | self, 8 | type: Literal["line"], 9 | x0: float, 10 | y0: float, 11 | x1: float, 12 | y1: float, 13 | xref: str, 14 | yref: str, 15 | line: Optional[Mapping[str, Any]] = None, 16 | ) -> None: ... 17 | -------------------------------------------------------------------------------- /stubs/plotly/graph_objects/layout/scene.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Optional 4 | 5 | class Annotation: 6 | def __init__( 7 | self, 8 | x: float, 9 | y: float, 10 | z: float, 11 | text: str, 12 | showarrow: bool = True, 13 | bgcolor: Optional[str] = None, 14 | borderpad: Optional[float] = 1, 15 | visible: bool = True, 16 | ) -> None: ... 17 | -------------------------------------------------------------------------------- /stubs/plotly/subplots.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from plotly.graph_objects import Figure 4 | 5 | def make_subplots(rows: int, cols: int, subplot_titles: list[str]) -> Figure: ... 
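A tiny smoke test of the plotly surface the stubs above cover (and that hammock/plot.py relies on): one hard-coded 3D trace written out as standalone HTML. The data and filename are arbitrary.

```python
import plotly.graph_objects as go

fig = go.Figure()
fig.add_trace(
    go.Scatter3d(
        x=[0.0, 1.0, 2.0],
        y=[0.0, 1.0, 0.0],
        z=[0.0, 0.0, 1.0],
        mode="markers",
        showlegend=False,
        marker=dict(size=4, color=["red", "green", "blue"], showscale=False),
    )
)
fig.write_html("stub-smoke-test.html", include_plotlyjs="cdn")
```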
6 | -------------------------------------------------------------------------------- /stubs/scipy/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | -------------------------------------------------------------------------------- /stubs/scipy/spatial/distance.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Literal, TypeVar 4 | 5 | from numpy import ndarray 6 | 7 | NSamples = TypeVar("NSamples") 8 | NFeatures = TypeVar("NFeatures") 9 | 10 | # The return type isn't quite right because `pdist` actually returns a condensed array but it's good enough for now 11 | def pdist(X: ndarray[float, NSamples, NFeatures], metric: Literal["euclidean"]) -> ndarray[float, NSamples]: ... 12 | def squareform(X: ndarray[float, NSamples]) -> ndarray[float, NSamples, NSamples]: ... 13 | -------------------------------------------------------------------------------- /stubs/sentence_transformers/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Any, Generic, Optional, Sequence, TypeVar 4 | 5 | from numpy import ndarray 6 | import torch 7 | 8 | EmbeddingDim = TypeVar("EmbeddingDim") 9 | 10 | class SentenceTransformer(Generic[EmbeddingDim]): 11 | def __init__( 12 | self, model_name_or_path: Optional[str] = None, modules: Optional[Sequence[torch.nn.Module]] = None 13 | ) -> None: ... 14 | def encode( 15 | self, 16 | sentences: list[str] | list[list[str]], 17 | batch_size: int = 32, 18 | show_progress_bar: bool = False, 19 | ) -> ndarray[float, Any, EmbeddingDim]: ... 20 | -------------------------------------------------------------------------------- /stubs/sentence_transformers/models.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from torch.nn import Module 4 | 5 | class Transformer(Module): 6 | def __init__(self, model_name_or_path: str) -> None: ... 7 | -------------------------------------------------------------------------------- /stubs/sklearn/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | -------------------------------------------------------------------------------- /stubs/sklearn/_base.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Generic, TypeVar, TypeVarTuple 4 | 5 | from numpy import ndarray 6 | 7 | Shape = TypeVarTuple("Shape") 8 | 9 | NSamples = TypeVar("NSamples") 10 | NFeatures = TypeVar("NFeatures") 11 | 12 | class LinearModel(Generic[NFeatures]): 13 | def fit( 14 | self, X: ndarray[float, NSamples, NFeatures], y: ndarray[float, NSamples] 15 | ) -> LinearModel[NFeatures]: ... 16 | def predict(self, X: ndarray[float, NSamples, NFeatures]) -> ndarray[float, NSamples]: ... 17 | -------------------------------------------------------------------------------- /stubs/sklearn/linear_model.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import TypeVar 4 | 5 | from ._base import LinearModel 6 | 7 | NFeatures = TypeVar("NFeatures") 8 | 9 | class LinearRegression(LinearModel[NFeatures]): 10 | def __init__(self, fit_intercept: bool = True) -> None: ...
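To make the distance and embedding stubs above concrete, here is a small, hypothetical sketch of how they might be combined to get pairwise distances between embedded text fragments; the model name is an assumption chosen for illustration, not necessarily the one hammock uses:

```python
# Hypothetical sketch (illustration only): embed a few fragments and compute
# their pairwise Euclidean distances with the stubbed scipy functions above.
from scipy.spatial.distance import pdist, squareform
from sentence_transformers import SentenceTransformer

fragments = ["A first sentence.", "A second, related sentence.", "Something unrelated."]
model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model name
embeddings = model.encode(fragments, batch_size=32)  # (n_fragments, embedding_dim)

condensed = pdist(embeddings, metric="euclidean")  # condensed distance vector
distances = squareform(condensed)  # full (n_fragments, n_fragments) matrix
print(distances.shape)
```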
11 | -------------------------------------------------------------------------------- /stubs/sklearn/metrics.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import TypeVarTuple 4 | 5 | from numpy import ndarray 6 | 7 | Shape = TypeVarTuple("Shape") 8 | 9 | def mean_squared_error(y_true: ndarray[float, *Shape], y_pred: ndarray[float, *Shape]) -> float: ... 10 | -------------------------------------------------------------------------------- /stubs/sklearn/preprocessing.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import TypeVarTuple 4 | from numpy import ndarray 5 | 6 | Shape = TypeVarTuple("Shape") 7 | 8 | def minmax_scale( 9 | X: ndarray[float, *Shape], axis: int = 0, feature_range: tuple[float, float] = (0, 1) 10 | ) -> ndarray[float, *Shape]: ... 11 | -------------------------------------------------------------------------------- /stubs/sklearn_extra/cluster.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from numpy import ndarray 6 | 7 | NClusters = TypeVar("NClusters") 8 | NSamples = TypeVar("NSamples") 9 | NFeatures = TypeVar("NFeatures") 10 | 11 | class KMedoids(Generic[NClusters, NFeatures]): 12 | def __init__(self, n_clusters: NClusters, method: Literal["pam"], init: Literal["k-medoids++"]) -> None: ... 13 | def fit(self, X: ndarray[float, NSamples, NFeatures]) -> KMedoids[NClusters, NFeatures]: ... 14 | medoid_indices_: ndarray[int, NClusters] 15 | cluster_centers_: ndarray[float, NClusters, NFeatures] 16 | -------------------------------------------------------------------------------- /stubs/transformers/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Literal, Optional 6 | 7 | from torch import Tensor 8 | 9 | class T5TokenizerFast: 10 | def tokenize(self, text: str) -> list[str]: ... 11 | 12 | class BatchEncoding: 13 | def __getitem__(self, key: str) -> Tensor: ... 14 | 15 | class AutoTokenizer: 16 | @staticmethod 17 | def from_pretrained(pretrained_model_name_or_path: str, local_files_only: bool = False) -> AutoTokenizer: ... 18 | def __call__( 19 | self, 20 | text: str | list[str], 21 | max_length: Optional[int] = None, 22 | truncation: Optional[bool] = None, 23 | return_tensors: Optional[Literal["pt"]] = None, 24 | padding: Optional[Literal["max_length", True]] = None, 25 | ) -> BatchEncoding: ... 26 | def decode(self, token_ids: Tensor, skip_special_tokens: bool = False) -> str: ... 27 | def batch_decode(self, token_ids: Tensor, skip_special_tokens: bool = False) -> list[str]: ... 28 | 29 | class PretrainedConfig: 30 | hidden_size: int 31 | 32 | class AutoModelForSeq2SeqLM: 33 | @staticmethod 34 | def from_pretrained( 35 | pretrained_model_name_or_path: str, 36 | low_cpu_mem_usage: bool = False, 37 | local_files_only: bool = False, 38 | ) -> AutoModelForSeq2SeqLM: ... 39 | def generate( 40 | self, 41 | input_ids: Tensor, 42 | min_length: int = 0, 43 | max_length: int = 20, 44 | length_penalty: float = 1.0, 45 | do_sample: bool = False, 46 | num_beams: int = 1, 47 | num_return_sequences: int = 1, 48 | repetition_penalty: float = 1.0, 49 | ) -> Tensor: ... 
50 | config: PretrainedConfig 51 | -------------------------------------------------------------------------------- /stubs/umap/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | from typing import Generic, TypeVar 4 | 5 | from numpy import ndarray 6 | 7 | NSamples = TypeVar("NSamples") 8 | NFeatures = TypeVar("NFeatures") 9 | NComponents = TypeVar("NComponents", bound=int) 10 | 11 | class UMAP(Generic[NComponents]): 12 | def __init__( 13 | self, 14 | n_neighbors: int = 15, 15 | n_components: NComponents = 2, 16 | min_dist: float = 0.1, 17 | metric: str = "euclidean", 18 | random_state: int = 0, 19 | verbose: bool = False, 20 | ) -> None: ... 21 | def fit_transform(self, X: ndarray[float, NSamples, NSamples]) -> ndarray[float, NSamples, NComponents]: ... 22 | -------------------------------------------------------------------------------- /stubs/wikipedia/__init__.pyi: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | 3 | class WikipediaPage: 4 | content: str 5 | 6 | def page(title: str, auto_suggest: bool = True) -> WikipediaPage: ... 7 | -------------------------------------------------------------------------------- /templates/books.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Visualize text with sentence embeddings 6 | 7 | 8 | {% for title, path in books.items() %} 9 | {{ title }}
10 | {% endfor %} 11 | 12 | 13 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Visualize text with sentence embeddings 6 | 7 | 8 | View prerendered books
9 | Visualize Wikipedia articles
10 | Visualize Project Gutenberg books
11 | Visualize freeform sentences 12 | 13 | 14 | -------------------------------------------------------------------------------- /templates/main.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Visualize {{title}} 6 | 7 | 8 |
9 |
10 |
11 | 12 |
13 | 14 | 15 | -------------------------------------------------------------------------------- /templates/plotly.js: -------------------------------------------------------------------------------- 1 | // We use JSDoc syntax here so we can use TypeScript 2 | // without having to set up a whole build system for this one file 3 | 4 | // Replaced with a generated ID by plotly when including this file from python 5 | const plotId = '{plot_id}'; 6 | 7 | // Just to tell TS about it 8 | // @ts-ignore 9 | const Plotly = /** @type {any} */ (window.Plotly); 10 | 11 | /** @typedef {{ 12 | * x: number, 13 | * y: number, 14 | * z: number, 15 | * text: string, 16 | * focalAnnotation: boolean, 17 | * showarrow: boolean, 18 | * bgcolor: string, 19 | * borderpad: number, 20 | * font: {color: string}, 21 | * xanchor: "left" | "center" | "right", 22 | * yanchor: "bottom" | "middle" | "top", 23 | * }} Annotation 24 | */ 25 | 26 | /** 27 | * @typedef {{ 28 | * scene: { 29 | * camera: {eye: Point3D}, 30 | * annotations: Array 31 | * } 32 | * }} Layout 33 | */ 34 | 35 | /** 36 | * @typedef {{ 37 | * x: Array, 38 | * y: Array, 39 | * z: Array, 40 | * marker: { 41 | * color: Array 42 | * }, 43 | * hovertemplate: string, 44 | * }} CoreTrace 45 | */ 46 | 47 | /** 48 | * Trace data containing individual text markers (i.e. representing paragraphs or sentences) 49 | * @typedef {{ 50 | * mode: "markers", 51 | * type: "scatter3d", 52 | * customdata: Array<{ 53 | * colors: Array 54 | * }> 55 | * } & CoreTrace} MarkerTrace, 56 | 57 | /** @typedef { MarkerTrace & {text: Array}} TextTrace */ 58 | 59 | /** 60 | * Trace data containing cluster labels 61 | * @typedef {{ 62 | * mode: "markers+text", 63 | * type: "scatter3d", 64 | * visible: boolean, 65 | * name: string, 66 | * } & CoreTrace} ClusterTrace, 67 | */ 68 | 69 | // We cheat on the type declarations here because we initialize these ASAP 70 | /** @type {Array} */ 71 | let textTraces = /** @type {any} */ (null); 72 | /** @type {Array} */ 73 | let clusterTraces = /** @type {any} */ (null); 74 | 75 | // Initialize camera 76 | Plotly.update( 77 | plotId, 78 | {}, 79 | { 80 | 'scene.camera.eye': { 81 | x: -1.25, 82 | y: 1.25, 83 | z: 1.25, 84 | }, 85 | } 86 | ); 87 | 88 | /** 89 | * Rescale to [0, 1] and center around 0 90 | * @type {(axis: Array, index: number) => number} 91 | */ 92 | function minmaxScale(axis, index) { 93 | const max = Math.max(...axis); 94 | const min = Math.min(...axis); 95 | const range = max - min; 96 | const mid = (max + min) / 2; 97 | return ((axis[index] - mid) / range) * 2; 98 | } 99 | 100 | /** @type {(backgroundColor: string) => boolean} */ 101 | function isLightColor(backgroundColor) { 102 | // Parse the background color string (assuming it's in the format "rgba(r, g, b, a)") 103 | const match = backgroundColor.match(/rgba\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*\)/i); 104 | if (match) { 105 | const [r, g, b, a] = Array(...match) 106 | .slice(1) 107 | .map(parseFloat); 108 | const luminance = (0.299 * r + 0.587 * g + 0.114 * b) * a; 109 | return luminance > 128; 110 | } else { 111 | return true; 112 | } 113 | } 114 | 115 | /** 116 | * @type {(x: number, power: number) => number} 117 | * We expect both `x` and the return value to be in [0, 1] 118 | */ 119 | function easeInOutPower(x, power) { 120 | return x < 0.5 ? 
0.5 * Math.pow(2 * x, power) : 1 - 0.5 * Math.pow(-2 * x + 2, power); 121 | } 122 | 123 | /** @typedef {{x: number, y: number, z: number}} Point3D */ 124 | 125 | /** 126 | * Ease the camera in and out from `currentEye` to `targetEye` over `duration` 127 | * @type { 128 | (timestamp: number, startTime: number | null, duration: number, currentEye: Point3D, targetEye: Point3D) => void 129 | * } 130 | */ 131 | function animateCamera(timestamp, startTime, duration, currentEye, targetEye) { 132 | if (!startTime) startTime = timestamp; 133 | 134 | const elapsedTime = timestamp - startTime; 135 | const frac = easeInOutPower(elapsedTime / duration, 4); 136 | 137 | const view_update = { 138 | 'scene.camera.eye': Object.fromEntries( 139 | /** @type {const} */ (['x', 'y', 'z']).map((axis) => [ 140 | axis, 141 | currentEye[axis] + frac * (targetEye[axis] - currentEye[axis]), 142 | ]) 143 | ), 144 | 'scene.camera.center': Object.fromEntries(['x', 'y', 'z'].map((axis) => [axis, 0])), 145 | }; 146 | Plotly.update(plotId, {}, view_update); 147 | 148 | if (elapsedTime < duration) { 149 | requestAnimationFrame((newTimestamp) => 150 | animateCamera(newTimestamp, startTime, duration, currentEye, targetEye) 151 | ); 152 | } 153 | } 154 | 155 | /** @type {(trace: TextTrace, index: MarkerIndex) => Annotation} */ 156 | function focalAnnotationFromTraceInfo(trace, index) { 157 | // TS has trouble with `fromEntries` 158 | const fromTrace = 159 | /** @type {Point3D & {text: string}} */ 160 | (Object.fromEntries(/** @type {const} */ (['x', 'y', 'z', 'text']).map((prop) => [prop, trace[prop][index]]))); 161 | return { 162 | ...fromTrace, 163 | // We want some trace of what type of annotation this is so we 164 | // can filter out the old focal annotation before creating the new one 165 | focalAnnotation: true, 166 | showarrow: true, 167 | bgcolor: trace.marker.color[index], 168 | borderpad: 4, 169 | font: { 170 | color: isLightColor(trace.marker.color[index]) ? 
'black' : 'white', 171 | }, 172 | xanchor: 'left', 173 | yanchor: 'bottom', 174 | }; 175 | } 176 | 177 | /** @param {ColorIndex} colorIndex */ 178 | function setTraceVisibility(colorIndex) { 179 | clusterTraces.forEach((trace, i) => { 180 | trace.visible = i === colorIndex; 181 | }); 182 | Plotly.react( 183 | plotId, 184 | textTraces.concat(/** @type {any} */ (clusterTraces)), 185 | /** @type any */ (document.getElementById(plotId)).layout 186 | ); 187 | } 188 | 189 | /** 190 | * Newtype distinguishing indices into alternative color schemes from arbitrary numbers 191 | * @typedef {number & {_brand: "ColorIndex"}} ColorIndex 192 | */ 193 | 194 | /** @param {ColorIndex} colorIndex */ 195 | function swapColors(colorIndex) { 196 | textTraces.forEach((trace, i) => { 197 | // Colors are strings in format like "rgba(255, 255, 255, 0.1)" 198 | const colors = trace.customdata.map((x) => x.colors[colorIndex]); 199 | Plotly.restyle( 200 | plotId, 201 | { 202 | 'marker.color': [colors], 203 | 'line.color': [colors], 204 | }, 205 | [i] 206 | ); 207 | }); 208 | } 209 | 210 | /** @type {[() => ColorIndex, () => ColorIndex]} */ 211 | const crementColorIndex = (() => { 212 | let currentColorIndex = 0; 213 | return [ 214 | () => { 215 | // Can't float this out because `textTraces` isn't defined by the time the constructor runs 216 | // We assume all text traces and all points in each text trace have the same number of colors 217 | // (should be true by construction) 218 | const numColorings = textTraces[0].customdata[0].colors.length; 219 | currentColorIndex = (currentColorIndex + 1) % numColorings; 220 | return /** @type {ColorIndex} */ (currentColorIndex); 221 | }, 222 | () => { 223 | const numColorings = textTraces[0].customdata[0].colors.length; 224 | currentColorIndex = (currentColorIndex - 1 + numColorings) % numColorings; 225 | return /** @type {ColorIndex} */ (currentColorIndex); 226 | }, 227 | ]; 228 | })(); 229 | const [incrementColorIndex, decrementColorIndex] = crementColorIndex; 230 | 231 | /** 232 | * @typedef {number & {_brand: "MarkerIndex"}} MarkerIndex 233 | * @typedef {number & {_brand: "TextTraceIndex"}} TextTraceIndex 234 | */ 235 | 236 | /** 237 | * Adds annotation for chosen marker and starts camera animation zooming to marker 238 | * @type {(textTraceIndex: TextTraceIndex, markerIndex: MarkerIndex) => void} 239 | */ 240 | const focusOnMarker = (() => { 241 | /** @type {number | null} */ 242 | let timeoutHandle = null; 243 | return (textTraceIndex, markerIndex) => { 244 | const trace = textTraces[textTraceIndex]; 245 | const currentLayout = /** @type Layout */ (/** @type any */ (document.getElementById(plotId)).layout); 246 | 247 | const targetEye = /** @type {Point3D} */ ( 248 | Object.fromEntries( 249 | /** @type {const} */ (['x', 'y', 'z']).map((axis) => [axis, minmaxScale(trace[axis], markerIndex)]) 250 | ) 251 | ); 252 | const currentEye = currentLayout.scene.camera.eye; 253 | const distance = Math.sqrt( 254 | /** @type {const} */ (['x', 'y', 'z']) 255 | .map((axis) => Math.pow(targetEye[axis] - currentEye[axis], 2)) 256 | .reduce((acc, el) => acc + el) 257 | ); 258 | // `cancelAnimationFrame` actually handles `null` gracefully, but TS doesn't know that 259 | cancelAnimationFrame(/** @type {number} */ (timeoutHandle)); 260 | const duration = Math.min(4000, Math.max(2000, Math.sqrt(distance) * 3000)); 261 | timeoutHandle = requestAnimationFrame((timestamp) => 262 | animateCamera(timestamp, null, duration, currentEye, targetEye) 263 | ); 264 | 265 | const updatedAnnotations = 
(currentLayout.scene.annotations ?? []) 266 | .filter((x) => !x.focalAnnotation) 267 | .concat([focalAnnotationFromTraceInfo(trace, markerIndex)]); 268 | Plotly.react(plotId, textTraces.concat(/** @type {any} */ (clusterTraces)), { 269 | ...currentLayout, 270 | scene: { 271 | ...currentLayout.scene, 272 | annotations: updatedAnnotations, 273 | }, 274 | }); 275 | }; 276 | })(); 277 | 278 | /** @type {[() => [TextTraceIndex, MarkerIndex], () => [TextTraceIndex, MarkerIndex]]} */ 279 | const crementMarkerIndex = (() => { 280 | let currentTraceIndex = 0; 281 | let currentMarkerIndex = 0; 282 | return [ 283 | () => { 284 | if (currentMarkerIndex === textTraces[currentTraceIndex].x.length - 1) { 285 | currentMarkerIndex = 0; 286 | if (currentTraceIndex === textTraces.length - 1) { 287 | currentTraceIndex = 0; 288 | } else { 289 | currentTraceIndex += 1; 290 | } 291 | } else { 292 | currentMarkerIndex += 1; 293 | } 294 | return /** @type {[TextTraceIndex, MarkerIndex]} */ ([currentTraceIndex, currentMarkerIndex]); 295 | }, 296 | () => { 297 | if (currentMarkerIndex === 0) { 298 | if (currentTraceIndex === 0) { 299 | currentTraceIndex = textTraces.length - 1; 300 | } else { 301 | currentTraceIndex -= 1; 302 | } 303 | currentMarkerIndex = textTraces[currentTraceIndex].x.length - 1; 304 | } else { 305 | currentMarkerIndex -= 1; 306 | } 307 | return /** @type {[TextTraceIndex, MarkerIndex]} */ ([currentTraceIndex, currentMarkerIndex]); 308 | }, 309 | ]; 310 | })(); 311 | const [incrementMarkerIndex, decrementMarkerIndex] = crementMarkerIndex; 312 | 313 | /** @type {(html: string) => DocumentFragment} */ 314 | function elementFromHtml(html) { 315 | const template = document.createElement('template'); 316 | template.innerHTML = html; 317 | return template.content; 318 | } 319 | 320 | /** @type {(array: Array, predicate: (el: T) => boolean) => Array} */ 321 | function findIndices(array, predicate) { 322 | return array.map((el, i) => (predicate(el) ? 
i : -1)).filter((x) => x !== -1); 323 | } 324 | 325 | function searchHandler(event) { 326 | event.preventDefault(); 327 | const text = /** @type HTMLInputElement */ (document.getElementById('marker-search-input')).value; 328 | const textTraceIndices = findIndices(textTraces, (trace) => trace.text.some((str) => str.includes(text))); 329 | if (textTraceIndices.length === 0) { 330 | alert(`Found 0 markers matching "${text}"`); 331 | return; 332 | } 333 | const indexPairs = /** @type {Array<[TextTraceIndex, MarkerIndex]>} */ ( 334 | textTraceIndices.flatMap((textTraceIndex) => 335 | findIndices(textTraces[textTraceIndex].text, (str) => str.includes(text)).map((markerIndex) => [ 336 | textTraceIndex, 337 | markerIndex, 338 | ]) 339 | ) 340 | ); 341 | if (indexPairs.length === 1) { 342 | focusOnMarker(...indexPairs[0]); 343 | } else if (indexPairs.length === 0) { 344 | alert(`Found 0 markers matching "${text}"`); 345 | } else { 346 | const matchingMarkers = indexPairs.map( 347 | ([textTraceIndex, markerIndex]) => textTraces[textTraceIndex].text[markerIndex] 348 | ); 349 | alert(`Found ${indexPairs.length} markers matching "${text}":\n${matchingMarkers.join('\n')}`); 350 | } 351 | } 352 | 353 | function addSearchHandler() { 354 | const button = /** @type HTMLElement */ (document.getElementById('marker-search-submit')); 355 | button.addEventListener('click', searchHandler); 356 | const input = /** @type HTMLElement */ (document.getElementById('marker-search-input')); 357 | input.addEventListener('keydown', (event) => { 358 | if (event.key === 'Enter') { 359 | searchHandler(event); 360 | } 361 | }); 362 | } 363 | 364 | document.addEventListener('touchstart', function (event) { 365 | focusOnMarker(...incrementMarkerIndex()); 366 | }); 367 | document.addEventListener('keydown', function (event) { 368 | if (event.key === 'ArrowRight') { 369 | focusOnMarker(...incrementMarkerIndex()); 370 | } else if (event.key === 'ArrowLeft') { 371 | focusOnMarker(...decrementMarkerIndex()); 372 | } else if (event.key === 'ArrowDown') { 373 | const newIndex = decrementColorIndex(); 374 | swapColors(newIndex); 375 | setTraceVisibility(newIndex); 376 | } else if (event.key === 'ArrowUp') { 377 | const newIndex = incrementColorIndex(); 378 | swapColors(newIndex); 379 | setTraceVisibility(newIndex); 380 | } 381 | }); 382 | document.addEventListener('DOMContentLoaded', function () { 383 | const plot = /** @type any */ (document.getElementById(plotId)); 384 | // @ts-ignore 385 | textTraces = plot.data.filter((trace) => !('name' in trace)); 386 | // @ts-ignore 387 | clusterTraces = plot.data.filter((trace) => 'name' in trace && trace.name.startsWith('cluster_markers_')); 388 | console.log(plotId); 389 | // Note that we're sort of cheating here by typing the `textTraces` as `Array` and 390 | // not `Array` 391 | if ('text' in textTraces[0]) { 392 | const markerSearchHtml = 393 | ''; 394 | document.body.insertBefore(elementFromHtml(markerSearchHtml), document.body.firstChild); 395 | addSearchHandler(); 396 | } 397 | }); 398 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "strict": true, 4 | "noEmit": true, 5 | "checkJs": true, 6 | "lib": ["ES2022", "DOM"] 7 | }, 8 | "include": [ 9 | "templates/plotly.js" 10 | ] 11 | } 12 | --------------------------------------------------------------------------------
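A closing note on how the pieces fit together: `templates/plotly.js` expects plotly to substitute its `{plot_id}` placeholder, and the `Figure.write_html` stub exposes a `post_script` parameter intended for exactly this kind of injected snippet. The sketch below is an assumption about that wiring for illustration; it is not code taken from the repo:

```python
# Hypothetical sketch of attaching templates/plotly.js to a generated figure.
from pathlib import Path

import plotly.graph_objects as go

fig = go.Figure()
# ... add the Scatter3d text and cluster traces here ...

interaction_js = Path("templates/plotly.js").read_text()
fig.write_html(
    "demo.html",
    include_plotlyjs="cdn",
    # plotly replaces the '{plot_id}' placeholder in post_script with the
    # generated div id, which is how the keyboard handlers find the plot.
    post_script=interaction_js,
)
```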