# Repository: Twelve-Labs-Fashion-Assistant
#
#   .gitignore                        (empty)
#   .milvus_twelvelabs_demo3.db.lock  (empty)
#   README.md                         "# Twelve-Labs-Fashion-Assistant"
#   milvus_twelvelabs_demo3.db        binary database file, available at
#       https://raw.githubusercontent.com/Hrishikesh332/Twelve-Labs-Fashion-Assistant/HEAD/milvus_twelvelabs_demo3.db
#   milvus_twelvelabs_demo5.db        binary database file, available at
#       https://raw.githubusercontent.com/Hrishikesh332/Twelve-Labs-Fashion-Assistant/HEAD/milvus_twelvelabs_demo5.db
#   requirements.txt                  streamlit, gunicorn, pandas, pymilvus, milvus,
#                                     twelvelabs, python-dotenv, torch, torchvision
#   app.py                            (this file)
"""Streamlit app that indexes product videos in Milvus and searches them by image.

Videos are embedded with the Twelve Labs Marengo engine and stored segment by
segment in a Milvus collection; an uploaded image is embedded the same way and
used as the similarity-search query.
"""
import os
import time
import uuid
from urllib.parse import urlparse

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from PIL import Image
from pymilvus import MilvusClient
from pymilvus import connections
from pymilvus import (
    CollectionSchema,
    Collection,
    DataType,
    FieldSchema,
    utility,
)
from twelvelabs import TwelveLabs

load_dotenv()

TWELVELABS_API_KEY = os.getenv('TWELVELABS_API_KEY')
MILVUS_DB_NAME = os.getenv('MILVUS_DB_NAME')
COLLECTION_NAME = os.getenv('COLLECTION_NAME')
MILVUS_HOST = os.getenv('MILVUS_HOST')
MILVUS_PORT = os.getenv('MILVUS_PORT')
URL = os.getenv('URL')
TOKEN = os.getenv('TOKEN')

# Fields for the collection schema: an INT64 primary key supplied by the app
# (auto_id=False) and a 1024-dim float vector.
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
]

# Dynamic fields let each entity carry a free-form "metadata" value.
schema = CollectionSchema(
    fields=fields,
    enable_dynamic_field=True,
)

INDEX_PARAMS = {
    "metric_type": "COSINE",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128},
}


@st.cache_resource(show_spinner=False)
def _connect_collection(uri, token, collection_name):
    """Connect to Milvus and return the loaded collection, creating it if needed.

    Streamlit re-executes this script on every widget interaction; caching the
    resource means the connect / create-index / load sequence runs once per
    server process instead of on every rerun.

    Args:
        uri: Milvus URI (``URL`` env var); may be ``None``.
        token: Milvus auth token (``TOKEN`` env var); may be ``None``.
        collection_name: Name of the collection to use or create.

    Returns:
        A loaded ``pymilvus.Collection`` with an index on ``vector``.
    """
    connections.connect(uri=uri, token=token)

    if utility.has_collection(collection_name):
        collection = Collection(collection_name)
        print(f"Using existing collection: {collection_name}")
    else:
        collection = Collection(collection_name, schema)
        print(f"Created new collection: {collection_name}")

    # Also covers pre-existing collections that were created without an index.
    if not collection.has_index():
        collection.create_index(field_name="vector", index_params=INDEX_PARAMS)
        print("Created index for the new collection")

    # A collection must be loaded into memory before it can be searched.
    collection.load()
    return collection


# Fail fast with a readable message instead of an opaque pymilvus error.
if not COLLECTION_NAME:
    st.error("COLLECTION_NAME is not set. Add it to your environment or .env file.")
    st.stop()

collection = _connect_collection(URL, TOKEN, COLLECTION_NAME)

