├── .gitignore ├── README.md ├── main.py ├── requirements.txt └── zero-shot-intent-classifer.chain.json /.gitignore: -------------------------------------------------------------------------------- 1 | _venv/ 2 | .env 3 | sandbox.ipynb 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-Shot Intent Classifier 2 | 3 | This sort of thing used to be non-trivial. I hacked this together in maybe an hour. 4 | 5 | Ho boy, the times: they are a-changin'. 6 | 7 | ## What is this? 8 | 9 | This is probably going to sound archaic in a few months, but a lot of "home assistant" type devices right now use a technique called "slot filling" under the hood. An "intent" classifier is the component that figures out what the relevant slots are and 'fills' them with values, resulting in a command being emitted and arguments passed. Instead of training a bespoke one, you can probably use this directly with little or no modification. 10 | 11 | ## Setup 12 | 13 | 1. `git clone <repo-url>; cd <repo-dir>` 14 | 2. `pip install -r requirements.txt` 15 | 3. Create a file named `.env` containing one line: `OPENAI_API_KEY=...`, replacing `...` with your key. 16 | 17 | ## Use 18 | 19 | $ python main.py "becca, how would I drive from my home to SeaTac airport?" 20 | ## {'intent': 'get_directions', 'arguments': {'start_location': 'home', 'end_location': 'SeaTac airport'}} 21 | 22 | ## Compiled Prompt 23 | 24 | > Act as the intent classification component of a home assistant, similar to Amazon Alexa (except your name is 'Becca', not 'Alexa'). 25 | > Common intents include: play_internet_radio, play_song_by_artist, get_weather, current_time, set_timer, remind_me, ... 26 | > You receive input in json format: `{"input":...}` 27 | > You respond in json format: `{"intent":..., "arguments":{ ...
}, }}` 28 | > {"input":`{spoken_request}`} 29 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from langchain.llms import OpenAI 3 | from langchain.prompts import PromptTemplate 4 | from langchain.chains import LLMChain 5 | import json 6 | 7 | load_dotenv() 8 | 9 | INTENTS = [ 10 | "play_internet_radio", 11 | "play_song_by_artist", 12 | "get_weather", 13 | "current_time", 14 | "set_timer", 15 | "remind_me", 16 | ] 17 | 18 | def build_intent_classifier(intents=INTENTS): 19 | template = ( 20 | "Act as the intent classification component of a home assistant, similar to Amazon Alexa " 21 | "(except your name is 'Becca', not 'Alexa').\n" 22 | f"Common intents include: {', '.join(intents)}, ...\n" 23 | 'You receive input in json format: `{{"input":...}}`\n' 24 | 'You respond in json format: `{{"intent":..., "arguments":{{ ... }}, }}}}`\n\n' 25 | '{{"input":{spoken_request}}}' 26 | ) 27 | 28 | llm = OpenAI(temperature=0.1) 29 | prompt = PromptTemplate( 30 | input_variables=["spoken_request"], 31 | template=template, 32 | ) 33 | return LLMChain(llm=llm, prompt=prompt) 34 | 35 | def evaluate(chain, text): 36 | response = chain.run(text) 37 | return json.loads(response.strip()) 38 | 39 | if __name__ == '__main__': 40 | import sys 41 | 42 | # e2e test 43 | chain = build_intent_classifier() 44 | if len(sys.argv) > 1: 45 | text = sys.argv[1] 46 | response = evaluate(chain, text) 47 | print(response) 48 | else: 49 | text = "becca play kexp" 50 | response = evaluate(chain, text) 51 | print(response) 52 | assert response == {'intent':'play_internet_radio', 'arguments':{'station':'KEXP'}} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | langchain 2 | openai 3 | google-api-python-client 4 
| python-dotenv 5 | omegaconf 6 | transformers 7 | torch 8 | -------------------------------------------------------------------------------- /zero-shot-intent-classifer.chain.json: -------------------------------------------------------------------------------- 1 | { 2 | "memory": null, 3 | "verbose": false, 4 | "prompt": { 5 | "input_variables": [ 6 | "spoken_request" 7 | ], 8 | "output_parser": null, 9 | "partial_variables": {}, 10 | "template": "Act as the intent classification component of a home assistant, similar to Amazon Alexa (except your name is 'Becca', not 'Alexa').\nCommon intents include: play_internet_radio, play_song_by_artist, get_weather, current_time, set_timer, remind_me, ...\nYou receive input in json format: `{{\"input\":...}}`\nYou respond in json format: `{{\"intent\":..., \"arguments\":{{ ... }}, }}}}`\n\n{{\"input\":{spoken_request}}}", 11 | "template_format": "f-string", 12 | "validate_template": true, 13 | "_type": "prompt" 14 | }, 15 | "llm": { 16 | "model_name": "text-davinci-003", 17 | "temperature": 0.1, 18 | "max_tokens": 256, 19 | "top_p": 1, 20 | "frequency_penalty": 0, 21 | "presence_penalty": 0, 22 | "n": 1, 23 | "best_of": 1, 24 | "request_timeout": null, 25 | "logit_bias": {}, 26 | "_type": "openai" 27 | }, 28 | "output_key": "text", 29 | "_type": "llm_chain" 30 | } --------------------------------------------------------------------------------