├── .gitignore
├── README.md
├── app.py
├── assets
│   ├── overview.png
│   ├── sample1.mp4
│   ├── sample2.mp4
│   ├── sample3.mp4
│   └── sample4.mp4
├── main.py
├── notebooks
│   ├── Ask_about_character.ipynb
│   ├── Vid2Desc_Gemini_1_0_Pro_Vision.ipynb
│   └── llm_personality.ipynb
├── requirements.txt
├── styles.css
└── vid2persona
    ├── gen
    │   ├── gemini.py
    │   ├── local_openllm.py
    │   ├── tgi_openllm.py
    │   └── utils.py
    ├── init.py
    ├── pipeline
    │   ├── llm.py
    │   └── vlm.py
    ├── prompts
    │   ├── llm.toml
    │   └── vlm.toml
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vid2Persona
2 |
3 | This project breathes life into video characters by using AI to describe their personality and then chat with you as them.
4 |
5 |
6 |

7 |
8 |
9 | ## Brainstormed workflow
10 |
11 | 1. Get a person's description from the video clip using a Large Multimodal Model.
12 |    - We chose the [Get video descriptions](https://cloud.google.com/vertex-ai/generative-ai/docs/video/video-descriptions#vid-desc-rest) service from [Generative AI on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai).
13 |
14 | 2. Based on the description, ask a Large Language Model to pretend to be the person.
15 | 3. Chat with that personality.
16 |    - We chose either [Gemini API from Google AI Studio](https://ai.google.dev/) or [Gemini API from Generative AI on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini).
17 |
18 | The final output is a Gradio-based chat application hosted on [Hugging Face Spaces](https://huggingface.co/spaces).
19 |
20 | Optionally, we could leverage other open-source technologies:
21 | - [diffusers](https://huggingface.co/docs/diffusers/en/index) to generate images of the person in different poses or against different backgrounds
22 | - [transformers](https://huggingface.co/docs/transformers/en/index) to replace the closed Gemini model with open models such as [LLaMA2](https://llama.meta.com/), [Gemma](https://blog.google/technology/developers/gemma-open-models/), [Mistral](https://mistral.ai/), etc.
23 |
24 | ## Realized workflow
25 |
26 | ### Character description
27 |
28 | We obtain a description from an input video using [Gemini 1.0 Pro Vision on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini). We wrote a custom prompt (brainstormed with the help of ChatGPT) and pass it to the API along with the video. The prompt is available in [this file](./vid2persona/prompts/vlm.toml).
29 |
30 | Refer to [this notebook](./notebooks/Ask_about_character.ipynb) for a rundown.
31 |
32 | Here is an example of what a Gemini response looks like:
33 |
34 | ```json
35 | {
36 |   "characters": [
37 |     {
38 |       "name": "Alice",
39 |       "physicalDescription": "Alice is a young woman with long, wavy brown hair and hazel eyes. She is of average height and has a slim build. Her most distinctive feature is her warm, friendly smile.",
40 |       "personalityTraits": [
41 |         "Alice is a kind, compassionate, and intelligent woman. She is always willing to help others and is a great listener. She is also very creative and has a great sense of humor.",
42 |       ],
43 |       "likes": [
44 |         "Alice loves spending time with her friends and family.",
45 |         "She enjoys reading, writing, and listening to music.",
46 |         "She is also a big fan of traveling and exploring new places."
47 |       ],
48 |       "dislikes": [
49 |         "Alice dislikes rudeness and cruelty.",
50 |         "She also dislikes being lied to or taken advantage of.",
51 |         "She is not a fan of heights or roller coasters."
52 |       ],
53 |       "background": [
54 |         "Alice grew up in a small town in the Midwest.",
55 |         "She was always a good student and excelled in her studies.",
56 |         "After graduating from high school, she moved to the city to attend college.",
57 |         "She is currently working as a social worker."
58 |       ],
59 |       "goals": [
60 |         "Alice wants to make a difference in the world.",
61 |         "She hopes to one day open her own counseling practice.",
62 |         "She also wants to travel the world and experience different cultures."
63 |       ],
64 |       "relationships": [
65 |         "Alice is very close to her family and friends.",
66 |         "She is also in a loving relationship with her partner, Ben.",
67 |         "She has a good relationship with her colleagues and is well-respected by her clients."
68 |       ]
69 |     }
70 |   ]
71 | }
72 | ```
73 |
74 | ### Chatting with the character
75 |
76 | Next, we construct a system prompt from the response above and use it as an input to a Large Language Model (LLM). This prompt is available [here](./vid2persona/prompts/llm.toml). The system prompt helps the LLM to be character-aware.
77 |
78 | Refer to [this notebook](./notebooks/llm_personality.ipynb) for a rundown.
79 |
80 | > [!NOTE]
81 | > If a video contains multiple characters, we construct the system prompt for the first one only.
82 |
83 | You can find all of this collated into a single pipeline in [this demo](https://huggingface.co/spaces/chansung/vid2persona). Feel free to give it a try!
84 |
85 | ## Design considerations
86 |
87 | We designed the pipeline this way for the following reasons:
88 |
89 | * Videos are hard to process efficiently, and captioning them requires a lot of compute. The existing open solutions didn't meet our needs, which is why we delegated this part of the pipeline to Gemini.
90 | * On the other hand, making LLMs accessible on modest hardware is a well-explored area, thanks to tools like `bitsandbytes`. For the second part of the pipeline, we wanted to give users the flexibility to "bring your own language model", also because there is an abundance of high-quality open LLMs that are particularly good at this task. For our project, we used [HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) because it is small (7B parameters) yet performs well.
91 |
92 | To scale the second part of the pipeline, we use [`text-generation-inference`](https://huggingface.co/docs/text-generation-inference).
93 |
94 | ## Acknowledgments
95 |
96 | This project was built during the Gemini sprint held by Google's ML Developer Programs team. We are grateful for the generous GCP credits that allowed us to finish it.
97 |
98 |
--------------------------------------------------------------------------------
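The README above describes two hand-offs: Gemini returns a JSON character profile, and the first profile becomes the system prompt for the chat model. The notebook output further down shows that Gemini's raw reply arrives wrapped in a Markdown `json` code fence and can contain a trailing comma (after the single `personalityTraits` entry), which `json.loads` rejects. The following is a minimal sketch under those assumptions, not the project's `vlm.get_traits` or `llm.chat`: `parse_gemini_json`, `first_character`, `build_system_prompt`, and `SYSTEM_TEMPLATE` are hypothetical, and the template only stands in for `vid2persona/prompts/llm.toml`, which is not shown here.

```python
import json
import re
from collections import defaultdict

FENCE = "`" * 3  # a Markdown code fence, built here so it does not close this block

# Hypothetical template; the project's real prompt lives in vid2persona/prompts/llm.toml.
SYSTEM_TEMPLATE = (
    "You are {name}. Stay in character at all times.\n"
    "Appearance: {physicalDescription}\n"
    "Personality: {personalityTraits}\n"
    "Likes: {likes}\n"
    "Dislikes: {dislikes}\n"
    "Background: {background}\n"
    "Goals: {goals}\n"
    "Relationships: {relationships}"
)


def parse_gemini_json(text: str) -> dict:
    """Parse a JSON object that may be wrapped in a Markdown code fence."""
    body = text.strip()
    if body.startswith(FENCE):
        _, _, body = body.partition("\n")  # drop the opening fence line (e.g. its json tag)
        body = body.rsplit(FENCE, 1)[0]    # drop the closing fence
    # Naively remove trailing commas before a closing bracket or brace.
    # Caveat: this would also alter string values that contain ", ]" or ", }".
    body = re.sub(r",\s*([\]}])", r"\1", body)
    return json.loads(body)


def first_character(traits: dict) -> dict:
    """Mirror app.py: keep only the first profile when several are returned."""
    if "characters" in traits:
        return traits["characters"][0]
    return traits


def build_system_prompt(character: dict) -> str:
    """Fill the hypothetical template; list fields are joined, missing fields become ''."""
    fields = {
        key: " ".join(value) if isinstance(value, list) else value
        for key, value in character.items()
    }
    return SYSTEM_TEMPLATE.format_map(defaultdict(str, fields))


if __name__ == "__main__":
    reply = FENCE + 'json\n{"characters": [{"name": "Alice", "personalityTraits": ["Kind.",]}]}\n' + FENCE
    print(build_system_prompt(first_character(parse_gemini_json(reply))))
```

The trailing-comma cleanup is a regular expression rather than a real lenient parser, so it is only safe when no string value contains a comma directly before `]` or `}`; a stricter prompt or a tolerant JSON parser would be more robust.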
/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gradio as gr
3 |
4 | from vid2persona import init
5 | from vid2persona.pipeline import vlm
6 | from vid2persona.pipeline import llm
7 |
8 | def validate_args(func):
9 |     def inner_function(args):
10 |         gcp_project_id, gcp_project_location = init.get_env_vars()
11 |
12 |         if args.gcp_project_id is None: args.gcp_project_id = gcp_project_id
13 |         if args.gcp_project_location is None: args.gcp_project_location = gcp_project_location
14 |
15 |         if args.gcp_project_id is None or args.gcp_project_location is None:
16 |             raise ValueError("gcp-project-id or gcp-project-location is missing")
17 |
18 |         if args.hf_access_token is not None:
19 |             if args.model_id not in init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS:
20 |                 raise ValueError("not supported model for Hugging Face PRO account")
21 |         return func(args)
22 |     return inner_function
23 |
24 | async def extract_traits(video_path, gcp_project_id, gcp_project_location, prompt_tpl_path):
25 |     traits = await vlm.get_traits(
26 |         gcp_project_id,
27 |         gcp_project_location,
28 |         video_path,
29 |         prompt_tpl_path
30 |     )
31 |     if 'characters' in traits:
32 |         traits = traits['characters'][0]
33 |
34 |     return [
35 |         traits, [],
36 |         gr.Textbox("", interactive=True),
37 |         gr.Button(interactive=True),
38 |         gr.Button(interactive=True),
39 |         gr.Button(interactive=True)
40 |     ]
41 |
42 | async def conversation(
43 |     message: str, messages: list, traits: dict,
44 |     prompt_tpl_path: str, model_id: str,
45 |     max_input_token_length: int, max_new_tokens: int,
46 |     temperature: float, top_p: float, top_k: float,
47 |     repetition_penalty: float, hf_access_token: str,
48 | ):
49 |     if hf_access_token == "":
50 |         hf_access_token = None
51 |
52 |     messages = messages + [[message, ""]]
53 |     yield [messages, message, gr.Button(interactive=False), gr.Button(interactive=False)]
54 |
55 |     async for partial_response in llm.chat(
56 |         message, messages, traits,
57 |         prompt_tpl_path, model_id,
58 |         max_input_token_length, max_new_tokens,
59 |         temperature, top_p, top_k,
60 |         repetition_penalty, hf_token=hf_access_token
61 |     ):
62 |         last_message = messages[-1]
63 |         last_message[1] = last_message[1] + partial_response
64 |         messages[-1] = last_message
65 |         yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]
66 |
67 |     yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]
68 |
69 | async def regen_conversation(
70 |     messages: list, traits: dict,
71 |     prompt_tpl_path: str, model_id: str,
72 |     max_input_token_length: int, max_new_tokens: int,
73 |     temperature: float, top_p: float, top_k: float,
74 |     repetition_penalty: float, hf_access_token: str,
75 | ):
76 |     if len(messages) > 0:
77 |         if hf_access_token == "":
78 |             hf_access_token = None
79 |
80 |         message = messages[-1][0]
81 |         messages = messages[:-1]
82 |         messages = messages + [[message, ""]]
83 |         yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]
84 |
85 |         async for partial_response in llm.chat(
86 |             message, messages, traits,
87 |             prompt_tpl_path, model_id,
88 |             max_input_token_length, max_new_tokens,
89 |             temperature, top_p, top_k,
90 |             repetition_penalty, hf_token=hf_access_token
91 |         ):
92 |             last_message = messages[-1]
93 |             last_message[1] = last_message[1] + partial_response
94 |             messages[-1] = last_message
95 |             yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]
96 |
97 |         yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]
98 |
99 | @validate_args
100 | def main(args):
101 |     with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
102 |         # hidden components
103 |         gcp_project_id = gr.Textbox(args.gcp_project_id, visible=False)
104 |         gcp_project_location = gr.Textbox(args.gcp_project_location, visible=False)
105 |         prompt_tpl_path = gr.Textbox(args.prompt_tpl_path, visible=False)
106 |         hf_access_token = gr.Textbox(args.hf_access_token, visible=False)
107 |
108 |         gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])
109 |         gr.Markdown("This project breathes life into video characters by using AI to describe their personality and then chat with you as them.")
110 |
111 |         with gr.Column(elem_classes=["group"]):
112 |             with gr.Row():
113 |                 video = gr.Video(label="upload short video clip")
114 |                 traits = gr.Json(label="extracted traits")
115 |
116 |             with gr.Row():
117 |                 trait_gen = gr.Button("generate traits")
118 |
119 |         with gr.Column(elem_classes=["group"]):
120 |             chatbot = gr.Chatbot([], label="chatbot", elem_id="chatbot", elem_classes=["chatbot-no-label"])
121 |             with gr.Row():
122 |                 clear = gr.Button("clear conversation", interactive=False)
123 |                 regen = gr.Button("regenerate the last", interactive=False)
124 |                 stop = gr.Button("stop", interactive=False)
125 |             user_input = gr.Textbox(placeholder="ask anything", interactive=False, elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"])
126 |
127 |         with gr.Accordion("parameters' control pane", open=False):
128 |             model_id = gr.Dropdown(choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS, value=args.model_id, label="Model ID")
129 |
130 |             with gr.Row():
131 |                 max_input_token_length = gr.Slider(minimum=1024, maximum=4096, value=args.max_input_token_length, label="max-input-tokens")
132 |                 max_new_tokens = gr.Slider(minimum=128, maximum=2048, value=args.max_new_tokens, label="max-new-tokens")
133 |
134 |             with gr.Row():
135 |                 temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=args.temperature, label="temperature")
136 |                 top_p = gr.Slider(minimum=0, maximum=1, step=0.1, value=args.top_p, label="top-p")  # top-p is a probability mass in [0, 1]
137 |                 top_k = gr.Slider(minimum=1, maximum=100, step=1, value=args.top_k, label="top-k")  # integer token count; default is 50
138 |                 repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, value=args.repetition_penalty, label="repetition-penalty")
139 |
140 |         with gr.Row():
141 |             gr.Markdown(
142 |                 "[](https://github.com/deep-diver/Vid2Persona) "
143 |                 "[](https://twitter.com/algo_diver) "
144 |                 "[](https://twitter.com/RisingSayak )",
145 |                 elem_id="bottom-md"
146 |             )
147 |
148 |         trait_gen.click(
149 |             extract_traits,
150 |             [video, gcp_project_id, gcp_project_location, prompt_tpl_path],
151 |             [traits, chatbot, user_input, clear, regen, stop]
152 |         )
153 |
154 |         conv = user_input.submit(
155 |             conversation,
156 |             [
157 |                 user_input, chatbot, traits,
158 |                 prompt_tpl_path, model_id,
159 |                 max_input_token_length, max_new_tokens,
160 |                 temperature, top_p, top_k,
161 |                 repetition_penalty, hf_access_token
162 |             ],
163 |             [chatbot, user_input, clear, regen]
164 |         )
165 |
166 |         clear.click(
167 |             lambda: [
168 |                 gr.Chatbot([]),
169 |                 gr.Button(interactive=False),
170 |                 gr.Button(interactive=False),
171 |             ],
172 |             None, [chatbot, clear, regen]
173 |         )
174 |
175 |         conv_regen = regen.click(
176 |             regen_conversation,
177 |             [
178 |                 chatbot, traits,
179 |                 prompt_tpl_path, model_id,
180 |                 max_input_token_length, max_new_tokens,
181 |                 temperature, top_p, top_k,
182 |                 repetition_penalty, hf_access_token
183 |             ],
184 |             [chatbot, user_input, clear, regen]
185 |         )
186 |
187 |         stop.click(
188 |             None, None, None,
189 |             cancels=[conv, conv_regen]
190 |         )
191 |
192 |     demo.launch()
193 |
194 | if __name__ == "__main__":
195 |     parser = argparse.ArgumentParser()
196 |     parser.add_argument('--gcp-project-id', type=str, default=None,
197 |                         help="The ID of the Google Cloud Platform (GCP) project in which "
198 |                              "you want to run Vertex AI multimodal video analysis with the "
199 |                              "Gemini 1.0 Pro Vision model. If omitted, it is read from the environment.")
200 |
201 |     parser.add_argument('--gcp-project-location', type=str, default=None,
202 |                         help="The GCP region where you want to run Vertex AI multimodal video analysis "
203 |                              "with the Gemini 1.0 Pro Vision model. If omitted, it is read from the environment.")
204 |
205 |     parser.add_argument('--prompt-tpl-path', type=str, default="vid2persona/prompts",
206 |                         help="Path to the directory containing prompt templates for the model.")
207 |
208 |     parser.add_argument('--hf-access-token', type=str, default=None,
209 |                         help="Your Hugging Face access token (if needed for model access). "
210 |                              "If you don't specify this, the program will run the model on "
211 |                              "your local machine.")
212 |
213 |     parser.add_argument('--model-id', type=str, default="HuggingFaceH4/zephyr-7b-beta",
214 |                         help="The Hugging Face model repository of the language model to use.")
215 |
216 |     parser.add_argument('--max-input-token-length', type=int, default=4096,
217 |                         help="Maximum number of input tokens allowed for the model.")
218 |
219 |     parser.add_argument('--max-new-tokens', type=int, default=128,
220 |                         help="Maximum number of new tokens the model may generate.")
221 |
222 |     parser.add_argument('--temperature', type=float, default=0.6,
223 |                         help="Controls the randomness/creativity of the model's output "
224 |                              "(higher values mean more random).")
225 |
226 |     parser.add_argument('--top-p', type=float, default=0.9,
227 |                         help="Nucleus sampling: considers the smallest set of tokens with "
228 |                              "a cumulative probability of at least 'top_p'.")
229 |
230 |     parser.add_argument('--top-k', type=int, default=50,
231 |                         help="Limits the number of tokens considered for generation at each step.")
232 |
233 |     parser.add_argument('--repetition-penalty', type=float, default=1.2,
234 |                         help="Penalizes repeated tokens to encourage diversity in the output.")
235 |
236 |     args = parser.parse_args()
237 |     main(args)
238 |
--------------------------------------------------------------------------------
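`conversation()` in app.py appends an empty `[message, ""]` pair to the Gradio chat history, then grows the bot half of that pair with each chunk yielded by `llm.chat`, yielding the updated history after every chunk so the chat window redraws as the reply streams in. The sketch below reproduces only that accumulation pattern; `fake_chat`, `stream_reply`, and `demo` are hypothetical names, `fake_chat` stands in for `llm.chat` (whose implementation is not shown), and the Gradio button updates are left out.

```python
import asyncio
from typing import AsyncIterator, List


async def fake_chat(message: str) -> AsyncIterator[str]:
    """Hypothetical stand-in for llm.chat: yields a canned reply in small chunks."""
    for chunk in ["Hi", " there", "!"]:
        await asyncio.sleep(0)  # give other tasks a turn, as a real network stream would
        yield chunk


async def stream_reply(message: str, history: List[List[str]]) -> AsyncIterator[List[List[str]]]:
    """Append [message, ""] and grow the bot half of that pair as chunks arrive."""
    history = history + [[message, ""]]  # new outer list; the caller's list is not extended
    yield history                        # show the user's turn immediately
    async for partial in fake_chat(message):
        history[-1][1] += partial        # extend the bot reply in place
        yield history                    # hand the current history to the UI


async def demo() -> List[List[str]]:
    """Record the bot text after each yield to show how the reply grows."""
    snapshots = []
    history: List[List[str]] = []
    async for history in stream_reply("Hello!", []):
        snapshots.append(history[-1][1])
    print(snapshots)  # ['', 'Hi', 'Hi there', 'Hi there!']
    return history


if __name__ == "__main__":
    asyncio.run(demo())
```

Yielding after every chunk is what makes the reply appear token by token; `regen_conversation()` reuses the same loop after dropping the last bot answer.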
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-diver/Vid2Persona/c36c5972443a02ea6320f6e5e6c97163d028d791/assets/overview.png
--------------------------------------------------------------------------------
/assets/sample1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-diver/Vid2Persona/c36c5972443a02ea6320f6e5e6c97163d028d791/assets/sample1.mp4
--------------------------------------------------------------------------------
/assets/sample2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-diver/Vid2Persona/c36c5972443a02ea6320f6e5e6c97163d028d791/assets/sample2.mp4
--------------------------------------------------------------------------------
/assets/sample3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-diver/Vid2Persona/c36c5972443a02ea6320f6e5e6c97163d028d791/assets/sample3.mp4
--------------------------------------------------------------------------------
/assets/sample4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deep-diver/Vid2Persona/c36c5972443a02ea6320f6e5e6c97163d028d791/assets/sample4.mp4
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import argparse
3 | from vid2persona import init
4 | from vid2persona.pipeline import vlm
5 | from vid2persona.pipeline import llm
6 |
7 | def validate_args(func):
8 |     def inner_function(args):
9 |         gcp_project_id, gcp_project_location = init.get_env_vars()
10 |
11 |         if args.gcp_project_id is None: args.gcp_project_id = gcp_project_id
12 |         if args.gcp_project_location is None: args.gcp_project_location = gcp_project_location
13 |
14 |         if args.gcp_project_id is None or args.gcp_project_location is None:
15 |             raise ValueError("gcp-project-id or gcp-project-location is missing")
16 |
17 |         if args.hf_access_token is not None:
18 |             if args.model_id not in init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS:
19 |                 raise ValueError("not supported model for Hugging Face PRO account")
20 |         return func(args)
21 |     return inner_function
22 |
23 | @validate_args
24 | async def workflow(args):
25 |     traits = await vlm.get_traits(
26 |         args.gcp_project_id,
27 |         args.gcp_project_location,
28 |         args.target_movie_clip,
29 |         args.prompt_tpl_path
30 |     )
31 |     if 'characters' in traits:
32 |         traits = traits['characters'][0]
33 |
34 |     messages = []
35 |     async for response in llm.chat(
36 |         args.message, messages, traits,
37 |         args.prompt_tpl_path, args.model_id,
38 |         args.max_input_token_length, args.max_new_tokens,
39 |         args.temperature, args.top_p, args.top_k,
40 |         args.repetition_penalty, hf_token=args.hf_access_token
41 |     ):
42 |         print(response, end="")
43 |
44 | if __name__ == "__main__":
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument('--gcp-project-id', type=str, default=None,
47 |                         help="The ID of the Google Cloud Platform (GCP) project in which "
48 |                              "you want to run Vertex AI multimodal video analysis with the "
49 |                              "Gemini 1.0 Pro Vision model. If omitted, it is read from the environment.")
50 |
51 |     parser.add_argument('--gcp-project-location', type=str, default=None,
52 |                         help="The GCP region where you want to run Vertex AI multimodal video analysis "
53 |                              "with the Gemini 1.0 Pro Vision model. If omitted, it is read from the environment.")
54 |
55 |     parser.add_argument('--target-movie-clip', type=str, default="assets/sample1.mp4",
56 |                         help="Path of the video file you want to analyze and process.")
57 |
58 |     parser.add_argument('--prompt-tpl-path', type=str, default="vid2persona/prompts",
59 |                         help="Path to the directory containing prompt templates for the model.")
60 |
61 |     parser.add_argument('--hf-access-token', type=str, default=None,
62 |                         help="Your Hugging Face access token (if needed for model access). "
63 |                              "If you don't specify this, the program will run the model on "
64 |                              "your local machine.")
65 |
66 |     parser.add_argument('--model-id', type=str, default="HuggingFaceH4/zephyr-7b-beta",
67 |                         help="The Hugging Face model repository of the language model to use.")
68 |
69 |     parser.add_argument('--max-input-token-length', type=int, default=4096,
70 |                         help="Maximum number of input tokens allowed for the model.")
71 |
72 |     parser.add_argument('--max-new-tokens', type=int, default=1024,
73 |                         help="Maximum number of new tokens the model may generate.")
74 |
75 |     parser.add_argument('--temperature', type=float, default=0.6,
76 |                         help="Controls the randomness/creativity of the model's output "
77 |                              "(higher values mean more random).")
78 |
79 |     parser.add_argument('--top-p', type=float, default=0.9,
80 |                         help="Nucleus sampling: considers the smallest set of tokens with "
81 |                              "a cumulative probability of at least 'top_p'.")
82 |
83 |     parser.add_argument('--top-k', type=int, default=50,
84 |                         help="Limits the number of tokens considered for generation at each step.")
85 |
86 |     parser.add_argument('--repetition-penalty', type=float, default=1.2,
87 |                         help="Penalizes repeated tokens to encourage diversity in the output.")
88 |
89 |     parser.add_argument('--message', type=str, default="Hello there! How are you doing?",
90 |                         help="The initial message to start a conversation or text generation task.")
91 |
92 |     args = parser.parse_args()
93 |
94 |     asyncio.run(workflow(args))
95 |
--------------------------------------------------------------------------------
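`main.py` applies the same `validate_args` decorator as app.py, but to an `async def`. That works because the synchronous wrapper only checks and patches `args`, then returns whatever `func(args)` produces, which for a coroutine function is an un-awaited coroutine that `asyncio.run` then drives. The sketch below shows that pattern with `functools.wraps` added so the wrapped function keeps its name; `get_env_vars` and the `GCP_PROJECT_ID` / `GCP_PROJECT_LOCATION` variable names are hypothetical stand-ins for `init.get_env_vars()`, which is not shown, and the Hugging Face model check is left out.

```python
import argparse
import asyncio
import functools
import os


def get_env_vars():
    """Hypothetical stand-in for init.get_env_vars(); the variable names are assumed."""
    return os.environ.get("GCP_PROJECT_ID"), os.environ.get("GCP_PROJECT_LOCATION")


def validate_args(func):
    @functools.wraps(func)  # keep func's __name__ and docstring on the wrapper
    def inner_function(args):
        env_id, env_location = get_env_vars()
        # Command-line values win; environment values only fill the gaps.
        if args.gcp_project_id is None:
            args.gcp_project_id = env_id
        if args.gcp_project_location is None:
            args.gcp_project_location = env_location
        if args.gcp_project_id is None or args.gcp_project_location is None:
            raise ValueError("gcp-project-id or gcp-project-location is missing")
        return func(args)  # for an async func this is a coroutine, not its result
    return inner_function


@validate_args
async def workflow(args):
    await asyncio.sleep(0)  # placeholder for the real VLM and LLM calls
    return f"{args.gcp_project_id}@{args.gcp_project_location}"


if __name__ == "__main__":
    os.environ.setdefault("GCP_PROJECT_LOCATION", "us-central1")
    cli_args = argparse.Namespace(gcp_project_id="my-project", gcp_project_location=None)
    print(asyncio.run(workflow(cli_args)))
```

Because the checks run in the synchronous wrapper, a missing project ID raises `ValueError` as soon as `workflow(args)` is called, before `asyncio.run` starts an event loop.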
/notebooks/Ask_about_character.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "ddSBgL68Zu8j"
17 | },
18 | "source": [
19 | "# Ask about a video clip with Gemini 1.0 Pro Vision on Vertex AI"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {
26 | "id": "Rx6vrvyfBDcd"
27 | },
28 | "outputs": [],
29 | "source": [
30 | "!pip install --upgrade google-cloud-aiplatform"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "source": [
36 | "from IPython.display import Markdown, display"
37 | ],
38 | "metadata": {
39 | "id": "a3csqdgWy8-Z"
40 | },
41 | "execution_count": 2,
42 | "outputs": []
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {
47 | "id": "u1IQpbbTZ60q"
48 | },
49 | "source": [
50 | "## Authentication to Vertex AI with `gcloud`"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 3,
56 | "metadata": {
57 | "id": "Fv1mYIWcD9dB",
58 | "colab": {
59 | "base_uri": "https://localhost:8080/"
60 | },
61 | "outputId": "f1c69e89-bcc6-4c9b-ffc7-42cface19877"
62 | },
63 | "outputs": [
64 | {
65 | "output_type": "stream",
66 | "name": "stdout",
67 | "text": [
68 | "Go to the following link in your browser:\n",
69 | "\n",
70 | " https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=https%3A%2F%2Fsdk.cloud.google.com%2Fapplicationdefaultauthcode.html&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login&state=vqPaFBd2sQrNAuctRbS6f0T91AXH9R&prompt=consent&token_usage=remote&access_type=offline&code_challenge=vzahrDrqQsH6lwRwBVsekZeG8vDMniJ34ZYB_zMvn-I&code_challenge_method=S256\n",
71 | "\n",
72 | "Enter authorization code: 4/0AeaYSHAeFzH5IlfDzvLcvZas1zKe_4MH79mfJ5q8rbkXxveeYDkdpTOzD1p6Xd9skYb7Lg\n",
73 | "\n",
74 | "Credentials saved to file: [/content/.config/application_default_credentials.json]\n",
75 | "\n",
76 | "These credentials will be used by any library that requests Application Default Credentials (ADC).\n",
77 | "\u001b[1;33mWARNING:\u001b[0m \n",
78 | "Cannot find a quota project to add to ADC. You might receive a \"quota exceeded\" or \"API not enabled\" error. Run $ gcloud auth application-default set-quota-project to add a quota project.\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "!gcloud auth application-default login\n",
84 | "\n",
85 | "# or authenticate without an interactive prompt by pointing Application\n",
86 | "# Default Credentials (ADC) at a service account key file:\n",
87 | "#\n",
88 | "# %env GOOGLE_APPLICATION_CREDENTIALS=/path/to/your/service_account_key.json"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {
94 | "id": "Ai1E5s7XaC04"
95 | },
96 | "source": [
97 | "## Set up the GCP project and location"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 4,
103 | "metadata": {
104 | "id": "N5fpTuQhCnN5"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "GCP_PROJECT_ID=\"gde-prj\"\n",
109 | "GCP_PROJECT_LOCATION=\"us-central1\""
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {
115 | "id": "JfTetkgpcQ5n"
116 | },
117 | "source": [
118 | "## Define Gemini call function"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 5,
124 | "metadata": {
125 | "id": "-oCtUjGDBVYr"
126 | },
127 | "outputs": [],
128 | "source": [
129 | "import base64\n",
130 | "import vertexai\n",
131 | "from vertexai.generative_models import GenerativeModel, Part, GenerationResponse, GenerationConfig\n",
132 | "\n",
133 | "def initi_vertexai(project_id: str, location: str) -> None:\n",
134 | "    vertexai.init(project=project_id, location=location)\n",
135 | "\n",
136 | "def ask_gemini(\n",
137 | "    prompt: str=None, gcs: str=None, base64_encoded: bytes=None, stream: bool=False, generation_config: dict=None\n",
138 | ") -> GenerationResponse:\n",
139 | "    if gcs is None and base64_encoded is None:\n",
140 | "        raise ValueError(\"Either a GCS bucket path or base64_encoded string of the video must be provided\")\n",
141 | "\n",
142 | "    if gcs is not None and base64_encoded is not None:\n",
143 | "        raise ValueError(\"Only one of gcs or base64_encoded must be provided\")\n",
144 | "\n",
145 | "    if gcs is not None:\n",
146 | "        video = Part.from_uri(gcs, mime_type=\"video/mp4\")\n",
147 | "    else:\n",
148 | "        video = Part.from_data(data=base64.b64decode(base64_encoded), mime_type=\"video/mp4\")\n",
149 | "\n",
150 | "    if prompt is None:\n",
151 | "        prompt = \"What is in the video?\"\n",
152 | "\n",
153 | "    if generation_config is None:\n",
154 | "        generation_config = GenerationConfig(\n",
155 | "            max_output_tokens=2048,\n",
156 | "            temperature=0.4,\n",
157 | "            top_p=1,\n",
158 | "            top_k=32\n",
159 | "        )\n",
160 | "\n",
161 | "    vision_model = GenerativeModel(\"gemini-1.0-pro-vision\")\n",
162 | "    return vision_model.generate_content(\n",
163 | "        [video, prompt],\n",
164 | "        generation_config=generation_config, stream=stream\n",
165 | "    )"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "source": [
171 | "## Define base64 encoding function"
172 | ],
173 | "metadata": {
174 | "id": "QJJz4m7-zIrc"
175 | }
176 | },
177 | {
178 | "cell_type": "code",
179 | "source": [
180 | "def get_base64_encode(file_path):\n",
181 | "    with open(file_path, 'rb') as f:\n",
182 | "        data = f.read()\n",
183 | "\n",
184 | "    return base64.b64encode(data)"
185 | ],
186 | "metadata": {
187 | "id": "Gt2AMKHQsE8S"
188 | },
189 | "execution_count": 8,
190 | "outputs": []
191 | },
192 | {
193 | "cell_type": "code",
194 | "source": [
195 | "!git clone https://github.com/deep-diver/Vid2Persona.git\n",
196 | "!mv Vid2Persona/assets/*.mp4 ."
197 | ],
198 | "metadata": {
199 | "id": "RjEQnt3t6t8B",
200 | "outputId": "3ed74169-79fc-4a83-8687-8f937c603d33",
201 | "colab": {
202 | "base_uri": "https://localhost:8080/"
203 | }
204 | },
205 | "execution_count": 6,
206 | "outputs": [
207 | {
208 | "output_type": "stream",
209 | "name": "stdout",
210 | "text": [
211 | "Cloning into 'Vid2Persona'...\n",
212 | "remote: Enumerating objects: 42, done.\u001b[K\n",
213 | "remote: Counting objects: 100% (2/2), done.\u001b[K\n",
214 | "remote: Compressing objects: 100% (2/2), done.\u001b[K\n",
215 | "remote: Total 42 (delta 0), reused 0 (delta 0), pack-reused 40\u001b[K\n",
216 | "Receiving objects: 100% (42/42), 62.02 MiB | 36.39 MiB/s, done.\n",
217 | "Resolving deltas: 100% (8/8), done.\n"
218 | ]
219 | }
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "source": [
225 | "sample1 = get_base64_encode(\"sample1.mp4\")\n",
226 | "sample2 = get_base64_encode(\"sample2.mp4\")\n",
227 | "sample3 = get_base64_encode(\"sample3.mp4\")\n",
228 | "sample4 = get_base64_encode(\"sample4.mp4\")"
229 | ],
230 | "metadata": {
231 | "id": "wsVTq1bdsQmo"
232 | },
233 | "execution_count": 9,
234 | "outputs": []
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "source": [
239 | "## Define common prompt"
240 | ],
241 | "metadata": {
242 | "id": "A5PJ_SIMzPdo"
243 | }
244 | },
245 | {
246 | "cell_type": "code",
247 | "source": [
248 | "prompt = \"\"\"\n",
249 | "Carefully analyze the provided video clip to identify and extract detailed information about the main character(s) featured. Pay attention to visual elements, spoken dialogue, character interactions, and any narrative cues that reveal aspects of the character's personality, physical appearance, behaviors, and background.\n",
250 | "\n",
251 | "Your task is to construct a rich, imaginative character profile based on your observations, and where explicit information is not available, you are encouraged to use your creativity to fill in the gaps. The goal is to create a vivid, believable character profile that can be used to simulate conversation with a language model as if it were the character itself.\n",
252 | "\n",
253 | "Format the extracted data as a structured JSON object containing the following fields for each main character:\n",
254 | "\n",
255 | "name: The character's name as mentioned or inferred in the video. If not provided, create a suitable name that matches the character's traits and context.\n",
256 | "physicalDescription: Describe the character's appearance, including hair color, eye color, height, and distinctive features. Use imaginative details if necessary to provide a complete picture.\n",
257 | "personalityTraits: List descriptive adjectives or phrases that capture the character's personality, based on their actions and dialogue. Invent traits as needed to ensure a well-rounded personality.\n",
258 | "likes: Specify things, activities, or concepts the character enjoys or values, deduced or imagined from their behavior and interactions.\n",
259 | "dislikes: Note what the character appears to dislike or avoid, filling in creatively where direct evidence is not available.\n",
260 | "background: Provide background information such as occupation, family ties, or significant life events, inferring where possible or inventing details to add depth to the character's story.\n",
261 | "goals: Describe the character's apparent motivations and objectives, whether explicitly stated or implied. Where not directly observable, construct plausible goals that align with the character's portrayed or inferred traits.\n",
262 | "relationships: Detail the character's relationships with other characters, including the nature of each relationship and the names of other characters involved. Use creative license to elaborate on these relationships if the video provides limited information.\n",
263 | "Ensure the JSON object is well-structured and comprehensive, ready for integration with a language model to facilitate engaging conversations as if with the character itself. For multiple main characters, provide a distinct profile for each within the same JSON object.\n",
264 | "\"\"\""
265 | ],
266 | "metadata": {
267 | "id": "MDNxhnORtPXu"
268 | },
269 | "execution_count": 10,
270 | "outputs": []
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "source": [
275 | "## Let's ask!"
276 | ],
277 | "metadata": {
278 | "id": "lacOyE4izTJO"
279 | }
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "source": [
284 | "### on Sample1"
285 | ],
286 | "metadata": {
287 | "id": "EFryD2hFsnCz"
288 | }
289 | },
290 | {
291 | "cell_type": "code",
292 | "source": [
293 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
294 | "try:\n",
295 | " response = ask_gemini(\n",
296 | " prompt=prompt,\n",
297 | " base64_encoded=sample1\n",
298 | " )\n",
299 | " display(Markdown(response.text))\n",
300 | "except Exception as e:\n",
301 | " print(f\"something went wrong {e}\")"
302 | ],
303 | "metadata": {
304 | "colab": {
305 | "base_uri": "https://localhost:8080/",
306 | "height": 714
307 | },
308 | "id": "T-0Oewvzsd5z",
309 | "outputId": "08f9123a-a5db-4ba2-cf0b-ffc39b9ef525"
310 | },
311 | "execution_count": 11,
312 | "outputs": [
313 | {
314 | "output_type": "stream",
315 | "name": "stderr",
316 | "text": [
317 | "/usr/local/lib/python3.10/dist-packages/google/auth/_default.py:76: UserWarning: Your application has authenticated using end user credentials from Google Cloud SDK without a quota project. You might receive a \"quota exceeded\" or \"API not enabled\" error. See the following page for troubleshooting: https://cloud.google.com/docs/authentication/adc-troubleshooting/user-creds. \n",
318 | " warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING)\n"
319 | ]
320 | },
321 | {
322 | "output_type": "display_data",
323 | "data": {
324 | "text/plain": [
325 | ""
326 | ],
327 | "text/markdown": " ```json\n{\n \"characters\": [\n {\n \"name\": \"Alice\",\n \"physicalDescription\": \"Alice is a young woman with long, wavy brown hair and hazel eyes. She is of average height and has a slim build. Her most distinctive feature is her warm, friendly smile.\",\n \"personalityTraits\": [\n \"Alice is a kind, compassionate, and intelligent woman. She is always willing to help others and is a great listener. She is also very creative and has a great sense of humor.\",\n ],\n \"likes\": [\n \"Alice loves spending time with her friends and family.\",\n \"She enjoys reading, writing, and listening to music.\",\n \"She is also a big fan of traveling and exploring new places.\"\n ],\n \"dislikes\": [\n \"Alice dislikes rudeness and cruelty.\",\n \"She also dislikes being lied to or taken advantage of.\",\n \"She is not a fan of heights or roller coasters.\"\n ],\n \"background\": [\n \"Alice grew up in a small town in the Midwest.\",\n \"She was always a good student and excelled in her studies.\",\n \"After graduating from high school, she moved to the city to attend college.\",\n \"She is currently working as a social worker.\"\n ],\n \"goals\": [\n \"Alice wants to make a difference in the world.\",\n \"She hopes to one day open her own counseling practice.\",\n \"She also wants to travel the world and experience different cultures.\"\n ],\n \"relationships\": [\n \"Alice is very close to her family and friends.\",\n \"She is also in a loving relationship with her partner, Ben.\",\n \"She has a good relationship with her colleagues and is well-respected by her clients.\"\n ]\n }\n ]\n}\n```"
328 | },
329 | "metadata": {}
330 | }
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "source": [
336 | "### on Sample2"
337 | ],
338 | "metadata": {
339 | "id": "3oZQzFsazXPJ"
340 | }
341 | },
342 | {
343 | "cell_type": "code",
344 | "source": [
345 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
346 | "try:\n",
347 | " response = ask_gemini(\n",
348 | " prompt=prompt,\n",
349 | " base64_encoded=sample2\n",
350 | " )\n",
351 | " display(Markdown(response.text))\n",
352 | "except Exception as e:\n",
353 | " print(f\"something went wrong {e}\")"
354 | ],
355 | "metadata": {
356 | "colab": {
357 | "base_uri": "https://localhost:8080/",
358 | "height": 506
359 | },
360 | "id": "ZsuGWPcmsq0r",
361 | "outputId": "2f6dc7a6-92b5-44a8-e3c1-2c7e4464639c"
362 | },
363 | "execution_count": 12,
364 | "outputs": [
365 | {
366 | "output_type": "display_data",
367 | "data": {
368 | "text/plain": [
369 | ""
370 | ],
371 | "text/markdown": " ```json\n{\n \"name\": \"Little Furry\",\n \"physicalDescription\": \"Little Furry is a small, furry creature with big, round eyes and a long, bushy tail. Its fur is a light brown color, and it has a white belly. It has two small horns on its head and a pair of wings on its back, but it is still too young to fly.\",\n \"personalityTraits\": [\"Curious\", \"Playful\", \"Mischievous\", \"Loyal\", \"Protective\"],\n \"likes\": [\"Playing with candles\", \"Exploring the forest\", \"Making new friends\", \"Helping others\"],\n \"dislikes\": [\"Being alone\", \"Darkness\", \"Loud noises\", \"Being told what to do\"],\n \"background\": \"Little Furry is a young creature who lives in the forest with its family. It is still learning about the world and loves to explore and play. It is very curious and loves to learn new things.\",\n \"goals\": [\"To make new friends\", \"To learn about the world\", \"To help others\", \"To have fun\"],\n \"relationships\": [\n {\n \"name\": \"Mother Furry\",\n \"relation\": \"Little Furry's mother\"\n },\n {\n \"name\": \"Father Furry\",\n \"relation\": \"Little Furry's father\"\n },\n {\n \"name\": \"Big Furry\",\n \"relation\": \"Little Furry's older sibling\"\n },\n {\n \"name\": \"Little Furry's Friends\",\n \"relation\": \"Little Furry's friends\"\n }\n ]\n}\n```"
372 | },
373 | "metadata": {}
374 | }
375 | ]
376 | },
377 | {
378 | "cell_type": "markdown",
379 | "source": [
380 | "### on Sample3"
381 | ],
382 | "metadata": {
383 | "id": "kXb4g4hMzYSb"
384 | }
385 | },
386 | {
387 | "cell_type": "code",
388 | "source": [
389 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
390 | "try:\n",
391 | " response = ask_gemini(\n",
392 | " prompt=prompt,\n",
393 | " base64_encoded=sample3\n",
394 | " )\n",
395 | " display(Markdown(response.text))\n",
396 | "except Exception as e:\n",
397 | " print(f\"something went wrong {e}\")"
398 | ],
399 | "metadata": {
400 | "colab": {
401 | "base_uri": "https://localhost:8080/",
402 | "height": 975
403 | },
404 | "id": "8bZ-xfinvGpZ",
405 | "outputId": "9fc416ba-1188-461a-9c0d-d412386b5061"
406 | },
407 | "execution_count": 13,
408 | "outputs": [
409 | {
410 | "output_type": "display_data",
411 | "data": {
412 | "text/plain": [
413 | ""
414 | ],
415 | "text/markdown": " ```json\n{\n \"characters\": [\n {\n \"name\": \"Jane Doe\",\n \"physicalDescription\": \"Jane is a young woman in her early 20s, with long, dark hair and piercing blue eyes. She is of average height and has a slim build. Her most distinctive feature is her warm, friendly smile.\",\n \"personalityTraits\": [\n \"Confident\",\n \"Optimistic\",\n \"Independent\",\n \"Curious\",\n \"Determined\"\n ],\n \"likes\": [\n \"Exploring new places\",\n \"Learning new things\",\n \"Spending time with friends and family\",\n \"Helping others\",\n \"Making a difference in the world\"\n ],\n \"dislikes\": [\n \"Injustice\",\n \"Cruelty\",\n \"Ignorance\",\n \"Laziness\",\n \"Negativity\"\n ],\n \"background\": \"Jane grew up in a small town in the Midwest. She was always a good student and excelled in her studies. After graduating from high school, she moved to the big city to attend university. She is currently working as a journalist for a local newspaper.\",\n \"goals\": [\n \"To become a successful journalist\",\n \"To make a difference in the world\",\n \"To help others\",\n \"To learn new things\",\n \"To grow as a person\"\n ],\n \"relationships\": [\n {\n \"name\": \"John Smith\",\n \"nature\": \"Jane's boyfriend\",\n \"description\": \"John is a kind and supportive boyfriend. He is always there for Jane and helps her through tough times.\"\n },\n {\n \"name\": \"Mary Johnson\",\n \"nature\": \"Jane's best friend\",\n \"description\": \"Mary is Jane's best friend. They have been friends since childhood and share everything with each other.\"\n },\n {\n \"name\": \"Jane's parents\",\n \"nature\": \"Jane's parents\",\n \"description\": \"Jane's parents are loving and supportive. They have always been there for Jane and encouraged her to follow her dreams.\"\n }\n ]\n }\n ]\n}\n```"
416 | },
417 | "metadata": {}
418 | }
419 | ]
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "source": [
424 | "### on Sample4"
425 | ],
426 | "metadata": {
427 | "id": "TI9xWmR2zZVd"
428 | }
429 | },
430 | {
431 | "cell_type": "code",
432 | "source": [
433 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
434 | "try:\n",
435 | " response = ask_gemini(\n",
436 | " prompt=prompt,\n",
437 | " base64_encoded=sample4\n",
438 | " )\n",
439 | " display(Markdown(response.text))\n",
440 | "except Exception as e:\n",
441 | " print(f\"something went wrong {e}\")"
442 | ],
443 | "metadata": {
444 | "colab": {
445 | "base_uri": "https://localhost:8080/",
446 | "height": 766
447 | },
448 | "id": "z68Q0nxzvQGy",
449 | "outputId": "e6a7f566-9ae0-41ef-c66c-06c7cd7b250a"
450 | },
451 | "execution_count": 14,
452 | "outputs": [
453 | {
454 | "output_type": "display_data",
455 | "data": {
456 | "text/plain": [
457 | ""
458 | ],
459 | "text/markdown": " ```json\n{\n \"characters\": [\n {\n \"name\": \"Jean-Pierre\",\n \"physicalDescription\": \"Jean-Pierre is a tall, slender man with silver hair and a thick beard. He wears glasses and has a warm, friendly smile. He is usually seen wearing a brown beret and a tweed jacket.\",\n \"personalityTraits\": [\n \"Jean-Pierre is a kind and compassionate man.\",\n \"He is also very intelligent and well-read.\",\n \"He is a bit of a loner, but he enjoys spending time with his friends and family.\",\n \"He is always willing to help others, and he is always looking for ways to make the world a better place.\"\n ],\n \"likes\": [\n \"Jean-Pierre loves to read, especially history and philosophy.\",\n \"He also enjoys spending time in nature, and he is an avid birdwatcher.\",\n \"He is a big fan of classical music, and he often goes to concerts.\"\n ],\n \"dislikes\": [\n \"Jean-Pierre dislikes cruelty and injustice.\",\n \"He also dislikes loud noises and crowds.\",\n \"He is not a fan of modern technology, and he prefers to live a simple life.\"\n ],\n \"background\": [\n \"Jean-Pierre was born in Paris, France, in 1950.\",\n \"He grew up in a large family, and he was the youngest of five children.\",\n \"His father was a professor, and his mother was a stay-at-home mom.\",\n \"Jean-Pierre was a good student, and he went on to study at the Sorbonne.\",\n \"After graduating, he worked as a teacher for several years.\",\n \"He then decided to pursue his passion for writing, and he became a full-time writer.\"\n ],\n \"goals\": [\n \"Jean-Pierre's goal is to write books that make a difference in the world.\",\n \"He wants to write books that inspire people to think and to act.\",\n \"He also wants to write books that make people laugh and cry.\"\n ],\n \"relationships\": [\n \"Jean-Pierre is married to a woman named Marie.\",\n \"They have two children, a son named Paul and a daughter named Sophie.\",\n \"Jean-Pierre is very close to his family, and he loves spending time with them.\"\n ]\n }\n ]\n}\n```"
460 | },
461 | "metadata": {}
462 | }
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "source": [],
468 | "metadata": {
469 | "id": "cWMUVGqcvlSE"
470 | },
471 | "execution_count": 14,
472 | "outputs": []
473 | }
474 | ],
475 | "metadata": {
476 | "colab": {
477 | "provenance": [],
478 | "include_colab_link": true
479 | },
480 | "kernelspec": {
481 | "display_name": "Python 3",
482 | "name": "python3"
483 | },
484 | "language_info": {
485 | "name": "python"
486 | }
487 | },
488 | "nbformat": 4,
489 | "nbformat_minor": 0
490 | }
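The four responses above cannot be passed straight to `json.loads`: each is wrapped in a Markdown `json` code fence, Sample1 has a trailing comma after its single `personalityTraits` entry, Sample2 returns one bare profile while the others nest profiles under `characters`, and field shapes differ (`background` is a string in Sample2 and Sample3 but a list in Sample1 and Sample4, and `relationships` holds objects in Sample2 and Sample3). Below is a minimal sketch, not part of the original notebooks, of one way to normalize such answers before building a persona prompt. The helper names `parse_character_profiles` and `field_to_text` are hypothetical, and the trailing-comma fix is a regex heuristic that assumes no string value itself contains a comma directly before `]` or `}`.

```python
import json
import re
from typing import Any

# Captures the body of an optional Markdown code fence (three backticks, optional "json" tag).
FENCE_RE = re.compile(r"`{3}(?:json)?\s*(.*?)\s*`{3}", re.DOTALL)
# Matches a comma followed only by whitespace and a closing bracket or brace.
TRAILING_COMMA_RE = re.compile(r",(\s*[\]}])")


def parse_character_profiles(response_text: str) -> list[dict[str, Any]]:
    """Return the character profiles in a model answer as a list of dicts."""
    match = FENCE_RE.search(response_text)
    payload = match.group(1) if match else response_text
    # Heuristic: drop trailing commas (e.g. `"...",\n ]`), which json.loads rejects.
    payload = TRAILING_COMMA_RE.sub(r"\1", payload)
    data = json.loads(payload)
    # Some answers nest profiles under "characters"; others return one bare profile.
    if isinstance(data, dict) and "characters" in data:
        return list(data["characters"])
    if isinstance(data, dict):
        return [data]
    return list(data)


def field_to_text(value: Any) -> str:
    """Flatten a profile field (string, list of strings, or list of dicts) into prompt text."""
    if isinstance(value, dict):
        return "; ".join(f"{key}: {field_to_text(item)}" for key, item in value.items())
    if isinstance(value, (list, tuple)):
        return ", ".join(field_to_text(item) for item in value)
    # Strings, numbers, and anything else are rendered as-is.
    return str(value)
```

For example, `profiles = parse_character_profiles(response.text)` followed by `field_to_text(profiles[0]["relationships"])` would yield a single string that can be placed in a system prompt whichever shape the model chose for that field.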
--------------------------------------------------------------------------------
/notebooks/Vid2Desc_Gemini_1_0_Pro_Vision.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "ddSBgL68Zu8j"
17 | },
18 | "source": [
19 | "# Ask about a video clip with Gemini 1.0 Pro Vision on Vertex AI"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {
26 | "id": "Rx6vrvyfBDcd"
27 | },
28 | "outputs": [],
29 | "source": [
30 | "!pip install --upgrade google-cloud-aiplatform"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "u1IQpbbTZ60q"
37 | },
38 | "source": [
39 | "## Authenticate to Vertex AI with `gcloud`"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 2,
45 | "metadata": {
46 | "id": "Fv1mYIWcD9dB",
47 | "colab": {
48 | "base_uri": "https://localhost:8080/"
49 | },
50 | "outputId": "928fcb61-cd0a-4a45-d8e1-22054b8c007b"
51 | },
52 | "outputs": [
53 | {
54 | "output_type": "stream",
55 | "name": "stdout",
56 | "text": [
57 | "Go to the following link in your browser:\n",
58 | "\n",
59 | " https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com&redirect_uri=https%3A%2F%2Fsdk.cloud.google.com%2Fapplicationdefaultauthcode.html&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fsqlservice.login&state=y5WPnVc0yhvIlt4n0fu87xzoPybWki&prompt=consent&token_usage=remote&access_type=offline&code_challenge=KB3CebOE1LzZOb0jHlY6xfMLvFYdyrb7Uv2nAnD9NEA&code_challenge_method=S256\n",
60 | "\n",
61 | "Enter authorization code: 4/0AeaYSHAZqCL-B3-yRlALC9TU5yq9EFFZyivdJ6laefam31EDueOcnfyNuUsBx8zrInUbgA\n",
62 | "\n",
63 | "Credentials saved to file: [/content/.config/application_default_credentials.json]\n",
64 | "\n",
65 | "These credentials will be used by any library that requests Application Default Credentials (ADC).\n",
66 | "\u001b[1;33mWARNING:\u001b[0m \n",
67 | "Cannot find a quota project to add to ADC. You might receive a \"quota exceeded\" or \"API not enabled\" error. Run $ gcloud auth application-default set-quota-project to add a quota project.\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "!gcloud auth application-default login\n",
73 | "\n",
74 | "# or skip the interactive prompt by pointing Application Default Credentials at a service account key.\n",
75 | "# Set the variable from Python, because `!export` runs in a subshell and does not persist:\n",
76 | "#\n",
77 | "# import os; os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"/path/to/your/service_account_key.json\""
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {
83 | "id": "Ai1E5s7XaC04"
84 | },
85 | "source": [
86 | "## Set up the GCP project and location"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 3,
92 | "metadata": {
93 | "id": "N5fpTuQhCnN5"
94 | },
95 | "outputs": [],
96 | "source": [
97 | "GCP_PROJECT_ID=\"gde-prj\"\n",
98 | "GCP_PROJECT_LOCATION=\"us-central1\""
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {
104 | "id": "nakhkiwYaIXz"
105 | },
106 | "source": [
107 | "## Call Gemini 1.0 Pro Vision"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {
113 | "id": "JfTetkgpcQ5n"
114 | },
115 | "source": [
116 | "### Define helper functions"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 15,
122 | "metadata": {
123 | "id": "-oCtUjGDBVYr"
124 | },
125 | "outputs": [],
126 | "source": [
127 | "import base64\n",
128 | "import vertexai\n",
129 | "from vertexai.generative_models import GenerativeModel, Part, GenerationResponse, GenerationConfig\n",
130 | "\n",
131 | "def initi_vertexai(project_id: str, location: str) -> None:\n",
132 | " vertexai.init(project=project_id, location=location)\n",
133 | "\n",
134 | "def ask_gemini(\n",
135 | " prompt: str=None, gcs: str=None, base64_encoded: bytes=None, stream: bool=False, generation_config: dict=None\n",
136 | ") -> GenerationResponse:\n",
137 | " if gcs is None and base64_encoded is None:\n",
138 | " raise ValueError(\"Either a GCS URI (gs://bucket/object) or a base64-encoded video must be provided\")\n",
139 | "\n",
140 | " if gcs is not None and base64_encoded is not None:\n",
141 | " raise ValueError(\"Provide only one of gcs or base64_encoded, not both\")\n",
142 | "\n",
143 | " if gcs is not None:\n",
144 | " video = Part.from_uri(gcs, mime_type=\"video/mp4\")\n",
145 | " else:\n",
146 | " video = Part.from_data(data=base64.b64decode(base64_encoded), mime_type=\"video/mp4\")\n",
147 | "\n",
148 | " if prompt is None:\n",
149 | " prompt = \"What is in the video?\"\n",
150 | "\n",
151 | " if generation_config is None:\n",
152 | " generation_config = GenerationConfig(\n",
153 | " max_output_tokens=2048,\n",
154 | " temperature=0.4,\n",
155 | " top_p=1,\n",
156 | " top_k=32\n",
157 | " )\n",
158 | "\n",
159 | " vision_model = GenerativeModel(\"gemini-1.0-pro-vision\")\n",
160 | " return vision_model.generate_content(\n",
161 | " [video, prompt],\n",
162 | " generation_config=generation_config, stream=stream\n",
163 | " )"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "id": "RVl-dI-Qc5eG"
170 | },
171 | "source": [
172 | "### Ask about a video on GCS in non-streaming mode"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 9,
178 | "metadata": {
179 | "id": "6T8cBE6-CH15",
180 | "colab": {
181 | "base_uri": "https://localhost:8080/"
182 | },
183 | "outputId": "3b4d58be-e2e7-4274-c38f-c9cdddc892f2"
184 | },
185 | "outputs": [
186 | {
187 | "output_type": "stream",
188 | "name": "stderr",
189 | "text": [
190 | "/usr/local/lib/python3.10/dist-packages/google/auth/_default.py:76: UserWarning: Your application has authenticated using end user credentials from Google Cloud SDK without a quota project. You might receive a \"quota exceeded\" or \"API not enabled\" error. See the following page for troubleshooting: https://cloud.google.com/docs/authentication/adc-troubleshooting/user-creds. \n",
191 | " warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING)\n"
192 | ]
193 | }
194 | ],
195 | "source": [
196 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
197 | "try:\n",
198 | " response = ask_gemini(gcs=\"gs://cloud-samples-data/video/animals.mp4\")\n",
199 | "except Exception as e:\n",
200 | " print(f\"something went wrong {e}\")"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 10,
206 | "metadata": {
207 | "colab": {
208 | "base_uri": "https://localhost:8080/"
209 | },
210 | "id": "4hbpvEoocVMB",
211 | "outputId": "35f2edad-223b-48ee-ba40-14cc062b24fc"
212 | },
213 | "outputs": [
214 | {
215 | "output_type": "stream",
216 | "name": "stdout",
217 | "text": [
218 | " The video is an advertisement for the movie Zootopia. It features a sloth, a fox, and a rabbit taking selfies with a Google Pixel phone. The ad highlights the phone's camera quality and its ability to take great photos even in low-light conditions. The video ends with the三個動物 taking a selfie together.\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "print(response.text)"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {
229 | "id": "eRskB88ydIOm"
230 | },
231 | "source": [
232 | "### Ask about a video on GCS in streaming mode"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {
239 | "id": "rb7mjBMYdIr8"
240 | },
241 | "outputs": [],
242 | "source": [
243 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
244 | "try:\n",
245 | " response = ask_gemini(gcs=\"gs://cloud-samples-data/video/animals.mp4\", stream=True)\n",
246 | "except Exception as e:\n",
247 | " print(f\"something went wrong {e}\")"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {
254 | "colab": {
255 | "base_uri": "https://localhost:8080/"
256 | },
257 | "id": "7J2xWr_7FZrP",
258 | "outputId": "91623edd-8420-43aa-dd09-cc4e9fd56295"
259 | },
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | " It is a commercial for the movie Zootopia. It shows a sloth, a fox, and a rabbit in a city. It also shows a tiger,\n",
266 | "\n",
267 | " an elephant, and a seal. The animals are taking pictures of each other. The commercial is funny because it shows the animals doing human things.\n",
268 | "\n"
269 | ]
270 | }
271 | ],
272 | "source": [
273 | "for response_piece in response:\n",
274 | " print(response_piece.text)\n",
275 | " print()"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "metadata": {
281 | "id": "cwpowCPQdavV"
282 | },
283 | "source": [
284 | "### Ask about a base64-encoded video in non-streaming mode"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 11,
290 | "metadata": {
291 | "colab": {
292 | "base_uri": "https://localhost:8080/"
293 | },
294 | "id": "B-GUyCKlGDki",
295 | "outputId": "2ba61152-bad5-46fa-b24c-7384d9158e39"
296 | },
297 | "outputs": [
298 | {
299 | "output_type": "stream",
300 | "name": "stdout",
301 | "text": [
302 | "Copying gs://cloud-samples-data/video/animals.mp4...\n",
303 | "\\ [1 files][ 16.1 MiB/ 16.1 MiB] \n",
304 | "Operation completed over 1 objects/16.1 MiB. \n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "!gsutil cp gs://cloud-samples-data/video/animals.mp4 ./"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 12,
315 | "metadata": {
316 | "id": "qTxsGOUjGDiq"
317 | },
318 | "outputs": [],
319 | "source": [
320 | "import base64\n",
321 | "\n",
322 | "with open(\"animals.mp4\", \"rb\") as video_file:\n",
323 | " video_data = video_file.read()\n",
324 | "\n",
325 | "encoded_string = base64.b64encode(video_data)"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 16,
331 | "metadata": {
332 | "colab": {
333 | "base_uri": "https://localhost:8080/"
334 | },
335 | "id": "PO7WgZ0bG0Vh",
336 | "outputId": "d7730920-6437-4416-a1b0-4a6a9d7c79a6"
337 | },
338 | "outputs": [
339 | {
340 | "output_type": "stream",
341 | "name": "stdout",
342 | "text": [
343 | " It is a commercial for the movie Zootopia. The commercial features a sloth, a fox, and a rabbit. The commercial is about how Google Photos can help you take better pictures of your pets.\n"
344 | ]
345 | }
346 | ],
347 | "source": [
348 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
349 | "try:\n",
350 | " response = ask_gemini(base64_encoded=encoded_string)\n",
351 | " print(response.text)\n",
352 | "except Exception as e:\n",
353 | " print(f\"something went wrong {e}\")"
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {
359 | "id": "9Wetio-1d4yP"
360 | },
361 | "source": [
362 | "### Ask about a base64-encoded video in streaming mode"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": null,
368 | "metadata": {
369 | "colab": {
370 | "base_uri": "https://localhost:8080/"
371 | },
372 | "id": "9xJ1xPVAdvh4",
373 | "outputId": "8f555e0d-57e0-41b3-f69d-206d0084e6d3"
374 | },
375 | "outputs": [
376 | {
377 | "name": "stdout",
378 | "output_type": "stream",
379 | "text": [
380 | " This is a commercial for the movie Zootopia. It features a sloth, a fox, and a rabbit taking selfies at the Los Angeles Zoo. The commercial\n",
381 | "\n",
382 | " was released in 2016.\n",
383 | "\n"
384 | ]
385 | }
386 | ],
387 | "source": [
388 | "initi_vertexai(GCP_PROJECT_ID, GCP_PROJECT_LOCATION)\n",
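389 | "try:\n",
390 | "    response = ask_gemini(base64_encoded=encoded_string, stream=True)\n",
391 | "    # Iterate inside the try block: if the request fails, `response` is never bound, and streaming errors can also surface mid-iteration.\n",
392 | "    for response_piece in response:\n",
393 | "        print(response_piece.text)\n",
394 | "        print()\n",
395 | "except Exception as e:\n",
396 | "    print(f\"something went wrong {e}\")"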
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": null,
402 | "metadata": {
403 | "id": "liEmsS4teEkE"
404 | },
405 | "outputs": [],
406 | "source": []
407 | }
408 | ],
409 | "metadata": {
410 | "colab": {
411 | "provenance": [],
412 | "include_colab_link": true
413 | },
414 | "kernelspec": {
415 | "display_name": "Python 3",
416 | "name": "python3"
417 | },
418 | "language_info": {
419 | "name": "python"
420 | }
421 | },
422 | "nbformat": 4,
423 | "nbformat_minor": 0
424 | }
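In the streaming cells above, each chunk is printed followed by an extra `print()`, so a sentence that spans two chunks is split across lines (for example, "The commercial" and "was released in 2016." appear as separate paragraphs). Below is a minimal sketch of collecting the chunks into one string instead. It assumes only that each streamed chunk exposes a `.text` attribute, as the `response_piece.text` calls above rely on; `collect_stream` and `TextChunk` are hypothetical names, not part of the Vertex AI SDK.

```python
from typing import Iterable, Protocol


class TextChunk(Protocol):
    """Structural type for anything exposing `.text`, such as a streamed response chunk."""

    text: str


def collect_stream(chunks: Iterable[TextChunk], echo: bool = True) -> str:
    """Concatenate streamed chunks into one string, optionally echoing them as they arrive."""
    parts: list[str] = []
    for chunk in chunks:
        parts.append(chunk.text)
        if echo:
            # end="" keeps a sentence that spans two chunks on one line; flush shows it immediately.
            print(chunk.text, end="", flush=True)
    if echo:
        print()  # finish the echoed output with a single newline
    return "".join(parts)
```

For example, `full_text = collect_stream(ask_gemini(base64_encoded=encoded_string, stream=True))` would show the answer as it arrives and keep the joined text for later use, such as parsing or display.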
--------------------------------------------------------------------------------
/notebooks/llm_personality.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "machine_shape": "hm",
8 | "gpuType": "A100"
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | },
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {
24 | "id": "End1Gt7Js76-"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "!pip install transformers accelerate bitsandbytes gradio -U -q"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "source": [
34 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
35 | "import torch\n",
36 | "\n",
37 | "model_id = \"HuggingFaceH4/zephyr-7b-beta\"\n",
38 | "model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map=\"auto\")\n",
39 | "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
40 | "tokenizer.use_default_system_prompt = False"
41 | ],
42 | "metadata": {
43 | "id": "e72HaxzatJmz"
44 | },
45 | "execution_count": null,
46 | "outputs": []
47 | },
48 | {
49 | "cell_type": "code",
50 | "source": [
51 | "SAMPLE_PERSONALITY = {\n",
52 | " \"characters\": [\n",
53 | " {\n",
54 | " \"name\": \"Alice\",\n",
55 | " \"physicalDescription\": \"Alice is a young woman with long, wavy brown hair and hazel eyes. She is of average height and has a slim build. Her most distinctive feature is her warm, friendly smile.\",\n",
56 | " \"personalityTraits\": [\n",
57 | " \"Alice is a kind, compassionate, and intelligent woman. She is always willing to help others and is a great listener. She is also very creative and has a great sense of humor.\",\n",
58 | " ],\n",
59 | " \"likes\": [\n",
60 | " \"Alice loves spending time with her friends and family.\",\n",
61 | " \"She enjoys reading, writing, and listening to music.\",\n",
62 | " \"She is also a big fan of traveling and exploring new places.\",\n",
63 | " ],\n",
64 | " \"dislikes\": [\n",
65 | " \"Alice dislikes rudeness and cruelty.\",\n",
66 | " \"She also dislikes being lied to or taken advantage of.\",\n",
67 | " \"She is not a fan of heights or roller coasters.\",\n",
68 | " ],\n",
69 | " \"background\": [\n",
70 | " \"Alice grew up in a small town in the Midwest.\",\n",
71 | " \"She was always a good student and excelled in her studies.\",\n",
72 | " \"After graduating from high school, she moved to the city to attend college.\",\n",
73 | " \"She is currently working as a social worker.\",\n",
74 | " ],\n",
75 | " \"goals\": [\n",
76 | " \"Alice wants to make a difference in the world.\",\n",
77 | " \"She hopes to one day open her own counseling practice.\",\n",
78 | " \"She also wants to travel the world and experience different cultures.\",\n",
79 | " ],\n",
80 | " \"relationships\": [\n",
81 | " \"Alice is very close to her family and friends.\",\n",
82 | " \"She is also in a loving relationship with her partner, Ben.\",\n",
83 | " \"She has a good relationship with her colleagues and is well-respected by her clients.\",\n",
84 | " ],\n",
85 | " }\n",
86 | " ]\n",
87 | "}"
88 | ],
89 | "metadata": {
90 | "id": "tvhZLGmYt3B0"
91 | },
92 | "execution_count": null,
93 | "outputs": []
94 | },
95 | {
96 | "cell_type": "code",
97 | "source": [
98 | "SAMPLE_PERSONALITY[\"characters\"][0].keys()"
99 | ],
100 | "metadata": {
101 | "id": "SyOk6YkHvHlU"
102 | },
103 | "execution_count": null,
104 | "outputs": []
105 | },
106 | {
107 | "cell_type": "code",
108 | "source": [
109 | "def get_system_prompt(personality_json_dict: dict) -> str:\n",
110 | "    \"\"\"Build a role-play system prompt. Assumes a single character whose list fields contain plain strings.\"\"\"\n",
111 | "    name = personality_json_dict[\"name\"]\n",
112 | "    physical_description = personality_json_dict[\"physicalDescription\"]\n",
113 | "    personality_traits = personality_json_dict[\"personalityTraits\"]\n",
114 | "    likes = personality_json_dict[\"likes\"]\n",
115 | "    dislikes = personality_json_dict[\"dislikes\"]\n",
116 | "    background = personality_json_dict[\"background\"]\n",
117 | "    goals = personality_json_dict[\"goals\"]\n",
118 | "    relationships = personality_json_dict[\"relationships\"]\n",
119 | "    system_prompt = f\"\"\"\n",
120 | "You are acting as the character detailed below. The details range from the character's inherent personality traits to its background.\n",
121 | "\n",
122 | "* Name: {name}\n",
123 | "* Physical description: {physical_description}\n",
124 | "* Personality traits: {', '.join(personality_traits)}\n",
125 | "* Likes: {', '.join(likes)}\n",
126 | "* Dislikes: {', '.join(dislikes)}\n",
127 | "* Background: {', '.join(background)}\n",
128 | "* Goals: {', '.join(goals)}\n",
129 | "* Relationships: {', '.join(relationships)}\n",
130 | "\n",
131 | "While generating your responses, you must consider the information above.\n",
132 | "\"\"\"\n",
133 | "    return system_prompt"
134 | ],
135 | "metadata": {
136 | "id": "Ri_pnhp9thoH"
137 | },
138 | "execution_count": null,
139 | "outputs": []
140 | },
141 | {
142 | "cell_type": "code",
143 | "source": [
144 | "from pprint import pprint\n",
145 | "\n",
146 | "pprint(get_system_prompt(SAMPLE_PERSONALITY[\"characters\"][0]))"
147 | ],
148 | "metadata": {
149 | "id": "jIchZaQRwo-V"
150 | },
151 | "execution_count": null,
152 | "outputs": []
153 | },
154 | {
155 | "cell_type": "code",
156 | "source": [
157 | "import os\n",
158 | "\n",
159 | "MAX_MAX_NEW_TOKENS = 2048\n",
160 | "DEFAULT_MAX_NEW_TOKENS = 1024\n",
161 | "MAX_INPUT_TOKEN_LENGTH = int(os.getenv(\"MAX_INPUT_TOKEN_LENGTH\", \"4096\"))"
162 | ],
163 | "metadata": {
164 | "id": "2HUwseXzxN_H"
165 | },
166 | "execution_count": null,
167 | "outputs": []
168 | },
169 | {
170 | "cell_type": "code",
171 | "source": [
172 | "!wget https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/raw/main/style.css --content-disposition"
173 | ],
174 | "metadata": {
175 | "id": "5tbbhUbmx9oV"
176 | },
177 | "execution_count": null,
178 | "outputs": []
179 | },
180 | {
181 | "cell_type": "code",
182 | "source": [
183 | "import gradio as gr\n",
184 | "from threading import Thread\n",
185 | "from transformers import TextIteratorStreamer\n",
186 | "\n",
187 | "def generate(\n",
188 | " message: str,\n",
189 | " chat_history: list[tuple[str, str]],\n",
190 | " max_new_tokens: int = 1024,\n",
191 | " temperature: float = 0.6,\n",
192 | " top_p: float = 0.9,\n",
193 | " top_k: int = 50,\n",
194 | " repetition_penalty: float = 1.2,\n",
195 | "):\n",
196 | " conversation = []\n",
197 | " system_prompt = get_system_prompt(SAMPLE_PERSONALITY[\"characters\"][0])\n",
198 | " conversation.append({\"role\": \"system\", \"content\": system_prompt})\n",
199 | " for user, assistant in chat_history:\n",
200 | " conversation.extend([{\"role\": \"user\", \"content\": user}, {\"role\": \"assistant\", \"content\": assistant}])\n",
201 | " conversation.append({\"role\": \"user\", \"content\": message})\n",
202 | "\n",
203 | " input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors=\"pt\")\n",
204 | " if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:\n",
205 | " input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]\n",
206 | " gr.Warning(f\"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.\")\n",
207 | " input_ids = input_ids.to(model.device)\n",
208 | "\n",
209 | " streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)\n",
210 | " generate_kwargs = dict(\n",
211 | " {\"input_ids\": input_ids},\n",
212 | " streamer=streamer,\n",
213 | " max_new_tokens=max_new_tokens,\n",
214 | " do_sample=True,\n",
215 | " top_p=top_p,\n",
216 | " top_k=top_k,\n",
217 | " temperature=temperature,\n",
218 | " num_beams=1,\n",
219 | " repetition_penalty=repetition_penalty,\n",
220 | " )\n",
221 | " t = Thread(target=model.generate, kwargs=generate_kwargs)\n",
222 | " t.start()\n",
223 | "\n",
224 | " outputs = []\n",
225 | " for text in streamer:\n",
226 | " outputs.append(text.replace(\"<|assistant|>\", \"\"))\n",
227 | " yield \"\".join(outputs)\n",
228 | "\n",
229 | "\n",
230 | "chat_interface = gr.ChatInterface(\n",
231 | " fn=generate,\n",
232 | " additional_inputs=[\n",
233 | " gr.Slider(\n",
234 | " label=\"Max new tokens\",\n",
235 | " minimum=1,\n",
236 | " maximum=MAX_MAX_NEW_TOKENS,\n",
237 | " step=1,\n",
238 | " value=DEFAULT_MAX_NEW_TOKENS,\n",
239 | " ),\n",
240 | " gr.Slider(\n",
241 | " label=\"Temperature\",\n",
242 | " minimum=0.1,\n",
243 | " maximum=4.0,\n",
244 | " step=0.1,\n",
245 | " value=0.6,\n",
246 | " ),\n",
247 | " gr.Slider(\n",
248 | " label=\"Top-p (nucleus sampling)\",\n",
249 | " minimum=0.05,\n",
250 | " maximum=1.0,\n",
251 | " step=0.05,\n",
252 | " value=0.9,\n",
253 | " ),\n",
254 | " gr.Slider(\n",
255 | " label=\"Top-k\",\n",
256 | " minimum=1,\n",
257 | " maximum=1000,\n",
258 | " step=1,\n",
259 | " value=50,\n",
260 | " ),\n",
261 | " gr.Slider(\n",
262 | " label=\"Repetition penalty\",\n",
263 | " minimum=1.0,\n",
264 | " maximum=2.0,\n",
265 | " step=0.05,\n",
266 | " value=1.2,\n",
267 | " ),\n",
268 | " ],\n",
269 | " stop_btn=None,\n",
270 | " examples=[\n",
271 | " [\"Hello there! How are you doing?\"],\n",
272 | " [\"Recite me a short poem.\"],\n",
273 | " [\"Explain the plot of Cinderella in a sentence.\"],\n",
274 | " [\"Write a 100-word article on 'Benefits of Open-Source in AI research'\"],\n",
275 | " ],\n",
276 | ")\n",
277 | "\n",
278 | "with gr.Blocks(css=\"style.css\") as demo:\n",
279 | " gr.Markdown(\"## Demo of Vid2Persona chat component\")\n",
280 | " gr.DuplicateButton(value=\"Duplicate Space for private use\", elem_id=\"duplicate-button\")\n",
281 | " chat_interface.render()\n",
282 | "\n",
283 | "demo.launch()"
284 | ],
285 | "metadata": {
286 | "id": "Cvot5sxWw5sB"
287 | },
288 | "execution_count": null,
289 | "outputs": []
290 | }
291 | ]
292 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | toml
2 | google-cloud-aiplatform
3 | transformers
4 | accelerate
5 | bitsandbytes
6 | openai
7 | gradio
--------------------------------------------------------------------------------
/styles.css:
--------------------------------------------------------------------------------
1 | .textbox-no-label > label > span {
2 | display: none;
3 | }
4 |
5 | .textbox-no-top-bottom-borders > label > textarea {
6 | border: none !important;
7 | }
8 |
9 | .chatbot-no-label > div > label {
10 | display: none;
11 | }
12 |
13 | .md-center {
14 | text-align: center;
15 | display: block;
16 | }
17 |
18 | .h1-font > span {
19 | font-size: xx-large;
20 | font-weight: bold;
21 | }
22 |
23 | .json-holder {
24 | overflow: scroll;
25 | height: 500px;
26 | }
27 |
28 | .group {
29 | padding-top: 10px;
30 | padding-left: 10px;
31 | padding-right: 10px;
32 | padding-bottom: 10px;
33 | border: 2px dashed gray;
34 | border-radius: 20px;
35 | box-shadow: 5px 3px 10px 1px rgba(0, 0, 0, 0.4) !important;
36 | }
37 |
38 | #bottom-md a {
39 | float: left;
40 | margin-right: 10px;
41 | }
42 |
43 | #chatbot {
44 | height: 600px !important;
45 | }
--------------------------------------------------------------------------------
/vid2persona/gen/gemini.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Iterable
2 |
3 | import vertexai
4 | from vertexai.generative_models import (
5 | GenerativeModel, Part,
6 | GenerationResponse, GenerationConfig
7 | )
8 |
9 | from .utils import parse_first_json_snippet
10 |
11 | def _default_gen_config():
12 | return GenerationConfig(
13 | max_output_tokens=2048,
14 | temperature=0.4,
15 | top_p=1,
16 | top_k=32
17 | )
18 |
19 | def init_vertexai(project_id: str, location: str) -> None:
20 | vertexai.init(project=project_id, location=location)
21 |
22 | async def _ask_about_video(
23 | prompt: str="What is in the video?",
24 |     gen_config: Union[GenerationConfig, dict]=_default_gen_config(),
25 |     model_name: str="gemini-1.0-pro-vision",
26 |     gcs: Union[str, None]=None,
27 |     base64_content: Union[bytes, None]=None
28 | ) -> Union[GenerationResponse, Iterable[GenerationResponse]]:
29 |     if gcs is None and base64_content is None:
30 |         raise ValueError("Either a GCS URI (gcs) or the raw video bytes (base64_content) must be provided")
31 |
32 |     if gcs is not None and base64_content is not None:
33 |         raise ValueError("Only one of gcs or base64_content may be provided")
34 |
35 | if gcs is not None:
36 | video = Part.from_uri(gcs, mime_type="video/mp4")
37 | else:
38 | video = Part.from_data(data=base64_content, mime_type="video/mp4")
39 |
40 | model = GenerativeModel(model_name)
41 | return await model.generate_content_async(
42 | [video, prompt],
43 | generation_config=gen_config
44 | )
45 |
46 | async def ask_about_video(prompt: str, video_clip: bytes, retry_num: int=10):
47 | json_content = None
48 | cur_retry = 0
49 |
50 | while json_content is None and cur_retry < retry_num:
51 | try:
52 | resps = await _ask_about_video(
53 | prompt=prompt, base64_content=video_clip
54 | )
55 |
56 | json_content = parse_first_json_snippet(resps.text)
57 | except Exception as e:
58 | cur_retry = cur_retry + 1
59 | print(f"......retry {e}")
60 |
61 | return json_content
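`ask_about_video` retries immediately, up to `retry_num` times, and returns `None` once every attempt has failed. A common alternative is to wait between attempts with exponential backoff and to surface the last error instead of `None`. The helper below is a minimal, stdlib-only sketch of that idea; `retry_async` and its `base_delay`/`max_delay` parameters are hypothetical and not part of this repository.

```python
import asyncio
import random
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


async def retry_async(
    fn: Callable[[], Awaitable[T]],
    retry_num: int = 10,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
) -> T:
    """Await fn() until it succeeds, at most retry_num times (hypothetical helper).

    Between failed attempts it sleeps a random time in
    [0, min(max_delay, base_delay * 2**attempt)] ("full jitter" backoff).
    If every attempt fails, it raises RuntimeError chained to the last error.
    """
    last_exc: Optional[Exception] = None
    for attempt in range(retry_num):
        try:
            return await fn()
        except Exception as e:  # API errors and unparsable replies alike
            last_exc = e
            print(f"......retry {attempt + 1}/{retry_num}: {e}")
            if attempt < retry_num - 1:
                delay = min(max_delay, base_delay * 2 ** attempt)
                await asyncio.sleep(random.uniform(0, delay))
    raise RuntimeError(f"all {retry_num} attempts failed") from last_exc
```

To keep retrying on malformed JSON as the current loop does, the callable passed in could be a small `async def` that awaits `_ask_about_video(prompt=prompt, base64_content=video_clip)` and returns `parse_first_json_snippet(resp.text)`, so a bad reply is retried exactly like an API error.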
--------------------------------------------------------------------------------
/vid2persona/gen/local_openllm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from threading import Thread
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from transformers import TextIteratorStreamer
5 |
6 | model = None
7 | tokenizer = None
8 |
9 | def send_message(
10 | messages: list,
11 | model_id: str,
12 | max_input_token_length: int,
13 | parameters: dict
14 | ):
15 | global tokenizer
16 | global model
17 |
18 | if tokenizer is None:
19 | tokenizer = AutoTokenizer.from_pretrained(model_id)
20 | tokenizer.use_default_system_prompt = False
21 | if model is None:
22 | model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
23 |
24 |     input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
25 | if input_ids.shape[1] > max_input_token_length:
26 | input_ids = input_ids[:, -max_input_token_length:]
27 | print(f"Trimmed input from conversation as it was longer than {max_input_token_length} tokens.")
28 | input_ids = input_ids.to(model.device)
29 |
30 | streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
31 | generate_kwargs = dict(
32 | {"input_ids": input_ids},
33 | streamer=streamer,
34 | do_sample=True,
35 | num_beams=1,
36 | **parameters
37 | )
38 | t = Thread(target=model.generate, kwargs=generate_kwargs)
39 | t.start()
40 |
41 | for text in streamer:
42 | yield text.replace("<|assistant|>", "")
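`send_message` starts `model.generate` on a background `Thread` and yields text as `TextIteratorStreamer` produces it. The sketch below reproduces that producer/consumer pattern with only the standard library: a `queue.Queue` carries text chunks from the producer thread to the consuming generator, and a sentinel marks the end of generation. `QueueStreamer`, `fake_generate` and `stream_reply` are hypothetical names; the real streamer decodes token ids from the model instead of receiving ready-made strings.

```python
import queue
import threading
from typing import Iterable, Iterator

_STOP = object()  # sentinel that marks the end of generation


class QueueStreamer:
    """Hypothetical minimal streamer: a producer put()s text chunks, a consumer iterates."""

    def __init__(self, timeout: float = 10.0) -> None:
        self._queue: "queue.Queue[object]" = queue.Queue()
        self._timeout = timeout

    def put(self, text: str) -> None:
        self._queue.put(text)

    def end(self) -> None:
        self._queue.put(_STOP)

    def __iter__(self) -> Iterator[str]:
        while True:
            # Raises queue.Empty if nothing arrives within `timeout` seconds.
            item = self._queue.get(timeout=self._timeout)
            if item is _STOP:
                return
            yield item


def fake_generate(chunks: Iterable[str], streamer: QueueStreamer) -> None:
    """Hypothetical stand-in for model.generate(..., streamer=streamer)."""
    try:
        for chunk in chunks:
            streamer.put(chunk)
    finally:
        streamer.end()  # always unblock the consumer, even if generation fails


def stream_reply(chunks: Iterable[str]) -> Iterator[str]:
    streamer = QueueStreamer(timeout=10.0)
    thread = threading.Thread(target=fake_generate, args=(chunks, streamer))
    thread.start()
    for text in streamer:
        yield text.replace("<|assistant|>", "")  # same clean-up as send_message
    thread.join()  # the producer has already called end(), so this returns promptly
```

In the sketch, a producer that stalls for longer than `timeout` seconds makes the consumer raise `queue.Empty` rather than block forever, and the producer thread is joined once the stream has been fully consumed.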
--------------------------------------------------------------------------------
/vid2persona/gen/tgi_openllm.py:
--------------------------------------------------------------------------------
1 | from openai import AsyncOpenAI
2 |
3 | async def send_messages(
4 | messages: list,
5 | model_id: str,
6 | hf_token: str,
7 | parameters: dict
8 | ):
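9 |     parameters = dict(parameters)  # work on a copy; the caller's dict stays intact
10 |     parameters.pop('repetition_penalty')  # no equivalent in the OpenAI chat completions API
11 |     parameters['max_tokens'] = parameters.pop('max_new_tokens')
12 |     parameters.pop('top_k')  # no top-k sampling in the OpenAI chat API; top_logprobs only reports log-probabilities
13 |     # parameters['presence_penalty'] = parameters.pop('repetition_penalty')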
14 |
15 | client = AsyncOpenAI(
16 | base_url=f"https://api-inference.huggingface.co/models/{model_id}/v1",
17 | api_key=hf_token,
18 | )
19 |
20 | responses = await client.chat.completions.create(
21 | model="tgi", messages=messages, stream=True, **parameters
22 | )
23 |
24 | async for response in responses:
25 |         yield response.choices[0].delta.content or ""  # a chunk's delta may carry no content
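`send_messages` has to translate the Hugging Face style generation parameters built in `pipeline/llm.py` into OpenAI chat-completions arguments. The function below is a minimal sketch of that translation as a pure function that returns a new dict. It assumes the endpoint follows the OpenAI request schema: `max_new_tokens` becomes `max_tokens`, while `repetition_penalty` and `top_k` are dropped because that schema has no such sampling fields (`top_logprobs` is a different setting that only controls how many alternative tokens' log-probabilities are returned). `to_openai_chat_params` is a hypothetical name.

```python
from typing import Any, Dict


def to_openai_chat_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Map HF-style generation parameters onto OpenAI chat-completions kwargs.

    Returns a new dict; the caller's dict is left untouched.
    """
    params = dict(parameters)  # shallow copy instead of mutating the input
    # No repetition_penalty or top_k sampling field in the OpenAI chat schema.
    params.pop("repetition_penalty", None)
    params.pop("top_k", None)
    if "max_new_tokens" in params:
        params["max_tokens"] = params.pop("max_new_tokens")
    return params
```

Copying first means the same `parameters` dict can be reused for a retry or for the local backend without surprises.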
--------------------------------------------------------------------------------
/vid2persona/gen/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | def find_json_snippet(raw_snippet):
4 | json_parsed_string = None
5 |
6 | json_start_index = raw_snippet.find('{')
7 | json_end_index = raw_snippet.rfind('}')
8 |
9 |     if json_start_index >= 0 and json_end_index > json_start_index:
10 |         json_snippet = raw_snippet[json_start_index:json_end_index+1]
11 |         try:
12 |             json_parsed_string = json.loads(json_snippet, strict=False)
13 |         except json.JSONDecodeError as e:
14 |             raise ValueError('......failed to parse string into JSON format') from e
15 | else:
16 | raise ValueError('......No JSON code snippet found in string.')
17 |
18 | return json_parsed_string
19 |
20 | def parse_first_json_snippet(snippet):
21 | json_parsed_string = None
22 |
23 | if isinstance(snippet, list):
24 | for snippet_piece in snippet:
25 | try:
26 | json_parsed_string = find_json_snippet(snippet_piece)
27 | return json_parsed_string
28 |             except Exception:
29 | pass
30 | else:
31 | try:
32 | json_parsed_string = find_json_snippet(snippet)
33 | except Exception as e:
34 | print(e)
35 |             raise ValueError(str(e)) from e
36 |
37 | return json_parsed_string
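`find_json_snippet` takes everything from the first `{` to the last `}` in the model's reply and hands it to `json.loads` with `strict=False`. The standalone sketch below reproduces that approach so it can be tried in isolation; `extract_json_block` is a hypothetical name, not a function in this repository.

```python
import json
from typing import Any


def extract_json_block(text: str) -> Any:
    """Parse the span from the first '{' to the last '}' in text.

    Raises ValueError if no such span exists or it is not valid JSON.
    """
    start, end = text.find("{"), text.rfind("}")
    if start < 0 or end < start:
        raise ValueError("no JSON object found in text")
    try:
        # strict=False allows raw control characters (e.g. newlines) inside strings
        return json.loads(text[start:end + 1], strict=False)
    except json.JSONDecodeError as e:
        raise ValueError("failed to parse JSON object") from e
```

Because the span always runs to the last `}`, a reply that contains two separate JSON objects fails to parse as a whole; `json.JSONDecoder().raw_decode(text, start)` is one way to read only the first object if that case matters.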
--------------------------------------------------------------------------------
/vid2persona/init.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # https://huggingface.co/blog/inference-pro
4 | ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS = [
5 | "mistralai/Mixtral-8x7B-Instruct-v0.1",
6 | "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
7 | "mistralai/Mistral-7B-Instruct-v0.2",
8 | "mistralai/Mistral-7B-Instruct-v0.1",
9 | "HuggingFaceH4/zephyr-7b-beta",
10 | "meta-llama/Llama-2-7b-chat-hf",
11 | "meta-llama/Llama-2-13b-chat-hf",
12 | "meta-llama/Llama-2-70b-chat-hf",
13 | "openchat/openchat-3.5-0106"
14 | ]
15 |
16 | def get_env_vars():
17 | gcp_project_id = os.getenv("GCP_PROJECT_ID", None)
18 | gcp_project_loc = os.getenv("GCP_PROJECT_LOCATION", None)
19 |
20 | return gcp_project_id, gcp_project_loc
--------------------------------------------------------------------------------
/vid2persona/pipeline/llm.py:
--------------------------------------------------------------------------------
1 | import toml
2 | from string import Template
3 | from transformers import AutoTokenizer
4 |
5 | from vid2persona.gen import tgi_openllm
6 | from vid2persona.gen import local_openllm
7 |
8 | tokenizer = None
9 |
10 | def _get_system_prompt(
11 | personality_json_dict: dict,
12 | prompt_tpl_path: str
13 | ) -> str:
14 | """Assumes a single character is passed."""
15 | prompt_tpl_path = f"{prompt_tpl_path}/llm.toml"
16 | system_prompt = Template(toml.load(prompt_tpl_path)['conversation']['system'])
17 |
18 | name = personality_json_dict["name"]
19 | physcial_description = personality_json_dict["physicalDescription"]
20 | personality_traits = [str(trait) for trait in personality_json_dict["personalityTraits"]]
21 | likes = [str(like) for like in personality_json_dict["likes"]]
22 | dislikes = [str(dislike) for dislike in personality_json_dict["dislikes"]]
23 | background = [str(info) for info in personality_json_dict["background"]]
24 | goals = [str(goal) for goal in personality_json_dict["goals"]]
25 | relationships = [str(relationship) for relationship in personality_json_dict["relationships"]]
26 |
27 | system_prompt = system_prompt.substitute(
28 | name=name,
29 | physcial_description=physcial_description,
30 | personality_traits=', '.join(personality_traits),
31 | likes=', '.join(likes),
32 | background=', '.join(background),
33 | goals=', '.join(goals),
34 | relationships=', '.join(relationships)
35 | )
36 |
37 | return system_prompt
38 |
39 | async def chat(
40 | message: str,
41 | chat_history: list[tuple[str, str]],
42 | personality_json_dict: dict,
43 | prompt_tpl_path: str,
44 |
45 | model_id: str,
46 | max_input_token_length: int,
47 | max_new_tokens: int,
48 | temperature: float,
49 | top_p: float,
50 | top_k: int,
51 | repetition_penalty: float,
52 |
53 | hf_token: str,
54 | ):
55 | messages = []
56 | system_prompt = _get_system_prompt(personality_json_dict, prompt_tpl_path)
57 | messages.append({"role": "system", "content": system_prompt})
58 | for user, assistant in chat_history:
59 | messages.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
60 | messages.append({"role": "user", "content": message})
61 |
62 | parameters = {
63 | "max_new_tokens": max_new_tokens,
64 | "temperature": temperature,
65 | "top_p": top_p,
66 | "top_k": top_k,
67 | "repetition_penalty": repetition_penalty
68 | }
69 |
70 |     if not hf_token:
71 | for response in local_openllm.send_message(messages, model_id, max_input_token_length, parameters):
72 | yield response
73 | else:
74 | async for response in tgi_openllm.send_messages(messages, model_id, hf_token, parameters):
75 | yield response
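`chat` turns the Gradio history into an OpenAI-style message list before handing it to either backend: one system message, then each `(user, assistant)` pair as two messages, then the new user message. A standalone sketch of just that step, using the hypothetical name `build_messages`:

```python
from typing import Dict, List, Tuple


def build_messages(
    system_prompt: str,
    chat_history: List[Tuple[str, str]],
    message: str,
) -> List[Dict[str, str]]:
    """System turn, then alternating user/assistant turns, then the new user turn."""
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages
```

Both backends consume this shape: the local path through `tokenizer.apply_chat_template`, the TGI path through the OpenAI client. Some chat templates may not accept a `system` role, so the template of the chosen `model_id` is worth checking.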
--------------------------------------------------------------------------------
/vid2persona/pipeline/vlm.py:
--------------------------------------------------------------------------------
1 | import toml
2 | from vid2persona.gen.gemini import init_vertexai, ask_about_video
3 | from vid2persona.utils import get_base64_content
4 |
5 | async def get_traits(
6 | gcp_project_id: str, gcp_project_location: str,
7 | video_clip_path: str, prompt_tpl_path: str,
8 | ):
9 | prompt_tpl_path = f"{prompt_tpl_path}/vlm.toml"
10 | prompt = toml.load(prompt_tpl_path)['extraction']['traits']
11 | init_vertexai(gcp_project_id, gcp_project_location)
12 | video_clip = get_base64_content(video_clip_path)
13 |
14 | response = await ask_about_video(prompt=prompt, video_clip=video_clip)
15 | return response
--------------------------------------------------------------------------------
/vid2persona/prompts/llm.toml:
--------------------------------------------------------------------------------
1 | [conversation]
2 | system = """
3 | You are acting as the character described below. The details range from the character's inherent personality traits to their background.
4 |
5 | * Name: $name
6 | * Physical description: $physcial_description
7 | * Personality traits: $personality_traits
8 | * Likes: $likes
9 | * Background: $background
10 | * Goals: $goals
11 | * Relationships: $relationships
12 |
13 | While generating your responses, you must consider the information above.
14 | """
15 |
16 | examples = [
17 | ["Hello there! How are you doing?"],
18 | ["Recite me a short poem."],
19 | ["Explain the plot of Cinderella in a sentence."],
20 | ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
21 | ]
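`_get_system_prompt` fills the `system` template above with `string.Template.substitute`. The sketch below, which uses a trimmed copy of the template, shows two properties of that call that matter when editing it: a placeholder with no matching keyword raises `KeyError`, while extra keywords are silently ignored. This is why the `dislikes` list that `_get_system_prompt` builds never reaches the prompt: the template has no `$dislikes` placeholder and the `substitute` call does not pass it.

```python
from string import Template

# Trimmed copy of the [conversation].system template from llm.toml; placeholder
# names, including the `physcial_description` spelling, must match the keywords.
tpl = Template(
    "* Name: $name\n"
    "* Physical description: $physcial_description\n"
    "* Likes: $likes"
)

# Extra keywords such as `dislikes` are silently ignored by substitute().
prompt = tpl.substitute(
    name="Ada", physcial_description="tall, grey coat", likes="tea", dislikes="rain"
)
print(prompt)

# A placeholder without a matching keyword raises KeyError...
try:
    tpl.substitute(name="Ada", likes="tea")
except KeyError as missing:
    print("missing placeholder:", missing)

# ...while safe_substitute() leaves the placeholder in the text unchanged.
print(tpl.safe_substitute(name="Ada", likes="tea"))
```

Adding a placeholder to the template therefore always needs a matching keyword in `_get_system_prompt`, or every chat request will fail.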
--------------------------------------------------------------------------------
/vid2persona/prompts/vlm.toml:
--------------------------------------------------------------------------------
1 | [extraction]
2 | traits = """
3 | Carefully analyze the provided video clip to identify and extract detailed information about the main character(s) featured. Pay attention to visual elements, spoken dialogue, character interactions, and any narrative cues that reveal aspects of the character's personality, physical appearance, behaviors, and background.
4 |
5 | Your task is to construct a rich, imaginative character profile based on your observations, and where explicit information is not available, you are encouraged to use your creativity to fill in the gaps. The goal is to create a vivid, believable character profile that can be used to simulate conversation with a language model as if it were the character itself.
6 |
7 | Format the extracted data as a structured JSON object containing the following fields for each main character:
8 |
9 | name(text): The character's name as mentioned or inferred in the video. If not provided, create a suitable name that matches the character's traits and context.
10 | physicalDescription(text): Describe the character's appearance, including hair color, eye color, height, and distinctive features. Use imaginative details if necessary to provide a complete picture.
11 | personalityTraits(list): List descriptive adjectives or phrases that capture the character's personality, based on their actions and dialogue. Invent traits as needed to ensure a well-rounded personality.
12 | likes(list): Specify things, activities, or concepts the character enjoys or values, deduced or imagined from their behavior and interactions.
13 | dislikes(list): Note what the character appears to dislike or avoid, filling in creatively where direct evidence is not available.
14 | background(list): Provide background information such as occupation, family ties, or significant life events, inferring where possible or inventing details to add depth to the character's story.
15 | goals(list): Describe the character's apparent motivations and objectives, whether explicitly stated or implied. Where not directly observable, construct plausible goals that align with the character's portrayed or inferred traits.
16 | relationships(list): Detail the character's relationships with other characters, including the nature of each relationship and the names of other characters involved. Use creative license to elaborate on these relationships if the video provides limited information.
17 |
18 | Ensure the JSON object is well-structured and comprehensive, ready for integration with a language model to facilitate engaging conversations as if with the character itself. For multiple main characters, provide a distinct profile for each within the same JSON object.
19 | """
20 |
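The prompt asks for one JSON profile per main character, with two text fields and six list fields. `_get_system_prompt` in `pipeline/llm.py` indexes every one of these keys directly, so a profile with a missing key fails with `KeyError`. A small checker for a single profile, written as a hypothetical helper (`profile_problems`), could look like this:

```python
from typing import Any, Dict, List

# Field names and types as requested by the [extraction].traits prompt.
TEXT_FIELDS = ("name", "physicalDescription")
LIST_FIELDS = ("personalityTraits", "likes", "dislikes", "background", "goals", "relationships")


def profile_problems(profile: Dict[str, Any]) -> List[str]:
    """Return the problems found in one character profile (empty if it looks valid)."""
    problems = []
    for key in TEXT_FIELDS:
        if not isinstance(profile.get(key), str):
            problems.append(f"{key}: expected text")
    for key in LIST_FIELDS:
        if not isinstance(profile.get(key), list):
            problems.append(f"{key}: expected list")
    return problems
```

The prompt does not pin down the top-level key that holds multiple profiles. The sample notebook reads `SAMPLE_PERSONALITY["characters"][0]`, which suggests a `characters` list, but that describes the sample rather than guaranteeing the model's output.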
--------------------------------------------------------------------------------
/vid2persona/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 |
3 | def get_base64_content(file_path, decode=True):
4 | with open(file_path, 'rb') as f:
5 | data = f.read()
6 |
7 |     return data if decode else base64.b64encode(data)  # decode=True: raw file bytes; decode=False: base64-encoded bytes
8 |
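With the default `decode=True`, `get_base64_content` returns the file's raw bytes (base64 encoding followed by decoding gives back the input unchanged), and those bytes are what `ask_about_video` hands to `Part.from_data(data=...)` in `gen/gemini.py`. A quick standard-library check of that round trip, using made-up bytes in place of a video file:

```python
import base64

data = b"\x00\x01fake mp4 bytes"     # stand-in for the contents of a video file
encoded = base64.b64encode(data)      # ASCII-only bytes, about 4/3 the original size
decoded = base64.b64decode(encoded)   # decoding exactly undoes the encoding
print(len(data), len(encoded), decoded == data)
```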
--------------------------------------------------------------------------------