# The rest of the app refers to this as "milvus_client"; note that it is a
# pymilvus ORM Collection, not a MilvusClient instance.
milvus_client = collection
# NOTE(review): the custom <style> block that app.py injects here is not
# visible in this dump (its lines were stripped); restore it from the
# repository. A blank placeholder is passed so the call stays harmless.
_CUSTOM_CSS = """
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


def generate_embedding(video_url):
    """Create Twelve Labs embeddings for every segment of a video.

    Args:
        video_url: URL of the video to embed.

    Returns:
        ``(embeddings, task_result, error)``. On success ``embeddings`` is a
        list of dicts with keys ``embedding``, ``start_offset_sec``,
        ``end_offset_sec``, ``embedding_scope`` and ``video_url``, and
        ``error`` is ``None``. On failure the first two items are ``None`` and
        ``error`` is a message string; exceptions are not propagated.
    """
    if not video_url:
        return None, None, "No video URL provided"

    try:
        # Read the key from the environment rather than committing it to source.
        twelvelabs_client = TwelveLabs(api_key=TWELVELABS_API_KEY)
        print(f"Processing video URL: {video_url}")

        task = twelvelabs_client.embed.task.create(
            engine_name="Marengo-retrieval-2.6",
            video_url=video_url,
        )
        print(f"Created task: id={task.id} engine_name={task.engine_name} status={task.status}")

        status = task.wait_for_done(
            sleep_interval=2,
            callback=lambda t: print(f"Status={t.status}"),
        )
        print(f"Embedding done: {status}")

        # Retrieve the finished task explicitly to read its embeddings.
        task_result = twelvelabs_client.embed.task.retrieve(task.id)

        video_embedding = getattr(task_result, "video_embedding", None)
        segments = getattr(video_embedding, "segments", None)
        if segments is None:
            return None, None, "No embeddings found in task result"

        embeddings = [
            {
                'embedding': segment.embeddings_float,
                'start_offset_sec': segment.start_offset_sec,
                'end_offset_sec': segment.end_offset_sec,
                'embedding_scope': segment.embedding_scope,
                'video_url': video_url,
            }
            for segment in segments
        ]
        return embeddings, task_result, None

    except Exception as e:
        print(f"Error in generate_embedding: {str(e)}")
        return None, None, str(e)


def image_embedding(twelvelabs_client, image_file):
    """Return the embedding vector of a single image.

    Args:
        twelvelabs_client: An initialised ``TwelveLabs`` client.
        image_file: Image path or file-like object accepted by the SDK.

    Returns:
        The ``embeddings_float`` of the first image-embedding segment.

    Raises:
        ValueError: If the response contains no image-embedding segments.
    """
    embedding_result = twelvelabs_client.embed.create(
        engine_name="Marengo-retrieval-2.6",
        image_file=image_file,
    )

    image_emb = getattr(embedding_result, "image_embedding", None)
    segments = getattr(image_emb, "segments", None)
    if not segments:
        raise ValueError("No image embedding returned by Twelve Labs")
    return segments[0].embeddings_float


def insert_embeddings(embeddings, video_url):
    """Store video-segment embeddings in the Milvus collection.

    Each segment becomes one entity with a random 63-bit ``id``, its
    ``vector``, and a ``metadata`` dict (kept in the dynamic field) holding
    the scope, start/end time, and source video URL.

    Args:
        embeddings: Segment dicts as produced by ``generate_embedding``.
        video_url: Source URL recorded in each segment's metadata.

    Returns:
        ``(True, inserted_count)`` on success, ``(False, error_message)`` otherwise.
    """
    if not embeddings:
        return False, "No embeddings to insert"

    # IDs previously came from unix seconds + a 3-digit index, which produced
    # duplicate primary keys when two videos were stored in the same second.
    # Shifting a UUID4 right by 65 bits keeps 63 bits, so it fits a signed INT64.
    data = [
        {
            "id": uuid.uuid4().int >> 65,
            "vector": emb['embedding'],
            "metadata": {
                "scope": emb['embedding_scope'],
                "start_time": emb['start_offset_sec'],
                "end_time": emb['end_offset_sec'],
                "video_url": video_url,
            },
        }
        for emb in embeddings
    ]

    try:
        milvus_client.insert(data)
        # Flush so the new segments are persisted and counted immediately.
        milvus_client.flush()
        return True, len(data)
    except Exception as e:
        return False, str(e)


def search_similar_videos(image, top_k=5):
    """Find the stored video segments most similar to an image.

    Args:
        image: Image path or file-like object (e.g. a Streamlit upload).
        top_k: Maximum number of segments to return.

    Returns:
        A list of display dicts with keys ``Title``, ``Description``,
        ``Link``, ``Start Time``, ``End Time`` (formatted like ``"12.0s"``),
        ``Video URL`` and ``Similarity`` (cosine similarity as a percentage).
        Segments without metadata are skipped.
    """
    twelvelabs_client = TwelveLabs(api_key=TWELVELABS_API_KEY)

    # Rewind file-like inputs in case they were read earlier (e.g. for a preview).
    if hasattr(image, "seek"):
        image.seek(0)

    features = image_embedding(
        twelvelabs_client=twelvelabs_client,
        image_file=image,
    )

    results = milvus_client.search(
        data=[features],
        anns_field="vector",
        param={"metric_type": "COSINE", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["metadata"],
    )

    search_results = []
    for hits in results:
        for hit in hits:
            metadata = hit.entity.get('metadata')
            if not metadata:
                continue
            search_results.append({
                # Segments written by insert_embeddings() carry no product
                # fields, so fall back instead of raising KeyError.
                'Title': metadata.get('title', 'Untitled'),
                'Description': metadata.get('description', ''),
                'Link': metadata.get('link', ''),
                'Start Time': f"{float(metadata.get('start_time', 0)):.1f}s",
                'End Time': f"{float(metadata.get('end_time', 0)):.1f}s",
                'Video URL': metadata.get('video_url', ''),
                # With the COSINE metric Milvus reports the similarity itself
                # (higher = more similar), not a distance to subtract from 1.
                'Similarity': f"{float(hit.distance) * 100:.2f}%",
            })

    return search_results


def format_time_for_url(seconds):
    """Return ``seconds`` truncated to whole seconds, as a string."""
    return f"{int(float(seconds))}"


# File extensions treated as directly playable video files.
_DIRECT_VIDEO_EXTENSIONS = ('.mp4', '.webm', '.ogg')


def get_video_id_from_url(url):
    """Work out how a video URL can be embedded.

    Returns:
        ``(video_id, 'vimeo')`` for Vimeo URLs, ``(url, 'direct')`` for links
        to ``.mp4``/``.webm``/``.ogg`` files (case-insensitive; a trailing
        query string is allowed), otherwise ``(None, None)``.
    """
    parsed_url = urlparse(url)

    # Vimeo: prefer the numeric video id so URLs such as
    # https://player.vimeo.com/video/123 or https://vimeo.com/123/abcdef
    # yield "123"; fall back to the raw path when there is no numeric segment.
    if 'vimeo.com' in url:
        path_parts = [part for part in parsed_url.path.split('/') if part]
        numeric_id = next((part for part in path_parts if part.isdigit()), None)
        video_id = numeric_id if numeric_id is not None else parsed_url.path[1:]
        return video_id, 'vimeo'

    # Direct video file. Checking the parsed path as well accepts signed URLs
    # whose query string follows the file extension.
    if (url.lower().endswith(_DIRECT_VIDEO_EXTENSIONS)
            or parsed_url.path.lower().endswith(_DIRECT_VIDEO_EXTENSIONS)):
        return url, 'direct'

    return None, None


def create_video_embed(video_url, start_time, end_time):
    """Build HTML that plays ``video_url`` starting at ``start_time``.

    Vimeo URLs become a player iframe that starts at ``start_time``. Direct
    files become an HTML5 ``<video>`` limited to ``start_time``-``end_time``
    through a media fragment. Anything else yields an error message.
    The URL is HTML-escaped because the result is rendered with
    ``unsafe_allow_html=True`` and the URL originates from user input.
    """
    from html import escape

    # NOTE(review): the original embed markup is not visible in this dump;
    # this is a reconstruction, so confirm sizing and styling against the repo.
    video_id, platform = get_video_id_from_url(video_url)
    start_seconds = format_time_for_url(start_time)

    if platform == 'vimeo':
        src = f"https://player.vimeo.com/video/{video_id}#t={start_seconds}s"
        return (
            f'<iframe src="{escape(src)}" width="100%" height="360" '
            'frameborder="0" allow="autoplay; fullscreen; picture-in-picture" '
            'allowfullscreen></iframe>'
        )
    if platform == 'direct':
        src = f"{video_url}#t={float(start_time):.1f},{float(end_time):.1f}"
        return (
            f'<video width="100%" controls preload="metadata" src="{escape(src)}">'
            'Your browser does not support the video tag.</video>'
        )
    return f"<div>Unable to embed video from URL: {escape(str(video_url))}</div>"


def _render_sidebar_status():
    """Show the connected collection and its entity count, or the connection error."""
    with st.sidebar:
        st.subheader("System Status")
        try:
            stats = milvus_client.num_entities
            st.success(f"Connected to: {COLLECTION_NAME}")
            st.info(f"Total Video Segments: {stats:,}")
        except Exception as e:
            st.error(f"Connection Error: {str(e)}")


def _process_video(video_url):
    """Embed ``video_url``, store its segments, and report the outcome in the UI."""
    with st.spinner("Processing video..."):
        embeddings, _task_result, error = generate_embedding(video_url)

    if error:
        st.error(f"Error: {error}")
        return
    if not embeddings:
        st.error("No embeddings generated from the video")
        return

    with st.spinner("Storing embeddings..."):
        success, result = insert_embeddings(embeddings, video_url)

    if not success:
        st.error(f"Error inserting embeddings: {result}")
        return

    st.success(f"Successfully processed {result} segments")
    with st.expander("View Processing Details"):
        st.json({
            "Segments processed": result,
            "Sample embedding": {
                "Time range": f"{embeddings[0]['start_offset_sec']} - {embeddings[0]['end_offset_sec']} seconds",
                "Vector preview": embeddings[0]['embedding'][:5],
            },
        })


def _render_add_video_tab():
    """Render the form that embeds a video URL and stores it in Milvus."""
    st.subheader("Add New Video to Knowledge Base")
    with st.container():
        video_url = st.text_input(
            "Video URL",
            placeholder="Enter the URL of your video file",
            help="Provide the complete URL of the video you want to process",
        )

        if st.button("Process Video", use_container_width=True):
            video_url = video_url.strip()
            if video_url:
                _process_video(video_url)
            else:
                st.warning("Please enter a video URL first")


def _render_search_result(idx, result):
    """Render one search hit as an expander with the clip and its details."""
    with st.expander(f"Match #{idx} - Similarity: {result['Similarity']}", expanded=(idx == 1)):
        start_time = float(result['Start Time'].replace('s', ''))
        end_time = float(result['End Time'].replace('s', ''))

        video_col, details_col = st.columns([2, 1])

        with video_col:
            st.markdown("#### Video Segment")
            video_embed = create_video_embed(result['Video URL'], start_time, end_time)
            st.markdown(video_embed, unsafe_allow_html=True)

        with details_col:
            st.markdown("#### Details")
            link_md = f"[Open Product]({result['Link']})" if result['Link'] else "Not available"
            # Trailing double spaces force a Markdown line break between label and value.
            details = [
                f"📝 **Title**  \n{result['Title']}",
                f"📖 **Description**  \n{result['Description']}",
                f"🔗 **Link**  \n{link_md}",
                f"🕒 **Time Range**  \n{result['Start Time']} - {result['End Time']}",
                f"🎥 **Video URL**  \n[Watch Video]({result['Video URL']})",
                f"📊 **Similarity Score**  \n{result['Similarity']}",
            ]
            st.markdown("\n\n".join(details))
            # st.code shows the URL with a built-in copy button. A st.button
            # here would trigger a rerun that discards these search results
            # before its branch could render anything.
            st.code(result['Video URL'])


def _render_search_tab():
    """Render the image upload, search controls, and result list."""
    st.subheader("Search Similar Product Clips")
    with st.container():
        col1, col2 = st.columns([1, 2])

        with col1:
            uploaded_file = st.file_uploader(
                "Upload Image",
                type=['png', 'jpg', 'jpeg'],
                help="Select an image to find similar video segments",
            )
            if uploaded_file:
                st.image(uploaded_file, caption="Query Image", use_column_width=True)

        with col2:
            if not uploaded_file:
                return

            st.subheader("Search Parameters")
            top_k = st.slider(
                "Number of results",
                min_value=1,
                max_value=20,
                value=2,
                help="Select the number of similar videos to retrieve",
            )

            if not st.button("Search", use_container_width=True):
                return

            with st.spinner("Searching for similar videos..."):
                try:
                    results = search_similar_videos(uploaded_file, top_k=top_k)
                except Exception as e:
                    st.error(f"Search failed: {e}")
                    return

            if not results:
                st.warning("No similar videos found")
                return

            st.subheader("Results")
            for idx, result in enumerate(results, 1):
                _render_search_result(idx, result)


def main():
    """Render the app: a sidebar status panel plus "Add Videos" and "Search Videos" tabs."""
    # The original opened and closed styled <div> wrappers in separate
    # st.markdown calls; Streamlit renders each call independently, so those
    # wrappers never enclosed anything and are omitted here.
    st.title("📹 Video Product Search Assistant")

    _render_sidebar_status()

    tabs = st.tabs(["Add Videos", "Search Videos"])
    with tabs[0]:
        _render_add_video_tab()
    with tabs[1]:
        _render_search_tab()


if __name__ == "__main__":
    main()