├── .gitignore ├── wise_topic ├── topic │ ├── __init__.py │ ├── topic_extractor.py │ ├── classifier.py │ └── greedy.py ├── classifier │ ├── __init__.py │ ├── prompts.py │ └── classifier.py ├── __init__.py ├── cluster │ ├── __init__.py │ ├── embedders.py │ ├── plot.py │ ├── sample.py │ ├── preprocessing.py │ ├── balanced_sampling.py │ └── clustering.py ├── parallel.py └── drilldown_tree.py ├── .github └── CODEOWNERS ├── README.md ├── requirements.txt ├── __init__.py ├── examples ├── classifier_scores.py ├── topic_extraction.py ├── Getting probability scores out of LLM classification.ipynb └── Topic extraction using LLMs only.ipynb ├── setup.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wise_topic/topic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wise_topic/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @transferwise/data-scientists 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wise-topic 2 | LLM-only topic extraction and classification 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | langchain 2 | langchain_core 3 | langchain_openai 4 | numpy 5 | scikit_learn 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from wise_topic.topic.greedy import greedy_topic_tree, tree_summary 2 | from wise_topic.classifier.classifier import ( 3 | llm_classifier_binary, 4 | llm_classifier_multiple, 5 | ) 6 | -------------------------------------------------------------------------------- /wise_topic/__init__.py: -------------------------------------------------------------------------------- 1 | from wise_topic.topic.greedy import greedy_topic_tree, tree_summary 2 | from wise_topic.classifier.classifier import ( 3 | llm_classifier_binary, 4 | llm_classifier_multiple, 5 | ) 6 | from wise_topic.drilldown_tree import display_drilldown 7 | -------------------------------------------------------------------------------- /wise_topic/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | from .balanced_sampling import cluster_docs # noqa: F401 2 | from .plot import plot_clustering # noqa: F401 3 | from .sample import sample_docs # noqa: F401 4 | from .preprocessing import anonymize_string, preprocess_nps_data # noqa: F401 5 | -------------------------------------------------------------------------------- /wise_topic/cluster/embedders.py: -------------------------------------------------------------------------------- 1 | # All the imports are inside the functions to not have to install them all 2 | def tfidf(): 3 | from sklearn.decomposition import TruncatedSVD 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | from 
sklearn.pipeline import make_pipeline 6 | 7 | pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100)) 8 | return pipe 9 | -------------------------------------------------------------------------------- /wise_topic/cluster/plot.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | 4 | 5 | def plot_clustering(X_red, labels, title=None): 6 | x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0) 7 | X_red = (X_red - x_min) / (x_max - x_min) 8 | 9 | plt.figure(figsize=(6, 4)) 10 | for label in np.unique(labels): 11 | plt.scatter( 12 | *X_red[labels == label].T, 13 | marker=f"${label}$", 14 | s=50, 15 | c=plt.cm.nipy_spectral(labels[labels == label] / 10), 16 | alpha=0.5, 17 | ) 18 | 19 | plt.xticks([]) 20 | plt.yticks([]) 21 | if title is not None: 22 | plt.title(title, size=17) 23 | plt.axis("off") 24 | plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 25 | -------------------------------------------------------------------------------- /wise_topic/cluster/sample.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | 5 | 6 | def sample_docs(documents: np.ndarray, labels: np.ndarray, docs_per_label: int = 10) -> Dict[int, List[str]]: 7 | out = {} 8 | for i in np.unique(labels): 9 | d = documents[labels == i] 10 | out[i] = np.random.choice(d, min(docs_per_label, len(d)), replace=False) 11 | return out 12 | 13 | 14 | def sample_docs_from_proba(documents: np.ndarray, p: np.ndarray, n: int = 10): 15 | out = {} 16 | # square the probabilities to accentuate high-probability topics 17 | re_p = (p * p) / ((p * p).sum(axis=1, keepdims=True)) 18 | for i in range(p.shape[1]): 19 | out[i] = np.random.choice(documents, size=n, replace=False, p=re_p[:, i] / re_p[:, i].sum()) 20 | return out 21 | -------------------------------------------------------------------------------- /examples/classifier_scores.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from wise_topic import llm_classifier_binary, llm_classifier_multiple 4 | 5 | 6 | llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) 7 | question1 = "Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?" 8 | question2 = "Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?" 9 | for question in [question1, question2]: 10 | out = llm_classifier_binary(llm, question) 11 | print(question) 12 | print(out) 13 | 14 | 15 | question3 = "Consider a very friendly pet with a long tail. You know it's a cat, a dog, or a dragon. Which is it?" 
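# A note on the expected output (illustrative values, not a guaranteed result): llm_classifier_multiple returns a dict mapping each option to a probability derived from the first answer token's logprobs and normalized to sum to one, e.g. {"cat": 0.94, "dog": 0.06, "dragon": ~0}.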
16 | out = llm_classifier_multiple(llm, question3, ["cat", "dog", "dragon"]) 17 | print(question3) 18 | print(out) 19 | 20 | print("done!") 21 | -------------------------------------------------------------------------------- /wise_topic/parallel.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable 2 | import logging 3 | import concurrent.futures 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def process_batch_parallel( 9 | function: Callable, 10 | batched_args: Iterable, 11 | max_workers: int, 12 | ) -> list: 13 | results = [] 14 | 15 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: 16 | future_to_batch = { 17 | executor.submit(function, *args): args for args in batched_args 18 | } 19 | 20 | for future in concurrent.futures.as_completed(future_to_batch): 21 | args = future_to_batch[future] 22 | arg_str = ",".join(map(str, args)) 23 | try: 24 | data = future.result() 25 | logger.info(f"Completed parallel task {arg_str}") 26 | results.append((data, args)) 27 | 28 | except Exception as ex: 29 | logger.error(f"Error in running task '{arg_str}': {ex}") 30 | 31 | return results 32 | -------------------------------------------------------------------------------- /wise_topic/classifier/prompts.py: -------------------------------------------------------------------------------- 1 | binary_prompt = """ 2 | {question} 3 | Return the digit 1 for a positive answer, and 0 for a negative answer. 4 | Return just the one-character digit, nothing else. 5 | Take a deep breath and think carefully before you make your reply. 6 | """ 7 | 8 | 9 | def multi_choice_prompt(include_other: bool): 10 | out = ( 11 | """I am about to give you a numbered list of options. 12 | Then I will pass to you a message (possibly, but not necessarily, a question), 13 | after the word MESSAGE. 14 | Return an integer that is the number of the option that best fits that message, 15 | or if the message is a question, the number of the option that best answers the question. 16 | """ 17 | + ( 18 | """ 19 | If no option fits the message, return 0. 20 | """ 21 | if include_other 22 | else "" 23 | ) 24 | + """ 25 | Return only the number, without additional text. 26 | {categories} 27 | MESSAGE: 28 | {question} 29 | Take a deep breath and think carefully before you make your reply. 
30 | BEST MATCH OPTION NUMBER:""" 31 | ) 32 | return out 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open('README.md') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name="wise-topic", 8 | version="0.0.1", 9 | description="LLM-only topic extraction and classification", 10 | long_description=long_description, 11 | long_description_content_type='text/markdown', 12 | author="Wise", 13 | url='https://github.com/transferwise/wise-topic', 14 | classifiers=[ 15 | 'Programming Language :: Python :: 3 :: Only', 16 | 'Programming Language :: Python :: 3.7', 17 | 'Programming Language :: Python :: 3.8', 18 | 'Programming Language :: Python :: 3.9', 19 | 'Programming Language :: Python :: 3.10', 20 | 'Programming Language :: Python :: 3.11', 21 | 'Programming Language :: Python :: 3.12', 22 | ], 23 | install_requires=[ 24 | "langchain", 25 | "langchain_core", 26 | "langchain_openai", 27 | "numpy", 28 | "scikit_learn" 29 | ], 30 | extras_require={ 31 | "test": [ 32 | "flake8", 33 | "pytest", 34 | "pytest-cov" 35 | ], 36 | }, 37 | packages=find_packages( 38 | include=[ 39 | 'wise_topic', 40 | 'wise_topic.*' 41 | ], 42 | exclude=['tests*'], 43 | ), 44 | include_package_data=True, 45 | keywords='wise-topic', 46 | ) 47 | -------------------------------------------------------------------------------- /wise_topic/cluster/preprocessing.py: -------------------------------------------------------------------------------- 1 | def preprocess_nps_data(text: str) -> str: 2 | # remove the most common non-informative substrings 3 | text = ( 4 | str(text) 5 | .replace("Good morning", "") 6 | .replace("Good afternoon", "") 7 | .replace("Good evening", "") 8 | .replace("probable", "likely") 9 | .replace("Very likely", "") 10 | .replace("Hello", "") 11 | .replace("Wise", "") 12 | .replace("WISE", "") 13 | .replace("wise", "") # They're all about Wise anyway 14 | ) 15 | return text 16 | 17 | 18 | try: 19 | from presidio_analyzer import AnalyzerEngine 20 | from presidio_anonymizer import AnonymizerEngine 21 | 22 | try: 23 | import spacy 24 | 25 | spacy.load("en_core_web_lg") 26 | except OSError: 27 | raise ImportError("run python -m spacy download 'en_core_web_lg' first") 28 | 29 | analyzer = AnalyzerEngine() 30 | anonymizer = AnonymizerEngine() 31 | 32 | def anonymize_string(text: str) -> str: 33 | # https://microsoft.github.io/presidio/supported_entities/ 34 | results = analyzer.analyze( 35 | text=text, 36 | entities=[ 37 | "PHONE_NUMBER", 38 | "CREDIT_CARD", 39 | "EMAIL_ADDRESS", 40 | "IBAN_CODE", 41 | "IP_ADDRESS", 42 | "PERSON", 43 | ], 44 | language="en", 45 | ) 46 | anonymized_text = anonymizer.anonymize(text=text, analyzer_results=results).text 47 | out = anonymized_text.split("\n")[0].replace("<PERSON>", "") # strip presidio's default placeholder for detected names 48 | return out 49 | 50 | except ImportError: 51 | # a soft fail: if we don't actually call the function, this will run fine 52 | def anonymize_string(text: str): 53 | raise ImportError("Please install presidio_analyzer, presidio_anonymizer, and the spacy model en_core_web_lg first") 54 | -------------------------------------------------------------------------------- /examples/topic_extraction.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from langchain_openai import ChatOpenAI 4 | 5 | from wise_topic import greedy_topic_tree, tree_summary 6 | 7 | docs = [ 8 | "The summer sun blazed 
high in the sky, bringing warmth to the sandy beaches.", 9 | "During summer, the days are long and the nights are warm and inviting.", 10 | "Ice cream sales soar as people seek relief from the summer heat.", 11 | "Families often choose summer for vacations to take advantage of the sunny weather.", 12 | "Many festivals and outdoor concerts are scheduled in the summer months.", 13 | "Winter brings the joy of snowfall and the excitement of skiing.", 14 | "The cold winter nights are perfect for sipping hot chocolate by the fire.", 15 | "Winter storms can transform the landscape into a snowy wonderland.", 16 | "Heating bills tend to rise as winter's chill sets in.", 17 | "Many animals hibernate or migrate to cope with the harsh winter conditions.", 18 | "Fish swim in schools to protect themselves from predators.", 19 | "Salmon migrate upstream during spawning season, a remarkable journey.", 20 | "Tropical fish add vibrant color and life to coral reefs.", 21 | "Overfishing threatens many species of fish with extinction.", 22 | "Fish have a diverse range of habitats from deep oceans to shallow streams.", 23 | ] 24 | 25 | topic_llm = ChatOpenAI( 26 | model="gpt-4-turbo", 27 | temperature=0, 28 | model_kwargs={"response_format": {"type": "json_object"}}, 29 | ) 30 | # Do topic extraction on sampled data 31 | topic_tree = greedy_topic_tree( 32 | docs, 33 | initial_topics=["Winter", "Summer"], 34 | topic_llm=topic_llm, 35 | max_depth=1, 36 | num_topics_per_update=1, 37 | max_unclassified_messages=2, 38 | ) 39 | 40 | pprint(tree_summary(topic_tree)) 41 | pprint("**************") 42 | pprint(topic_tree) 43 | print("yay!") 44 | -------------------------------------------------------------------------------- /wise_topic/drilldown_tree.py: -------------------------------------------------------------------------------- 1 | import ipywidgets as widgets 2 | 3 | 4 | def create_collapsible_item(item, level=0): 5 | """ 6 | Create a collapsible widget to display a list or a dictionary. 7 | """ 8 | if isinstance(item, dict): 9 | return display_drilldown(item, level) 10 | elif isinstance(item, list): 11 | return create_collapsible_list(item, level) 12 | else: 13 | return widgets.Label(f"{item}") 14 | 15 | 16 | def display_drilldown(d, level=0): 17 | """ 18 | Recursively create a collapsible widget to display a nested dictionary. 19 | """ 20 | items = [] 21 | for key, value in d.items(): 22 | if isinstance(value, (dict, list)): 23 | # Create a collapsible widget for nested dictionaries or lists 24 | sub_items = create_collapsible_item(value, level + 1) 25 | accordion = widgets.Accordion(children=[sub_items]) 26 | accordion.set_title(0, key) 27 | items.append(accordion) 28 | else: 29 | # Display key-value pairs as labels 30 | items.append(widgets.Label(f"{key}: {value}")) 31 | 32 | return widgets.VBox(items) 33 | 34 | 35 | def create_collapsible_list(lst, level=0): 36 | """ 37 | Create a collapsible widget to display a list without "Item X" labels. 
38 | """ 39 | items = [] 40 | for value in lst: 41 | if isinstance(value, (dict, list)): 42 | # Create a collapsible widget for nested dictionaries or lists within the list 43 | sub_items = create_collapsible_item(value, level + 1) 44 | accordion = widgets.Accordion(children=[sub_items]) 45 | accordion.set_title(0, f"List item") 46 | items.append(accordion) 47 | else: 48 | # Display list items directly without labels 49 | items.append(widgets.Label(f"{value}")) 50 | 51 | return widgets.VBox(items) 52 | 53 | 54 | # This is how you'd use this in a notebook: 55 | # # Example usage with a nested dictionary 56 | # nested_dict = { 57 | # "level1": { 58 | # "level2a": { 59 | # "level3a": {"value1": 1, "value2": ["blah1", "blah2"]}, 60 | # "level3b": {"value3": 3}, 61 | # }, 62 | # "level2b": {"level3c": {"value4": ["Foo", "bar"]}}, 63 | # } 64 | # } 65 | # 66 | # 67 | # 68 | # from IPython.display import display 69 | # # Create and display the collapsible dictionary widget 70 | # collapsible_widget = display_drilldown(nested_dict) 71 | # display(collapsible_widget) 72 | -------------------------------------------------------------------------------- /wise_topic/cluster/balanced_sampling.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence, Optional, Dict 2 | 3 | import numpy as np 4 | 5 | from bertopic.backend._utils import select_backend 6 | from flair.embeddings import TransformerDocumentEmbeddings 7 | 8 | from phate import PHATE 9 | from umap import UMAP 10 | 11 | from llm_patterns.utils.caching_query import caching_query 12 | from .clustering import kmeans_clustering, hdbscan_clustering, bayesian_gaussian_mixture, agglomerative_clustering 13 | from .embedders import tfidf 14 | 15 | 16 | def embedder_template(emb_model, caching_fn: str) -> Callable: 17 | return lambda x: caching_query(caching_fn, lambda: select_backend(emb_model).embed_documents(x)) 18 | 19 | 20 | def transformer_template(model) -> Callable: 21 | return lambda x, fn: caching_query(fn, lambda: model.fit_transform(x)) 22 | 23 | 24 | embedders = { 25 | "instructor_large": lambda: embedder_template( 26 | TransformerDocumentEmbeddings("hkunlp/instructor-large"), 27 | "instructor-large.pklz", 28 | ), 29 | "sbert": lambda: embedder_template("all-mpnet-base-v2", "sbert.pklz"), 30 | "sbert_mini": lambda: embedder_template("all-MiniLM-L6-v2", "sbert_mini.pklz"), 31 | "tfidf": embedder_template(tfidf(), "tfidf.pklz"), 32 | } 33 | 34 | dim_reducers = { 35 | "umap": lambda dim: transformer_template(UMAP(n_neighbors=15, n_components=dim, min_dist=0.0, metric="cosine")), 36 | "phate": lambda dim: transformer_template(PHATE(n_components=dim)), 37 | } 38 | 39 | clusterers = { 40 | "kmeans": lambda x, kwargs: kmeans_clustering(x, **{"n_clusters": 10, **kwargs}), 41 | "bgm": lambda x, kwargs: bayesian_gaussian_mixture(x, **{"prior": 0.01, **kwargs}), 42 | "hdbscan": lambda x, kwargs: hdbscan_clustering(x, **kwargs), 43 | "agglomerative": lambda x, kwargs: agglomerative_clustering(x, **{"n_clusters": 10, **kwargs}), 44 | } 45 | 46 | 47 | def cluster_docs( 48 | docs: Sequence[str], 49 | embedder: str = "sbert", 50 | dim_reducer: str = "umap", 51 | cluster_algo: str = "kmeans", 52 | cluster_dim: int = 5, 53 | cluster_kwargs: Optional[Dict] = None, 54 | ): 55 | cluster_kwargs = cluster_kwargs or {} 56 | embeddings = embedders[embedder]()(np.array(docs)) 57 | reduced = dim_reducers[dim_reducer.lower()](dim=cluster_dim)(embeddings, f"{embedder}_{dim_reducer}.pklz") 58 | reduced2 
= dim_reducers[dim_reducer.lower()](dim=2)(embeddings, f"{embedder}_{dim_reducer}2.pklz") 59 | 60 | cluster_fit = clusterers[cluster_algo](reduced, cluster_kwargs) 61 | labels = cluster_fit.labels_ 62 | 63 | out = {"labels": labels, "embeddings": embeddings, "reduced": reduced, "reduced2d": reduced2} 64 | return out 65 | -------------------------------------------------------------------------------- /wise_topic/cluster/clustering.py: -------------------------------------------------------------------------------- 1 | import hdbscan 2 | import numpy as np 3 | from sklearn import mixture 4 | from sklearn.cluster import AgglomerativeClustering, KMeans 5 | from sklearn.metrics import silhouette_score 6 | 7 | 8 | # Do the clustering in higher dims, plot in lower dims 9 | def WithProbabilities(a: type): 10 | class WithProbabilities(a): 11 | def fit(self, X): 12 | super().fit(X) 13 | try: 14 | self.probabilities_ = self.predict_proba(X) 15 | except AttributeError: # clusterers with hard labels only get one-hot probabilities 16 | labels = np.unique(self.labels_) 17 | self.probabilities_ = np.zeros((len(X), len(labels))) 18 | self.probabilities_[np.arange(len(X)), np.searchsorted(labels, self.labels_)] = 1.0 19 | return self 20 | 21 | return WithProbabilities 22 | 23 | 24 | def agglomerative_clustering(X, n_clusters=10, **kwargs): 25 | clustering = AgglomerativeClustering(linkage="ward", n_clusters=n_clusters, **kwargs) 26 | clustering.fit(X) 27 | return clustering 28 | 29 | 30 | def kmeans_clustering(dataset, max_clusters=10, n_clusters=None, **kwargs): 31 | best_score = -1 32 | best_clusters = None 33 | 34 | if n_clusters is None: 35 | for n_clusters in range(3, max_clusters + 1): 36 | # Perform k-means clustering 37 | kmeans = KMeans(n_clusters=n_clusters, **kwargs) 38 | labels = kmeans.fit_predict(dataset) 39 | 40 | # Calculate silhouette score 41 | score = silhouette_score(dataset, labels) 42 | print(n_clusters, score) 43 | 44 | # Update best score and clusters if current score is better 45 | if score > best_score: 46 | best_score = score 47 | best_clusters = kmeans 48 | return best_clusters 49 | else: 50 | kmeans = KMeans(n_clusters=n_clusters, **kwargs) 51 | kmeans.fit(dataset) 52 | return kmeans 53 | 54 | 55 | def bayesian_gaussian_mixture(X, prior: float = 0.01, **kwargs): 56 | bgmm = mixture.BayesianGaussianMixture( 57 | n_components=10, 58 | covariance_type="full", 59 | weight_concentration_prior=prior, 60 | weight_concentration_prior_type="dirichlet_process", 61 | mean_precision_prior=1e-2, 62 | # covariance_prior=1e0 * np.eye(x.shape), 63 | init_params="kmeans", 64 | max_iter=100, 65 | random_state=2, 66 | **kwargs 67 | ).fit(X) 68 | bgmm.labels_ = bgmm.predict(X) 69 | return bgmm 70 | 71 | 72 | def hdbscan_clustering(X, min_topic_size=10, **kwargs): 73 | hdb = hdbscan.HDBSCAN( 74 | min_cluster_size=min_topic_size, 75 | metric="euclidean", 76 | cluster_selection_method="eom", 77 | prediction_data=True, 78 | **kwargs 79 | ) 80 | hdb.fit(X) 81 | return hdb 82 | -------------------------------------------------------------------------------- /wise_topic/topic/topic_extractor.py: -------------------------------------------------------------------------------- 1 | import json 2 | from json import JSONDecodeError 3 | from typing import Sequence, Optional 4 | from math import ceil 5 | 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain.prompts import PromptTemplate 8 | 9 | 10 | def extract_topics( 11 | docs: Sequence[str], 12 | llm: BaseChatModel, 13 | extra_prompt: Optional[str] = None, 14 | n_topics=3, 15 | max_words=5, 16 | with_count: bool = False, 17 | ): 18 | min_topics = ceil(n_topics / 2) 19 | docs_ = 
";\n".join(docs) 20 | prompt_text = f"""I am about to give you a list of documents, separated by semicolons 21 | and line breaks, and jointly delimited by triple back quotes. 22 | 23 | After the final document, there will be the word 24 | 25 | Generate a list of at most {n_topics} and at least {min_topics} distinct, non-overlapping topics 26 | that between them best describe the content of the documents. 27 | {extra_prompt if extra_prompt is not None else ""} 28 | Each topic should have at least 3 and at most {max_words} words. 29 | Avoid composite topics (`something and something else`) 30 | """ 31 | if with_count: 32 | prompt_text += """ 33 | 34 | Each topic should be followed by a count of how many documents it applies to 35 | 36 | Output the topics in the following format: 37 | ``` 38 | {{"topic1":count_of_documents_with_topic1, ...}} 39 | ``` 40 | """ 41 | else: 42 | prompt_text += """ 43 | Output the topics in the following format: 44 | ``` 45 | ["topic1", "topic2", ...] 46 | ``` 47 | """ 48 | prompt_text += """Don't exceed the {max_words} words limit, including words such as 'and' and 'of'. 49 | Make sure each following topic does not overlap with or duplicate any of the previous topics. 50 | After seeing , output a well-formed json containing the list of topics as described above. 51 | Make sure to return between {min_topics} and {n_topics} distinct, non-overlapping topics. 52 | ```{docs} 53 | ``` 54 | """ 55 | if hasattr(llm, "max_tokens"): 56 | llm.max_tokens = 5 * (max_words + 1) * n_topics 57 | 58 | chain = PromptTemplate.from_template(prompt_text) | llm 59 | for _ in range(3): 60 | try: 61 | out = chain.invoke( 62 | { 63 | "docs": docs_, 64 | "n_topics": n_topics, 65 | "max_words": max_words, 66 | "min_topics": min_topics, 67 | } 68 | ) 69 | return json.loads(out.content.replace("```json", "").replace("```", "")) 70 | except JSONDecodeError: 71 | pass 72 | -------------------------------------------------------------------------------- /wise_topic/topic/classifier.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Sequence 3 | 4 | from langchain.prompts import PromptTemplate 5 | from langchain_core.language_models import BaseChatModel 6 | 7 | 8 | def topic_classifier(topics: Sequence[str], message: str, llm: BaseChatModel): 9 | re_topics = "\n".join([f"{i+1}. {t}" for i, t in enumerate(topics)]) 10 | prompt_text = """I am about to give you a numbered list of topics. 11 | Then I will pass to you a message, after the word MESSAGE. 12 | Return an integer that is the number of the topic that best fits that message; 13 | if no topic fits the message, return 0. 14 | 15 | Return only the number, without additional text. 16 | {categories} 17 | MESSAGE: 18 | {message} 19 | TOPIC NUMBER:""" 20 | 21 | if hasattr(llm, "max_tokens"): 22 | llm.max_tokens = 5 23 | 24 | chain = PromptTemplate.from_template(prompt_text) | llm 25 | out = chain.invoke({"categories": re_topics, "message": message}) 26 | try: 27 | return int(out.content) 28 | except Exception as e: 29 | logging.warning(f"Error in classifier: {e}") 30 | return 0 31 | 32 | 33 | async def topic_classifier_async( 34 | topics: Sequence[str], message: str, llm: BaseChatModel 35 | ) -> int: 36 | """ 37 | Asynchronously classify the message into one of the provided topics using a language model. 38 | Return an integer corresponding to the number of the best fitting topic. 39 | If no topic fits, return 0. 
40 | 41 | Parameters: 42 | - topics (Sequence[str]): List of possible topics. 43 | - message (str): The message to classify. 44 | - llm (BaseChatModel): The language model used for classification. 45 | 46 | Returns: 47 | - int: The number corresponding to the best-fitting topic, or 0 if no fit. 48 | """ 49 | re_topics = "\n".join([f"{i+1}. {t}" for i, t in enumerate(topics)]) 50 | prompt_text = """I am about to give you a numbered list of topics. 51 | Then I will pass to you a message, after the word MESSAGE. 52 | Return an integer that is the number of the topic that best fits that message; 53 | if no topic fits the message, return 0. 54 | 55 | Return only the number, without additional text. 56 | {categories} 57 | MESSAGE: 58 | {message} 59 | TOPIC NUMBER:""" 60 | 61 | if hasattr(llm, "max_tokens"): 62 | llm.max_tokens = 5 63 | 64 | # Create the prompt chain using the template and language model 65 | chain = PromptTemplate.from_template(prompt_text) | llm 66 | 67 | try: 68 | # Invoke the language model asynchronously 69 | out = await chain.ainvoke({"categories": re_topics, "message": message}) 70 | return int(out.content) 71 | except Exception as e: 72 | logging.warning(f"Error in classifier: {e}") 73 | return 0 74 | 75 | 76 | # Run the example 77 | if __name__ == "__main__": 78 | import asyncio 79 | from langchain_openai import ChatOpenAI 80 | 81 | async def main(llm): 82 | 83 | # Simulated topics and message 84 | topics = ["Sports", "Technology", "Music", "Politics"] 85 | message = "Artificial intelligence is transforming industries." 86 | # Call the topic_classifier with the mock LLM 87 | result = await topic_classifier_async(topics, message, llm) 88 | 89 | # Return the resulting topic number 90 | return result 91 | 92 | classifier_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) 93 | result = asyncio.run(main(classifier_llm)) 94 | print(f"Topic number: {result}") 95 | -------------------------------------------------------------------------------- /wise_topic/classifier/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | from langchain_core.messages import HumanMessage 6 | from langchain_core.prompts import PromptTemplate 7 | from langchain_core.language_models import BaseChatModel 8 | 9 | from wise_topic.classifier.prompts import binary_prompt, multi_choice_prompt 10 | 11 | 12 | def llm_classifier_binary(llm: BaseChatModel, question: str): 13 | prompt_value = ( 14 | PromptTemplate.from_template(binary_prompt) 15 | .invoke({"question": question}) 16 | .to_messages() 17 | ) 18 | out = llm_classifier(llm, [prompt_value], ["0", "1"]) 19 | return {False: out[0], True: out[1]} 20 | 21 | 22 | def llm_classifier_multiple( 23 | llm: BaseChatModel, 24 | question: str, 25 | answer_options: List[str], 26 | include_other: bool = False, 27 | ): 28 | 29 | assert ( 30 | len(answer_options) <= 9 31 | ), "Only up to 9 answer options are supported at the moment" 32 | categories = "\n".join([f"{i + 1}. 
{t}" for i, t in enumerate(answer_options)]) 33 | prompt_value = ( 34 | PromptTemplate.from_template(multi_choice_prompt(include_other)) 35 | .invoke({"question": question, "categories": categories}) 36 | .to_messages() 37 | ) 38 | 39 | valid_outputs = [str(i) for i in range(len(answer_options) + 1)] 40 | scores = llm_classifier( 41 | llm, 42 | [prompt_value], 43 | valid_outputs if include_other else valid_outputs[1:], 44 | top_logprobs=15, 45 | ) 46 | if include_other: 47 | used_options = ["Other"] + list(answer_options) 48 | else: 49 | used_options = list(answer_options) 50 | 51 | out = {k: v for k, v in zip(used_options, scores)} 52 | return out 53 | 54 | 55 | def llm_classifier( 56 | llm: BaseChatModel, messages, valid_options, top_logprobs=5, max_tokens=1 57 | ) -> np.ndarray: 58 | result = llm.generate( 59 | messages, 60 | logprobs=True, 61 | top_logprobs=top_logprobs, 62 | max_tokens=max_tokens, 63 | ) 64 | info = result.generations[0][0].generation_info["logprobs"]["content"][0][ 65 | "top_logprobs" 66 | ] 67 | 68 | scores = logprobs_to_scores(info, valid_options) 69 | return scores 70 | 71 | 72 | def logprobs_to_scores(logprobs, valid_options: List[str]) -> np.ndarray: 73 | scores = np.array(len(valid_options) * [float("-inf")]) 74 | matches = False 75 | for i, c in enumerate(valid_options): 76 | for p in logprobs: 77 | if isinstance(p, dict): # Langchain interface 78 | token = p["token"] 79 | logprob = p["logprob"] 80 | else: # OpenAI interface 81 | token = p.token 82 | logprob = p.logprob 83 | if token == c: 84 | matches = True 85 | scores[i] = logprob 86 | if matches: 87 | scores = scores - np.max(scores) 88 | scores = np.exp(scores) 89 | scores = scores / np.sum(scores) 90 | else: # If no matches, return uniform distribution - is that optimal? 
91 | scores = np.ones(len(valid_options)) / len(valid_options) 92 | 93 | return scores 94 | 95 | # And this is how to do this with openai direct 96 | # response = raw_llm.chat.completions.create( 97 | # model="gpt-4-1106-preview", 98 | # messages=[ 99 | # {"role": "user", "content": reasoning_prompt.to_string()}, 100 | # ], 101 | # logprobs=True, 102 | # top_logprobs=5, 103 | # max_tokens=1, 104 | # ) 105 | # info = response.choices[0].logprobs.content[0].top_logprobs 106 | 107 | 108 | def classifier_with_reasoning( 109 | llm, reasoning_prompt: str, binary_prompt: str, reasoning_args: dict 110 | ): 111 | 112 | reasoning_template = PromptTemplate.from_template(reasoning_prompt) 113 | reasoning_messages = reasoning_template.invoke(reasoning_args).to_messages() 114 | chain = reasoning_template | llm 115 | reasoning = chain.invoke(reasoning_args) 116 | messages = [reasoning_messages + [reasoning, HumanMessage(content=binary_prompt)]] 117 | scores = llm_classifier(llm, messages, ["0", "1"]) 118 | out = { 119 | "reasoning": reasoning.content, 120 | "prob(1)": scores[1], 121 | } 122 | 123 | return out 124 | -------------------------------------------------------------------------------- /examples/Getting probability scores out of LLM classification.ipynb: -------------------------------------------------------------------------------- { "cells": [ { "cell_type": "markdown", "id": "45ebdda2", "metadata": {}, "source": [ "# Getting probability scores out of LLM classification\n", "\n", "When comparing traditional ML classifiers to LLM-based ones, a common problem is that most classifier performance metrics require a vector of confidence/probability scores across the available options, not just the most likely answer. \n", "\n", "Fortunately, the OpenAI API, for example, allows querying token logprobs for up to 20 of the most likely tokens in each position of its response. \n", "These still need to be masked (discarding irrelevant options), converted to probabilities, and normalized to sum to one. \n", "\n", "To spare you the hassle of doing this, we provide two functions: a binary classifier (which expects a yes/no question), and a multiple-choice classifier that expects a multiple-choice question and a list of valid options. The latter also has an optional boolean argument `include_other`, which if true makes the classifier also include an \"Other\" option in its output, for when none of the valid options fit. \n", "\n", "To keep it simple, the multiple-choice classifier only supports up to 9 choice options, so the LLM output can be a single digit (for speed and parsing simplicity). Feel free to contribute a version that supports a larger number of choices! 
;)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "id": "f32f114d", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from pprint import pprint\n", 28 | "\n", 29 | "from langchain_openai import ChatOpenAI\n", 30 | "\n", 31 | "try:\n", 32 | " import wise_topic\n", 33 | "except ImportError:\n", 34 | " import os, sys\n", 35 | " sys.path.append(os.path.realpath(\"..\"))\n", 36 | "\n", 37 | "\n", 38 | "from wise_topic import llm_classifier_binary, llm_classifier_multiple\n", 39 | "llm = ChatOpenAI(model=\"gpt-4-turbo\", temperature=0)\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "id": "d7288d4b", 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "{False: 0.03559647724243312, True: 0.9644035227575669}" 52 | ] 53 | }, 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "output_type": "execute_result" 57 | } 58 | ], 59 | "source": [ 60 | "question1 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat or a dog. Is it a cat?\"\n", 61 | "llm_classifier_binary(llm, question1)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "id": "c3081966", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "{False: 0.9999912515146222, True: 8.748485377892584e-06}" 74 | ] 75 | }, 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "question2 = \"Consider a very friendly pet with a waggy tail. You know it's a cat or a dog. Is it a cat?\"\n", 83 | "llm_classifier_binary(llm, question2)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "id": "0689d004", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "{'cat': 0.9372176977942116,\n", 96 | " 'dog': 0.062782248112413,\n", 97 | " 'dragon': 5.215838794110004e-09,\n", 98 | " 'duck': 4.887753666874768e-08}" 99 | ] 100 | }, 101 | "execution_count": 4, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "question3 = \"Consider a very friendly pet with a fluffy tail. You know it's a cat, a dog, or a dragon. Which is it?\"\n", 108 | "llm_classifier_multiple(llm, question3, [\"cat\", \"dog\", \"dragon\", \"duck\"], include_other=False)" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": "Python [conda env:llm3.11]", 115 | "language": "python", 116 | "name": "conda-env-llm3.11-py" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.11.5" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 5 133 | } 134 | -------------------------------------------------------------------------------- /examples/Topic extraction using LLMs only.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "45ebdda2", 6 | "metadata": {}, 7 | "source": [ 8 | "# Topic extraction using LLMs only\n", 9 | "\n", 10 | "The traditional ML ways of topic extraction rely on converting each message into a vector in some vector space, and then clustering in that vector space. 
\"Topics\" are then really just regions in that vector space.\n", 11 | "\n", 12 | "This approach has several weaknesses: Even interpreting such clusters is not trivial; editing them after the fit, let alone specifying an initial list of human-defined topics, to be automatically expanded if necessary, is pretty much impossible.\n", 13 | "\n", 14 | "Here we show a different way, using only LLM calls. It works as follows: we feed one message at the time to the topic processor; it either assigns it to one of the existing topics, or if none are a good fit, puts it aside. Once the number of messages put aside reaches a threshold, these are used to extract a new topic, which is added to the list. There is also the option of generating topich hierarchies, by setting `max_depth` to a value bigger than 1.\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "f32f114d", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from pprint import pprint\n", 25 | "\n", 26 | "from langchain_openai import ChatOpenAI\n", 27 | "\n", 28 | "try:\n", 29 | " import wise_topic\n", 30 | "except ImportError:\n", 31 | " import os, sys\n", 32 | " sys.path.append(os.path.realpath(\"..\"))\n", 33 | "\n", 34 | "\n", 35 | "from wise_topic import greedy_topic_tree, tree_summary\n", 36 | "\n", 37 | "docs = [\n", 38 | " \"The summer sun blazed high in the sky, bringing warmth to the sandy beaches.\",\n", 39 | " \"During summer, the days are long and the nights are warm and inviting.\",\n", 40 | " \"Ice cream sales soar as people seek relief from the summer heat.\",\n", 41 | " \"Families often choose summer for vacations to take advantage of the sunny weather.\",\n", 42 | " \"Many festivals and outdoor concerts are scheduled in the summer months.\",\n", 43 | " \"Winter brings the joy of snowfall and the excitement of skiing.\",\n", 44 | " \"The cold winter nights are perfect for sipping hot chocolate by the fire.\",\n", 45 | " \"Winter storms can transform the landscape into a snowy wonderland.\",\n", 46 | " \"Heating bills tend to rise as winter's chill sets in.\",\n", 47 | " \"Many animals hibernate or migrate to cope with the harsh winter conditions.\",\n", 48 | " \"Fish swim in schools to protect themselves from predators.\",\n", 49 | " \"Salmon migrate upstream during spawning season, a remarkable journey.\",\n", 50 | " \"Tropical fish add vibrant color and life to coral reefs.\",\n", 51 | " \"Overfishing threatens many species of fish with extinction.\",\n", 52 | " \"Fish have a diverse range of habitats from deep oceans to shallow streams.\",\n", 53 | "]\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "id": "d7288d4b", 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "{'Summer': 6, 'Winter': 5, 'threats to diverse fish species': 4}" 66 | ] 67 | }, 68 | "execution_count": 2, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "topic_llm = ChatOpenAI(\n", 75 | " model=\"gpt-4-turbo\",\n", 76 | " temperature=0,\n", 77 | " model_kwargs={\"response_format\": {\"type\": \"json_object\"}},\n", 78 | ")\n", 79 | "\n", 80 | "classifier_llm = ChatOpenAI(model=\"gpt-4-turbo\", temperature=0)\n", 81 | "\n", 82 | "topic_tree = greedy_topic_tree(\n", 83 | " docs,\n", 84 | " initial_topics=[\"Winter\", \"Summer\"],\n", 85 | " topic_llm=topic_llm,\n", 86 | " classifier_llm=classifier_llm,\n", 87 | " max_depth=1,\n", 88 | " num_topics_per_update=1,\n", 89 | " 
max_unclassified_messages=2,\n", 90 | ")\n", 91 | "\n", 92 | "tree_summary(topic_tree)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "id": "a4757ef7", 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "{'Summer': {'messages': ['During summer, the days are long and the nights are warm and inviting.',\n", 105 | " 'Tropical fish add vibrant color and life to coral reefs.',\n", 106 | " 'The summer sun blazed high in the sky, bringing warmth to the sandy beaches.',\n", 107 | " 'Ice cream sales soar as people seek relief from the summer heat.',\n", 108 | " 'Many festivals and outdoor concerts are scheduled in the summer months.',\n", 109 | " 'Families often choose summer for vacations to take advantage of the sunny weather.']},\n", 110 | " 'Winter': {'messages': ['Winter storms can transform the landscape into a snowy wonderland.',\n", 111 | " 'Winter brings the joy of snowfall and the excitement of skiing.',\n", 112 | " 'The cold winter nights are perfect for sipping hot chocolate by the fire.',\n", 113 | " 'Many animals hibernate or migrate to cope with the harsh winter conditions.',\n", 114 | " \"Heating bills tend to rise as winter's chill sets in.\"]},\n", 115 | " 'threats to diverse fish species': {'messages': ['Fish swim in schools to protect themselves from predators.',\n", 116 | " 'Fish have a diverse range of habitats from deep oceans to shallow streams.',\n", 117 | " 'Overfishing threatens many species of fish with extinction.',\n", 118 | " 'Salmon migrate upstream during spawning season, a remarkable journey.']}}" 119 | ] 120 | }, 121 | "metadata": {}, 122 | "output_type": "display_data" 123 | } 124 | ], 125 | "source": [ 126 | "display(topic_tree)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "381d095e", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": "Python [conda env:llm3.11]", 141 | "language": "python", 142 | "name": "conda-env-llm3.11-py" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 3 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython3", 154 | "version": "3.11.5" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 5 159 | } 160 | -------------------------------------------------------------------------------- /wise_topic/topic/greedy.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections import defaultdict 3 | from copy import deepcopy 4 | from typing import Optional, Sequence, Callable, Union 5 | import logging 6 | import copy 7 | 8 | import numpy as np 9 | from langchain_core.language_models import BaseLanguageModel 10 | from langchain_openai import ChatOpenAI 11 | 12 | from .classifier import topic_classifier, topic_classifier_async 13 | from .topic_extractor import extract_topics 14 | from ..parallel import process_batch_parallel 15 | 16 | 17 | class GreedyTopicExtractor: 18 | def __init__( 19 | self, 20 | max_unclassified_messages: int = 30, 21 | initial_unclassified_messages: Optional[int] = None, 22 | num_topics_per_update: int = 3, 23 | initial_num_topics: Optional[int] = None, 24 | max_words_in_topic: int = 6, 25 | initial_topics: Optional[Sequence[str]] = None, 26 | topic_extractor: Callable = 
extract_topics, 27 | topic_llm: Union[str, BaseLanguageModel] = "gpt-4o", 28 | classifier: Callable = topic_classifier, 29 | classifier_llm: Union[str, BaseLanguageModel] = "gpt-4o-mini", 30 | max_parallel_calls: int = 5, 31 | extra_prompt: Optional[str] = None, 32 | verbose: bool = False, 33 | ): 34 | """ 35 | 36 | :param max_unclassified_messages: how many unmatched messages may accumulate before new topics are extracted from them 37 | :param initial_unclassified_messages: threshold for the very first topic extraction (defaults to twice max_unclassified_messages) 38 | :param num_topics_per_update: number of new topics to extract per update 39 | :param initial_num_topics: number of topics for the first extraction (defaults to twice num_topics_per_update) 40 | :param max_words_in_topic: maximum number of words allowed in a topic name 41 | :param initial_topics: optional human-defined topics to start from 42 | :param topic_extractor: callable that extracts new topics from a batch of messages 43 | :param topic_llm: model (or model name) used for topic extraction 44 | :param classifier: callable that assigns a message to one of the current topics 45 | :param classifier_llm: model (or model name) used for classification 46 | :param verbose: if True, print each message and its assigned topic 47 | """ 48 | self.max_unclassified_messages = max_unclassified_messages 49 | self.initial_unclassified_messages = ( 50 | self.max_unclassified_messages * 2 51 | if initial_unclassified_messages is None 52 | else initial_unclassified_messages 53 | ) 54 | self.num_topics_per_update = num_topics_per_update 55 | self.initial_num_topics = ( 56 | num_topics_per_update * 2 57 | if initial_num_topics is None 58 | else initial_num_topics 59 | ) 60 | self.max_words_in_topic = max_words_in_topic 61 | self.topic_extractor = topic_extractor 62 | if isinstance(topic_llm, str): 63 | topic_llm = ChatOpenAI(model=topic_llm, temperature=0) 64 | self.topic_llm = topic_llm 65 | self.classifier = classifier 66 | if isinstance(classifier_llm, str): 67 | classifier_llm = ChatOpenAI(model=classifier_llm, temperature=0) 68 | self.classifier_llm = classifier_llm 69 | self.extra_prompt = extra_prompt 70 | self.verbose = verbose 71 | self.max_parallel_calls = max_parallel_calls 72 | 73 | self.topics = [] if initial_topics is None else list(initial_topics) 74 | self.messages_to_topics = {} 75 | self.unclassified_messages = [] 76 | self.rename = True 77 | 78 | def __call__(self, messages: Sequence[str]): 79 | """ 80 | Processes a list of messages, classifying each into the current topics with up to 81 | max_parallel_calls simultaneous classifier calls. Classification pauses while 82 | new topics are being extracted from the accumulated unclassified messages. 83 | 84 | Parameters: 85 | - messages (Sequence[str]): The list of messages to classify. 86 | 
87 | """ 88 | messages = list(messages) # work on a copy, so the caller's list is not shuffled in place 89 | np.random.shuffle(messages) 90 | # Process messages in batches, assigning them to existing topics; 91 | # if too many messages that don't fit have accumulated, extract more topics from them 92 | # and add them to the list 93 | for ind in range(0, len(messages), self.max_unclassified_messages): 94 | msg = messages[ 95 | ind : min(ind + self.max_unclassified_messages, len(messages)) 96 | ] 97 | self.process_batch(msg, self.max_parallel_calls) 98 | 99 | # Check if topics need to be updated 100 | update_topics = False 101 | if len(self.topics): 102 | if len(self.unclassified_messages) > self.max_unclassified_messages: 103 | update_topics = True 104 | else: 105 | if len(self.unclassified_messages) > self.initial_unclassified_messages: 106 | update_topics = True 107 | elif len(messages) - ind < self.max_unclassified_messages: 108 | 109 | # we're done, and there were not enough messages in total to trigger the above 110 | update_topics = True 111 | 112 | # Pause classification and update topics synchronously if necessary 113 | if update_topics: 114 | logging.info("Updating topics...") 115 | new_topics = self.extract_topics( 116 | self.unclassified_messages, extra_prompt=self.extra_prompt 117 | ) 118 | self.topics += new_topics 119 | logging.info(f"New topics: {new_topics}") 120 | unclassified = self.unclassified_messages 121 | self.unclassified_messages = [] 122 | self.process_batch(unclassified, self.max_parallel_calls) 123 | logging.info(self.topic_counts) 124 | 125 | # after we've finished classifying, rename the topics to reflect all the messages under them 126 | t2m = self.topics_to_messages 127 | for i, t in enumerate(copy.copy(self.topics)): 128 | if self.rename and t != "Other": 129 | new_name = self.extract_topics(t2m[t])[0] 130 | self.topics[i] = new_name 131 | 132 | # Log topic counts after processing all messages 133 | logging.info(self.topic_counts) 134 | 135 | def process_batch(self, msg, n_jobs: int = 5): 136 | def classify_message(message: str): 137 | if len(self.topics): 138 | topic_num = topic_classifier( 139 | self.topics, message, llm=self.classifier_llm 140 | ) 141 | assert topic_num <= len(self.topics) 142 | else: 143 | topic_num = 0 144 | 145 | return topic_num 146 | 147 | results = process_batch_parallel( 148 | classify_message, [(m,) for m in msg], max_workers=n_jobs 149 | ) 150 | 151 | for topic_num, (message,) in results: 152 | if topic_num == 0: 153 | self.unclassified_messages.append(message) 154 | else: 155 | self.messages_to_topics[message] = topic_num - 1 156 | if self.verbose: 157 | print(message, "\n", self.topics[topic_num - 1], "\n***********") 158 | 159 | def update_topics(self): 160 | messages = self.unclassified_messages 161 | self.unclassified_messages = [] 162 | 163 | new_topics = self.extract_topics(messages) 164 | self.topics += new_topics 165 | 166 | self.process_batch(messages, self.max_parallel_calls) 167 | 168 | 169 | def extract_topics(self, messages: Sequence[str], extra_prompt: Optional[str] = None): 170 | logging.info("extracting topics...") 171 | 172 | num_topics = ( 173 | self.num_topics_per_update if len(self.topics) else self.initial_num_topics 174 | ) 175 | 176 | topics = self.topic_extractor( 177 | messages, 178 | extra_prompt=extra_prompt, 179 | n_topics=num_topics, 180 | with_count=True, 181 | max_words=self.max_words_in_topic, 182 | llm=self.topic_llm, 183 | ) 184 | if self.verbose: 185 | print(topics) 186 | return list(topics.keys()) 187 | 188 | @property 189 | def topics_to_messages(self): 190 | out = defaultdict(list) 191 | for m, t in self.messages_to_topics.items(): 192 | 
out[self.topics[t]].append(m) 193 | 194 | if len(self.unclassified_messages): 195 | out["Other"] = self.unclassified_messages 196 | return out 197 | 198 | @property 199 | def topic_counts(self): 200 | return {k: len(v) for k, v in self.topics_to_messages.items()} 201 | 202 | 203 | def greedy_topic_tree(messages, max_depth=0, **kwargs): 204 | 205 | initial_topics = kwargs.pop("initial_topics", None) 206 | assert initial_topics is None or isinstance(initial_topics, (list, tuple, dict)) 207 | 208 | for key in ["max_unclassified_messages", "initial_unclassified_messages"]: 209 | if key in kwargs and kwargs[key] < 1: 210 | kwargs[key] = int(len(messages) * kwargs[key]) 211 | 212 | gte = GreedyTopicExtractor( 213 | **kwargs, 214 | initial_topics=( 215 | list(initial_topics.keys()) 216 | if isinstance(initial_topics, dict) 217 | else initial_topics 218 | ), 219 | ) 220 | gte(messages) 221 | topic_tree = {} 222 | t2m = gte.topics_to_messages 223 | for t, m in t2m.items(): 224 | topic_tree[t] = {"messages": m} 225 | 226 | if max_depth > 1 and len(topic_tree) > 1: # to prevent endless recursion 227 | for t, m in t2m.items(): 228 | if len(m) > gte.max_unclassified_messages: 229 | logging.info(f"Expanding topic {t} with {len(m)} messages") 230 | candidate_tree = greedy_topic_tree(m, max_depth=max_depth - 1, **kwargs) 231 | if not (len(candidate_tree) == 1 and "Other" in candidate_tree): 232 | topic_tree[t]["sub-topics"] = candidate_tree 233 | return topic_tree 234 | 235 | 236 | def cleanup(x: dict, min_topic_size: int = 5): 237 | x = deepcopy(x) 238 | # Collapse all the tiny topics into "Other" 239 | if "Other" not in x: 240 | x["Other"] = {"messages": []} 241 | 242 | keys = list(x.keys()) 243 | for k in keys: 244 | if k == "Other": 245 | continue 246 | if len(x[k]["messages"]) < min_topic_size: 247 | tmp = x.pop(k) 248 | x["Other"]["messages"] += tmp["messages"] 249 | 250 | if not len(x["Other"]["messages"]): 251 | x.pop("Other") 252 | 253 | # TODO: Move all subtopics with just one topic one level up 254 | 255 | return x 256 | 257 | 258 | def tree_summary(x: dict, include_messages: bool = True, min_topic_size: int = 1): 259 | x = cleanup(x, min_topic_size) 260 | out = {} 261 | for k, v in sorted( 262 | list(x.items()), key=lambda kv: len(kv[1]["messages"]), reverse=True 263 | ): 264 | m = v["messages"] 265 | new_k = f"{k} : {len(m)}" 266 | if "sub-topics" in v: 267 | out[new_k] = tree_summary(v["sub-topics"]) 268 | else: 269 | out[new_k] = m if include_messages else "" 270 | 271 | if "Other" in out: 272 | # Move "Other" to the end 273 | tmp = out.pop("Other") 274 | out["Other"] = tmp 275 | 276 | return out 277 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 |
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 |
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 |
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 |
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 |
176 |    END OF TERMS AND CONDITIONS
177 |
178 |    APPENDIX: How to apply the Apache License to your work.
179 |
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 |
189 |    Copyright [yyyy] [name of copyright owner]
190 |
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 |
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 |
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 |
--------------------------------------------------------------------------------
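As the APPENDIX above describes, applying the license means attaching the boilerplate notice at the top of each file, wrapped in that file format's comment syntax. A minimal sketch of what this would look like at the head of a Python module in this repository; the bracketed fields are placeholders, to be replaced with the actual year and copyright owner:

# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.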