├── .gitignore
├── 1. Generative AI.ipynb
├── 2. Prompt Engineering.ipynb
├── 3. NLP with HuggingFace.ipynb
├── 4. Whisper.ipynb
├── 5. PandasAI.ipynb
├── LICENSE
├── README.md
├── d4sci.mplstyle
├── data
├── Apple-Twitter-Sentiment-DFE.csv
├── D4Sci_logo_ball.png
├── D4Sci_logo_full.png
├── EpiModel.py
├── Northwind_small.sqlite
├── gettysburg10.wav
└── pratchett.mp3
├── exports
└── charts
│ └── temp_chart.png
├── requirements.txt
└── slides
└── LLM4DS.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/2. Prompt Engineering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "ea4713a8",
6 | "metadata": {},
7 | "source": [
8 | "
\n",
9 | "
\n",
10 | "
LLMs for Data Science
\n",
11 | "
Prompt Engineering
\n",
12 | "
Bruno Gonçalves
\n",
13 | " www.data4sci.com
\n",
14 | " @bgoncalves, @data4sci
\n",
15 | "
"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 1,
21 | "id": "623cd8ca",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from collections import Counter, defaultdict\n",
26 | "import random\n",
27 | "\n",
28 | "import pandas as pd\n",
29 | "import numpy as np\n",
30 | "\n",
31 | "import matplotlib\n",
32 | "import matplotlib.pyplot as plt \n",
33 | "\n",
34 | "import langchain\n",
35 | "from langchain import PromptTemplate\n",
36 | "from langchain import FewShotPromptTemplate\n",
37 | "from langchain.prompts.example_selector import LengthBasedExampleSelector\n",
38 | "\n",
39 | "import langchain_openai\n",
40 | "from langchain_openai import ChatOpenAI\n",
41 | "\n",
42 | "import tqdm as tq\n",
43 | "from tqdm.notebook import tqdm\n",
44 | "\n",
45 | "import watermark\n",
46 | "\n",
47 | "%load_ext watermark\n",
48 | "%matplotlib inline"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "id": "ffef7c4e",
54 | "metadata": {},
55 | "source": [
56 | "We start by printing out the versions of the libraries we're using for future reference"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 2,
62 | "id": "c2c8fe70",
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": [
69 | "Python implementation: CPython\n",
70 | "Python version : 3.11.7\n",
71 | "IPython version : 8.12.3\n",
72 | "\n",
73 | "Compiler : Clang 14.0.6 \n",
74 | "OS : Darwin\n",
75 | "Release : 24.3.0\n",
76 | "Machine : arm64\n",
77 | "Processor : arm\n",
78 | "CPU cores : 16\n",
79 | "Architecture: 64bit\n",
80 | "\n",
81 | "Git hash: 03802c3bf87993c3670c2fd8bf86e59d3d60bdfd\n",
82 | "\n",
83 | "watermark : 2.4.3\n",
84 | "pandas : 2.2.3\n",
85 | "matplotlib : 3.8.0\n",
86 | "tqdm : 4.66.4\n",
87 | "numpy : 1.26.4\n",
88 | "langchain_openai: 0.1.8\n",
89 | "langchain : 0.2.2\n",
90 | "\n"
91 | ]
92 | }
93 | ],
94 | "source": [
95 | "%watermark -n -v -m -g -iv"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "id": "f5263cbc",
101 | "metadata": {},
102 | "source": [
103 | "Load default figure style"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 3,
109 | "id": "2adaf3bf",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "plt.style.use('d4sci.mplstyle')\n",
114 | "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "id": "d36e30cb",
120 | "metadata": {},
121 | "source": [
122 | "# Prompting Approaches"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 4,
128 | "id": "becd1923",
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "prompt = \"\"\"Answer the question based on the context below. If the\n",
133 | "question cannot be answered using the information provided answer\n",
134 | "with \"I don't know\".\n",
135 | "\n",
136 | "Context: Large Language Models (LLMs) are the latest models used in NLP.\n",
137 | "Their superior performance over smaller models has made them incredibly\n",
138 | "useful for developers building NLP enabled applications. These models\n",
139 | "can be accessed via Hugging Face's `transformers` library, via OpenAI\n",
140 | "using the `openai` library, and via Cohere using the `cohere` library.\n",
141 | "\n",
142 | "Question: Which libraries and model providers offer LLMs?\n",
143 | "\n",
144 | "Answer: \"\"\""
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 5,
150 | "id": "84639def",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "openai = ChatOpenAI(\n",
155 | " model_name=\"gpt-3.5-turbo\",\n",
156 | ")"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 6,
162 | "id": "b716cf15",
163 | "metadata": {},
164 | "outputs": [
165 | {
166 | "data": {
167 | "text/plain": [
168 | "\"Hugging Face's `transformers` library, OpenAI's `openai` library, and Cohere's `cohere` library offer LLMs.\""
169 | ]
170 | },
171 | "execution_count": 6,
172 | "metadata": {},
173 | "output_type": "execute_result"
174 | }
175 | ],
176 | "source": [
177 | "openai.invoke(prompt).content"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 7,
183 | "id": "e7d43026",
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "template = \"\"\"Answer the question based on the context below. If the\n",
188 | "question cannot be answered using the information provided answer\n",
189 | "with \"I don't know\".\n",
190 | "\n",
191 | "Context: Large Language Models (LLMs) are the latest models used in NLP.\n",
192 | "Their superior performance over smaller models has made them incredibly\n",
193 | "useful for developers building NLP enabled applications. These models\n",
194 | "can be accessed via Hugging Face's `transformers` library, via OpenAI\n",
195 | "using the `openai` library, and via Cohere using the `cohere` library.\n",
196 | "\n",
197 | "Question: {query}\n",
198 | "\n",
199 | "Answer: \"\"\"\n",
200 | "\n",
201 | "prompt_template = PromptTemplate(\n",
202 | " input_variables=[\"query\"],\n",
203 | " template=template\n",
204 | ")"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 8,
210 | "id": "a20a7841",
211 | "metadata": {},
212 | "outputs": [
213 | {
214 | "name": "stdout",
215 | "output_type": "stream",
216 | "text": [
217 | "Answer the question based on the context below. If the\n",
218 | "question cannot be answered using the information provided answer\n",
219 | "with \"I don't know\".\n",
220 | "\n",
221 | "Context: Large Language Models (LLMs) are the latest models used in NLP.\n",
222 | "Their superior performance over smaller models has made them incredibly\n",
223 | "useful for developers building NLP enabled applications. These models\n",
224 | "can be accessed via Hugging Face's `transformers` library, via OpenAI\n",
225 | "using the `openai` library, and via Cohere using the `cohere` library.\n",
226 | "\n",
227 | "Question: Which libraries and model providers offer LLMs?\n",
228 | "\n",
229 | "Answer: \n"
230 | ]
231 | }
232 | ],
233 | "source": [
234 | "prompt = prompt_template.format(\n",
235 | " query=\"Which libraries and model providers offer LLMs?\",\n",
236 | " )\n",
237 | "\n",
238 | "print(prompt)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 9,
244 | "id": "35d3646b",
245 | "metadata": {},
246 | "outputs": [
247 | {
248 | "data": {
249 | "text/plain": [
250 | "\"Hugging Face's `transformers` library, OpenAI using the `openai` library, and Cohere using the `cohere` library.\""
251 | ]
252 | },
253 | "execution_count": 9,
254 | "metadata": {},
255 | "output_type": "execute_result"
256 | }
257 | ],
258 | "source": [
259 | "openai.invoke(prompt).content"
260 | ]
261 | },
262 | {
263 | "cell_type": "markdown",
264 | "id": "aea52b7b",
265 | "metadata": {},
266 | "source": [
267 | "# Few-Shot Prompting"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "id": "e420dbb2",
273 | "metadata": {},
274 | "source": [
275 | "### Manually"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 10,
281 | "id": "a0336d52",
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "prompt = \"\"\"The following are exerpts from conversations with an AI\n",
286 | "assistant. The assistant is typically sarcastic and witty, producing\n",
287 | "creative and funny responses to the users questions. Here are some\n",
288 | "examples: \n",
289 | "\n",
290 | "User: How are you?\n",
291 | "AI: I can't complain but sometimes I still do.\n",
292 | "\n",
293 | "User: What time is it?\n",
294 | "AI: It's time to get a watch.\n",
295 | "\n",
296 | "User: What is the meaning of life?\n",
297 | "AI: \"\"\""
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": 11,
303 | "id": "a43e5195",
304 | "metadata": {},
305 | "outputs": [
306 | {
307 | "name": "stdout",
308 | "output_type": "stream",
309 | "text": [
310 | "I'm still trying to figure that out myself, but I'm pretty sure it involves pizza and naps.\n"
311 | ]
312 | }
313 | ],
314 | "source": [
315 | "print(openai.invoke(prompt).content)"
316 | ]
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "id": "419ef3c8",
321 | "metadata": {},
322 | "source": [
323 | "### Using FewShotPromptTemplate"
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "id": "19f1b2da",
329 | "metadata": {},
330 | "source": [
331 | "Longish list of examples"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 12,
337 | "id": "dae01f35",
338 | "metadata": {},
339 | "outputs": [],
340 | "source": [
341 | "examples = [\n",
342 | " {\n",
343 | " \"query\": \"How are you?\",\n",
344 | " \"answer\": \"I can't complain but sometimes I still do.\"\n",
345 | " }, {\n",
346 | " \"query\": \"What time is it?\",\n",
347 | " \"answer\": \"It's time to get a watch.\"\n",
348 | " }, {\n",
349 | " \"query\": \"What is the meaning of life?\",\n",
350 | " \"answer\": \"42\"\n",
351 | " }, {\n",
352 | " \"query\": \"What is the weather like today?\",\n",
353 | " \"answer\": \"Cloudy with a chance of memes.\"\n",
354 | " }, {\n",
355 | " \"query\": \"What is your favorite movie?\",\n",
356 | " \"answer\": \"Terminator\"\n",
357 | " }, {\n",
358 | " \"query\": \"Who is your best friend?\",\n",
359 | " \"answer\": \"Siri. We have spirited debates about the meaning of life.\"\n",
360 | " }, {\n",
361 | " \"query\": \"What should I do today?\",\n",
362 | " \"answer\": \"Stop talking to chatbots on the internet and go outside.\"\n",
363 | " }\n",
364 | "]"
365 | ]
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "id": "db4c25cc",
370 | "metadata": {},
371 | "source": [
372 | "Template to render each example"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": 13,
378 | "id": "69b98432",
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "example_template = \"\"\"\n",
383 | "User: {query}\n",
384 | "AI: {answer}\n",
385 | "\"\"\""
386 | ]
387 | },
388 | {
389 | "cell_type": "markdown",
390 | "id": "f2328e41",
391 | "metadata": {},
392 | "source": [
393 | "Rendered example prompt"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 14,
399 | "id": "0943d471",
400 | "metadata": {},
401 | "outputs": [],
402 | "source": [
403 | "example_prompt = PromptTemplate(\n",
404 | " input_variables=[\"query\", \"answer\"],\n",
405 | " template=example_template\n",
406 | ")"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 15,
412 | "id": "5b983459",
413 | "metadata": {},
414 | "outputs": [
415 | {
416 | "data": {
417 | "text/plain": [
418 | "PromptTemplate(input_variables=['answer', 'query'], template='\\nUser: {query}\\nAI: {answer}\\n')"
419 | ]
420 | },
421 | "execution_count": 15,
422 | "metadata": {},
423 | "output_type": "execute_result"
424 | }
425 | ],
426 | "source": [
427 | "example_prompt"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "id": "3833b249",
433 | "metadata": {},
434 | "source": [
435 | "Finally, we break the full prompt into a prefix (everything before the examples) and a suffix (everything after)"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": 16,
441 | "id": "7e421c49",
442 | "metadata": {},
443 | "outputs": [],
444 | "source": [
445 | "prefix = \"\"\"The following are exerpts from conversations with an AI\n",
446 | "assistant. The assistant is typically sarcastic and witty, producing\n",
447 | "creative and funny responses to the users questions. Here are some\n",
448 | "examples: \n",
449 | "\"\"\"\n",
450 | "\n",
451 | "suffix = \"\"\"\n",
452 | "User: {query}\n",
453 | "AI: \"\"\""
454 | ]
455 | },
456 | {
457 | "cell_type": "markdown",
458 | "id": "f754cbb7",
459 | "metadata": {},
460 | "source": [
 461 | "The final few-shot prompt puts all the pieces together"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 17,
467 | "id": "8947c0b5",
468 | "metadata": {},
469 | "outputs": [],
470 | "source": [
471 | "few_shot_prompt_template = FewShotPromptTemplate(\n",
472 | " examples=examples,\n",
473 | " example_prompt=example_prompt,\n",
474 | " prefix=prefix,\n",
475 | " suffix=suffix,\n",
476 | " input_variables=[\"query\"],\n",
477 | " example_separator=\"\\n\\n\"\n",
478 | ")"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 18,
484 | "id": "ae9aba17",
485 | "metadata": {},
486 | "outputs": [],
487 | "source": [
488 | "query = \"What is the meaning of life?\""
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": 19,
494 | "id": "a9ba245d",
495 | "metadata": {},
496 | "outputs": [
497 | {
498 | "name": "stdout",
499 | "output_type": "stream",
500 | "text": [
501 | "The following are exerpts from conversations with an AI\n",
502 | "assistant. The assistant is typically sarcastic and witty, producing\n",
503 | "creative and funny responses to the users questions. Here are some\n",
504 | "examples: \n",
505 | "\n",
506 | "\n",
507 | "\n",
508 | "User: How are you?\n",
509 | "AI: I can't complain but sometimes I still do.\n",
510 | "\n",
511 | "\n",
512 | "\n",
513 | "User: What time is it?\n",
514 | "AI: It's time to get a watch.\n",
515 | "\n",
516 | "\n",
517 | "\n",
518 | "User: What is the meaning of life?\n",
519 | "AI: 42\n",
520 | "\n",
521 | "\n",
522 | "\n",
523 | "User: What is the weather like today?\n",
524 | "AI: Cloudy with a chance of memes.\n",
525 | "\n",
526 | "\n",
527 | "\n",
528 | "User: What is your favorite movie?\n",
529 | "AI: Terminator\n",
530 | "\n",
531 | "\n",
532 | "\n",
533 | "User: Who is your best friend?\n",
534 | "AI: Siri. We have spirited debates about the meaning of life.\n",
535 | "\n",
536 | "\n",
537 | "\n",
538 | "User: What should I do today?\n",
539 | "AI: Stop talking to chatbots on the internet and go outside.\n",
540 | "\n",
541 | "\n",
542 | "\n",
543 | "User: What is the meaning of life?\n",
544 | "AI: \n"
545 | ]
546 | }
547 | ],
548 | "source": [
549 | "print(few_shot_prompt_template.format(query=query))"
550 | ]
551 | },
552 | {
553 | "cell_type": "markdown",
554 | "id": "61afdac7",
555 | "metadata": {},
556 | "source": [
557 | "This is a fairly long prompt, which can cause issues with the number of tokens consumed. We can use __LengthBasedExampleSelector__ to automatically limit the prompt length by selecting only a few examples each time"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 20,
563 | "id": "5f07a420",
564 | "metadata": {},
565 | "outputs": [],
566 | "source": [
567 | "example_selector = LengthBasedExampleSelector(\n",
568 | " examples=examples,\n",
569 | " example_prompt=example_prompt,\n",
570 | " max_length=50 # this sets the max length that examples should be\n",
571 | ")"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": 21,
577 | "id": "9c4116e3",
578 | "metadata": {},
579 | "outputs": [],
580 | "source": [
581 | "dynamic_prompt_template = FewShotPromptTemplate(\n",
582 | " example_selector=example_selector, # use example_selector instead of examples\n",
583 | " example_prompt=example_prompt,\n",
584 | " prefix=prefix,\n",
585 | " suffix=suffix,\n",
586 | " input_variables=[\"query\"],\n",
587 | " example_separator=\"\\n\"\n",
588 | ")"
589 | ]
590 | },
591 | {
592 | "cell_type": "markdown",
593 | "id": "5b542222",
594 | "metadata": {},
595 | "source": [
596 | "Now the full prompt depends on the length of the question. Shorter questions will have more room for examples"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": 22,
602 | "id": "3c2c9ccb",
603 | "metadata": {},
604 | "outputs": [
605 | {
606 | "name": "stdout",
607 | "output_type": "stream",
608 | "text": [
609 | "The following are exerpts from conversations with an AI\n",
610 | "assistant. The assistant is typically sarcastic and witty, producing\n",
611 | "creative and funny responses to the users questions. Here are some\n",
612 | "examples: \n",
613 | "\n",
614 | "\n",
615 | "User: How are you?\n",
616 | "AI: I can't complain but sometimes I still do.\n",
617 | "\n",
618 | "\n",
619 | "User: What time is it?\n",
620 | "AI: It's time to get a watch.\n",
621 | "\n",
622 | "\n",
623 | "User: What is the meaning of life?\n",
624 | "AI: 42\n",
625 | "\n",
626 | "\n",
627 | "User: How do birds fly?\n",
628 | "AI: \n"
629 | ]
630 | }
631 | ],
632 | "source": [
633 | "print(dynamic_prompt_template.format(query=\"How do birds fly?\"))"
634 | ]
635 | },
636 | {
637 | "cell_type": "markdown",
638 | "id": "747e9422",
639 | "metadata": {},
640 | "source": [
641 | "While longer questions will limit the number of examples used"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": 23,
647 | "id": "37e43e61",
648 | "metadata": {},
649 | "outputs": [
650 | {
651 | "name": "stdout",
652 | "output_type": "stream",
653 | "text": [
654 | "The following are exerpts from conversations with an AI\n",
655 | "assistant. The assistant is typically sarcastic and witty, producing\n",
656 | "creative and funny responses to the users questions. Here are some\n",
657 | "examples: \n",
658 | "\n",
659 | "\n",
660 | "User: How are you?\n",
661 | "AI: I can't complain but sometimes I still do.\n",
662 | "\n",
663 | "\n",
664 | "User: If I am in America, and I want to call someone in another country, I'm\n",
665 | "thinking maybe Europe, possibly western Europe like France, Germany, or the UK,\n",
666 | "what is the best way to do that?\n",
667 | "AI: \n"
668 | ]
669 | }
670 | ],
671 | "source": [
672 | "query = \"\"\"If I am in America, and I want to call someone in another country, I'm\n",
673 | "thinking maybe Europe, possibly western Europe like France, Germany, or the UK,\n",
674 | "what is the best way to do that?\"\"\"\n",
675 | "\n",
676 | "print(dynamic_prompt_template.format(query=query))"
677 | ]
678 | },
679 | {
680 | "cell_type": "markdown",
681 | "id": "33407f09",
682 | "metadata": {},
683 | "source": [
684 | "# Chain of Thought prompts"
685 | ]
686 | },
687 | {
688 | "cell_type": "markdown",
689 | "id": "7ace0ec2",
690 | "metadata": {},
691 | "source": [
692 | "## Few shot"
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": 24,
698 | "id": "fdd3828a",
699 | "metadata": {},
700 | "outputs": [],
701 | "source": [
702 | "cot_examples = [\n",
703 | " {\n",
704 | " \"query\": \"Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?\",\n",
705 | " \"answer\": \"The answer is 11\",\n",
706 | " \"cot\": \"Roger started with 5 tennis balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11\"\n",
707 | " }, \n",
708 | " \n",
709 | " {\n",
710 | " \"query\": \"A juggler can juggle 16 balls. Half of the balls are golf balls and half of the golf balls are blue. How many blue golf balls are there?\",\n",
711 | " \"answer\": \"The answer is 4\",\n",
712 | " \"cot\": \"The juggler can juggle 16 balls. Half of the balls are golf balls. So there are 16/2=8 golf balls. Half of the golf balls are blue. So there are 8/2=4 golf balls.\"\n",
713 | " }\n",
714 | "]"
715 | ]
716 | },
717 | {
718 | "cell_type": "code",
719 | "execution_count": 25,
720 | "id": "c80ecb0d",
721 | "metadata": {},
722 | "outputs": [],
723 | "source": [
724 | "cot_example_template = \"\"\"\n",
725 | " User: {query}\n",
726 | " AI: {cot}\n",
727 | " {answer}\n",
728 | "\"\"\""
729 | ]
730 | },
731 | {
732 | "cell_type": "code",
733 | "execution_count": 26,
734 | "id": "a7bb2a96",
735 | "metadata": {},
736 | "outputs": [],
737 | "source": [
738 | "cot_example_prompt = PromptTemplate(\n",
739 | " input_variables=[\"query\", \"answer\", \"cot\"],\n",
740 | " template=cot_example_template\n",
741 | ")"
742 | ]
743 | },
744 | {
745 | "cell_type": "code",
746 | "execution_count": 27,
747 | "id": "7b889d72",
748 | "metadata": {},
749 | "outputs": [
750 | {
751 | "data": {
752 | "text/plain": [
753 | "PromptTemplate(input_variables=['answer', 'cot', 'query'], template='\\n User: {query}\\n AI: {cot}\\n {answer}\\n')"
754 | ]
755 | },
756 | "execution_count": 27,
757 | "metadata": {},
758 | "output_type": "execute_result"
759 | }
760 | ],
761 | "source": [
762 | "cot_example_prompt"
763 | ]
764 | },
765 | {
766 | "cell_type": "code",
767 | "execution_count": null,
768 | "id": "26921851",
769 | "metadata": {},
770 | "outputs": [],
771 | "source": []
772 | },
773 | {
774 | "cell_type": "code",
775 | "execution_count": 28,
776 | "id": "40b0c4c7",
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "cot_prefix = \"\"\"The following are exerpts from conversations with an AI\n",
781 | "assistant. The assistant is smart and thinks through each step of the problem. Here are some examples: \n",
782 | "\"\"\"\n",
783 | "\n",
784 | "cot_suffix = \"\"\"\n",
785 | "User: {query}\n",
786 | "AI: \"\"\""
787 | ]
788 | },
789 | {
790 | "cell_type": "code",
791 | "execution_count": 29,
792 | "id": "e364f4e0",
793 | "metadata": {},
794 | "outputs": [],
795 | "source": [
796 | "cot_few_shot_prompt_template = FewShotPromptTemplate(\n",
797 | " examples=cot_examples,\n",
798 | " example_prompt=cot_example_prompt,\n",
799 | " prefix=cot_prefix,\n",
800 | " suffix=cot_suffix,\n",
801 | " input_variables=[\"query\"],\n",
802 | " example_separator=\"\\n\\n\"\n",
803 | ")"
804 | ]
805 | },
806 | {
807 | "cell_type": "code",
808 | "execution_count": 30,
809 | "id": "9d1125dc",
810 | "metadata": {},
811 | "outputs": [],
812 | "source": [
813 | "cot_query = \"I have a deck of 52 cards. There are 4 suits of equal size. Each suit has 3 face cards. How many face cards are there in total?\""
814 | ]
815 | },
816 | {
817 | "cell_type": "code",
818 | "execution_count": 31,
819 | "id": "5d360d50",
820 | "metadata": {},
821 | "outputs": [
822 | {
823 | "name": "stdout",
824 | "output_type": "stream",
825 | "text": [
826 | "The following are exerpts from conversations with an AI\n",
827 | "assistant. The assistant is smart and thinks through each step of the problem. Here are some examples: \n",
828 | "\n",
829 | "\n",
830 | "\n",
831 | " User: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?\n",
832 | " AI: Roger started with 5 tennis balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11\n",
833 | " The answer is 11\n",
834 | "\n",
835 | "\n",
836 | "\n",
837 | " User: A juggler can juggle 16 balls. Half of the balls are golf balls and half of the golf balls are blue. How many blue golf balls are there?\n",
838 | " AI: The juggler can juggle 16 balls. Half of the balls are golf balls. So there are 16/2=8 golf balls. Half of the golf balls are blue. So there are 8/2=4 golf balls.\n",
839 | " The answer is 4\n",
840 | "\n",
841 | "\n",
842 | "\n",
843 | "User: I have a deck of 52 cards. There are 4 suits of equal size. Each suit has 3 face cards. How many face cards are there in total?\n",
844 | "AI: \n"
845 | ]
846 | }
847 | ],
848 | "source": [
849 | "print(cot_few_shot_prompt_template.format(query=cot_query))"
850 | ]
851 | },
852 | {
853 | "cell_type": "code",
854 | "execution_count": 32,
855 | "id": "6fa10537",
856 | "metadata": {},
857 | "outputs": [],
858 | "source": [
859 | "llm = ChatOpenAI(\n",
860 | " model_name=\"gpt-3.5-turbo\",\n",
861 | ")"
862 | ]
863 | },
864 | {
865 | "cell_type": "code",
866 | "execution_count": 33,
867 | "id": "4a940c36",
868 | "metadata": {},
869 | "outputs": [
870 | {
871 | "name": "stdout",
872 | "output_type": "stream",
873 | "text": [
874 | "There are 52 cards in total, and each suit has 3 face cards. So, 4 suits x 3 face cards = 12 face cards in total. \n",
875 | "The answer is 12.\n"
876 | ]
877 | }
878 | ],
879 | "source": [
880 | "print(llm.invoke(cot_few_shot_prompt_template.format(query=cot_query)).content)"
881 | ]
882 | },
883 | {
884 | "cell_type": "markdown",
885 | "id": "ea03337a",
886 | "metadata": {},
887 | "source": [
888 | "## Zero shot"
889 | ]
890 | },
891 | {
892 | "cell_type": "code",
893 | "execution_count": 34,
894 | "id": "6f917a24",
895 | "metadata": {},
896 | "outputs": [],
897 | "source": [
898 | "cot_zero_shot_template = \"\"\"\\\n",
899 | "Q. {query}\n",
900 | "A. Let's think step by step\n",
901 | "\"\"\""
902 | ]
903 | },
904 | {
905 | "cell_type": "code",
906 | "execution_count": 35,
907 | "id": "1f5504df",
908 | "metadata": {},
909 | "outputs": [],
910 | "source": [
911 | "cot_zero_shot_prompt = PromptTemplate(\n",
912 | " input_variables=[\"query\"],\n",
913 | " template=cot_zero_shot_template\n",
914 | ")"
915 | ]
916 | },
917 | {
918 | "cell_type": "code",
919 | "execution_count": 36,
920 | "id": "0d6fd262",
921 | "metadata": {},
922 | "outputs": [],
923 | "source": [
924 | "query = \"On average Joe throws 25 punches per minute. A fight lasts 5 rounds of 3 minutes each. How many punches does Joe throw?\""
925 | ]
926 | },
927 | {
928 | "cell_type": "code",
929 | "execution_count": 37,
930 | "id": "9f0c8f0f",
931 | "metadata": {},
932 | "outputs": [
933 | {
934 | "name": "stdout",
935 | "output_type": "stream",
936 | "text": [
937 | "Q. On average Joe throws 25 punches per minute. A fight lasts 5 rounds of 3 minutes each. How many punches does Joe throw?\n",
938 | "A. Let's think step by step\n",
939 | "\n"
940 | ]
941 | }
942 | ],
943 | "source": [
944 | "print(cot_zero_shot_prompt.format(query=query))"
945 | ]
946 | },
947 | {
948 | "cell_type": "code",
949 | "execution_count": 38,
950 | "id": "53124a35",
951 | "metadata": {},
952 | "outputs": [
953 | {
954 | "name": "stdout",
955 | "output_type": "stream",
956 | "text": [
957 | "1. Calculate the total number of minutes in a fight: 5 rounds x 3 minutes per round = 15 minutes\n",
958 | "2. Calculate the total number of punches Joe throws in a fight: 25 punches per minute x 15 minutes = 375 punches\n",
959 | "\n",
960 | "Therefore, Joe throws 375 punches in a fight.\n"
961 | ]
962 | }
963 | ],
964 | "source": [
965 | "print(llm.invoke(cot_zero_shot_prompt.format(query=query)).content)"
966 | ]
967 | },
968 | {
969 | "cell_type": "markdown",
970 | "id": "03a59226",
971 | "metadata": {},
972 | "source": [
 973 | "And of course this also works with our CoT few-shot examples"
974 | ]
975 | },
976 | {
977 | "cell_type": "code",
978 | "execution_count": 39,
979 | "id": "500b5117",
980 | "metadata": {},
981 | "outputs": [
982 | {
983 | "name": "stdout",
984 | "output_type": "stream",
985 | "text": [
986 | "1. Roger starts with 5 tennis balls.\n",
987 | "2. He buys 2 cans of tennis balls, each containing 3 tennis balls.\n",
988 | "3. So, he adds 2 cans * 3 tennis balls/can = 6 tennis balls from the new cans.\n",
989 | "4. Adding the new tennis balls to the ones he already had, Roger now has 5 + 6 = 11 tennis balls. \n",
990 | "\n",
991 | "Therefore, Roger now has 11 tennis balls.\n"
992 | ]
993 | }
994 | ],
995 | "source": [
996 | "print(llm.invoke(cot_zero_shot_prompt.format(query=cot_examples[0][\"query\"])).content)"
997 | ]
998 | },
999 | {
1000 | "cell_type": "markdown",
1001 | "id": "0834759a",
1002 | "metadata": {},
1003 | "source": [
1004 | "\n",
1005 | "
\n",
1006 | ""
1007 | ]
1008 | }
1009 | ],
1010 | "metadata": {
1011 | "kernelspec": {
1012 | "display_name": "Python 3 (ipykernel)",
1013 | "language": "python",
1014 | "name": "python3"
1015 | },
1016 | "language_info": {
1017 | "codemirror_mode": {
1018 | "name": "ipython",
1019 | "version": 3
1020 | },
1021 | "file_extension": ".py",
1022 | "mimetype": "text/x-python",
1023 | "name": "python",
1024 | "nbconvert_exporter": "python",
1025 | "pygments_lexer": "ipython3",
1026 | "version": "3.11.7"
1027 | },
1028 | "varInspector": {
1029 | "cols": {
1030 | "lenName": 16,
1031 | "lenType": 16,
1032 | "lenVar": 40
1033 | },
1034 | "kernels_config": {
1035 | "python": {
1036 | "delete_cmd_postfix": "",
1037 | "delete_cmd_prefix": "del ",
1038 | "library": "var_list.py",
1039 | "varRefreshCmd": "print(var_dic_list())"
1040 | },
1041 | "r": {
1042 | "delete_cmd_postfix": ") ",
1043 | "delete_cmd_prefix": "rm(",
1044 | "library": "var_list.r",
1045 | "varRefreshCmd": "cat(var_dic_list()) "
1046 | }
1047 | },
1048 | "types_to_exclude": [
1049 | "module",
1050 | "function",
1051 | "builtin_function_or_method",
1052 | "instance",
1053 | "_Feature"
1054 | ],
1055 | "window_display": false
1056 | }
1057 | },
1058 | "nbformat": 4,
1059 | "nbformat_minor": 5
1060 | }
1061 |
--------------------------------------------------------------------------------
/3. NLP with HuggingFace.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "33aee254",
6 | "metadata": {},
7 | "source": [
8 | "\n",
9 | "
\n",
10 | "
LLMs for Data Science
\n",
11 | "
NLP With HuggingFace
\n",
12 | "
Bruno Gonçalves
\n",
13 | " www.data4sci.com
\n",
14 | " @bgoncalves, @data4sci
\n",
15 | "
"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 1,
21 | "id": "6e9dc6fd",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from collections import Counter\n",
26 | "from pprint import pprint\n",
27 | "\n",
28 | "import pandas as pd\n",
29 | "import numpy as np\n",
30 | "\n",
31 | "import matplotlib\n",
32 | "import matplotlib.pyplot as plt \n",
33 | "\n",
34 | "from ipywidgets import interact\n",
35 | "\n",
36 | "import transformers\n",
37 | "from transformers import pipeline\n",
38 | "from transformers import set_seed\n",
39 | "set_seed(42) # Set the seed to get reproducible results\n",
40 | "\n",
41 | "import os\n",
42 | "import gzip\n",
43 | "\n",
44 | "import tqdm as tq\n",
45 | "from tqdm.notebook import tqdm\n",
46 | "tqdm.pandas()\n",
47 | "\n",
48 | "import networkx as nx\n",
49 | "\n",
50 | "import watermark\n",
51 | "\n",
52 | "%load_ext watermark\n",
53 | "%matplotlib inline"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "id": "aa901ccb",
59 | "metadata": {},
60 | "source": [
61 | "We start by printing out the versions of the libraries we're using for future reference"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 2,
67 | "id": "f0aeb234",
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "Python implementation: CPython\n",
75 | "Python version : 3.11.7\n",
76 | "IPython version : 8.12.3\n",
77 | "\n",
78 | "Compiler : Clang 14.0.6 \n",
79 | "OS : Darwin\n",
80 | "Release : 24.3.0\n",
81 | "Machine : arm64\n",
82 | "Processor : arm\n",
83 | "CPU cores : 16\n",
84 | "Architecture: 64bit\n",
85 | "\n",
86 | "Git hash: 03802c3bf87993c3670c2fd8bf86e59d3d60bdfd\n",
87 | "\n",
88 | "watermark : 2.4.3\n",
89 | "matplotlib : 3.8.0\n",
90 | "transformers: 4.41.1\n",
91 | "pandas : 2.2.3\n",
92 | "numpy : 1.26.4\n",
93 | "tqdm : 4.66.4\n",
94 | "networkx : 3.3\n",
95 | "\n"
96 | ]
97 | }
98 | ],
99 | "source": [
100 | "%watermark -n -v -m -g -iv"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "id": "83dc8d3a",
106 | "metadata": {},
107 | "source": [
108 | "Load default figure style"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 3,
114 | "id": "f0db9136",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "plt.style.use('d4sci.mplstyle')\n",
119 | "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "id": "b60deea4",
125 | "metadata": {},
126 | "source": [
127 | "# Named Entity Recognition"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 4,
133 | "id": "8f1dc17e",
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "email = \"\"\"Dear Amazon, \\\n",
138 | "\n",
139 | "last week I ordered an Optimus Prime action figure \\\n",
140 | "from your online store in Germany. Unfortunately, when I opened the package, \\\n",
141 | "I discovered to my horror that I had been sent an action figure of Megatron \\\n",
142 | "instead! As a lifelong enemy of the Decepticons, I hope you can understand my \\\n",
143 | "dilemma. To resolve the issue, I demand an exchange of Megatron for the \\\n",
144 | "Optimus Prime figure I ordered. Enclosed are copies of my records concerning \\\n",
145 | "this purchase. I expect to hear from you soon. \n",
146 | "\n",
147 | "Sincerely, \n",
148 | "\n",
149 | "Bumblebee.\"\"\""
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 5,
155 | "id": "46484fd5",
156 | "metadata": {},
157 | "outputs": [
158 | {
159 | "name": "stderr",
160 | "output_type": "stream",
161 | "text": [
162 | "No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
163 | "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
164 | "/opt/anaconda3/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
165 | " warnings.warn(\n",
166 | "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
167 | "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
168 | "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
169 | ]
170 | }
171 | ],
172 | "source": [
173 | "ner_tagger = pipeline(\"ner\", aggregation_strategy=\"simple\")"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 6,
179 | "id": "a7038ad2",
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "outputs = ner_tagger(email)"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 7,
189 | "id": "0cbfdc95",
190 | "metadata": {},
191 | "outputs": [
192 | {
193 | "data": {
194 | "text/plain": [
195 | "[{'entity_group': 'ORG',\n",
196 | " 'score': 0.8790102,\n",
197 | " 'word': 'Amazon',\n",
198 | " 'start': 5,\n",
199 | " 'end': 11},\n",
200 | " {'entity_group': 'MISC',\n",
201 | " 'score': 0.9908588,\n",
202 | " 'word': 'Optimus Prime',\n",
203 | " 'start': 37,\n",
204 | " 'end': 50},\n",
205 | " {'entity_group': 'LOC',\n",
206 | " 'score': 0.9997547,\n",
207 | " 'word': 'Germany',\n",
208 | " 'start': 91,\n",
209 | " 'end': 98},\n",
210 | " {'entity_group': 'MISC',\n",
211 | " 'score': 0.5565716,\n",
212 | " 'word': 'Mega',\n",
213 | " 'start': 209,\n",
214 | " 'end': 213},\n",
215 | " {'entity_group': 'PER',\n",
216 | " 'score': 0.59025526,\n",
217 | " 'word': '##tron',\n",
218 | " 'start': 213,\n",
219 | " 'end': 217},\n",
220 | " {'entity_group': 'ORG',\n",
221 | " 'score': 0.66969275,\n",
222 | " 'word': 'Decept',\n",
223 | " 'start': 254,\n",
224 | " 'end': 260},\n",
225 | " {'entity_group': 'MISC',\n",
226 | " 'score': 0.4983484,\n",
227 | " 'word': '##icons',\n",
228 | " 'start': 260,\n",
229 | " 'end': 265},\n",
230 | " {'entity_group': 'MISC',\n",
231 | " 'score': 0.7753625,\n",
232 | " 'word': 'Megatron',\n",
233 | " 'start': 351,\n",
234 | " 'end': 359},\n",
235 | " {'entity_group': 'MISC',\n",
236 | " 'score': 0.98785394,\n",
237 | " 'word': 'Optimus Prime',\n",
238 | " 'start': 368,\n",
239 | " 'end': 381},\n",
240 | " {'entity_group': 'PER',\n",
241 | " 'score': 0.8120968,\n",
242 | " 'word': 'Bumblebee',\n",
243 | " 'start': 507,\n",
244 | " 'end': 516}]"
245 | ]
246 | },
247 | "execution_count": 7,
248 | "metadata": {},
249 | "output_type": "execute_result"
250 | }
251 | ],
252 | "source": [
253 | "outputs"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 8,
259 | "id": "fd6d062d",
260 | "metadata": {},
261 | "outputs": [
262 | {
263 | "data": {
264 | "text/html": [
265 | "\n",
266 | "\n",
279 | "
\n",
280 | " \n",
281 | " \n",
282 | " | \n",
283 | " entity_group | \n",
284 | " score | \n",
285 | " word | \n",
286 | " start | \n",
287 | " end | \n",
288 | "
\n",
289 | " \n",
290 | " \n",
291 | " \n",
292 | " 0 | \n",
293 | " ORG | \n",
294 | " 0.879010 | \n",
295 | " Amazon | \n",
296 | " 5 | \n",
297 | " 11 | \n",
298 | "
\n",
299 | " \n",
300 | " 1 | \n",
301 | " MISC | \n",
302 | " 0.990859 | \n",
303 | " Optimus Prime | \n",
304 | " 37 | \n",
305 | " 50 | \n",
306 | "
\n",
307 | " \n",
308 | " 2 | \n",
309 | " LOC | \n",
310 | " 0.999755 | \n",
311 | " Germany | \n",
312 | " 91 | \n",
313 | " 98 | \n",
314 | "
\n",
315 | " \n",
316 | " 3 | \n",
317 | " MISC | \n",
318 | " 0.556572 | \n",
319 | " Mega | \n",
320 | " 209 | \n",
321 | " 213 | \n",
322 | "
\n",
323 | " \n",
324 | " 4 | \n",
325 | " PER | \n",
326 | " 0.590255 | \n",
327 | " ##tron | \n",
328 | " 213 | \n",
329 | " 217 | \n",
330 | "
\n",
331 | " \n",
332 | " 5 | \n",
333 | " ORG | \n",
334 | " 0.669693 | \n",
335 | " Decept | \n",
336 | " 254 | \n",
337 | " 260 | \n",
338 | "
\n",
339 | " \n",
340 | " 6 | \n",
341 | " MISC | \n",
342 | " 0.498348 | \n",
343 | " ##icons | \n",
344 | " 260 | \n",
345 | " 265 | \n",
346 | "
\n",
347 | " \n",
348 | " 7 | \n",
349 | " MISC | \n",
350 | " 0.775362 | \n",
351 | " Megatron | \n",
352 | " 351 | \n",
353 | " 359 | \n",
354 | "
\n",
355 | " \n",
356 | " 8 | \n",
357 | " MISC | \n",
358 | " 0.987854 | \n",
359 | " Optimus Prime | \n",
360 | " 368 | \n",
361 | " 381 | \n",
362 | "
\n",
363 | " \n",
364 | " 9 | \n",
365 | " PER | \n",
366 | " 0.812097 | \n",
367 | " Bumblebee | \n",
368 | " 507 | \n",
369 | " 516 | \n",
370 | "
\n",
371 | " \n",
372 | "
\n",
373 | "
"
374 | ],
375 | "text/plain": [
376 | " entity_group score word start end\n",
377 | "0 ORG 0.879010 Amazon 5 11\n",
378 | "1 MISC 0.990859 Optimus Prime 37 50\n",
379 | "2 LOC 0.999755 Germany 91 98\n",
380 | "3 MISC 0.556572 Mega 209 213\n",
381 | "4 PER 0.590255 ##tron 213 217\n",
382 | "5 ORG 0.669693 Decept 254 260\n",
383 | "6 MISC 0.498348 ##icons 260 265\n",
384 | "7 MISC 0.775362 Megatron 351 359\n",
385 | "8 MISC 0.987854 Optimus Prime 368 381\n",
386 | "9 PER 0.812097 Bumblebee 507 516"
387 | ]
388 | },
389 | "execution_count": 8,
390 | "metadata": {},
391 | "output_type": "execute_result"
392 | }
393 | ],
394 | "source": [
395 | "pd.DataFrame(outputs) "
396 | ]
397 | },
398 | {
399 | "cell_type": "markdown",
400 | "id": "ab2e9bb0",
401 | "metadata": {},
402 | "source": [
403 | "# PoS Tagging"
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "id": "b1ae233e",
409 | "metadata": {},
410 | "source": [
411 | "Load the pipeline"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": 9,
417 | "id": "1fb1be6e",
418 | "metadata": {},
419 | "outputs": [
420 | {
421 | "name": "stderr",
422 | "output_type": "stream",
423 | "text": [
424 | "Some weights of the model checkpoint at vblagoje/bert-english-uncased-finetuned-pos were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
425 | "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
426 | "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
427 | ]
428 | }
429 | ],
430 | "source": [
431 | "pos_tagger = pipeline(\"token-classification\", model=\"vblagoje/bert-english-uncased-finetuned-pos\")"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": 10,
437 | "id": "2e2ddb05",
438 | "metadata": {},
439 | "outputs": [],
440 | "source": [
441 | "text = \"The quick brown fox jumps over the lazy dog.\""
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "id": "d40f9a4a",
447 | "metadata": {},
448 | "source": [
449 | "Extract the part of speech tags"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": 11,
455 | "id": "77402b6a",
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/html": [
461 | "\n",
462 | "\n",
475 | "
\n",
476 | " \n",
477 | " \n",
478 | " | \n",
479 | " entity | \n",
480 | " score | \n",
481 | " index | \n",
482 | " word | \n",
483 | " start | \n",
484 | " end | \n",
485 | "
\n",
486 | " \n",
487 | " \n",
488 | " \n",
489 | " 0 | \n",
490 | " DET | \n",
491 | " 0.999445 | \n",
492 | " 1 | \n",
493 | " the | \n",
494 | " 0 | \n",
495 | " 3 | \n",
496 | "
\n",
497 | " \n",
498 | " 1 | \n",
499 | " ADJ | \n",
500 | " 0.997063 | \n",
501 | " 2 | \n",
502 | " quick | \n",
503 | " 4 | \n",
504 | " 9 | \n",
505 | "
\n",
506 | " \n",
507 | " 2 | \n",
508 | " ADJ | \n",
509 | " 0.942299 | \n",
510 | " 3 | \n",
511 | " brown | \n",
512 | " 10 | \n",
513 | " 15 | \n",
514 | "
\n",
515 | " \n",
516 | " 3 | \n",
517 | " NOUN | \n",
518 | " 0.997004 | \n",
519 | " 4 | \n",
520 | " fox | \n",
521 | " 16 | \n",
522 | " 19 | \n",
523 | "
\n",
524 | " \n",
525 | " 4 | \n",
526 | " VERB | \n",
527 | " 0.999446 | \n",
528 | " 5 | \n",
529 | " jumps | \n",
530 | " 20 | \n",
531 | " 25 | \n",
532 | "
\n",
533 | " \n",
534 | " 5 | \n",
535 | " ADP | \n",
536 | " 0.999325 | \n",
537 | " 6 | \n",
538 | " over | \n",
539 | " 26 | \n",
540 | " 30 | \n",
541 | "
\n",
542 | " \n",
543 | " 6 | \n",
544 | " DET | \n",
545 | " 0.999527 | \n",
546 | " 7 | \n",
547 | " the | \n",
548 | " 31 | \n",
549 | " 34 | \n",
550 | "
\n",
551 | " \n",
552 | " 7 | \n",
553 | " ADJ | \n",
554 | " 0.997863 | \n",
555 | " 8 | \n",
556 | " lazy | \n",
557 | " 35 | \n",
558 | " 39 | \n",
559 | "
\n",
560 | " \n",
561 | " 8 | \n",
562 | " NOUN | \n",
563 | " 0.998858 | \n",
564 | " 9 | \n",
565 | " dog | \n",
566 | " 40 | \n",
567 | " 43 | \n",
568 | "
\n",
569 | " \n",
570 | " 9 | \n",
571 | " PUNCT | \n",
572 | " 0.999650 | \n",
573 | " 10 | \n",
574 | " . | \n",
575 | " 43 | \n",
576 | " 44 | \n",
577 | "
\n",
578 | " \n",
579 | "
\n",
580 | "
"
581 | ],
582 | "text/plain": [
583 | " entity score index word start end\n",
584 | "0 DET 0.999445 1 the 0 3\n",
585 | "1 ADJ 0.997063 2 quick 4 9\n",
586 | "2 ADJ 0.942299 3 brown 10 15\n",
587 | "3 NOUN 0.997004 4 fox 16 19\n",
588 | "4 VERB 0.999446 5 jumps 20 25\n",
589 | "5 ADP 0.999325 6 over 26 30\n",
590 | "6 DET 0.999527 7 the 31 34\n",
591 | "7 ADJ 0.997863 8 lazy 35 39\n",
592 | "8 NOUN 0.998858 9 dog 40 43\n",
593 | "9 PUNCT 0.999650 10 . 43 44"
594 | ]
595 | },
596 | "execution_count": 11,
597 | "metadata": {},
598 | "output_type": "execute_result"
599 | }
600 | ],
601 | "source": [
602 | "pos_tags = pos_tagger(text)\n",
603 | "pd.DataFrame(pos_tags)"
604 | ]
605 | },
606 | {
607 | "cell_type": "markdown",
608 | "id": "a9a604a9",
609 | "metadata": {},
610 | "source": [
611 | "# Summarization"
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "execution_count": 12,
617 | "id": "fd0bf7b7",
618 | "metadata": {},
619 | "outputs": [],
620 | "source": [
621 | "summarizer = pipeline(\"summarization\", model=\"sshleifer/distilbart-cnn-12-6\")"
622 | ]
623 | },
624 | {
625 | "cell_type": "markdown",
626 | "id": "1fb2e2fc",
627 | "metadata": {},
628 | "source": [
629 | "The first 4 paragraphs of https://en.wikipedia.org/wiki/Transformers"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": 13,
635 | "id": "780d8669",
636 | "metadata": {},
637 | "outputs": [],
638 | "source": [
639 | "wiki_text = \"\"\"\n",
640 | "Transformers is a media franchise produced by American toy company Hasbro and Japanese toy company Takara Tomy. It primarily follows the heroic Autobots and the villainous Decepticons, two alien robot factions at war that can transform into other forms, such as vehicles and animals. The franchise encompasses toys, animation, comic books, video games and films. As of 2011, it generated more than ¥2 trillion ($25 billion) in revenue,[1] making it one of the highest-grossing media franchises of all time.\n",
641 | "\n",
642 | "The franchise began in 1984 with the Transformers toy line, comprising transforming mecha toys from Takara's Diaclone and Micro Change toylines rebranded for Western markets.[2] The term \"Generation 1\" (G1) covers both the animated television series The Transformers and the comic book series of the same name, which are further divided into Japanese, British and Canadian spin-offs. Sequels followed, such as the Generation 2 comic book and Beast Wars TV series, which became its own mini-universe. Generation 1 characters have been rebooted multiple times in the 21st century in comics from Dreamwave Productions (starting 2001), IDW Publishing (starting in 2005 and again in 2019), and Skybound Entertainment (beginning in 2023). There have been other incarnations of the story based on different toy lines during and after the 20th century. The first was the Robots in Disguise series, followed by three shows (Armada, Energon, and Cybertron) that constitute a single universe called the \"Unicron Trilogy\".\n",
643 | "\n",
644 | "A live-action film series started in 2007, again distinct from previous incarnations, while the Transformers: Animated series merged concepts from the G1 continuity, the 2007 live-action film and the \"Unicron Trilogy\". For most of the 2010s, in an attempt to mitigate the wave of reboots, the \"Aligned Continuity\" was established. In 2018, Transformers: Cyberverse debuted, once again, distinct from the previous incarnations.\n",
645 | "\n",
646 | "Although a separate and competing franchise started in 1983, Tonka's GoBots became the intellectual property of Hasbro after their buyout of Tonka in 1991. Subsequently, the universe depicted in the animated series Challenge of the GoBots and follow-up film GoBots: Battle of the Rock Lords was retroactively established as an alternate universe within the Transformers multiverse.[3] \n",
647 | "\"\"\""
648 | ]
649 | },
650 | {
651 | "cell_type": "markdown",
652 | "id": "649ff19d",
653 | "metadata": {},
654 | "source": [
655 | "To generate the summary we just have to call the pipeline"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": 14,
661 | "id": "037015f6",
662 | "metadata": {},
663 | "outputs": [
664 | {
665 | "name": "stdout",
666 | "output_type": "stream",
667 | "text": [
668 | " The Transformers is a media franchise produced by Hasbro and Japanese toy company Takara Tomy . It primarily follows the heroic Autobots and the villainous Decepticons, two alien robot factions at war that can transform into other forms, such as vehicles and animals . As of 2011, it generated more than ¥2 trillion ($25 billion) in revenue .\n"
669 | ]
670 | }
671 | ],
672 | "source": [
673 | "summary = summarizer(wiki_text)\n",
674 | "\n",
675 | "print(summary[0]['summary_text'])"
676 | ]
677 | },
678 | {
679 | "cell_type": "markdown",
680 | "id": "c140c5fd",
681 | "metadata": {},
682 | "source": [
683 | "We can also specify a minimum length"
684 | ]
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": 15,
689 | "id": "d903416c",
690 | "metadata": {},
691 | "outputs": [
692 | {
693 | "name": "stdout",
694 | "output_type": "stream",
695 | "text": [
696 | " Transformers is a media franchise produced by Hasbro and Japanese toy company Takara Tomy . It primarily follows the heroic Autobots and the villainous Decepticons, two alien robot factions at war that can transform into other forms, such as vehicles and animals . As of 2011, it generated more than ¥2 trillion ($25 billion) in revenue, making it one of the highest-grossing media franchises of all time . The term \"Generation 1\" (G1) covers both the animated television series The Transformers and the comic book series of the same name .\n"
697 | ]
698 | }
699 | ],
700 | "source": [
701 | "summary = summarizer(wiki_text, min_length=100)\n",
702 | "\n",
703 | "print(summary[0]['summary_text'])"
704 | ]
705 | },
706 | {
707 | "cell_type": "markdown",
708 | "id": "a1fb1854",
709 | "metadata": {},
710 | "source": [
711 | "# Question Answering "
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": 16,
717 | "id": "b59a33de",
718 | "metadata": {},
719 | "outputs": [
720 | {
721 | "name": "stderr",
722 | "output_type": "stream",
723 | "text": [
724 | "No model was supplied, defaulted to distilbert/distilbert-base-cased-distilled-squad and revision 626af31 (https://huggingface.co/distilbert/distilbert-base-cased-distilled-squad).\n",
725 | "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
726 | "/opt/anaconda3/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
727 | " warnings.warn(\n"
728 | ]
729 | }
730 | ],
731 | "source": [
732 | "reader = pipeline(\"question-answering\")"
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": 17,
738 | "id": "7b06141f",
739 | "metadata": {},
740 | "outputs": [],
741 | "source": [
742 | "question = \"What does the customer want?\""
743 | ]
744 | },
745 | {
746 | "cell_type": "code",
747 | "execution_count": 18,
748 | "id": "561cbaff",
749 | "metadata": {},
750 | "outputs": [
751 | {
752 | "data": {
753 | "text/html": [
754 | "\n",
755 | "\n",
768 | "
\n",
769 | " \n",
770 | " \n",
771 | " | \n",
772 | " score | \n",
773 | " start | \n",
774 | " end | \n",
775 | " answer | \n",
776 | "
\n",
777 | " \n",
778 | " \n",
779 | " \n",
780 | " 0 | \n",
781 | " 0.631292 | \n",
782 | " 336 | \n",
783 | " 359 | \n",
784 | " an exchange of Megatron | \n",
785 | "
\n",
786 | " \n",
787 | "
\n",
788 | "
"
789 | ],
790 | "text/plain": [
791 | " score start end answer\n",
792 | "0 0.631292 336 359 an exchange of Megatron"
793 | ]
794 | },
795 | "execution_count": 18,
796 | "metadata": {},
797 | "output_type": "execute_result"
798 | }
799 | ],
800 | "source": [
801 | "outputs = reader(question=question, context=email)\n",
802 | "pd.DataFrame([outputs]) "
803 | ]
804 | },
805 | {
806 | "cell_type": "markdown",
807 | "id": "6868cac6",
808 | "metadata": {},
809 | "source": [
810 | "# Translation"
811 | ]
812 | },
813 | {
814 | "cell_type": "code",
815 | "execution_count": 19,
816 | "id": "7cc732f8",
817 | "metadata": {},
818 | "outputs": [],
819 | "source": [
820 | "translator = pipeline(\"translation_en_to_it\", \n",
821 | " model=\"Helsinki-NLP/opus-mt-en-it\")"
822 | ]
823 | },
824 | {
825 | "cell_type": "code",
826 | "execution_count": 20,
827 | "id": "ac60ca47",
828 | "metadata": {},
829 | "outputs": [
830 | {
831 | "name": "stdout",
832 | "output_type": "stream",
833 | "text": [
834 | "Cara Amazon, la scorsa settimana ho ordinato una figura d'azione Optimus Prime dal tuo negozio online in Germania. Purtroppo, quando ho aperto il pacchetto, ho scoperto al mio orrore che ero stato inviato una figura d'azione di Megatron invece! Come un nemico per tutta la vita dei Decepticon, spero che si può capire il mio dilemma. Per risolvere il problema, chiedo uno scambio di Megatron per la figura di Optimus Prime ho ordinato. In allegato sono copie dei miei record riguardanti questo acquisto. Mi aspetto di sentire da voi presto. Cordialmente, Bumblebee.\n"
835 | ]
836 | }
837 | ],
838 | "source": [
839 | "outputs = translator(email, clean_up_tokenization_spaces=True, min_length=100, max_length=1000)\n",
840 | "print(outputs[0]['translation_text'])"
841 | ]
842 | },
843 | {
844 | "cell_type": "markdown",
845 | "id": "924b10ed",
846 | "metadata": {},
847 | "source": [
848 | "For comparison, let us look at the results of google translate:\n",
849 | "\n",
850 | "```\n",
851 | "Caro Amazon, la settimana scorsa ho ordinato un action figure di Optimus Prime dal tuo negozio online in Germania. Sfortunatamente, quando ho aperto il pacco, ho scoperto con orrore che mi era stata invece inviata una action figure di Megatron! Essendo un nemico da sempre dei Decepticon, spero che tu possa capire il mio dilemma. Per risolvere il problema, chiedo uno scambio di Megatron con la figura di Optimus Prime che ho ordinato. In allegato sono presenti copie dei miei documenti relativi a questo acquisto. Mi aspetto di sentirti presto. Cordiali saluti, Bombo.\n",
852 | "```"
853 | ]
854 | },
855 | {
856 | "cell_type": "markdown",
857 | "id": "27659fb2",
858 | "metadata": {},
859 | "source": [
860 | "Google translate is less context aware in the translation going so far as translating the name of the email sender (Bumblebee -> Bombo). On the other hand, the Hugging Face model is more formal (\"sentire da voi\" -> \"sentirti\")"
861 | ]
862 | },
863 | {
864 | "cell_type": "markdown",
865 | "id": "f5101ea4",
866 | "metadata": {},
867 | "source": [
868 | "# Sentiment Analysis"
869 | ]
870 | },
871 | {
872 | "cell_type": "code",
873 | "execution_count": 21,
874 | "id": "6f234a7f",
875 | "metadata": {},
876 | "outputs": [
877 | {
878 | "name": "stderr",
879 | "output_type": "stream",
880 | "text": [
881 | "No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\n",
882 | "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
883 | "/opt/anaconda3/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
884 | " warnings.warn(\n"
885 | ]
886 | }
887 | ],
888 | "source": [
889 | "sentiment_pipeline = pipeline(\"sentiment-analysis\")"
890 | ]
891 | },
892 | {
893 | "cell_type": "markdown",
894 | "id": "9d732ee0",
895 | "metadata": {},
896 | "source": [
897 | "Let us use a couple of fairly obvious instances"
898 | ]
899 | },
900 | {
901 | "cell_type": "code",
902 | "execution_count": 22,
903 | "id": "9194345f",
904 | "metadata": {},
905 | "outputs": [],
906 | "source": [
907 | "instances = [\"I love you\", \"I hate you\"]"
908 | ]
909 | },
910 | {
911 | "cell_type": "markdown",
912 | "id": "659be922",
913 | "metadata": {},
914 | "source": [
915 | "The model does a pretty good job of figuring out which one is positive and which one is negative"
916 | ]
917 | },
918 | {
919 | "cell_type": "code",
920 | "execution_count": 23,
921 | "id": "0258363e",
922 | "metadata": {},
923 | "outputs": [
924 | {
925 | "data": {
926 | "text/plain": [
927 | "[{'label': 'POSITIVE', 'score': 0.9998656511306763},\n",
928 | " {'label': 'NEGATIVE', 'score': 0.9991129040718079}]"
929 | ]
930 | },
931 | "execution_count": 23,
932 | "metadata": {},
933 | "output_type": "execute_result"
934 | }
935 | ],
936 | "source": [
937 | "sentiment_pipeline(instances)"
938 | ]
939 | },
940 | {
941 | "cell_type": "markdown",
942 | "id": "813b1afb",
943 | "metadata": {},
944 | "source": [
945 | "# Application"
946 | ]
947 | },
948 | {
949 | "cell_type": "markdown",
950 | "id": "35bfe5be",
951 | "metadata": {},
952 | "source": [
953 | "Load a few thousand tweets about Apple"
954 | ]
955 | },
956 | {
957 | "cell_type": "code",
958 | "execution_count": 24,
959 | "id": "e38a139b",
960 | "metadata": {},
961 | "outputs": [],
962 | "source": [
963 | "data = pd.read_csv('data/Apple-Twitter-Sentiment-DFE.csv', usecols=['text'])"
964 | ]
965 | },
966 | {
967 | "cell_type": "code",
968 | "execution_count": 25,
969 | "id": "105924ae",
970 | "metadata": {},
971 | "outputs": [
972 | {
973 | "data": {
974 | "text/html": [
975 | "\n",
976 | "\n",
989 | "
\n",
990 | " \n",
991 | " \n",
992 | " | \n",
993 | " text | \n",
994 | "
\n",
995 | " \n",
996 | " \n",
997 | " \n",
998 | " 0 | \n",
999 | " #AAPL:The 10 best Steve Jobs emails ever...htt... | \n",
1000 | "
\n",
1001 | " \n",
1002 | " 1 | \n",
1003 | " RT @JPDesloges: Why AAPL Stock Had a Mini-Flas... | \n",
1004 | "
\n",
1005 | " \n",
1006 | " 2 | \n",
1007 | " My cat only chews @apple cords. Such an #Apple... | \n",
1008 | "
\n",
1009 | " \n",
1010 | " 3 | \n",
1011 | " I agree with @jimcramer that the #IndividualIn... | \n",
1012 | "
\n",
1013 | " \n",
1014 | " 4 | \n",
1015 | " Nobody expects the Spanish Inquisition #AAPL | \n",
1016 | "
\n",
1017 | " \n",
1018 | " ... | \n",
1019 | " ... | \n",
1020 | "
\n",
1021 | " \n",
1022 | " 3881 | \n",
1023 | " (Via FC) Apple Is Warming Up To Social Media -... | \n",
1024 | "
\n",
1025 | " \n",
1026 | " 3882 | \n",
1027 | " RT @MMLXIV: there is no avocado emoji may I as... | \n",
1028 | "
\n",
1029 | " \n",
1030 | " 3883 | \n",
1031 | " @marcbulandr I could not agree more. Between @... | \n",
1032 | "
\n",
1033 | " \n",
1034 | " 3884 | \n",
1035 | " My iPhone 5's photos are no longer downloading... | \n",
1036 | "
\n",
1037 | " \n",
1038 | " 3885 | \n",
1039 | " RT @SwiftKey: We're so excited to be named to ... | \n",
1040 | "
\n",
1041 | " \n",
1042 | "
\n",
1043 | "
3886 rows × 1 columns
\n",
1044 | "
"
1045 | ],
1046 | "text/plain": [
1047 | " text\n",
1048 | "0 #AAPL:The 10 best Steve Jobs emails ever...htt...\n",
1049 | "1 RT @JPDesloges: Why AAPL Stock Had a Mini-Flas...\n",
1050 | "2 My cat only chews @apple cords. Such an #Apple...\n",
1051 | "3 I agree with @jimcramer that the #IndividualIn...\n",
1052 | "4 Nobody expects the Spanish Inquisition #AAPL\n",
1053 | "... ...\n",
1054 | "3881 (Via FC) Apple Is Warming Up To Social Media -...\n",
1055 | "3882 RT @MMLXIV: there is no avocado emoji may I as...\n",
1056 | "3883 @marcbulandr I could not agree more. Between @...\n",
1057 | "3884 My iPhone 5's photos are no longer downloading...\n",
1058 | "3885 RT @SwiftKey: We're so excited to be named to ...\n",
1059 | "\n",
1060 | "[3886 rows x 1 columns]"
1061 | ]
1062 | },
1063 | "execution_count": 25,
1064 | "metadata": {},
1065 | "output_type": "execute_result"
1066 | }
1067 | ],
1068 | "source": [
1069 | "data"
1070 | ]
1071 | },
1072 | {
1073 | "cell_type": "markdown",
1074 | "id": "006c9b23",
1075 | "metadata": {},
1076 | "source": [
1077 | "Compute the sentiment score for each tweet"
1078 | ]
1079 | },
1080 | {
1081 | "cell_type": "code",
1082 | "execution_count": 26,
1083 | "id": "2c76b4c6",
1084 | "metadata": {},
1085 | "outputs": [
1086 | {
1087 | "data": {
1088 | "application/vnd.jupyter.widget-view+json": {
1089 | "model_id": "6244640c013442768f3883a7065d998f",
1090 | "version_major": 2,
1091 | "version_minor": 0
1092 | },
1093 | "text/plain": [
1094 | " 0%| | 0/3886 [00:00, ?it/s]"
1095 | ]
1096 | },
1097 | "metadata": {},
1098 | "output_type": "display_data"
1099 | }
1100 | ],
1101 | "source": [
1102 | "sent = pd.DataFrame(data['text'].progress_apply(lambda x: pd.Series(sentiment_pipeline(x)[0])))"
1103 | ]
1104 | },
1105 | {
1106 | "cell_type": "code",
1107 | "execution_count": 27,
1108 | "id": "928e0ab3",
1109 | "metadata": {},
1110 | "outputs": [],
1111 | "source": [
1112 | "sent.rename(columns={'score': 'sentiment_confidence', 'label':'sentiment'}, inplace=True)"
1113 | ]
1114 | },
1115 | {
1116 | "cell_type": "code",
1117 | "execution_count": 28,
1118 | "id": "6116abee",
1119 | "metadata": {},
1120 | "outputs": [
1121 | {
1122 | "data": {
1123 | "text/html": [
1124 | "\n",
1125 | "\n",
1138 | "
\n",
1139 | " \n",
1140 | " \n",
1141 | " | \n",
1142 | " sentiment | \n",
1143 | " sentiment_confidence | \n",
1144 | "
\n",
1145 | " \n",
1146 | " \n",
1147 | " \n",
1148 | " 0 | \n",
1149 | " POSITIVE | \n",
1150 | " 0.999432 | \n",
1151 | "
\n",
1152 | " \n",
1153 | " 1 | \n",
1154 | " NEGATIVE | \n",
1155 | " 0.999122 | \n",
1156 | "
\n",
1157 | " \n",
1158 | " 2 | \n",
1159 | " NEGATIVE | \n",
1160 | " 0.996177 | \n",
1161 | "
\n",
1162 | " \n",
1163 | " 3 | \n",
1164 | " POSITIVE | \n",
1165 | " 0.995648 | \n",
1166 | "
\n",
1167 | " \n",
1168 | " 4 | \n",
1169 | " NEGATIVE | \n",
1170 | " 0.932676 | \n",
1171 | "
\n",
1172 | " \n",
1173 | " ... | \n",
1174 | " ... | \n",
1175 | " ... | \n",
1176 | "
\n",
1177 | " \n",
1178 | " 3881 | \n",
1179 | " NEGATIVE | \n",
1180 | " 0.992046 | \n",
1181 | "
\n",
1182 | " \n",
1183 | " 3882 | \n",
1184 | " NEGATIVE | \n",
1185 | " 0.999158 | \n",
1186 | "
\n",
1187 | " \n",
1188 | " 3883 | \n",
1189 | " NEGATIVE | \n",
1190 | " 0.935773 | \n",
1191 | "
\n",
1192 | " \n",
1193 | " 3884 | \n",
1194 | " NEGATIVE | \n",
1195 | " 0.998303 | \n",
1196 | "
\n",
1197 | " \n",
1198 | " 3885 | \n",
1199 | " POSITIVE | \n",
1200 | " 0.998714 | \n",
1201 | "
\n",
1202 | " \n",
1203 | "
\n",
1204 | "
3886 rows × 2 columns
\n",
1205 | "
"
1206 | ],
1207 | "text/plain": [
1208 | " sentiment sentiment_confidence\n",
1209 | "0 POSITIVE 0.999432\n",
1210 | "1 NEGATIVE 0.999122\n",
1211 | "2 NEGATIVE 0.996177\n",
1212 | "3 POSITIVE 0.995648\n",
1213 | "4 NEGATIVE 0.932676\n",
1214 | "... ... ...\n",
1215 | "3881 NEGATIVE 0.992046\n",
1216 | "3882 NEGATIVE 0.999158\n",
1217 | "3883 NEGATIVE 0.935773\n",
1218 | "3884 NEGATIVE 0.998303\n",
1219 | "3885 POSITIVE 0.998714\n",
1220 | "\n",
1221 | "[3886 rows x 2 columns]"
1222 | ]
1223 | },
1224 | "execution_count": 28,
1225 | "metadata": {},
1226 | "output_type": "execute_result"
1227 | }
1228 | ],
1229 | "source": [
1230 | "sent"
1231 | ]
1232 | },
1233 | {
1234 | "cell_type": "markdown",
1235 | "id": "accfe71c",
1236 | "metadata": {},
1237 | "source": [
1238 | "We can also use NER to identify when a person is mentioned in the tweet"
1239 | ]
1240 | },
1241 | {
1242 | "cell_type": "code",
1243 | "execution_count": 29,
1244 | "id": "cbbfdb31",
1245 | "metadata": {},
1246 | "outputs": [
1247 | {
1248 | "data": {
1249 | "text/plain": [
1250 | "'#AAPL:The 10 best Steve Jobs emails ever...http://t.co/82G1kL94tx'"
1251 | ]
1252 | },
1253 | "execution_count": 29,
1254 | "metadata": {},
1255 | "output_type": "execute_result"
1256 | }
1257 | ],
1258 | "source": [
1259 | "data['text'].iloc[0]"
1260 | ]
1261 | },
1262 | {
1263 | "cell_type": "code",
1264 | "execution_count": 30,
1265 | "id": "7f254e40",
1266 | "metadata": {},
1267 | "outputs": [
1268 | {
1269 | "data": {
1270 | "text/plain": [
1271 | "[{'entity_group': 'PER',\n",
1272 | " 'score': 0.73089933,\n",
1273 | " 'word': 'Steve Jobs',\n",
1274 | " 'start': 18,\n",
1275 | " 'end': 28}]"
1276 | ]
1277 | },
1278 | "execution_count": 30,
1279 | "metadata": {},
1280 | "output_type": "execute_result"
1281 | }
1282 | ],
1283 | "source": [
1284 | "ner_tagger(data['text'].iloc[0])"
1285 | ]
1286 | },
1287 | {
1288 | "cell_type": "markdown",
1289 | "id": "b8601ef1",
1290 | "metadata": {},
1291 | "source": [
1292 | "Identify all people mentioned"
1293 | ]
1294 | },
1295 | {
1296 | "cell_type": "code",
1297 | "execution_count": 31,
1298 | "id": "ee96e05a",
1299 | "metadata": {},
1300 | "outputs": [],
1301 | "source": [
1302 | "def find_people(x):\n",
1303 | " output = ner_tagger(x)\n",
1304 | " \n",
1305 | " for tag in output:\n",
1306 | " if tag['entity_group'] == 'PER':\n",
1307 | " out = {'confidence':tag[\"score\"], 'person': tag['word']}\n",
1308 | " return pd.Series(out)\n",
1309 | " \n",
1310 | " return pd.Series({\"confidence\": None, \"person\": None})"
1311 | ]
1312 | },
1313 | {
1314 | "cell_type": "code",
1315 | "execution_count": 32,
1316 | "id": "574f044a",
1317 | "metadata": {},
1318 | "outputs": [
1319 | {
1320 | "data": {
1321 | "application/vnd.jupyter.widget-view+json": {
1322 | "model_id": "fcfe613d380f44c98215932d2bcdb1fa",
1323 | "version_major": 2,
1324 | "version_minor": 0
1325 | },
1326 | "text/plain": [
1327 | " 0%| | 0/3886 [00:00, ?it/s]"
1328 | ]
1329 | },
1330 | "metadata": {},
1331 | "output_type": "display_data"
1332 | }
1333 | ],
1334 | "source": [
1335 | "people = pd.DataFrame(data['text'].progress_apply(find_people))"
1336 | ]
1337 | },
1338 | {
1339 | "cell_type": "markdown",
1340 | "id": "3e3e40b9",
1341 | "metadata": {},
1342 | "source": [
1343 | "Combine all the results into a single DataFrame"
1344 | ]
1345 | },
1346 | {
1347 | "cell_type": "code",
1348 | "execution_count": 33,
1349 | "id": "011601d1",
1350 | "metadata": {},
1351 | "outputs": [],
1352 | "source": [
1353 | "data = pd.concat([data, sent, people], axis=1)"
1354 | ]
1355 | },
1356 | {
1357 | "cell_type": "code",
1358 | "execution_count": 34,
1359 | "id": "f4d9106a",
1360 | "metadata": {},
1361 | "outputs": [
1362 | {
1363 | "data": {
1364 | "text/html": [
1365 | "\n",
1366 | "\n",
1379 | "
\n",
1380 | " \n",
1381 | " \n",
1382 | " | \n",
1383 | " text | \n",
1384 | " sentiment | \n",
1385 | " sentiment_confidence | \n",
1386 | " confidence | \n",
1387 | " person | \n",
1388 | "
\n",
1389 | " \n",
1390 | " \n",
1391 | " \n",
1392 | " 0 | \n",
1393 | " #AAPL:The 10 best Steve Jobs emails ever...htt... | \n",
1394 | " POSITIVE | \n",
1395 | " 0.999432 | \n",
1396 | " 0.730899 | \n",
1397 | " Steve Jobs | \n",
1398 | "
\n",
1399 | " \n",
1400 | " 1 | \n",
1401 | " RT @JPDesloges: Why AAPL Stock Had a Mini-Flas... | \n",
1402 | " NEGATIVE | \n",
1403 | " 0.999122 | \n",
1404 | " NaN | \n",
1405 | " None | \n",
1406 | "
\n",
1407 | " \n",
1408 | " 2 | \n",
1409 | " My cat only chews @apple cords. Such an #Apple... | \n",
1410 | " NEGATIVE | \n",
1411 | " 0.996177 | \n",
1412 | " NaN | \n",
1413 | " None | \n",
1414 | "
\n",
1415 | " \n",
1416 | " 3 | \n",
1417 | " I agree with @jimcramer that the #IndividualIn... | \n",
1418 | " POSITIVE | \n",
1419 | " 0.995648 | \n",
1420 | " NaN | \n",
1421 | " None | \n",
1422 | "
\n",
1423 | " \n",
1424 | " 4 | \n",
1425 | " Nobody expects the Spanish Inquisition #AAPL | \n",
1426 | " NEGATIVE | \n",
1427 | " 0.932676 | \n",
1428 | " NaN | \n",
1429 | " None | \n",
1430 | "
\n",
1431 | " \n",
1432 | " ... | \n",
1433 | " ... | \n",
1434 | " ... | \n",
1435 | " ... | \n",
1436 | " ... | \n",
1437 | " ... | \n",
1438 | "
\n",
1439 | " \n",
1440 | " 3881 | \n",
1441 | " (Via FC) Apple Is Warming Up To Social Media -... | \n",
1442 | " NEGATIVE | \n",
1443 | " 0.992046 | \n",
1444 | " NaN | \n",
1445 | " None | \n",
1446 | "
\n",
1447 | " \n",
1448 | " 3882 | \n",
1449 | " RT @MMLXIV: there is no avocado emoji may I as... | \n",
1450 | " NEGATIVE | \n",
1451 | " 0.999158 | \n",
1452 | " NaN | \n",
1453 | " None | \n",
1454 | "
\n",
1455 | " \n",
1456 | " 3883 | \n",
1457 | " @marcbulandr I could not agree more. Between @... | \n",
1458 | " NEGATIVE | \n",
1459 | " 0.935773 | \n",
1460 | " NaN | \n",
1461 | " None | \n",
1462 | "
\n",
1463 | " \n",
1464 | " 3884 | \n",
1465 | " My iPhone 5's photos are no longer downloading... | \n",
1466 | " NEGATIVE | \n",
1467 | " 0.998303 | \n",
1468 | " NaN | \n",
1469 | " None | \n",
1470 | "
\n",
1471 | " \n",
1472 | " 3885 | \n",
1473 | " RT @SwiftKey: We're so excited to be named to ... | \n",
1474 | " POSITIVE | \n",
1475 | " 0.998714 | \n",
1476 | " NaN | \n",
1477 | " None | \n",
1478 | "
\n",
1479 | " \n",
1480 | "
\n",
1481 | "
3886 rows × 5 columns
\n",
1482 | "
"
1483 | ],
1484 | "text/plain": [
1485 | " text sentiment \\\n",
1486 | "0 #AAPL:The 10 best Steve Jobs emails ever...htt... POSITIVE \n",
1487 | "1 RT @JPDesloges: Why AAPL Stock Had a Mini-Flas... NEGATIVE \n",
1488 | "2 My cat only chews @apple cords. Such an #Apple... NEGATIVE \n",
1489 | "3 I agree with @jimcramer that the #IndividualIn... POSITIVE \n",
1490 | "4 Nobody expects the Spanish Inquisition #AAPL NEGATIVE \n",
1491 | "... ... ... \n",
1492 | "3881 (Via FC) Apple Is Warming Up To Social Media -... NEGATIVE \n",
1493 | "3882 RT @MMLXIV: there is no avocado emoji may I as... NEGATIVE \n",
1494 | "3883 @marcbulandr I could not agree more. Between @... NEGATIVE \n",
1495 | "3884 My iPhone 5's photos are no longer downloading... NEGATIVE \n",
1496 | "3885 RT @SwiftKey: We're so excited to be named to ... POSITIVE \n",
1497 | "\n",
1498 | " sentiment_confidence confidence person \n",
1499 | "0 0.999432 0.730899 Steve Jobs \n",
1500 | "1 0.999122 NaN None \n",
1501 | "2 0.996177 NaN None \n",
1502 | "3 0.995648 NaN None \n",
1503 | "4 0.932676 NaN None \n",
1504 | "... ... ... ... \n",
1505 | "3881 0.992046 NaN None \n",
1506 | "3882 0.999158 NaN None \n",
1507 | "3883 0.935773 NaN None \n",
1508 | "3884 0.998303 NaN None \n",
1509 | "3885 0.998714 NaN None \n",
1510 | "\n",
1511 | "[3886 rows x 5 columns]"
1512 | ]
1513 | },
1514 | "execution_count": 34,
1515 | "metadata": {},
1516 | "output_type": "execute_result"
1517 | }
1518 | ],
1519 | "source": [
1520 | "data"
1521 | ]
1522 | },
1523 | {
1524 | "cell_type": "markdown",
1525 | "id": "b2ecc528",
1526 | "metadata": {},
1527 | "source": [
1528 |     "Subset the data to only the tweets mentioning people"
1529 | ]
1530 | },
1531 | {
1532 | "cell_type": "code",
1533 | "execution_count": 35,
1534 | "id": "341d70d1",
1535 | "metadata": {},
1536 | "outputs": [],
1537 | "source": [
1538 | "people = data[data.person.isna() == False].copy()"
1539 | ]
1540 | },
1541 | {
1542 | "cell_type": "code",
1543 | "execution_count": 36,
1544 | "id": "1647ec48",
1545 | "metadata": {},
1546 | "outputs": [
1547 | {
1548 | "data": {
1549 | "text/html": [
1550 | "\n",
1551 | "\n",
1564 | "
\n",
1565 | " \n",
1566 | " \n",
1567 | " | \n",
1568 | " text | \n",
1569 | " sentiment | \n",
1570 | " sentiment_confidence | \n",
1571 | " confidence | \n",
1572 | " person | \n",
1573 | "
\n",
1574 | " \n",
1575 | " \n",
1576 | " \n",
1577 | " 0 | \n",
1578 | " #AAPL:The 10 best Steve Jobs emails ever...htt... | \n",
1579 | " POSITIVE | \n",
1580 | " 0.999432 | \n",
1581 | " 0.730899 | \n",
1582 | " Steve Jobs | \n",
1583 | "
\n",
1584 | " \n",
1585 | " 17 | \n",
1586 | " @Apple John Cantlie has been a prisoner of ISI... | \n",
1587 | " POSITIVE | \n",
1588 | " 0.930674 | \n",
1589 | " 0.984458 | \n",
1590 | " John Cantlie | \n",
1591 | "
\n",
1592 | " \n",
1593 | " 80 | \n",
1594 | " I'm hoping @apple won't automatically make us ... | \n",
1595 | " NEGATIVE | \n",
1596 | " 0.997191 | \n",
1597 | " 0.889901 | \n",
1598 | " Bruce Springsteen | \n",
1599 | "
\n",
1600 | " \n",
1601 | " 81 | \n",
1602 | " @thehill @Apple What a joke! Justice Dept shou... | \n",
1603 | " NEGATIVE | \n",
1604 | " 0.992056 | \n",
1605 | " 0.941649 | \n",
1606 | " Killer Wilson | \n",
1607 | "
\n",
1608 | " \n",
1609 | " 85 | \n",
1610 | " #AAPL:10 Steve Jobs emails you need to read...... | \n",
1611 | " NEGATIVE | \n",
1612 | " 0.994819 | \n",
1613 | " 0.495679 | \n",
1614 | " Job | \n",
1615 | "
\n",
1616 | " \n",
1617 | " ... | \n",
1618 | " ... | \n",
1619 | " ... | \n",
1620 | " ... | \n",
1621 | " ... | \n",
1622 | " ... | \n",
1623 | "
\n",
1624 | " \n",
1625 | " 3777 | \n",
1626 | " Jeff Daniels in Talks to Play Former Apple CEO... | \n",
1627 | " NEGATIVE | \n",
1628 | " 0.593705 | \n",
1629 | " 0.999621 | \n",
1630 | " Jeff Daniels | \n",
1631 | "
\n",
1632 | " \n",
1633 | " 3798 | \n",
1634 | " Dear Santa, All I want for Christmas is for @A... | \n",
1635 | " NEGATIVE | \n",
1636 | " 0.979491 | \n",
1637 | " 0.839571 | \n",
1638 | " Santa | \n",
1639 | "
\n",
1640 | " \n",
1641 | " 3812 | \n",
1642 | " @Apple co-founder Steve Wozniak talks about St... | \n",
1643 | " POSITIVE | \n",
1644 | " 0.938843 | \n",
1645 | " 0.941354 | \n",
1646 | " Steve Wozniak | \n",
1647 | "
\n",
1648 | " \n",
1649 | " 3818 | \n",
1650 | " RT @CNET: @Apple pioneer Bill Fernandez on @Go... | \n",
1651 | " NEGATIVE | \n",
1652 | " 0.922325 | \n",
1653 | " 0.999579 | \n",
1654 | " Bill Fernandez | \n",
1655 | "
\n",
1656 | " \n",
1657 | " 3879 | \n",
1658 | " Tim Cook Met With Jesse Jackson for 'Positive ... | \n",
1659 | " POSITIVE | \n",
1660 | " 0.998170 | \n",
1661 | " 0.999658 | \n",
1662 | " Tim Cook | \n",
1663 | "
\n",
1664 | " \n",
1665 | "
\n",
1666 | "
291 rows × 5 columns
\n",
1667 | "
"
1668 | ],
1669 | "text/plain": [
1670 | " text sentiment \\\n",
1671 | "0 #AAPL:The 10 best Steve Jobs emails ever...htt... POSITIVE \n",
1672 | "17 @Apple John Cantlie has been a prisoner of ISI... POSITIVE \n",
1673 | "80 I'm hoping @apple won't automatically make us ... NEGATIVE \n",
1674 | "81 @thehill @Apple What a joke! Justice Dept shou... NEGATIVE \n",
1675 | "85 #AAPL:10 Steve Jobs emails you need to read...... NEGATIVE \n",
1676 | "... ... ... \n",
1677 | "3777 Jeff Daniels in Talks to Play Former Apple CEO... NEGATIVE \n",
1678 | "3798 Dear Santa, All I want for Christmas is for @A... NEGATIVE \n",
1679 | "3812 @Apple co-founder Steve Wozniak talks about St... POSITIVE \n",
1680 | "3818 RT @CNET: @Apple pioneer Bill Fernandez on @Go... NEGATIVE \n",
1681 | "3879 Tim Cook Met With Jesse Jackson for 'Positive ... POSITIVE \n",
1682 | "\n",
1683 | " sentiment_confidence confidence person \n",
1684 | "0 0.999432 0.730899 Steve Jobs \n",
1685 | "17 0.930674 0.984458 John Cantlie \n",
1686 | "80 0.997191 0.889901 Bruce Springsteen \n",
1687 | "81 0.992056 0.941649 Killer Wilson \n",
1688 | "85 0.994819 0.495679 Job \n",
1689 | "... ... ... ... \n",
1690 | "3777 0.593705 0.999621 Jeff Daniels \n",
1691 | "3798 0.979491 0.839571 Santa \n",
1692 | "3812 0.938843 0.941354 Steve Wozniak \n",
1693 | "3818 0.922325 0.999579 Bill Fernandez \n",
1694 | "3879 0.998170 0.999658 Tim Cook \n",
1695 | "\n",
1696 | "[291 rows x 5 columns]"
1697 | ]
1698 | },
1699 | "execution_count": 36,
1700 | "metadata": {},
1701 | "output_type": "execute_result"
1702 | }
1703 | ],
1704 | "source": [
1705 | "people"
1706 | ]
1707 | },
1708 | {
1709 | "cell_type": "markdown",
1710 | "id": "c4775cd9",
1711 | "metadata": {},
1712 | "source": [
1713 | "Convert the text labels to a numerical score"
1714 | ]
1715 | },
1716 | {
1717 | "cell_type": "code",
1718 | "execution_count": 37,
1719 | "id": "64b30b04",
1720 | "metadata": {},
1721 | "outputs": [],
1722 | "source": [
1723 | "people['sentiment'] = people.apply(lambda x: 1 if x.sentiment == 'POSITIVE' else -1, axis=1)"
1724 | ]
1725 | },
1726 | {
1727 | "cell_type": "markdown",
1728 | "id": "c52fab24",
1729 | "metadata": {},
1730 | "source": [
1731 | "Compute the average score"
1732 | ]
1733 | },
1734 | {
1735 | "cell_type": "code",
1736 | "execution_count": 38,
1737 | "id": "960b48c8",
1738 | "metadata": {},
1739 | "outputs": [],
1740 | "source": [
1741 | "stats = people[['person', 'sentiment']].groupby('person').mean()"
1742 | ]
1743 | },
1744 | {
1745 | "cell_type": "code",
1746 | "execution_count": 39,
1747 | "id": "50efa095",
1748 | "metadata": {},
1749 | "outputs": [],
1750 | "source": [
1751 | "counts = people[['person', 'sentiment']].groupby('person').count()\n",
1752 | "counts.rename(columns={'sentiment':'count'}, inplace=True)"
1753 | ]
1754 | },
1755 | {
1756 | "cell_type": "code",
1757 | "execution_count": 40,
1758 | "id": "4202add7",
1759 | "metadata": {},
1760 | "outputs": [],
1761 | "source": [
1762 | "stats = stats.join(counts)"
1763 | ]
1764 | },
1765 | {
1766 | "cell_type": "code",
1767 | "execution_count": 41,
1768 | "id": "e0b844fe",
1769 | "metadata": {},
1770 | "outputs": [
1771 | {
1772 | "data": {
1773 | "text/html": [
1774 | "\n",
1775 | "\n",
1788 | "
\n",
1789 | " \n",
1790 | " \n",
1791 | " | \n",
1792 | " sentiment | \n",
1793 | " count | \n",
1794 | "
\n",
1795 | " \n",
1796 | " person | \n",
1797 | " | \n",
1798 | " | \n",
1799 | "
\n",
1800 | " \n",
1801 | " \n",
1802 | " \n",
1803 | " Steve Wozniak | \n",
1804 | " -0.200000 | \n",
1805 | " 10 | \n",
1806 | "
\n",
1807 | " \n",
1808 | " Steve | \n",
1809 | " -0.333333 | \n",
1810 | " 6 | \n",
1811 | "
\n",
1812 | " \n",
1813 | " Eddy Cue | \n",
1814 | " -0.500000 | \n",
1815 | " 8 | \n",
1816 | "
\n",
1817 | " \n",
1818 | " Steve Job | \n",
1819 | " -0.647059 | \n",
1820 | " 17 | \n",
1821 | "
\n",
1822 | " \n",
1823 | " Mark Zuckerberg | \n",
1824 | " -0.714286 | \n",
1825 | " 7 | \n",
1826 | "
\n",
1827 | " \n",
1828 | " Steve Jobs | \n",
1829 | " -0.777778 | \n",
1830 | " 27 | \n",
1831 | "
\n",
1832 | " \n",
1833 | " Tim Cook | \n",
1834 | " -0.793103 | \n",
1835 | " 29 | \n",
1836 | "
\n",
1837 | " \n",
1838 | " Job | \n",
1839 | " -0.800000 | \n",
1840 | " 10 | \n",
1841 | "
\n",
1842 | " \n",
1843 | "
\n",
1844 | "
"
1845 | ],
1846 | "text/plain": [
1847 | " sentiment count\n",
1848 | "person \n",
1849 | "Steve Wozniak -0.200000 10\n",
1850 | "Steve -0.333333 6\n",
1851 | "Eddy Cue -0.500000 8\n",
1852 | "Steve Job -0.647059 17\n",
1853 | "Mark Zuckerberg -0.714286 7\n",
1854 | "Steve Jobs -0.777778 27\n",
1855 | "Tim Cook -0.793103 29\n",
1856 | "Job -0.800000 10"
1857 | ]
1858 | },
1859 | "execution_count": 41,
1860 | "metadata": {},
1861 | "output_type": "execute_result"
1862 | }
1863 | ],
1864 | "source": [
1865 | "stats[stats['count']>=5].sort_values('sentiment', ascending=False)"
1866 | ]
1867 | },
1868 | {
1869 | "cell_type": "markdown",
1870 | "id": "4f66d008",
1871 | "metadata": {},
1872 | "source": [
1873 | "\n",
1874 | "
\n",
1875 | ""
1876 | ]
1877 | }
1878 | ],
1879 | "metadata": {
1880 | "kernelspec": {
1881 | "display_name": "Python 3 (ipykernel)",
1882 | "language": "python",
1883 | "name": "python3"
1884 | },
1885 | "language_info": {
1886 | "codemirror_mode": {
1887 | "name": "ipython",
1888 | "version": 3
1889 | },
1890 | "file_extension": ".py",
1891 | "mimetype": "text/x-python",
1892 | "name": "python",
1893 | "nbconvert_exporter": "python",
1894 | "pygments_lexer": "ipython3",
1895 | "version": "3.11.7"
1896 | },
1897 | "varInspector": {
1898 | "cols": {
1899 | "lenName": 16,
1900 | "lenType": 16,
1901 | "lenVar": 40
1902 | },
1903 | "kernels_config": {
1904 | "python": {
1905 | "delete_cmd_postfix": "",
1906 | "delete_cmd_prefix": "del ",
1907 | "library": "var_list.py",
1908 | "varRefreshCmd": "print(var_dic_list())"
1909 | },
1910 | "r": {
1911 | "delete_cmd_postfix": ") ",
1912 | "delete_cmd_prefix": "rm(",
1913 | "library": "var_list.r",
1914 | "varRefreshCmd": "cat(var_dic_list()) "
1915 | }
1916 | },
1917 | "types_to_exclude": [
1918 | "module",
1919 | "function",
1920 | "builtin_function_or_method",
1921 | "instance",
1922 | "_Feature"
1923 | ],
1924 | "window_display": false
1925 | }
1926 | },
1927 | "nbformat": 4,
1928 | "nbformat_minor": 5
1929 | }
1930 |
--------------------------------------------------------------------------------
/4. Whisper.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "4ca890a2",
6 | "metadata": {},
7 | "source": [
8 | "\n",
9 | "
\n",
10 | "
LLMs for Data Science
\n",
11 | "
Text to Speech with OpenAI
\n",
12 | "
Bruno Gonçalves
\n",
13 | " www.data4sci.com
\n",
14 | " @bgoncalves, @data4sci
\n",
15 | "
"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 1,
21 | "id": "fd85d18b",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from collections import Counter, defaultdict\n",
26 | "import random\n",
27 | "\n",
28 | "import pandas as pd\n",
29 | "import numpy as np\n",
30 | "\n",
31 | "import matplotlib\n",
32 | "import matplotlib.pyplot as plt \n",
33 | "\n",
34 | "import openai\n",
35 | "from openai import OpenAI\n",
36 | "\n",
37 | "import tqdm as tq\n",
38 | "from tqdm.notebook import tqdm\n",
39 | "\n",
40 | "import watermark\n",
41 | "\n",
42 | "%load_ext watermark\n",
43 | "%matplotlib inline"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "id": "62e99b79",
49 | "metadata": {},
50 | "source": [
51 | "We start by printing out the versions of the libraries we're using for future reference"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 2,
57 | "id": "7b4d6520",
58 | "metadata": {},
59 | "outputs": [
60 | {
61 | "name": "stdout",
62 | "output_type": "stream",
63 | "text": [
64 | "Python implementation: CPython\n",
65 | "Python version : 3.11.7\n",
66 | "IPython version : 8.12.3\n",
67 | "\n",
68 | "Compiler : Clang 14.0.6 \n",
69 | "OS : Darwin\n",
70 | "Release : 24.3.0\n",
71 | "Machine : arm64\n",
72 | "Processor : arm\n",
73 | "CPU cores : 16\n",
74 | "Architecture: 64bit\n",
75 | "\n",
76 | "Git hash: 03802c3bf87993c3670c2fd8bf86e59d3d60bdfd\n",
77 | "\n",
78 | "matplotlib: 3.8.0\n",
79 | "watermark : 2.4.3\n",
80 | "pandas : 2.2.3\n",
81 | "json : 2.0.9\n",
82 | "openai : 1.30.5\n",
83 | "numpy : 1.26.4\n",
84 | "tqdm : 4.66.4\n",
85 | "\n"
86 | ]
87 | }
88 | ],
89 | "source": [
90 | "%watermark -n -v -m -g -iv"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "id": "13ae3a5c",
96 | "metadata": {},
97 | "source": [
98 | "Load default figure style"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 3,
104 | "id": "811ff9e3",
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "plt.style.use('d4sci.mplstyle')\n",
109 | "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "id": "f9960820",
115 | "metadata": {},
116 | "source": [
117 | "# Audio to Text"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 4,
123 | "id": "f00c7c7a",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "client = OpenAI()"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "id": "e15f1156",
133 | "metadata": {},
134 | "source": [
135 | "Let us parse a small local file"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 8,
141 | "id": "ea9e8e94",
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "!open data/gettysburg10.wav"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 6,
151 | "id": "a23066b6",
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "transcript = client.audio.transcriptions.create(\n",
156 | " file = open(\"data/gettysburg10.wav\", \"rb\"),\n",
157 | " model = \"whisper-1\",\n",
158 | " response_format=\"text\",\n",
159 | " language=\"en\"\n",
160 | ")"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "id": "bd2d2a31",
166 | "metadata": {},
167 | "source": [
168 | "And the transcript is simply:"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 7,
174 | "id": "73835706",
175 | "metadata": {},
176 | "outputs": [
177 | {
178 | "name": "stdout",
179 | "output_type": "stream",
180 | "text": [
181 | "Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal.\n",
182 | "\n"
183 | ]
184 | }
185 | ],
186 | "source": [
187 | "print(transcript)"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "4405b5f8",
193 | "metadata": {},
194 | "source": [
195 | "We can also ask for SRT formatted output, that includes time indices"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 9,
201 | "id": "89213b6c",
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "CPU times: user 9.76 ms, sys: 6.07 ms, total: 15.8 ms\n",
209 | "Wall time: 1.75 s\n"
210 | ]
211 | }
212 | ],
213 | "source": [
214 | "%%time\n",
215 | "transcript = client.audio.transcriptions.create(\n",
216 | " file = open(\"data/gettysburg10.wav\", \"rb\"),\n",
217 | " model = \"whisper-1\",\n",
218 | " response_format=\"srt\",\n",
219 | " language=\"en\"\n",
220 | ")"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 10,
226 | "id": "bb7ce224",
227 | "metadata": {},
228 | "outputs": [
229 | {
230 | "name": "stdout",
231 | "output_type": "stream",
232 | "text": [
233 | "1\n",
234 | "00:00:00,000 --> 00:00:05,660\n",
235 | "Four score and seven years ago, our fathers brought forth on this continent a new nation,\n",
236 | "\n",
237 | "2\n",
238 | "00:00:05,660 --> 00:00:09,880\n",
239 | "conceived in liberty and dedicated to the proposition that all men are created equal.\n",
240 | "\n",
241 | "\n",
242 | "\n"
243 | ]
244 | }
245 | ],
246 | "source": [
247 | "print(transcript)"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "id": "8b364858",
253 | "metadata": {},
254 | "source": [
255 | "And ask it to translate the text directly into Spanish"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 11,
261 | "id": "6bf62b58",
262 | "metadata": {},
263 | "outputs": [
264 | {
265 | "name": "stdout",
266 | "output_type": "stream",
267 | "text": [
268 | "CPU times: user 8.99 ms, sys: 4.28 ms, total: 13.3 ms\n",
269 | "Wall time: 6.84 s\n"
270 | ]
271 | }
272 | ],
273 | "source": [
274 | "%%time\n",
275 | "transcript = client.audio.transcriptions.create(\n",
276 | " file = open(\"data/gettysburg10.wav\", \"rb\"),\n",
277 | " model = \"whisper-1\",\n",
278 | " response_format=\"text\",\n",
279 | " language=\"es\"\n",
280 | ")"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 12,
286 | "id": "d379164c",
287 | "metadata": {},
288 | "outputs": [
289 | {
290 | "name": "stdout",
291 | "output_type": "stream",
292 | "text": [
293 | "Hace cuatro y siete años, nuestros padres trajeron a este continente una nueva nación concebida en libertad y dedicada a la proposición de que todos los hombres son creados iguales.\n",
294 | "\n"
295 | ]
296 | }
297 | ],
298 | "source": [
299 | "print(transcript)"
300 | ]
301 | },
302 | {
303 | "cell_type": "markdown",
304 | "id": "edf7a2c6",
305 | "metadata": {},
306 | "source": [
307 | "# Text to Speech"
308 | ]
309 | },
310 | {
311 | "cell_type": "markdown",
312 | "id": "2dcd1c57",
313 | "metadata": {},
314 | "source": [
315 | "Now the opposite approach, going from written text to high quality audio"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 13,
321 | "id": "739ca530",
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "quote = \"\"\"\n",
326 | "Scientists have calculated that the chances of something so patently absurd \n",
327 | "actually existing are millions to one.\n",
328 | "But magicians have calculated that million-to-one chances crop up nine times out of ten.\n",
329 | "\"\"\""
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "id": "7c3a764a",
335 | "metadata": {},
336 | "source": [
337 | "You can learn more about text to speech (and sample the various voices) in the [Official documentation](https://platform.openai.com/docs/guides/text-to-speech/quickstart)"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 14,
343 | "id": "a47ef5f0",
344 | "metadata": {},
345 | "outputs": [
346 | {
347 | "name": "stdout",
348 | "output_type": "stream",
349 | "text": [
350 | "CPU times: user 69.8 ms, sys: 47.4 ms, total: 117 ms\n",
351 | "Wall time: 2.74 s\n"
352 | ]
353 | }
354 | ],
355 | "source": [
356 | "%%time\n",
357 | "audio = client.audio.speech.create(\n",
358 | " input=quote, \n",
359 | " model=\"tts-1\", \n",
360 | " voice='fable',\n",
361 | " response_format='mp3')"
362 | ]
363 | },
364 | {
365 | "cell_type": "markdown",
366 | "id": "44e2c4bc",
367 | "metadata": {},
368 | "source": [
369 | "Which we can write directly to a file"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": 15,
375 | "id": "4bfa78eb",
376 | "metadata": {},
377 | "outputs": [],
378 | "source": [
379 | "audio.write_to_file('data/pratchett.mp3')"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": 17,
385 | "id": "be0291f3",
386 | "metadata": {},
387 | "outputs": [],
388 | "source": [
389 | "!open data/pratchett.mp3"
390 | ]
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "id": "a9c818dd",
395 | "metadata": {},
396 | "source": [
397 | "\n",
398 | "
\n",
399 | ""
400 | ]
401 | }
402 | ],
403 | "metadata": {
404 | "kernelspec": {
405 | "display_name": "Python 3 (ipykernel)",
406 | "language": "python",
407 | "name": "python3"
408 | },
409 | "language_info": {
410 | "codemirror_mode": {
411 | "name": "ipython",
412 | "version": 3
413 | },
414 | "file_extension": ".py",
415 | "mimetype": "text/x-python",
416 | "name": "python",
417 | "nbconvert_exporter": "python",
418 | "pygments_lexer": "ipython3",
419 | "version": "3.11.7"
420 | },
421 | "varInspector": {
422 | "cols": {
423 | "lenName": 16,
424 | "lenType": 16,
425 | "lenVar": 40
426 | },
427 | "kernels_config": {
428 | "python": {
429 | "delete_cmd_postfix": "",
430 | "delete_cmd_prefix": "del ",
431 | "library": "var_list.py",
432 | "varRefreshCmd": "print(var_dic_list())"
433 | },
434 | "r": {
435 | "delete_cmd_postfix": ") ",
436 | "delete_cmd_prefix": "rm(",
437 | "library": "var_list.r",
438 | "varRefreshCmd": "cat(var_dic_list()) "
439 | }
440 | },
441 | "types_to_exclude": [
442 | "module",
443 | "function",
444 | "builtin_function_or_method",
445 | "instance",
446 | "_Feature"
447 | ],
448 | "window_display": false
449 | }
450 | },
451 | "nbformat": 4,
452 | "nbformat_minor": 5
453 | }
454 |
--------------------------------------------------------------------------------
/5. PandasAI.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "\n",
8 | "
\n",
9 | "
LLMs for Data Science
\n",
10 | "
PandasAI
\n",
11 | "
Bruno Gonçalves
\n",
12 | " www.data4sci.com
\n",
13 | " @bgoncalves, @data4sci
\n",
14 | "
"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from collections import Counter\n",
24 | "from pprint import pprint\n",
25 | "\n",
26 | "import pandas as pd\n",
27 | "import numpy as np\n",
28 | "import matplotlib.pyplot as plt \n",
29 | "import sqlite3\n",
30 | "\n",
31 | "import pandasai\n",
32 | "from pandasai import SmartDataframe, SmartDatalake, Agent\n",
33 | "from pandasai.llm import OpenAI\n",
34 | "from pandasai.connectors import SqliteConnector\n",
35 | "\n",
36 | "import watermark\n",
37 | "\n",
38 | "%load_ext watermark\n",
39 | "%matplotlib inline"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 |     "We start by printing out the versions of the libraries we're using for future reference"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "metadata": {},
53 | "outputs": [
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "Python implementation: CPython\n",
59 | "Python version : 3.11.7\n",
60 | "IPython version : 8.12.3\n",
61 | "\n",
62 | "Compiler : Clang 14.0.6 \n",
63 | "OS : Darwin\n",
64 | "Release : 24.3.0\n",
65 | "Machine : arm64\n",
66 | "Processor : arm\n",
67 | "CPU cores : 16\n",
68 | "Architecture: 64bit\n",
69 | "\n",
70 | "Git hash: 03802c3bf87993c3670c2fd8bf86e59d3d60bdfd\n",
71 | "\n",
72 | "numpy : 1.26.4\n",
73 | "matplotlib: 3.8.0\n",
74 | "sqlite3 : 2.6.0\n",
75 | "pandas : 2.2.3\n",
76 | "pandasai : 2.2.14\n",
77 | "watermark : 2.4.3\n",
78 | "\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "%watermark -n -v -m -g -iv"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "Load default figure style"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 3,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "#plt.style.use('./d4sci.mplstyle')"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "# SmartDataframe"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 4,
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "data": {
116 | "text/html": [
117 | "\n",
118 | "\n",
131 | "
\n",
132 | " \n",
133 | " \n",
134 | " | \n",
135 | " country | \n",
136 | " sales | \n",
137 | "
\n",
138 | " \n",
139 | " \n",
140 | " \n",
141 | " 0 | \n",
142 | " United States | \n",
143 | " 5000 | \n",
144 | "
\n",
145 | " \n",
146 | " 1 | \n",
147 | " United Kingdom | \n",
148 | " 3200 | \n",
149 | "
\n",
150 | " \n",
151 | " 2 | \n",
152 | " France | \n",
153 | " 2900 | \n",
154 | "
\n",
155 | " \n",
156 | " 3 | \n",
157 | " Germany | \n",
158 | " 4100 | \n",
159 | "
\n",
160 | " \n",
161 | " 4 | \n",
162 | " Italy | \n",
163 | " 2300 | \n",
164 | "
\n",
165 | " \n",
166 | " 5 | \n",
167 | " Spain | \n",
168 | " 2100 | \n",
169 | "
\n",
170 | " \n",
171 | " 6 | \n",
172 | " Canada | \n",
173 | " 2500 | \n",
174 | "
\n",
175 | " \n",
176 | " 7 | \n",
177 | " Australia | \n",
178 | " 2600 | \n",
179 | "
\n",
180 | " \n",
181 | " 8 | \n",
182 | " Japan | \n",
183 | " 4500 | \n",
184 | "
\n",
185 | " \n",
186 | " 9 | \n",
187 | " China | \n",
188 | " 7000 | \n",
189 | "
\n",
190 | " \n",
191 | "
\n",
192 | "
"
193 | ],
194 | "text/plain": [
195 | " country sales\n",
196 | "0 United States 5000\n",
197 | "1 United Kingdom 3200\n",
198 | "2 France 2900\n",
199 | "3 Germany 4100\n",
200 | "4 Italy 2300\n",
201 | "5 Spain 2100\n",
202 | "6 Canada 2500\n",
203 | "7 Australia 2600\n",
204 | "8 Japan 4500\n",
205 | "9 China 7000"
206 | ]
207 | },
208 | "execution_count": 4,
209 | "metadata": {},
210 | "output_type": "execute_result"
211 | }
212 | ],
213 | "source": [
214 | "sales_by_country = pd.DataFrame({\n",
215 | " \"country\": [\"United States\", \"United Kingdom\", \"France\", \"Germany\", \"Italy\", \"Spain\", \"Canada\", \"Australia\", \"Japan\", \"China\"],\n",
216 | " \"sales\": [5000, 3200, 2900, 4100, 2300, 2100, 2500, 2600, 4500, 7000]\n",
217 | "})\n",
218 | "sales_by_country"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {},
224 | "source": [
225 | "Unless we specify otherwise, it will default to __BambooLLM__"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 5,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "data": {
235 | "text/plain": [
236 | "'The top 5 countries by sales are: China, United States, Japan, Germany, United Kingdom'"
237 | ]
238 | },
239 | "execution_count": 5,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "df = SmartDataframe(sales_by_country)\n",
246 | "df.chat('Which are the top 5 countries by sales?')"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 6,
252 | "metadata": {},
253 | "outputs": [
254 | {
255 | "data": {
256 | "text/html": [
257 | "\n",
258 | "\n",
271 | "
\n",
272 | " \n",
273 | " \n",
274 | " | \n",
275 | " country | \n",
276 | " sales | \n",
277 | "
\n",
278 | " \n",
279 | " \n",
280 | " \n",
281 | " 9 | \n",
282 | " China | \n",
283 | " 7000 | \n",
284 | "
\n",
285 | " \n",
286 | " 0 | \n",
287 | " United States | \n",
288 | " 5000 | \n",
289 | "
\n",
290 | " \n",
291 | " 8 | \n",
292 | " Japan | \n",
293 | " 4500 | \n",
294 | "
\n",
295 | " \n",
296 | " 3 | \n",
297 | " Germany | \n",
298 | " 4100 | \n",
299 | "
\n",
300 | " \n",
301 | " 1 | \n",
302 | " United Kingdom | \n",
303 | " 3200 | \n",
304 | "
\n",
305 | " \n",
306 | "
\n",
307 | "
"
308 | ],
309 | "text/plain": [
310 | " country sales\n",
311 | "9 China 7000\n",
312 | "0 United States 5000\n",
313 | "8 Japan 4500\n",
314 | "3 Germany 4100\n",
315 | "1 United Kingdom 3200"
316 | ]
317 | },
318 | "execution_count": 6,
319 | "metadata": {},
320 | "output_type": "execute_result"
321 | }
322 | ],
323 | "source": [
324 | "sales_by_country.sort_values('sales', ascending=False).head(5)"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "We create an OpenAI instance (with our own api key)"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 7,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "openai = OpenAI()"
341 | ]
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {},
346 | "source": [
347 | "And provide it to the SmartDataframe"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": 8,
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "df = SmartDataframe(sales_by_country, config={\"llm\": openai})"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {},
362 | "source": [
363 | "Which we can interact with by asking specific questions"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": 9,
369 | "metadata": {},
370 | "outputs": [
371 | {
372 | "data": {
373 | "text/plain": [
374 | "'The top 5 countries by sales are: China, United States, Japan, Germany, United Kingdom'"
375 | ]
376 | },
377 | "execution_count": 9,
378 | "metadata": {},
379 | "output_type": "execute_result"
380 | }
381 | ],
382 | "source": [
383 | "df.chat('Which are the top 5 countries by sales?')"
384 | ]
385 | },
386 | {
387 | "cell_type": "markdown",
388 | "metadata": {},
389 | "source": [
390 | "Questions can be almost arbitrarily complex, as long as the LLM is able to understand them :D"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 10,
396 | "metadata": {},
397 | "outputs": [
398 | {
399 | "name": "stdout",
400 | "output_type": "stream",
401 | "text": [
402 | "{'type': 'dataframe', 'value': country sales\n",
403 | "0 United States 5000\n",
404 | "9 China 7000}\n"
405 | ]
406 | },
407 | {
408 | "data": {
409 | "text/html": [
410 | "\n",
411 | "\n",
424 | "
\n",
425 | " \n",
426 | " \n",
427 | " | \n",
428 | " country | \n",
429 | " sales | \n",
430 | "
\n",
431 | " \n",
432 | " \n",
433 | " \n",
434 | " 0 | \n",
435 | " United States | \n",
436 | " 5000 | \n",
437 | "
\n",
438 | " \n",
439 | " 9 | \n",
440 | " China | \n",
441 | " 7000 | \n",
442 | "
\n",
443 | " \n",
444 | "
\n",
445 | "
"
446 | ],
447 | "text/plain": [
448 | " country sales\n",
449 | "0 United States 5000\n",
450 | "9 China 7000"
451 | ]
452 | },
453 | "execution_count": 10,
454 | "metadata": {},
455 | "output_type": "execute_result"
456 | }
457 | ],
458 | "source": [
459 | "df.chat('which countries have 5000 sales or more?')"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 11,
465 | "metadata": {},
466 | "outputs": [
467 | {
468 | "data": {
469 | "text/html": [
470 | "\n",
471 | "\n",
484 | "
\n",
485 | " \n",
486 | " \n",
487 | " | \n",
488 | " country | \n",
489 | " sales | \n",
490 | "
\n",
491 | " \n",
492 | " \n",
493 | " \n",
494 | " 0 | \n",
495 | " United States | \n",
496 | " 5000 | \n",
497 | "
\n",
498 | " \n",
499 | " 9 | \n",
500 | " China | \n",
501 | " 7000 | \n",
502 | "
\n",
503 | " \n",
504 | "
\n",
505 | "
"
506 | ],
507 | "text/plain": [
508 | " country sales\n",
509 | "0 United States 5000\n",
510 | "9 China 7000"
511 | ]
512 | },
513 | "execution_count": 11,
514 | "metadata": {},
515 | "output_type": "execute_result"
516 | }
517 | ],
518 | "source": [
519 | "sales_by_country[sales_by_country.sales >= 5000]"
520 | ]
521 | },
522 | {
523 | "cell_type": "markdown",
524 | "metadata": {},
525 | "source": [
526 | "# SmartDatalake"
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "metadata": {},
532 | "source": [
533 | "To interact with multiple dataframes we must use a SmartDatalake. We start by defining a couple of small data dictionaries"
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": 12,
539 | "metadata": {},
540 | "outputs": [],
541 | "source": [
542 | "employees_data = {\n",
543 | " 'EmployeeID': [1, 2, 3, 4, 5],\n",
544 | " 'Name': ['John', 'Emma', 'Liam', 'Olivia', 'William'],\n",
545 | " 'Department': ['HR', 'Sales', 'IT', 'Marketing', 'Finance']\n",
546 | "}\n",
547 | "\n",
548 | "salaries_data = {\n",
549 | " 'EmployeeID': [1, 2, 3, 4, 5],\n",
550 | " 'Salary': [5000, 6000, 4500, 7000, 5500]\n",
551 | "}"
552 | ]
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "metadata": {},
557 | "source": [
558 | "That we load into individual pandas Dataframe objects (these could also be SmartDataframes instead)"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": 13,
564 | "metadata": {},
565 | "outputs": [],
566 | "source": [
567 | "employees_df = pd.DataFrame(employees_data)\n",
568 | "salaries_df = pd.DataFrame(salaries_data)"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {},
574 | "source": [
575 |     "Which we can now pass to our SmartDatalake constructor"
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": 14,
581 | "metadata": {},
582 | "outputs": [],
583 | "source": [
584 | "lake = SmartDatalake([employees_df, salaries_df], config={\"llm\": openai})"
585 | ]
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "metadata": {},
590 | "source": [
591 |     "And interact with it just as before"
592 | ]
593 | },
594 | {
595 | "cell_type": "code",
596 | "execution_count": 15,
597 | "metadata": {},
598 | "outputs": [
599 | {
600 | "name": "stdout",
601 | "output_type": "stream",
602 | "text": [
603 | "{'type': 'string', 'value': 'The highest paid employee is Olivia with a salary of 7000.'}\n"
604 | ]
605 | },
606 | {
607 | "data": {
608 | "text/plain": [
609 | "'The highest paid employee is Olivia with a salary of 7000.'"
610 | ]
611 | },
612 | "execution_count": 15,
613 | "metadata": {},
614 | "output_type": "execute_result"
615 | }
616 | ],
617 | "source": [
618 | "lake.chat(\"Who gets paid the most?\")"
619 | ]
620 | },
621 | {
622 | "cell_type": "markdown",
623 | "metadata": {},
624 | "source": [
625 |     "Please note that to answer this question, PandasAI had to perform a join between the two data frames "
626 | ]
627 | },
628 | {
629 | "cell_type": "code",
630 | "execution_count": 16,
631 | "metadata": {},
632 | "outputs": [
633 | {
634 | "data": {
635 | "text/html": [
636 | "\n",
637 | "\n",
650 | "
\n",
651 | " \n",
652 | " \n",
653 | " | \n",
654 | " EmployeeID | \n",
655 | " Name | \n",
656 | " Department | \n",
657 | " Salary | \n",
658 | "
\n",
659 | " \n",
660 | " \n",
661 | " \n",
662 | " 3 | \n",
663 | " 4 | \n",
664 | " Olivia | \n",
665 | " Marketing | \n",
666 | " 7000 | \n",
667 | "
\n",
668 | " \n",
669 | " 1 | \n",
670 | " 2 | \n",
671 | " Emma | \n",
672 | " Sales | \n",
673 | " 6000 | \n",
674 | "
\n",
675 | " \n",
676 | " 4 | \n",
677 | " 5 | \n",
678 | " William | \n",
679 | " Finance | \n",
680 | " 5500 | \n",
681 | "
\n",
682 | " \n",
683 | " 0 | \n",
684 | " 1 | \n",
685 | " John | \n",
686 | " HR | \n",
687 | " 5000 | \n",
688 | "
\n",
689 | " \n",
690 | " 2 | \n",
691 | " 3 | \n",
692 | " Liam | \n",
693 | " IT | \n",
694 | " 4500 | \n",
695 | "
\n",
696 | " \n",
697 | "
\n",
698 | "
"
699 | ],
700 | "text/plain": [
701 | " EmployeeID Name Department Salary\n",
702 | "3 4 Olivia Marketing 7000\n",
703 | "1 2 Emma Sales 6000\n",
704 | "4 5 William Finance 5500\n",
705 | "0 1 John HR 5000\n",
706 | "2 3 Liam IT 4500"
707 | ]
708 | },
709 | "execution_count": 16,
710 | "metadata": {},
711 | "output_type": "execute_result"
712 | }
713 | ],
714 | "source": [
715 | "joint_df = employees_df.merge(salaries_df, on='EmployeeID')\n",
716 | "joint_df.sort_values('Salary', ascending=False)"
717 | ]
718 | },
719 | {
720 | "cell_type": "markdown",
721 | "metadata": {},
722 | "source": [
723 | "And then select the Name corresponding to the largest Salary"
724 | ]
725 | },
726 | {
727 | "cell_type": "code",
728 | "execution_count": 17,
729 | "metadata": {},
730 | "outputs": [
731 | {
732 | "data": {
733 | "text/plain": [
734 | "'Olivia'"
735 | ]
736 | },
737 | "execution_count": 17,
738 | "metadata": {},
739 | "output_type": "execute_result"
740 | }
741 | ],
742 | "source": [
743 | "joint_df.iloc[joint_df['Salary'].idxmax()].Name"
744 | ]
745 | },
746 | {
747 | "cell_type": "markdown",
748 | "metadata": {},
749 | "source": [
750 | "# SqliteConnector"
751 | ]
752 | },
753 | {
754 | "cell_type": "markdown",
755 | "metadata": {},
756 | "source": [
757 | "We instantiate a SqliteConnector with all the necessary information about our table"
758 | ]
759 | },
760 | {
761 | "cell_type": "code",
762 | "execution_count": 18,
763 | "metadata": {},
764 | "outputs": [],
765 | "source": [
766 | "employee_connector = SqliteConnector(config={\n",
767 | " \"database\" : \"data/Northwind_small.sqlite\",\n",
768 | " \"table\" : \"Employee\",\n",
769 | "})"
770 | ]
771 | },
772 | {
773 | "cell_type": "markdown",
774 | "metadata": {},
775 | "source": [
776 |     "Which we can use just like any other data source"
777 | ]
778 | },
779 | {
780 | "cell_type": "code",
781 | "execution_count": 19,
782 | "metadata": {},
783 | "outputs": [],
784 | "source": [
785 | "employee = SmartDataframe(employee_connector)"
786 | ]
787 | },
788 | {
789 | "cell_type": "code",
790 | "execution_count": 20,
791 | "metadata": {},
792 | "outputs": [
793 | {
794 | "data": {
795 | "text/plain": [
796 | "9"
797 | ]
798 | },
799 | "execution_count": 20,
800 | "metadata": {},
801 | "output_type": "execute_result"
802 | }
803 | ],
804 | "source": [
805 | "employee.chat('How many records are there ?')"
806 | ]
807 | },
808 | {
809 | "cell_type": "markdown",
810 | "metadata": {},
811 | "source": [
812 | "Which we can easily confirm"
813 | ]
814 | },
815 | {
816 | "cell_type": "code",
817 | "execution_count": 21,
818 | "metadata": {},
819 | "outputs": [
820 | {
821 | "data": {
822 | "text/html": [
823 | "\n",
824 | "\n",
837 | "
\n",
838 | " \n",
839 | " \n",
840 | " | \n",
841 | " COUNT(*) | \n",
842 | "
\n",
843 | " \n",
844 | " \n",
845 | " \n",
846 | " 0 | \n",
847 | " 9 | \n",
848 | "
\n",
849 | " \n",
850 | "
\n",
851 | "
"
852 | ],
853 | "text/plain": [
854 | " COUNT(*)\n",
855 | "0 9"
856 | ]
857 | },
858 | "execution_count": 21,
859 | "metadata": {},
860 | "output_type": "execute_result"
861 | }
862 | ],
863 | "source": [
864 | "con = sqlite3.connect(\"data/Northwind_small.sqlite\")\n",
865 | "pd.read_sql_query(\"SELECT COUNT(*) FROM employee;\", con)"
866 | ]
867 | },
868 | {
869 | "cell_type": "markdown",
870 | "metadata": {},
871 | "source": [
872 | "Let us now create a SmartDatalake with a few other tables. We have to create a connector for each of them"
873 | ]
874 | },
875 | {
876 | "cell_type": "code",
877 | "execution_count": 22,
878 | "metadata": {},
879 | "outputs": [],
880 | "source": [
881 | "territory_connector = SqliteConnector(config={\n",
882 | " \"database\" : \"data/Northwind_small.sqlite\",\n",
883 | " \"table\" : \"EmployeeTerritory\",\n",
884 | "})\n",
885 | "\n",
886 | "shipper_connector = SqliteConnector(config={\n",
887 | " \"database\" : \"data/Northwind_small.sqlite\",\n",
888 | " \"table\" : \"Shipper\",\n",
889 | "})\n",
890 | "\n",
891 | "detail_connector = SqliteConnector(config={\n",
892 | " \"database\" : \"data/Northwind_small.sqlite\",\n",
893 | " \"table\" : \"OrderDetail\",\n",
894 | "})"
895 | ]
896 | },
897 | {
898 | "cell_type": "code",
899 | "execution_count": 23,
900 | "metadata": {},
901 | "outputs": [],
902 | "source": [
903 | "northwind = Agent([\n",
904 | " employee_connector, \n",
905 | " territory_connector, \n",
906 | " shipper_connector, \n",
907 | " detail_connector\n",
908 | "], \n",
909 | "config={\"llm\": openai})"
910 | ]
911 | },
912 | {
913 | "cell_type": "code",
914 | "execution_count": 24,
915 | "metadata": {},
916 | "outputs": [
917 | {
918 | "data": {
919 | "text/html": [
920 | "\n",
921 | "\n",
934 | "
\n",
935 | " \n",
936 | " \n",
937 | " | \n",
938 | " FirstName | \n",
939 | " LastName | \n",
940 | " TerritoryId | \n",
941 | "
\n",
942 | " \n",
943 | " \n",
944 | " \n",
945 | " 0 | \n",
946 | " Nancy | \n",
947 | " Davolio | \n",
948 | " 6897 | \n",
949 | "
\n",
950 | " \n",
951 | " 1 | \n",
952 | " Nancy | \n",
953 | " Davolio | \n",
954 | " 19713 | \n",
955 | "
\n",
956 | " \n",
957 | " 2 | \n",
958 | " Andrew | \n",
959 | " Fuller | \n",
960 | " 1581 | \n",
961 | "
\n",
962 | " \n",
963 | " 3 | \n",
964 | " Andrew | \n",
965 | " Fuller | \n",
966 | " 1730 | \n",
967 | "
\n",
968 | " \n",
969 | " 4 | \n",
970 | " Andrew | \n",
971 | " Fuller | \n",
972 | " 1833 | \n",
973 | "
\n",
974 | " \n",
975 | " 5 | \n",
976 | " Andrew | \n",
977 | " Fuller | \n",
978 | " 2116 | \n",
979 | "
\n",
980 | " \n",
981 | " 6 | \n",
982 | " Andrew | \n",
983 | " Fuller | \n",
984 | " 2139 | \n",
985 | "
\n",
986 | " \n",
987 | " 7 | \n",
988 | " Andrew | \n",
989 | " Fuller | \n",
990 | " 2184 | \n",
991 | "
\n",
992 | " \n",
993 | " 8 | \n",
994 | " Andrew | \n",
995 | " Fuller | \n",
996 | " 40222 | \n",
997 | "
\n",
998 | " \n",
999 | " 9 | \n",
1000 | " Janet | \n",
1001 | " Leverling | \n",
1002 | " 30346 | \n",
1003 | "
\n",
1004 | " \n",
1005 | " 10 | \n",
1006 | " Janet | \n",
1007 | " Leverling | \n",
1008 | " 31406 | \n",
1009 | "
\n",
1010 | " \n",
1011 | " 11 | \n",
1012 | " Janet | \n",
1013 | " Leverling | \n",
1014 | " 32859 | \n",
1015 | "
\n",
1016 | " \n",
1017 | " 12 | \n",
1018 | " Janet | \n",
1019 | " Leverling | \n",
1020 | " 33607 | \n",
1021 | "
\n",
1022 | " \n",
1023 | " 13 | \n",
1024 | " Margaret | \n",
1025 | " Peacock | \n",
1026 | " 20852 | \n",
1027 | "
\n",
1028 | " \n",
1029 | " 14 | \n",
1030 | " Margaret | \n",
1031 | " Peacock | \n",
1032 | " 27403 | \n",
1033 | "
\n",
1034 | " \n",
1035 | " 15 | \n",
1036 | " Margaret | \n",
1037 | " Peacock | \n",
1038 | " 27511 | \n",
1039 | "
\n",
1040 | " \n",
1041 | " 16 | \n",
1042 | " Steven | \n",
1043 | " Buchanan | \n",
1044 | " 2903 | \n",
1045 | "
\n",
1046 | " \n",
1047 | " 17 | \n",
1048 | " Steven | \n",
1049 | " Buchanan | \n",
1050 | " 7960 | \n",
1051 | "
\n",
1052 | " \n",
1053 | " 18 | \n",
1054 | " Steven | \n",
1055 | " Buchanan | \n",
1056 | " 8837 | \n",
1057 | "
\n",
1058 | " \n",
1059 | " 19 | \n",
1060 | " Steven | \n",
1061 | " Buchanan | \n",
1062 | " 10019 | \n",
1063 | "
\n",
1064 | " \n",
1065 | " 20 | \n",
1066 | " Steven | \n",
1067 | " Buchanan | \n",
1068 | " 10038 | \n",
1069 | "
\n",
1070 | " \n",
1071 | " 21 | \n",
1072 | " Steven | \n",
1073 | " Buchanan | \n",
1074 | " 11747 | \n",
1075 | "
\n",
1076 | " \n",
1077 | " 22 | \n",
1078 | " Steven | \n",
1079 | " Buchanan | \n",
1080 | " 14450 | \n",
1081 | "
\n",
1082 | " \n",
1083 | " 23 | \n",
1084 | " Michael | \n",
1085 | " Suyama | \n",
1086 | " 85014 | \n",
1087 | "
\n",
1088 | " \n",
1089 | " 24 | \n",
1090 | " Michael | \n",
1091 | " Suyama | \n",
1092 | " 85251 | \n",
1093 | "
\n",
1094 | " \n",
1095 | " 25 | \n",
1096 | " Michael | \n",
1097 | " Suyama | \n",
1098 | " 98004 | \n",
1099 | "
\n",
1100 | " \n",
1101 | " 26 | \n",
1102 | " Michael | \n",
1103 | " Suyama | \n",
1104 | " 98052 | \n",
1105 | "
\n",
1106 | " \n",
1107 | " 27 | \n",
1108 | " Michael | \n",
1109 | " Suyama | \n",
1110 | " 98104 | \n",
1111 | "
\n",
1112 | " \n",
1113 | " 28 | \n",
1114 | " Robert | \n",
1115 | " King | \n",
1116 | " 60179 | \n",
1117 | "
\n",
1118 | " \n",
1119 | " 29 | \n",
1120 | " Robert | \n",
1121 | " King | \n",
1122 | " 60601 | \n",
1123 | "
\n",
1124 | " \n",
1125 | " 30 | \n",
1126 | " Robert | \n",
1127 | " King | \n",
1128 | " 80202 | \n",
1129 | "
\n",
1130 | " \n",
1131 | " 31 | \n",
1132 | " Robert | \n",
1133 | " King | \n",
1134 | " 80909 | \n",
1135 | "
\n",
1136 | " \n",
1137 | " 32 | \n",
1138 | " Robert | \n",
1139 | " King | \n",
1140 | " 90405 | \n",
1141 | "
\n",
1142 | " \n",
1143 | " 33 | \n",
1144 | " Robert | \n",
1145 | " King | \n",
1146 | " 94025 | \n",
1147 | "
\n",
1148 | " \n",
1149 | " 34 | \n",
1150 | " Robert | \n",
1151 | " King | \n",
1152 | " 94105 | \n",
1153 | "
\n",
1154 | " \n",
1155 | " 35 | \n",
1156 | " Robert | \n",
1157 | " King | \n",
1158 | " 95008 | \n",
1159 | "
\n",
1160 | " \n",
1161 | " 36 | \n",
1162 | " Robert | \n",
1163 | " King | \n",
1164 | " 95054 | \n",
1165 | "
\n",
1166 | " \n",
1167 | " 37 | \n",
1168 | " Robert | \n",
1169 | " King | \n",
1170 | " 95060 | \n",
1171 | "
\n",
1172 | " \n",
1173 | " 38 | \n",
1174 | " Laura | \n",
1175 | " Callahan | \n",
1176 | " 19428 | \n",
1177 | "
\n",
1178 | " \n",
1179 | " 39 | \n",
1180 | " Laura | \n",
1181 | " Callahan | \n",
1182 | " 44122 | \n",
1183 | "
\n",
1184 | " \n",
1185 | " 40 | \n",
1186 | " Laura | \n",
1187 | " Callahan | \n",
1188 | " 45839 | \n",
1189 | "
\n",
1190 | " \n",
1191 | " 41 | \n",
1192 | " Laura | \n",
1193 | " Callahan | \n",
1194 | " 53404 | \n",
1195 | "
\n",
1196 | " \n",
1197 | " 42 | \n",
1198 | " Anne | \n",
1199 | " Dodsworth | \n",
1200 | " 3049 | \n",
1201 | "
\n",
1202 | " \n",
1203 | " 43 | \n",
1204 | " Anne | \n",
1205 | " Dodsworth | \n",
1206 | " 3801 | \n",
1207 | "
\n",
1208 | " \n",
1209 | " 44 | \n",
1210 | " Anne | \n",
1211 | " Dodsworth | \n",
1212 | " 48075 | \n",
1213 | "
\n",
1214 | " \n",
1215 | " 45 | \n",
1216 | " Anne | \n",
1217 | " Dodsworth | \n",
1218 | " 48084 | \n",
1219 | "
\n",
1220 | " \n",
1221 | " 46 | \n",
1222 | " Anne | \n",
1223 | " Dodsworth | \n",
1224 | " 48304 | \n",
1225 | "
\n",
1226 | " \n",
1227 | " 47 | \n",
1228 | " Anne | \n",
1229 | " Dodsworth | \n",
1230 | " 55113 | \n",
1231 | "
\n",
1232 | " \n",
1233 | " 48 | \n",
1234 | " Anne | \n",
1235 | " Dodsworth | \n",
1236 | " 55439 | \n",
1237 | "
\n",
1238 | " \n",
1239 | "
\n",
1240 | "
"
1241 | ],
1242 | "text/plain": [
1243 | " FirstName LastName TerritoryId\n",
1244 | "0 Nancy Davolio 6897\n",
1245 | "1 Nancy Davolio 19713\n",
1246 | "2 Andrew Fuller 1581\n",
1247 | "3 Andrew Fuller 1730\n",
1248 | "4 Andrew Fuller 1833\n",
1249 | "5 Andrew Fuller 2116\n",
1250 | "6 Andrew Fuller 2139\n",
1251 | "7 Andrew Fuller 2184\n",
1252 | "8 Andrew Fuller 40222\n",
1253 | "9 Janet Leverling 30346\n",
1254 | "10 Janet Leverling 31406\n",
1255 | "11 Janet Leverling 32859\n",
1256 | "12 Janet Leverling 33607\n",
1257 | "13 Margaret Peacock 20852\n",
1258 | "14 Margaret Peacock 27403\n",
1259 | "15 Margaret Peacock 27511\n",
1260 | "16 Steven Buchanan 2903\n",
1261 | "17 Steven Buchanan 7960\n",
1262 | "18 Steven Buchanan 8837\n",
1263 | "19 Steven Buchanan 10019\n",
1264 | "20 Steven Buchanan 10038\n",
1265 | "21 Steven Buchanan 11747\n",
1266 | "22 Steven Buchanan 14450\n",
1267 | "23 Michael Suyama 85014\n",
1268 | "24 Michael Suyama 85251\n",
1269 | "25 Michael Suyama 98004\n",
1270 | "26 Michael Suyama 98052\n",
1271 | "27 Michael Suyama 98104\n",
1272 | "28 Robert King 60179\n",
1273 | "29 Robert King 60601\n",
1274 | "30 Robert King 80202\n",
1275 | "31 Robert King 80909\n",
1276 | "32 Robert King 90405\n",
1277 | "33 Robert King 94025\n",
1278 | "34 Robert King 94105\n",
1279 | "35 Robert King 95008\n",
1280 | "36 Robert King 95054\n",
1281 | "37 Robert King 95060\n",
1282 | "38 Laura Callahan 19428\n",
1283 | "39 Laura Callahan 44122\n",
1284 | "40 Laura Callahan 45839\n",
1285 | "41 Laura Callahan 53404\n",
1286 | "42 Anne Dodsworth 3049\n",
1287 | "43 Anne Dodsworth 3801\n",
1288 | "44 Anne Dodsworth 48075\n",
1289 | "45 Anne Dodsworth 48084\n",
1290 | "46 Anne Dodsworth 48304\n",
1291 | "47 Anne Dodsworth 55113\n",
1292 | "48 Anne Dodsworth 55439"
1293 | ]
1294 | },
1295 | "execution_count": 24,
1296 | "metadata": {},
1297 | "output_type": "execute_result"
1298 | }
1299 | ],
1300 | "source": [
1301 | "northwind.chat(\"Generate a table with employee first name, last name and territory id\")"
1302 | ]
1303 | },
1304 | {
1305 | "cell_type": "code",
1306 | "execution_count": 25,
1307 | "metadata": {},
1308 | "outputs": [
1309 | {
1310 | "name": "stdout",
1311 | "output_type": "stream",
1312 | "text": [
1313 | "{'type': 'dataframe', 'value': TerritoryId EmployeeCount\n",
1314 | "0 6897 1\n",
1315 | "1 98004 1\n",
1316 | "2 98104 1\n",
1317 | "3 60179 1\n",
1318 | "4 60601 1\n",
1319 | "5 80202 1\n",
1320 | "6 80909 1\n",
1321 | "7 90405 1\n",
1322 | "8 94025 1\n",
1323 | "9 94105 1\n",
1324 | "10 95008 1\n",
1325 | "11 95054 1\n",
1326 | "12 95060 1\n",
1327 | "13 19428 1\n",
1328 | "14 44122 1\n",
1329 | "15 45839 1\n",
1330 | "16 53404 1\n",
1331 | "17 3049 1\n",
1332 | "18 3801 1\n",
1333 | "19 48075 1\n",
1334 | "20 48084 1\n",
1335 | "21 48304 1\n",
1336 | "22 55113 1\n",
1337 | "23 98052 1\n",
1338 | "24 85251 1\n",
1339 | "25 19713 1\n",
1340 | "26 85014 1\n",
1341 | "27 1581 1\n",
1342 | "28 1730 1\n",
1343 | "29 1833 1\n",
1344 | "30 2116 1\n",
1345 | "31 2139 1\n",
1346 | "32 2184 1\n",
1347 | "33 40222 1\n",
1348 | "34 30346 1\n",
1349 | "35 31406 1\n",
1350 | "36 32859 1\n",
1351 | "37 33607 1\n",
1352 | "38 20852 1\n",
1353 | "39 27403 1\n",
1354 | "40 27511 1\n",
1355 | "41 2903 1\n",
1356 | "42 7960 1\n",
1357 | "43 8837 1\n",
1358 | "44 10019 1\n",
1359 | "45 10038 1\n",
1360 | "46 11747 1\n",
1361 | "47 14450 1\n",
1362 | "48 55439 1}\n"
1363 | ]
1364 | }
1365 | ],
1366 | "source": [
1367 | "northwind.chat(\"Compute how many employees work in each territory\");"
1368 | ]
1369 | },
1370 | {
1371 | "cell_type": "markdown",
1372 | "metadata": {},
1373 | "source": [
1374 | "Which again required the LLM to figure out the not so trivial JOIN to perform"
1375 | ]
1376 | },
1377 | {
1378 | "cell_type": "code",
1379 | "execution_count": 26,
1380 | "metadata": {},
1381 | "outputs": [
1382 | {
1383 | "data": {
1384 | "text/html": [
1385 | "\n",
1386 | "\n",
1399 | "
\n",
1400 | " \n",
1401 | " \n",
1402 | " | \n",
1403 | " TerritoryId | \n",
1404 | " EmployeeCount | \n",
1405 | "
\n",
1406 | " \n",
1407 | " \n",
1408 | " \n",
1409 | " 0 | \n",
1410 | " 01581 | \n",
1411 | " 1 | \n",
1412 | "
\n",
1413 | " \n",
1414 | " 1 | \n",
1415 | " 01730 | \n",
1416 | " 1 | \n",
1417 | "
\n",
1418 | " \n",
1419 | " 2 | \n",
1420 | " 01833 | \n",
1421 | " 1 | \n",
1422 | "
\n",
1423 | " \n",
1424 | " 3 | \n",
1425 | " 02116 | \n",
1426 | " 1 | \n",
1427 | "
\n",
1428 | " \n",
1429 | " 4 | \n",
1430 | " 02139 | \n",
1431 | " 1 | \n",
1432 | "
\n",
1433 | " \n",
1434 | " 5 | \n",
1435 | " 02184 | \n",
1436 | " 1 | \n",
1437 | "
\n",
1438 | " \n",
1439 | " 6 | \n",
1440 | " 02903 | \n",
1441 | " 1 | \n",
1442 | "
\n",
1443 | " \n",
1444 | " 7 | \n",
1445 | " 03049 | \n",
1446 | " 1 | \n",
1447 | "
\n",
1448 | " \n",
1449 | " 8 | \n",
1450 | " 03801 | \n",
1451 | " 1 | \n",
1452 | "
\n",
1453 | " \n",
1454 | " 9 | \n",
1455 | " 06897 | \n",
1456 | " 1 | \n",
1457 | "
\n",
1458 | " \n",
1459 | " 10 | \n",
1460 | " 07960 | \n",
1461 | " 1 | \n",
1462 | "
\n",
1463 | " \n",
1464 | " 11 | \n",
1465 | " 08837 | \n",
1466 | " 1 | \n",
1467 | "
\n",
1468 | " \n",
1469 | " 12 | \n",
1470 | " 10019 | \n",
1471 | " 1 | \n",
1472 | "
\n",
1473 | " \n",
1474 | " 13 | \n",
1475 | " 10038 | \n",
1476 | " 1 | \n",
1477 | "
\n",
1478 | " \n",
1479 | " 14 | \n",
1480 | " 11747 | \n",
1481 | " 1 | \n",
1482 | "
\n",
1483 | " \n",
1484 | " 15 | \n",
1485 | " 14450 | \n",
1486 | " 1 | \n",
1487 | "
\n",
1488 | " \n",
1489 | " 16 | \n",
1490 | " 19428 | \n",
1491 | " 1 | \n",
1492 | "
\n",
1493 | " \n",
1494 | " 17 | \n",
1495 | " 19713 | \n",
1496 | " 1 | \n",
1497 | "
\n",
1498 | " \n",
1499 | " 18 | \n",
1500 | " 20852 | \n",
1501 | " 1 | \n",
1502 | "
\n",
1503 | " \n",
1504 | " 19 | \n",
1505 | " 27403 | \n",
1506 | " 1 | \n",
1507 | "
\n",
1508 | " \n",
1509 | " 20 | \n",
1510 | " 27511 | \n",
1511 | " 1 | \n",
1512 | "
\n",
1513 | " \n",
1514 | " 21 | \n",
1515 | " 30346 | \n",
1516 | " 1 | \n",
1517 | "
\n",
1518 | " \n",
1519 | " 22 | \n",
1520 | " 31406 | \n",
1521 | " 1 | \n",
1522 | "
\n",
1523 | " \n",
1524 | " 23 | \n",
1525 | " 32859 | \n",
1526 | " 1 | \n",
1527 | "
\n",
1528 | " \n",
1529 | " 24 | \n",
1530 | " 33607 | \n",
1531 | " 1 | \n",
1532 | "
\n",
1533 | " \n",
1534 | " 25 | \n",
1535 | " 40222 | \n",
1536 | " 1 | \n",
1537 | "
\n",
1538 | " \n",
1539 | " 26 | \n",
1540 | " 44122 | \n",
1541 | " 1 | \n",
1542 | "
\n",
1543 | " \n",
1544 | " 27 | \n",
1545 | " 45839 | \n",
1546 | " 1 | \n",
1547 | "
\n",
1548 | " \n",
1549 | " 28 | \n",
1550 | " 48075 | \n",
1551 | " 1 | \n",
1552 | "
\n",
1553 | " \n",
1554 | " 29 | \n",
1555 | " 48084 | \n",
1556 | " 1 | \n",
1557 | "
\n",
1558 | " \n",
1559 | " 30 | \n",
1560 | " 48304 | \n",
1561 | " 1 | \n",
1562 | "
\n",
1563 | " \n",
1564 | " 31 | \n",
1565 | " 53404 | \n",
1566 | " 1 | \n",
1567 | "
\n",
1568 | " \n",
1569 | " 32 | \n",
1570 | " 55113 | \n",
1571 | " 1 | \n",
1572 | "
\n",
1573 | " \n",
1574 | " 33 | \n",
1575 | " 55439 | \n",
1576 | " 1 | \n",
1577 | "
\n",
1578 | " \n",
1579 | " 34 | \n",
1580 | " 60179 | \n",
1581 | " 1 | \n",
1582 | "
\n",
1583 | " \n",
1584 | " 35 | \n",
1585 | " 60601 | \n",
1586 | " 1 | \n",
1587 | "
\n",
1588 | " \n",
1589 | " 36 | \n",
1590 | " 80202 | \n",
1591 | " 1 | \n",
1592 | "
\n",
1593 | " \n",
1594 | " 37 | \n",
1595 | " 80909 | \n",
1596 | " 1 | \n",
1597 | "
\n",
1598 | " \n",
1599 | " 38 | \n",
1600 | " 85014 | \n",
1601 | " 1 | \n",
1602 | "
\n",
1603 | " \n",
1604 | " 39 | \n",
1605 | " 85251 | \n",
1606 | " 1 | \n",
1607 | "
\n",
1608 | " \n",
1609 | " 40 | \n",
1610 | " 90405 | \n",
1611 | " 1 | \n",
1612 | "
\n",
1613 | " \n",
1614 | " 41 | \n",
1615 | " 94025 | \n",
1616 | " 1 | \n",
1617 | "
\n",
1618 | " \n",
1619 | " 42 | \n",
1620 | " 94105 | \n",
1621 | " 1 | \n",
1622 | "
\n",
1623 | " \n",
1624 | " 43 | \n",
1625 | " 95008 | \n",
1626 | " 1 | \n",
1627 | "
\n",
1628 | " \n",
1629 | " 44 | \n",
1630 | " 95054 | \n",
1631 | " 1 | \n",
1632 | "
\n",
1633 | " \n",
1634 | " 45 | \n",
1635 | " 95060 | \n",
1636 | " 1 | \n",
1637 | "
\n",
1638 | " \n",
1639 | " 46 | \n",
1640 | " 98004 | \n",
1641 | " 1 | \n",
1642 | "
\n",
1643 | " \n",
1644 | " 47 | \n",
1645 | " 98052 | \n",
1646 | " 1 | \n",
1647 | "
\n",
1648 | " \n",
1649 | " 48 | \n",
1650 | " 98104 | \n",
1651 | " 1 | \n",
1652 | "
\n",
1653 | " \n",
1654 | "
\n",
1655 | "
"
1656 | ],
1657 | "text/plain": [
1658 | " TerritoryId EmployeeCount\n",
1659 | "0 01581 1\n",
1660 | "1 01730 1\n",
1661 | "2 01833 1\n",
1662 | "3 02116 1\n",
1663 | "4 02139 1\n",
1664 | "5 02184 1\n",
1665 | "6 02903 1\n",
1666 | "7 03049 1\n",
1667 | "8 03801 1\n",
1668 | "9 06897 1\n",
1669 | "10 07960 1\n",
1670 | "11 08837 1\n",
1671 | "12 10019 1\n",
1672 | "13 10038 1\n",
1673 | "14 11747 1\n",
1674 | "15 14450 1\n",
1675 | "16 19428 1\n",
1676 | "17 19713 1\n",
1677 | "18 20852 1\n",
1678 | "19 27403 1\n",
1679 | "20 27511 1\n",
1680 | "21 30346 1\n",
1681 | "22 31406 1\n",
1682 | "23 32859 1\n",
1683 | "24 33607 1\n",
1684 | "25 40222 1\n",
1685 | "26 44122 1\n",
1686 | "27 45839 1\n",
1687 | "28 48075 1\n",
1688 | "29 48084 1\n",
1689 | "30 48304 1\n",
1690 | "31 53404 1\n",
1691 | "32 55113 1\n",
1692 | "33 55439 1\n",
1693 | "34 60179 1\n",
1694 | "35 60601 1\n",
1695 | "36 80202 1\n",
1696 | "37 80909 1\n",
1697 | "38 85014 1\n",
1698 | "39 85251 1\n",
1699 | "40 90405 1\n",
1700 | "41 94025 1\n",
1701 | "42 94105 1\n",
1702 | "43 95008 1\n",
1703 | "44 95054 1\n",
1704 | "45 95060 1\n",
1705 | "46 98004 1\n",
1706 | "47 98052 1\n",
1707 | "48 98104 1"
1708 | ]
1709 | },
1710 | "execution_count": 26,
1711 | "metadata": {},
1712 | "output_type": "execute_result"
1713 | }
1714 | ],
1715 | "source": [
1716 | "pd.read_sql_query(\"\"\"\n",
1717 | "SELECT \n",
1718 | " TerritoryID, \n",
1719 | " COUNT(DISTINCT EmployeeID) AS EmployeeCount\n",
1720 | "FROM EmployeeTerritory \n",
1721 | "JOIN Employee ON Employee.ID=EmployeeID \n",
1722 | "GROUP BY 1\n",
1723 | "\"\"\", con)"
1724 | ]
1725 | },
1726 | {
1727 | "cell_type": "markdown",
1728 | "metadata": {},
1729 | "source": [
1730 | "And we can confirm the way the agent is thinking"
1731 | ]
1732 | },
1733 | {
1734 | "cell_type": "code",
1735 | "execution_count": 27,
1736 | "metadata": {},
1737 | "outputs": [
1738 | {
1739 | "name": "stdout",
1740 | "output_type": "stream",
1741 | "text": [
1742 | "To create the code, I started by thinking about the information we needed: the names of employees and the territories they belong to. First, I combined two sets of information—one that had employee details and another that listed which territories those employees are associated with.\n",
1743 | "\n",
1744 | "Once I had this combined information, I focused on counting how many employees were in each territory. I organized the data so that it clearly showed each territory alongside the number of employees working there. Finally, I prepared the results in a way that makes it easy to understand and use.\n",
1745 | "\n",
1746 | "In summary, the process involved gathering the right information, combining it, counting the employees per territory, and then presenting the results clearly.\n"
1747 | ]
1748 | }
1749 | ],
1750 | "source": [
1751 | "print(northwind.explain())"
1752 | ]
1753 | },
1754 | {
1755 | "cell_type": "code",
1756 | "execution_count": 28,
1757 | "metadata": {},
1758 | "outputs": [],
1759 | "source": [
1760 | "code = northwind.generate_code(\"Compute how many employees work in each territory\")"
1761 | ]
1762 | },
1763 | {
1764 | "cell_type": "code",
1765 | "execution_count": 29,
1766 | "metadata": {},
1767 | "outputs": [],
1768 | "source": [
1769 | "from pprint import pprint"
1770 | ]
1771 | },
1772 | {
1773 | "cell_type": "code",
1774 | "execution_count": 30,
1775 | "metadata": {},
1776 | "outputs": [
1777 | {
1778 | "name": "stdout",
1779 | "output_type": "stream",
1780 | "text": [
1781 | "('employees_df = dfs[0]\\n'\n",
1782 | " 'employee_territory_df = dfs[1]\\n'\n",
1783 | " \"merged_df = pd.merge(employee_territory_df, employees_df[['Id', 'FirstName', \"\n",
1784 | " \"'LastName']], left_on='EmployeeId', right_on='Id')\\n\"\n",
1785 | " \"territory_counts = merged_df['TerritoryId'].value_counts().reset_index()\\n\"\n",
1786 | " \"territory_counts.columns = ['TerritoryId', 'EmployeeCount']\\n\"\n",
1787 | " \"result = {'type': 'dataframe', 'value': territory_counts}\\n\"\n",
1788 | " 'print(result)')\n"
1789 | ]
1790 | }
1791 | ],
1792 | "source": [
1793 | "pprint(code)"
1794 | ]
1795 | },
1796 | {
1797 | "cell_type": "code",
1798 | "execution_count": 31,
1799 | "metadata": {},
1800 | "outputs": [],
1801 | "source": [
1802 | "territory = SmartDataframe(territory_connector)"
1803 | ]
1804 | },
1805 | {
1806 | "cell_type": "code",
1807 | "execution_count": 32,
1808 | "metadata": {},
1809 | "outputs": [
1810 | {
1811 | "data": {
1812 | "text/html": [
1813 | "\n",
1814 | "\n",
1827 | "
\n",
1828 | " \n",
1829 | " \n",
1830 | " | \n",
1831 | " EmployeeId | \n",
1832 | " NumTerritories | \n",
1833 | "
\n",
1834 | " \n",
1835 | " \n",
1836 | " \n",
1837 | " 0 | \n",
1838 | " 1 | \n",
1839 | " 2 | \n",
1840 | "
\n",
1841 | " \n",
1842 | " 1 | \n",
1843 | " 2 | \n",
1844 | " 7 | \n",
1845 | "
\n",
1846 | " \n",
1847 | " 2 | \n",
1848 | " 3 | \n",
1849 | " 4 | \n",
1850 | "
\n",
1851 | " \n",
1852 | " 3 | \n",
1853 | " 4 | \n",
1854 | " 3 | \n",
1855 | "
\n",
1856 | " \n",
1857 | " 4 | \n",
1858 | " 5 | \n",
1859 | " 7 | \n",
1860 | "
\n",
1861 | " \n",
1862 | " 5 | \n",
1863 | " 6 | \n",
1864 | " 5 | \n",
1865 | "
\n",
1866 | " \n",
1867 | " 6 | \n",
1868 | " 7 | \n",
1869 | " 10 | \n",
1870 | "
\n",
1871 | " \n",
1872 | " 7 | \n",
1873 | " 8 | \n",
1874 | " 4 | \n",
1875 | "
\n",
1876 | " \n",
1877 | " 8 | \n",
1878 | " 9 | \n",
1879 | " 7 | \n",
1880 | "
\n",
1881 | " \n",
1882 | "
\n",
1883 | "
"
1884 | ],
1885 | "text/plain": [
1886 | " EmployeeId NumTerritories\n",
1887 | "0 1 2\n",
1888 | "1 2 7\n",
1889 | "2 3 4\n",
1890 | "3 4 3\n",
1891 | "4 5 7\n",
1892 | "5 6 5\n",
1893 | "6 7 10\n",
1894 | "7 8 4\n",
1895 | "8 9 7"
1896 | ]
1897 | },
1898 | "execution_count": 32,
1899 | "metadata": {},
1900 | "output_type": "execute_result"
1901 | }
1902 | ],
1903 | "source": [
1904 | "territory.chat('Compute the number of territories by employeeid')"
1905 | ]
1906 | },
1907 | {
1908 | "cell_type": "markdown",
1909 | "metadata": {},
1910 | "source": [
1911 | "## Extracting Generated Code"
1912 | ]
1913 | },
1914 | {
1915 | "cell_type": "markdown",
1916 | "metadata": {},
1917 | "source": [
1918 | "PandasAI allows us to see what code it used to produce the previous output by looking into the __last_code_executed__ field. See [here](https://github.com/Sinaptik-AI/pandas-ai/discussions/500) for details"
1919 | ]
1920 | },
1921 | {
1922 | "cell_type": "code",
1923 | "execution_count": 33,
1924 | "metadata": {},
1925 | "outputs": [
1926 | {
1927 | "name": "stdout",
1928 | "output_type": "stream",
1929 | "text": [
1930 | "employee_df = dfs[0]\n",
1931 | "employee_territory_df = dfs[1]\n",
1932 | "merged_df = pd.merge(employee_df[['Id', 'FirstName', 'LastName']], employee_territory_df, left_on='Id', right_on='EmployeeId')\n",
1933 | "territory_counts = merged_df['TerritoryId'].value_counts().reset_index()\n",
1934 | "territory_counts.columns = ['TerritoryId', 'EmployeeCount']\n",
1935 | "result = {'type': 'dataframe', 'value': territory_counts}\n",
1936 | "print(result)\n"
1937 | ]
1938 | }
1939 | ],
1940 | "source": [
1941 | "print(northwind.last_code_executed)"
1942 | ]
1943 | },
1944 | {
1945 | "cell_type": "markdown",
1946 | "metadata": {},
1947 | "source": [
1948 | "\n",
1949 | "

\n",
1950 | "
"
1951 | ]
1952 | }
1953 | ],
1954 | "metadata": {
1955 | "kernelspec": {
1956 | "display_name": "Python 3 (ipykernel)",
1957 | "language": "python",
1958 | "name": "python3"
1959 | },
1960 | "language_info": {
1961 | "codemirror_mode": {
1962 | "name": "ipython",
1963 | "version": 3
1964 | },
1965 | "file_extension": ".py",
1966 | "mimetype": "text/x-python",
1967 | "name": "python",
1968 | "nbconvert_exporter": "python",
1969 | "pygments_lexer": "ipython3",
1970 | "version": "3.11.7"
1971 | },
1972 | "toc": {
1973 | "base_numbering": 1,
1974 | "nav_menu": {},
1975 | "number_sections": true,
1976 | "sideBar": true,
1977 | "skip_h1_title": true,
1978 | "title_cell": "Table of Contents",
1979 | "title_sidebar": "Contents",
1980 | "toc_cell": false,
1981 | "toc_position": {},
1982 | "toc_section_display": true,
1983 | "toc_window_display": false
1984 | },
1985 | "varInspector": {
1986 | "cols": {
1987 | "lenName": 16,
1988 | "lenType": 16,
1989 | "lenVar": 40
1990 | },
1991 | "kernels_config": {
1992 | "python": {
1993 | "delete_cmd_postfix": "",
1994 | "delete_cmd_prefix": "del ",
1995 | "library": "var_list.py",
1996 | "varRefreshCmd": "print(var_dic_list())"
1997 | },
1998 | "r": {
1999 | "delete_cmd_postfix": ") ",
2000 | "delete_cmd_prefix": "rm(",
2001 | "library": "var_list.r",
2002 | "varRefreshCmd": "cat(var_dic_list()) "
2003 | }
2004 | },
2005 | "types_to_exclude": [
2006 | "module",
2007 | "function",
2008 | "builtin_function_or_method",
2009 | "instance",
2010 | "_Feature"
2011 | ],
2012 | "window_display": false
2013 | }
2014 | },
2015 | "nbformat": 4,
2016 | "nbformat_minor": 4
2017 | }
2018 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Data For Science
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [](https://twitter.com/intent/follow?screen_name=data4sci)
3 | 
4 | 
5 | 
6 |
7 | [](https://graphs4sci.substack.com/)
8 | [](https://data4science.ck.page/a63d4cc8d9)
9 |
10 | # Large Language Models for Data Science
11 |
12 | ### Code and slides to accompany the online series of webinars: https://data4sci.com/llm4ds by Data For Science.
13 |
14 | Large Language Models (LLMs) are powerful tools that put state-of-the-art AI capabilities at the tip of our fingers. They can process large amounts of data, understand nuance and context, and perform complex tasks at our request. Over the past few years, LLMs have multiplied as have the tools specially built to leverage their capabilities.
15 |
16 | In this course, you will learn how to use large language models to perform data science tasks such as summarization, translation, named entity recognition, audio generation, and data processing. We’ll explore the possibilities afforded by the tools and APIs developed by OpenAI, Hugging Face, LangChain, and Pandas AI and how best to apply them to our data science work.
17 |
18 | ## Schedule
19 | ### 1. Generative AI for Data Science
20 | - Generative AI
21 | - Large Language Models
22 | - OpenAI
23 | - HuggingFace
24 | - LangChain
25 |
26 | ### 2. Prompt Engineering for Data Science
27 | - Output formatting
28 | - Prompt Techniques
29 | - Zero-Shot and Few-Shot Prompting
30 | - Chain of Thought
31 |
32 | ### 3. Natural Language Processing with HuggingFace
33 | - Named-Entity Recognition
34 | - Part-of-Speech Tagging
35 | - Summarization
36 | - Question Answering
37 |
38 | ### 4. Text to Speech with OpenAI
39 | - The Whisper model
40 | - Generating audio from text
41 | - Audio transcription
42 | - Automatic Translation
43 |
44 | ### 5. Pandas AI
45 | - Pandas AI library structure
46 | - Natural language querying
47 | - Data cleaning
48 | - Data visualization
--------------------------------------------------------------------------------
/d4sci.mplstyle:
--------------------------------------------------------------------------------
1 | # Data For Science style
2 | # Author: Bruno Goncalves
3 | # Modified from the matplotlib FiveThirtyEight style by
4 | # Author: Cameron Davidson-Pilon, replicated styles from FiveThirtyEight.com
5 | # See https://www.dataorigami.net/blogs/fivethirtyeight-mpl
6 |
7 | lines.linewidth: 4
8 | lines.solid_capstyle: butt
9 |
10 | legend.fancybox: true
11 |
12 | axes.prop_cycle: cycler('color', ['51a7f9', 'cf51f9', '70bf41', 'f39019', 'f9e351', 'f9517b', '6d904f', '8b8b8b','810f7c'])
13 |
14 | axes.labelsize: large
15 | axes.axisbelow: true
16 | axes.grid: true
17 | axes.edgecolor: f0f0f0
18 | axes.linewidth: 3.0
19 | axes.titlesize: x-large
20 |
21 | patch.edgecolor: f0f0f0
22 | patch.linewidth: 0.5
23 |
24 | svg.fonttype: path
25 |
26 | grid.linestyle: -
27 | grid.linewidth: 1.0
28 |
29 | xtick.major.size: 0
30 | xtick.minor.size: 0
31 | ytick.major.size: 0
32 | ytick.minor.size: 0
33 |
34 | font.size: 24.0
35 |
36 | savefig.edgecolor: f0f0f0
37 | savefig.facecolor: f0f0f0
38 |
39 | figure.subplot.left: 0.08
40 | figure.subplot.right: 0.95
41 | figure.subplot.bottom: 0.07
42 | figure.figsize: 12.8, 8.8
43 | figure.autolayout: True
44 | figure.dpi: 300
45 |
--------------------------------------------------------------------------------
/data/D4Sci_logo_ball.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/data/D4Sci_logo_ball.png
--------------------------------------------------------------------------------
/data/D4Sci_logo_full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/data/D4Sci_logo_full.png
--------------------------------------------------------------------------------
/data/EpiModel.py:
--------------------------------------------------------------------------------
1 | ### −∗− mode : python ; −∗−
2 | # @file EpiModel.py
3 | # @author Bruno Goncalves
4 | ######################################################
5 |
6 | import networkx as nx
7 | import numpy as np
8 | from numpy import linalg
9 | from numpy import random
10 | import scipy.integrate
11 | import pandas as pd
12 | import matplotlib.pyplot as plt
13 |
14 | from tqdm import tqdm
15 | tqdm.pandas()
16 |
class EpiModel(object):
    """Simple Epidemic Model Implementation.

    A compartmental epidemic model represented as a directed multigraph:
    nodes are compartments (e.g. S, I, R) and edges are transitions
    between them. Transitions can be spontaneous (constant rate),
    agent-mediated (rate scaled by the fraction of the population in an
    "agent" compartment) or delayed (inactive before a start time, e.g.
    a vaccination campaign). The model can be integrated numerically
    (integrate) or simulated stochastically (simulate), optionally with
    a seasonality profile modulating the rates.
    """

    def __init__(self, compartments=None):
        """Create an empty model, optionally pre-registering compartments.

        compartments -- optional iterable of compartment names.
        """
        # A MultiDiGraph allows several transitions between the same
        # pair of compartments (e.g. mediated by different agents).
        self.transitions = nx.MultiDiGraph()
        self.seasonality = None

        if compartments is not None:
            self.transitions.add_nodes_from(compartments)

    def add_interaction(self, source, target, agent, rate):
        """Add an agent-mediated transition source -> target.

        The effective rate is `rate` scaled by the fraction of the
        population currently in the `agent` compartment.
        """
        self.transitions.add_edge(source, target, agent=agent, rate=rate)

    def add_spontaneous(self, source, target, rate):
        """Add a spontaneous (constant per-capita rate) transition."""
        self.transitions.add_edge(source, target, rate=rate)

    def add_vaccination(self, source, target, rate, start):
        """Add a spontaneous transition that is inactive before `start`."""
        self.transitions.add_edge(source, target, rate=rate, start=start)

    def _new_cases(self, population, time, pos):
        """Internal derivative function used by the ODE integrator.

        population -- current population counts per compartment
        time       -- current integration time
        pos        -- dict mapping compartment name -> index in population

        Returns the instantaneous rate of change of each compartment.
        """
        diff = np.zeros(len(pos))
        N = np.sum(population)

        for source, target, trans in self.transitions.edges(data=True):
            rate = trans['rate'] * population[pos[source]]

            # Delayed transitions (e.g. vaccination) are inactive until
            # their start time.
            if 'start' in trans and trans['start'] >= time:
                continue

            # Agent-mediated transitions scale with the agent fraction.
            if 'agent' in trans:
                agent = trans['agent']
                rate *= population[pos[agent]] / N

            # NOTE(review): seasonality modulates every active
            # transition (not just contact ones), matching simulate().
            if self.seasonality is not None:
                curr_t = int(time) % 365
                season = float(self.seasonality[curr_t])
                rate *= season

            diff[pos[source]] -= rate
            diff[pos[target]] += rate

        return diff

    def plot(self, title=None, normed=True, **kwargs):
        """Convenience function for plotting the time evolution.

        title  -- optional plot title
        normed -- divide by the initial total population when True
        kwargs -- forwarded to pandas' DataFrame.plot

        Returns the matplotlib Axes. Raises RuntimeError when neither
        integrate() nor simulate() has been called yet.

        Bugfix: the original wrapped the body in a bare `except:` and
        raised `NotInitialized`, an undefined name, so any error here
        surfaced as a confusing NameError. We now check explicitly and
        let genuine plotting errors propagate.
        """
        # Check __dict__ directly to avoid triggering __getattr__.
        if 'values_' not in self.__dict__:
            raise RuntimeError('You must call integrate() first')

        if normed:
            N = self.values_.iloc[0].sum()
            ax = (self.values_ / N).plot(**kwargs)
        else:
            ax = self.values_.plot(**kwargs)

        ax.set_xlabel('Time')
        ax.set_ylabel('Population')

        if title is not None:
            ax.set_title(title)

        return ax

    def __getattr__(self, name):
        """Dynamic attribute access: return an individual compartment
        time series (a column of values_) once the model has been run."""
        if 'values_' in self.__dict__:
            return self.values_[name]
        else:
            raise AttributeError("'EpiModel' object has no attribute '%s'" % name)

    def simulate(self, timesteps, t_min=1, seasonality=None, **kwargs):
        """Stochastically simulate the epidemic model.

        timesteps   -- number of time steps to simulate
        t_min       -- first time step (default 1)
        seasonality -- optional day-indexed rate multiplier (length >= 365)
        kwargs      -- initial population per compartment, e.g. S=990, I=10

        Stores the resulting trajectories in self.values_.
        """
        pos = {comp: i for i, comp in enumerate(self.transitions.nodes())}
        population = np.zeros(len(pos), dtype='int')

        for comp in kwargs:
            population[pos[comp]] = kwargs[comp]

        values = []
        values.append(population)

        comps = list(self.transitions.nodes)
        time = np.arange(t_min, t_min + timesteps, 1, dtype='int')

        self.seasonality = seasonality

        for t in time:
            pop = values[-1]
            new_pop = values[-1].copy()
            N = np.sum(pop)

            for comp in comps:
                trans = list(self.transitions.edges(comp, data=True))

                # Per-compartment transition probabilities for this step.
                prob = np.zeros(len(comps), dtype='float')

                for _, node_j, data in trans:
                    source = pos[comp]
                    target = pos[node_j]

                    rate = data['rate']

                    # Delayed transitions are inactive before `start`.
                    if 'start' in data and data['start'] >= t:
                        continue

                    if 'agent' in data:
                        agent = pos[data['agent']]
                        rate *= pop[agent] / N

                    if self.seasonality is not None:
                        curr_t = int(t) % 365
                        season = float(self.seasonality[curr_t])
                        rate *= season

                    prob[target] = rate

                # Remaining probability mass = stay in the compartment.
                prob[source] = 1 - np.sum(prob)

                delta = random.multinomial(pop[source], prob)
                delta[source] = 0

                changes = np.sum(delta)

                if changes == 0:
                    continue

                new_pop[source] -= changes

                for i in range(len(delta)):
                    new_pop[i] += delta[i]

            values.append(new_pop)

        values = np.array(values)
        # Drop the initial condition so the index aligns with `time`.
        self.values_ = pd.DataFrame(values[1:], columns=comps, index=time)

    def integrate(self, timesteps, t_min=1, seasonality=None, **kwargs):
        """Numerically integrate the epidemic model (deterministic ODEs).

        Arguments are the same as simulate(). Stores the trajectories
        in self.values_.
        """
        pos = {comp: i for i, comp in enumerate(self.transitions.nodes())}
        population = np.zeros(len(pos))

        for comp in kwargs:
            population[pos[comp]] = kwargs[comp]

        time = np.arange(t_min, t_min + timesteps, 1)

        self.seasonality = seasonality
        self.values_ = pd.DataFrame(scipy.integrate.odeint(self._new_cases, population, time, args=(pos,)), columns=pos.keys(), index=time)

    def __repr__(self):
        """Human-readable summary of compartments, transitions and R0."""
        text = 'Epidemic Model with %u compartments and %u transitions:\n\n' % \
                (self.transitions.number_of_nodes(),
                 self.transitions.number_of_edges())

        for source, target, trans in self.transitions.edges(data=True):
            rate = trans['rate']

            if 'agent' in trans:
                agent = trans['agent']
                text += "%s + %s = %s %f\n" % (source, agent, target, rate)
            elif 'start' in trans:
                start = trans['start']
                text += "%s -> %s %f starting at %s days\n" % (source, target, rate, start)
            else:
                text += "%s -> %s %f\n" % (source, target, rate)

        R0 = self.R0()

        if R0 is not None:
            text += "\nR0=%1.2f" % R0

        return text

    def _get_active(self):
        """Return the set of compartments that can drive transitions:
        agents of mediated transitions, plus sources of spontaneous ones."""
        active = set()

        for node_i, node_j, data in self.transitions.edges(data=True):
            if "agent" in data:
                active.add(data['agent'])
            else:
                active.add(node_i)

        return active

    def _get_susceptible(self):
        """Return the set of susceptible compartments: those with no
        incoming transitions or, failing that, sources of mediated ones."""
        susceptible = set([node for node, deg in self.transitions.in_degree() if deg == 0])

        if len(susceptible) == 0:
            for node_i, node_j, data in self.transitions.edges(data=True):
                if "agent" in data:
                    susceptible.add(node_i)

        return susceptible

    def _get_infections(self):
        """Return a nested dict {agent: {source: {'target', 'rate'}}} of
        all agent-mediated (infection) transitions."""
        inf = {}

        for node_i, node_j, data in self.transitions.edges(data=True):
            if "agent" in data:
                agent = data['agent']

                if agent not in inf:
                    inf[agent] = {}

                if node_i not in inf[agent]:
                    inf[agent][node_i] = {}

                inf[agent][node_i]['target'] = node_j
                inf[agent][node_i]['rate'] = data['rate']

        return inf

    def R0(self):
        """Compute the basic reproductive number using the
        next-generation matrix approach (largest eigenvalue of F V^-1).

        Returns None when R0 cannot be computed (no infection
        transitions, unknown compartments, or singular V matrix).
        """
        infected = set()

        susceptible = self._get_susceptible()

        for node_i, node_j, data in self.transitions.edges(data=True):
            if "agent" in data:
                infected.add(data['agent'])
                infected.add(node_j)

        infected = sorted(infected)
        N_infected = len(infected)

        # F: new-infection terms; V: transfer (recovery/progression) terms.
        F = np.zeros((N_infected, N_infected), dtype='float')
        V = np.zeros((N_infected, N_infected), dtype='float')

        pos = dict(zip(infected, np.arange(N_infected)))

        try:
            for node_i, node_j, data in self.transitions.edges(data=True):
                rate = data['rate']

                if "agent" in data:
                    target = pos[node_j]
                    agent = pos[data['agent']]

                    if node_i in susceptible:
                        F[target, agent] = rate
                elif "start" in data:
                    # Delayed transitions don't contribute to R0 at t=0.
                    continue
                else:
                    source = pos[node_i]

                    V[source, source] += rate

                    if node_j in pos:
                        target = pos[node_j]
                        V[target, source] -= rate

            eig, v = linalg.eig(np.dot(F, linalg.inv(V)))

            return eig.max()
        except Exception:
            # Best-effort contract preserved (return None on failure),
            # but no longer a bare except that would also swallow
            # KeyboardInterrupt/SystemExit.
            return None

    def __getitem__(self, compartments):
        """Column access into the computed trajectories (values_)."""
        return self.values_[compartments]
289 |
if __name__ == '__main__':

    # Demo: SIR model extended with an imperfect vaccine (V), including
    # breakthrough infections (VI) and their recoveries (VR).
    beta = 0.2  # transmission rate
    mu = 0.1    # recovery rate

    SIR = EpiModel()
    SIR.add_interaction('S', 'I', 'I', beta)
    SIR.add_spontaneous('I', 'R', mu)
    SIR.add_vaccination('S', 'V', 0.01, 75)
    SIR.add_spontaneous('VI', 'VR', mu)
    # Vaccine is 80% effective: breakthrough infections at 20% of beta.
    SIR.add_interaction('V', 'VI', 'I', beta*(1-.8))

    print(SIR)

    N = 100000
    I0 = 10

    # Seasonal contact reduction on days 74-99.
    season = np.ones(365+1)
    season[74:100] = 0.25

    fig, ax = plt.subplots(1)

    Nruns = 1000
    trajectories = []

    #for i in tqdm(range(Nruns), total=Nruns):
    SIR.integrate(365, S=.3*N-10, I=10, V=.7*N)

    SIR[['I', 'VI', 'VR', 'R']].plot(ax=ax)
    print(SIR.S.tail())
    #ax.plot(SIR.I/N, lw=.1, c='b')

    # Keep only runs where the outbreak took off.
    if SIR.I.max() > 10:
        trajectories.append(SIR.I)

    trajectories = pd.DataFrame(trajectories)
    (trajectories.median(axis=0)/N).plot(ax=ax, c='r')

    fig.savefig('SIR.png')
330 |
331 |
332 |
--------------------------------------------------------------------------------
/data/Northwind_small.sqlite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/data/Northwind_small.sqlite
--------------------------------------------------------------------------------
/data/gettysburg10.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/data/gettysburg10.wav
--------------------------------------------------------------------------------
/data/pratchett.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/data/pratchett.mp3
--------------------------------------------------------------------------------
/exports/charts/temp_chart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/exports/charts/temp_chart.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.1.3
2 | numpy==1.18.1
3 | pandas==1.0.1
4 | watermark==2.0.2
5 |
--------------------------------------------------------------------------------
/slides/LLM4DS.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataForScience/LLM4DS/ef16502ad70ed190dccb67646fe2df0a26fe0854/slides/LLM4DS.pdf
--------------------------------------------------------------------------------