├── .gitignore ├── examples └── videos │ ├── cat_video.mp4 │ ├── cutting_pepper.mp4 │ └── dog_digging.mp4 ├── go.mod ├── go.sum ├── inference ├── __pycache__ │ ├── _registry.cpython-38.pyc │ ├── compare_embedded_text.cpython-38.pyc │ ├── frame_text_processor.cpython-38.pyc │ ├── recommender.cpython-38.pyc │ ├── tokenizer.cpython-38.pyc │ ├── topics.cpython-38.pyc │ ├── video_features.cpython-38.pyc │ ├── web.cpython-38.pyc │ └── zmq_ops.cpython-38.pyc ├── clip │ ├── __pycache__ │ │ ├── frame_text_processor.cpython-38.pyc │ │ └── scene_features.cpython-38.pyc │ ├── frame_text_processor.py │ └── scene_features.py ├── utils │ ├── __pycache__ │ │ ├── audio_features.cpython-38.pyc │ │ ├── scene_features.cpython-38.pyc │ │ ├── tensor.cpython-38.pyc │ │ └── video_downloader.cpython-38.pyc │ ├── tensor.py │ └── video_downloader.py ├── video_features.py ├── zmq_ops.py └── zmq_server.py ├── internal ├── engine │ └── engine.go ├── index │ ├── scene_embedding │ │ └── scene_embedding.go │ ├── video_metadata │ │ └── video_metadata.go │ └── video_scene │ │ └── video_scene.go └── web │ └── engine_controller.go ├── main.go ├── pkg ├── inference │ └── inference.go ├── queue │ ├── queue.go │ ├── queue_iterator.go │ └── queue_test.go ├── sorted_list │ ├── sorted_list.go │ └── sorted_list_test.go ├── util │ └── tensor.go └── zmqpool │ └── zmqpool.go └── readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | /dbs* 2 | /doc -------------------------------------------------------------------------------- /examples/videos/cat_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/examples/videos/cat_video.mp4 -------------------------------------------------------------------------------- /examples/videos/cutting_pepper.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/examples/videos/cutting_pepper.mp4 -------------------------------------------------------------------------------- /examples/videos/dog_digging.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/examples/videos/dog_digging.mp4 -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/GuyARoss/clip-video-search 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/felixge/fgprof v0.9.3 // indirect 7 | github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect 8 | github.com/google/pprof v0.0.0-20221203041831-ce31453925ec // indirect 9 | github.com/google/uuid v1.3.0 // indirect 10 | github.com/pebbe/zmq4 v1.2.9 // indirect 11 | github.com/pkg/profile v1.7.0 // indirect 12 | github.com/syndtr/goleveldb v1.0.0 // indirect 13 | golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= 2 | github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= 3 | github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/felixge/fgprof v0.9.3 
h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= 7 | github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= 8 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 9 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 10 | github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= 11 | github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 12 | github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= 13 | github.com/google/pprof v0.0.0-20221203041831-ce31453925ec h1:fR20TYVVwhK4O7r7y+McjRYyaTH6/vjwJOajE+XhlzM= 14 | github.com/google/pprof v0.0.0-20221203041831-ce31453925ec/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= 15 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 16 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 17 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 18 | github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= 19 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 20 | github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 21 | github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= 22 | github.com/pebbe/zmq4 v1.2.9 h1:JlHcdgq6zpppNR1tH0wXJq0XK03pRUc4lBlHTD7aj/4= 23 | github.com/pebbe/zmq4 v1.2.9/go.mod h1:nqnPueOapVhE2wItZ0uOErngczsJdLOGkebMxaO8r48= 24 | github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= 25 | github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= 26 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 27 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 28 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 29 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 30 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 31 | github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= 32 | github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= 33 | golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9 h1:yZNXmy+j/JpX19vZkVktWqAo7Gny4PBWYYK3zskGpx4= 34 | golang.org/x/exp v0.0.0-20221126150942-6ab00d035af9/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= 35 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 36 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 37 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 38 | golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 39 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 40 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 41 | gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 42 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 43 | gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 44 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 45 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 46 | 
-------------------------------------------------------------------------------- /inference/__pycache__/_registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/_registry.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/compare_embedded_text.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/compare_embedded_text.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/frame_text_processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/frame_text_processor.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/recommender.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/recommender.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/topics.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/topics.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/video_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/video_features.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/web.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/web.cpython-38.pyc -------------------------------------------------------------------------------- /inference/__pycache__/zmq_ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/__pycache__/zmq_ops.cpython-38.pyc -------------------------------------------------------------------------------- /inference/clip/__pycache__/frame_text_processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/clip/__pycache__/frame_text_processor.cpython-38.pyc -------------------------------------------------------------------------------- /inference/clip/__pycache__/scene_features.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/clip/__pycache__/scene_features.cpython-38.pyc -------------------------------------------------------------------------------- /inference/clip/frame_text_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPProcessor, CLIPModel 2 | import torch 3 | 4 | class FrameProcessor: 5 | def __init__(self) -> None: 6 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device) 9 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 10 | 11 | def text_probability_from_tensor_paths(self, serial_image_tensor_paths: str, text: str) -> str: 12 | image_tensor_paths = serial_image_tensor_paths.split(' ') 13 | 14 | avg_sum = 0.0 15 | for image_tensor_path in image_tensor_paths: 16 | image_tensor = torch.load(image_tensor_path) 17 | image_tensor = torch.load(image_tensor_path).to(self.device) 18 | 19 | inputs = self.processor(text=text, return_tensors="pt", padding=True).to(self.device) 20 | 21 | inputs['pixel_values'] = image_tensor 22 | outputs = self.model(**inputs) 23 | 24 | logits_per_image = outputs.logits_per_image 25 | probs = logits_per_image.squeeze() 26 | 27 | avg_sum += probs.item() 28 | 29 | return str(avg_sum / len(image_tensor_paths)) -------------------------------------------------------------------------------- /inference/clip/scene_features.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import typing 3 | import scenedetect as sd 4 | from PIL import Image 5 | import torch 6 | from utils.tensor import save_tensor 7 | 8 | from transformers import CLIPProcessor, CLIPModel 9 | 10 | class FrameNumTimecode(): 11 | def __init__(self, frame_num: int) -> None: 12 | self.frame_num = frame_num 13 | 14 
| class SceneFeatures: 15 | def __init__(self) -> None: 16 | self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 17 | self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 18 | 19 | def collect_scenes_in_video(self, video_path: str) -> typing.List[typing.Tuple[sd.FrameTimecode, sd.FrameTimecode]]: 20 | video = sd.open_video(video_path) 21 | sm = sd.SceneManager() 22 | 23 | sm.add_detector(sd.ContentDetector(threshold=27.0)) 24 | sm.detect_scenes(video) 25 | 26 | return sm.get_scene_list() 27 | 28 | def clip_embeddings(self, image: Image): 29 | inputs = self.clip_processor(images=image, return_tensors="pt", padding=True) 30 | input_tokens = { 31 | k: v for k, v in inputs.items() 32 | } 33 | return input_tokens['pixel_values'] 34 | 35 | def clip_features_to_dic(self, num_of_scenes: int, clip_pixel_scenes: typing.List, scenes: typing.List[typing.Tuple[sd.FrameTimecode, sd.FrameTimecode]]) -> typing.Dict[str, any]: 36 | d = {} 37 | d['num_of_scenes'] = num_of_scenes 38 | d['clip_pixel_scenes'] = [{ 39 | 'local_path': save_tensor(s['pixel_embeddings']), 40 | 'scene': { 41 | 'start_frame_num': scenes[s['scene_no']][0].frame_num, 42 | 'end_frame_num': scenes[s['scene_no']][1].frame_num, 43 | } 44 | } for s in clip_pixel_scenes] 45 | return d 46 | 47 | def scene_features(self, video_path: str, no_of_samples: int = 3) -> typing.Dict: 48 | scenes = self.collect_scenes_in_video(video_path) 49 | 50 | cap = cv2.VideoCapture(video_path) 51 | scenes_frame_samples = [] 52 | for scene_idx in range(len(scenes)): 53 | scene_length = abs(scenes[scene_idx][0].frame_num - scenes[scene_idx][1].frame_num) 54 | every_n = round(scene_length/no_of_samples) 55 | local_samples = [(every_n * n) + scenes[scene_idx][0].frame_num for n in range(3)] 56 | 57 | scenes_frame_samples.append(local_samples) 58 | 59 | if len(scenes) == 0: 60 | # this could denote a single contiguous scene. 
61 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 62 | if frame_count > 0: 63 | every_n = round(frame_count/no_of_samples) 64 | local_samples = [(every_n * n) for n in range(3)] 65 | scenes_frame_samples.append(local_samples) 66 | scenes = [(FrameNumTimecode(0),FrameNumTimecode(frame_count))] 67 | 68 | scene_clip_embeddings = [] 69 | for scene_idx in range(len(scenes_frame_samples)): 70 | scene_samples = scenes_frame_samples[scene_idx] 71 | 72 | pixel_tensors = [] 73 | for frame_sample in scene_samples: 74 | cap.set(1, frame_sample) 75 | ret, frame = cap.read() 76 | if not ret: 77 | print('breaks oops', ret, frame_sample, scene_idx, frame) 78 | break 79 | 80 | pil_image = Image.fromarray(frame) 81 | 82 | clip_pixel_values = self.clip_embeddings(pil_image) 83 | pixel_tensors.append(clip_pixel_values) 84 | 85 | scene_clip_embeddings.append({ 'pixel_embeddings': torch.mean(torch.stack(pixel_tensors), dim=0), 'scene_no': scene_idx }) 86 | 87 | return self.clip_features_to_dic(len(scenes), scene_clip_embeddings, scenes) 88 | -------------------------------------------------------------------------------- /inference/utils/__pycache__/audio_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/utils/__pycache__/audio_features.cpython-38.pyc -------------------------------------------------------------------------------- /inference/utils/__pycache__/scene_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/utils/__pycache__/scene_features.cpython-38.pyc -------------------------------------------------------------------------------- /inference/utils/__pycache__/tensor.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/utils/__pycache__/tensor.cpython-38.pyc -------------------------------------------------------------------------------- /inference/utils/__pycache__/video_downloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuyARoss/CLIP-video-search/0d9a418f1c96f24cd984636c72a77c756348563e/inference/utils/__pycache__/video_downloader.cpython-38.pyc -------------------------------------------------------------------------------- /inference/utils/tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import uuid 3 | 4 | def save_tensor(t: torch.Tensor) -> str: 5 | path = f'/tmp/{uuid.uuid4()}' 6 | torch.save(t, path) 7 | 8 | return path 9 | -------------------------------------------------------------------------------- /inference/utils/video_downloader.py: -------------------------------------------------------------------------------- 1 | from urllib import request 2 | import subprocess 3 | 4 | def get_video_duration(filename: str) -> float: 5 | try: 6 | result = subprocess.run(["ffprobe", "-v", "error", "-show_entries", 7 | "format=duration", "-of", 8 | "default=noprint_wrappers=1:nokey=1", filename], 9 | stdout=subprocess.PIPE, 10 | stderr=subprocess.STDOUT) 11 | return float(result.stdout) 12 | 13 | except: 14 | return 0 15 | 16 | def generic_web_downloader(url: str, path: str, filename: str) -> str: 17 | try: 18 | if not url and ('https://' not in url or 'http://' not in url): 19 | raise Exception('video URL not found') 20 | 21 | with open(f'{path}/{filename}.mp4', 'wb') as video_file: 22 | video_file.write(request.urlopen(url).read()) 23 | 24 | return f'{path}/{filename}.mp4' 25 | 26 | except Exception as e: 27 | raise Exception('error: web 
downloader: ', e.__str__()) 28 | 29 | -------------------------------------------------------------------------------- /inference/video_features.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | 4 | from utils.video_downloader import generic_web_downloader, get_video_duration 5 | from clip.scene_features import SceneFeatures 6 | 7 | sf = SceneFeatures() 8 | 9 | def local_path(path_uri: str) -> str: 10 | if "http://" not in path_uri and "https://" not in path_uri: 11 | # we will assume that this is a local path 12 | return path_uri 13 | 14 | md5 = hashlib.md5(path_uri.encode()) 15 | filename = md5.hexdigest() 16 | 17 | return generic_web_downloader(path_uri, '/tmp', filename) 18 | 19 | def video_features(path_uri: str) -> str: 20 | system_path = local_path(path_uri) 21 | 22 | features = { 23 | 'scene_features': sf.scene_features(system_path), 24 | 'video_duration': get_video_duration(system_path), 25 | } 26 | 27 | return json.dumps(features) 28 | -------------------------------------------------------------------------------- /inference/zmq_ops.py: -------------------------------------------------------------------------------- 1 | from video_features import video_features 2 | from clip.frame_text_processor import FrameProcessor 3 | fp = FrameProcessor() 4 | 5 | registry = {} 6 | registry['video_features'] = video_features 7 | registry['frame_text_processor'] = fp.text_probability_from_tensor_paths -------------------------------------------------------------------------------- /inference/zmq_server.py: -------------------------------------------------------------------------------- 1 | import zmq 2 | import threading 3 | 4 | from zmq_ops import registry 5 | 6 | def build_worker(port: int): 7 | def worker(): 8 | print('started worker', port) 9 | context = zmq.Context() 10 | socket = context.socket(zmq.REP) 11 | socket.bind(f"tcp://*:{str(port)}") 12 | 13 | while True: 14 | message = 
socket.recv() 15 | message:str = message.decode('utf-8') 16 | 17 | split_message = message.split(',') 18 | type = split_message[0] 19 | 20 | response = registry[type](*split_message[1:]) 21 | socket.send(bytes(response,'utf-8')) 22 | 23 | return worker 24 | 25 | 26 | def main() -> int: 27 | free_ports = [5550, 5551, 5552] 28 | 29 | threads = [threading.Thread(target=build_worker(free_ports[i])) for i in range(len(free_ports))] 30 | for thread in threads: 31 | thread.start() 32 | 33 | if __name__ == '__main__': 34 | SystemExit(main()) -------------------------------------------------------------------------------- /internal/engine/engine.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | 7 | sceneembedding "github.com/GuyARoss/clip-video-search/internal/index/scene_embedding" 8 | videometadata "github.com/GuyARoss/clip-video-search/internal/index/video_metadata" 9 | sceneIndex "github.com/GuyARoss/clip-video-search/internal/index/video_scene" 10 | "github.com/GuyARoss/clip-video-search/pkg/inference" 11 | sortedlist "github.com/GuyARoss/clip-video-search/pkg/sorted_list" 12 | "github.com/GuyARoss/clip-video-search/pkg/util" 13 | "github.com/google/uuid" 14 | ) 15 | 16 | type VideoItem struct { 17 | VideoURI string `json:"videoURI"` 18 | } 19 | 20 | type Engine struct { 21 | VideoSceneIndex sceneIndex.VideoSceneIndex 22 | SceneEmbeddingIndex sceneembedding.SceneEmbeddingIndex 23 | VideoMetadataIndex videometadata.VideoMetadataIndex 24 | inference inference.InferenceImplementation 25 | } 26 | 27 | func (s *Engine) Next(next interface{}) { 28 | nextItem := next.(*VideoItem) 29 | 30 | if nextItem.VideoURI == "" { 31 | fmt.Println("empty video uri") 32 | return 33 | } 34 | 35 | response, err := s.inference.VideoFeatures(nextItem.VideoURI) 36 | if err != nil { 37 | fmt.Println(err) 38 | return 39 | } 40 | 41 | videoId := uuid.NewString() 42 | 43 | 
s.VideoMetadataIndex.InsertByVideoID(videoId, videometadata.VideoMetaData{ 44 | VideoDuration: response.VideoDuration, 45 | VideoURI: nextItem.VideoURI, 46 | }) 47 | 48 | sceneIds := []string{} 49 | for _, pixelData := range response.SceneFeatures.ClipPixelScenes { 50 | sceneID := uuid.NewString() 51 | f, err := ioutil.ReadFile(pixelData.LocalPath) 52 | if err != nil { 53 | fmt.Println(err) 54 | } 55 | 56 | sceneIds = append(sceneIds, sceneID) 57 | s.SceneEmbeddingIndex.InsertSceneEmbedding(sceneID, f) 58 | } 59 | 60 | s.VideoSceneIndex.InsertVideoScenes(videoId, sceneIds) 61 | fmt.Println("scenes inserted") 62 | } 63 | 64 | type TransientResult struct { 65 | // TODO(feature): would also be beneficial to understand which scene is the best fit 66 | VideoID string 67 | Score float64 68 | } 69 | 70 | type EngineResult struct { 71 | *TransientResult 72 | VideoURI string 73 | } 74 | 75 | type topNVisitor struct { 76 | *Engine 77 | input string 78 | Results sortedlist.SortableList 79 | } 80 | 81 | func (s *topNVisitor) InferSceneScore(sceneIDs []string) (float64, error) { 82 | tensorPaths := make([]string, len(sceneIDs)) 83 | 84 | for idx, sceneID := range sceneIDs { 85 | cachePath := util.TensorCachedPath(sceneID) 86 | if cachePath != "" { 87 | tensorPaths[idx] = cachePath 88 | continue 89 | } 90 | 91 | tensor := s.SceneEmbeddingIndex.GetTensorBySceneID(sceneID) 92 | tensorPath, err := util.TensorBytesToFile(sceneID, tensor) 93 | 94 | if err != nil { 95 | fmt.Println("cannot read tensor for", sceneID) 96 | continue 97 | } 98 | 99 | tensorPaths[idx] = tensorPath 100 | } 101 | 102 | return s.inference.FrameTextProcessor(tensorPaths, s.input) 103 | } 104 | 105 | func (t *topNVisitor) Visit(vis sceneIndex.VideoSceneRecord) error { 106 | v, err := t.InferSceneScore(vis.SceneIDs) 107 | if err != nil { 108 | return err 109 | } 110 | 111 | t.Results.MaybeAdd(v, &TransientResult{ 112 | Score: v, 113 | VideoID: vis.VideoID, 114 | }) 115 | 116 | return nil 117 | } 118 | 119 | func 
(e *Engine) TopNFromText(text string, limit int) []EngineResult { 120 | visitorInstance := &topNVisitor{ 121 | input: text, 122 | Engine: e, 123 | Results: sortedlist.New(limit), 124 | } 125 | 126 | err := e.VideoSceneIndex.ForEach(visitorInstance) 127 | if err != nil { 128 | fmt.Println("error propagated in record", err) 129 | } 130 | 131 | recommendations := make([]EngineResult, limit) 132 | 133 | for i, r := range visitorInstance.Results.Results() { 134 | if r != nil { 135 | unpacked := r.(*TransientResult) 136 | 137 | metadata := e.VideoMetadataIndex.GetByVideoID(unpacked.VideoID) 138 | recommendations[i] = EngineResult{ 139 | TransientResult: unpacked, 140 | VideoURI: metadata.VideoURI, 141 | } 142 | } 143 | } 144 | 145 | return recommendations 146 | } 147 | 148 | func NewDefaultEngine(inference inference.InferenceImplementation) (*Engine, error) { 149 | sceneIndex, err := sceneIndex.New("./dbs/scene_index") 150 | if err != nil { 151 | return nil, err 152 | } 153 | 154 | sceneEmbeddingIndex, err := sceneembedding.New("./dbs/scene_embedding_index") 155 | if err != nil { 156 | return nil, err 157 | } 158 | 159 | videoMetadataIndex, err := videometadata.New("./dbs/videometadata_index") 160 | if err != nil { 161 | return nil, err 162 | } 163 | 164 | return &Engine{ 165 | inference: inference, 166 | VideoSceneIndex: sceneIndex, 167 | SceneEmbeddingIndex: sceneEmbeddingIndex, 168 | VideoMetadataIndex: videoMetadataIndex, 169 | }, nil 170 | } 171 | -------------------------------------------------------------------------------- /internal/index/scene_embedding/scene_embedding.go: -------------------------------------------------------------------------------- 1 | package sceneembedding 2 | 3 | import ( 4 | "github.com/syndtr/goleveldb/leveldb" 5 | ) 6 | 7 | type SceneEmbeddingIndex interface { 8 | InsertSceneEmbedding(sceneID string, tensor []byte) 9 | GetTensorBySceneID(string) []byte 10 | } 11 | 12 | type LevelSceneEmbeddingIndex struct { 13 | db *leveldb.DB 14 | } 
15 | 16 | func (i *LevelSceneEmbeddingIndex) InsertSceneEmbedding(sceneID string, tensor []byte) { 17 | i.db.Put([]byte(sceneID), tensor, nil) 18 | } 19 | 20 | func (i *LevelSceneEmbeddingIndex) GetTensorBySceneID(sceneID string) []byte { 21 | b, _ := i.db.Get([]byte(sceneID), nil) 22 | 23 | return b 24 | } 25 | 26 | func New(path string) (SceneEmbeddingIndex, error) { 27 | db, err := leveldb.OpenFile(path, nil) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | return &LevelSceneEmbeddingIndex{ 33 | db: db, 34 | }, nil 35 | } 36 | -------------------------------------------------------------------------------- /internal/index/video_metadata/video_metadata.go: -------------------------------------------------------------------------------- 1 | package videometadata 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/syndtr/goleveldb/leveldb" 7 | ) 8 | 9 | type VideoMetadataIndex interface { 10 | InsertByVideoID(videoID string, data VideoMetaData) 11 | GetByVideoID(string) *VideoMetaData 12 | } 13 | 14 | type LevelVideoMetadataIndex struct { 15 | db *leveldb.DB 16 | } 17 | 18 | type VideoMetaData struct { 19 | VideoDuration float32 20 | VideoURI string 21 | } 22 | 23 | func (i *LevelVideoMetadataIndex) InsertByVideoID(videoID string, data VideoMetaData) { 24 | b, _ := json.Marshal(data) 25 | 26 | i.db.Put([]byte(videoID), b, nil) 27 | } 28 | 29 | func (i *LevelVideoMetadataIndex) GetByVideoID(videoID string) *VideoMetaData { 30 | b, _ := i.db.Get([]byte(videoID), nil) 31 | 32 | md := &VideoMetaData{} 33 | json.Unmarshal(b, md) 34 | 35 | return md 36 | } 37 | 38 | func New(path string) (VideoMetadataIndex, error) { 39 | db, err := leveldb.OpenFile(path, nil) 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | return &LevelVideoMetadataIndex{ 45 | db: db, 46 | }, nil 47 | } 48 | -------------------------------------------------------------------------------- /internal/index/video_scene/video_scene.go: 
-------------------------------------------------------------------------------- 1 | package videoscene 2 | 3 | import ( 4 | "strings" 5 | "sync" 6 | 7 | "github.com/syndtr/goleveldb/leveldb" 8 | ) 9 | 10 | type VideoSceneRecord struct { 11 | VideoID string 12 | SceneIDs []string 13 | } 14 | 15 | type RecordVisitor interface { 16 | Visit(VideoSceneRecord) error 17 | } 18 | 19 | type VideoSceneIndex interface { 20 | InsertVideoScenes(videoID string, sceneIds []string) 21 | ForEach(RecordVisitor) error 22 | } 23 | 24 | type LevelVideoSceneIndex struct { 25 | db *leveldb.DB 26 | } 27 | 28 | func (i *LevelVideoSceneIndex) InsertVideoScenes(videoID string, sceneIds []string) { 29 | i.db.Put([]byte(videoID), []byte(strings.Join(sceneIds, ",")), nil) 30 | } 31 | 32 | func (i *LevelVideoSceneIndex) ForEach(vis RecordVisitor) error { 33 | iter := i.db.NewIterator(nil, nil) 34 | wg := &sync.WaitGroup{} 35 | 36 | for iter.Next() { 37 | key := iter.Key() 38 | value := iter.Value() 39 | 40 | r := VideoSceneRecord{ 41 | VideoID: string(key), 42 | SceneIDs: strings.Split(string(value), ","), 43 | } 44 | 45 | wg.Add(1) 46 | go func(g *sync.WaitGroup) { 47 | vis.Visit(r) 48 | wg.Done() 49 | }(wg) 50 | // err := vis.Visit(r) 51 | // if err != nil { 52 | // return err 53 | // } 54 | } 55 | wg.Wait() 56 | return nil 57 | } 58 | 59 | func New(path string) (VideoSceneIndex, error) { 60 | db, err := leveldb.OpenFile(path, nil) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return &LevelVideoSceneIndex{ 66 | db: db, 67 | }, nil 68 | } 69 | -------------------------------------------------------------------------------- /internal/web/engine_controller.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/GuyARoss/clip-video-search/internal/engine" 9 | "github.com/GuyARoss/clip-video-search/pkg/queue" 10 | ) 11 | 12 | type EngineWebController struct { 
13 | 	engine *engine.Engine
14 | 	queue  queue.Queue
15 | }
16 | 
17 | // SearchRequest is the JSON body accepted by TopNFromText.
18 | type SearchRequest struct {
19 | 	Input      string `json:"input"`
20 | 	MaxResults int    `json:"maxResults"`
21 | }
22 | 
23 | // defaultMaxResults is used when a search request's maxResults is missing,
24 | // zero, or negative.
25 | const defaultMaxResults = 3
26 | 
27 | // TopNFromText decodes a SearchRequest from the request body and writes the
28 | // engine's top req.MaxResults matches for req.Input as JSON.
29 | //
30 | // A body that is not valid SearchRequest JSON (including an empty body) is
31 | // rejected with 400 instead of being treated as a search for "".
32 | func (c *EngineWebController) TopNFromText(w http.ResponseWriter, r *http.Request) {
33 | 	req := &SearchRequest{}
34 | 	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
35 | 		http.Error(w, fmt.Sprintf("invalid search request: %v", err), http.StatusBadRequest)
36 | 		return
37 | 	}
38 | 
39 | 	// A top-N query needs N > 0; anything else falls back to the default.
40 | 	if req.MaxResults <= 0 {
41 | 		req.MaxResults = defaultMaxResults
42 | 	}
43 | 
44 | 	response := c.engine.TopNFromText(req.Input, req.MaxResults)
45 | 
46 | 	rout, err := json.Marshal(response)
47 | 	if err != nil {
48 | 		http.Error(w, fmt.Sprintf("encoding search response: %v", err), http.StatusInternalServerError)
49 | 		return
50 | 	}
51 | 
52 | 	w.Header().Set("Content-Type", "application/json")
53 | 	w.Write(rout)
54 | }
55 | 
56 | // Insert decodes an engine.VideoItem from the request body, adds it to the
57 | // controller's queue, and replies with the queue length as JSON.
58 | //
59 | // Malformed JSON is rejected with 400 so that no zero-value item is enqueued.
60 | func (c *EngineWebController) Insert(w http.ResponseWriter, r *http.Request) {
61 | 	req := &engine.VideoItem{}
62 | 	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
63 | 		http.Error(w, fmt.Sprintf("invalid insert request: %v", err), http.StatusBadRequest)
64 | 		return
65 | 	}
66 | 
67 | 	c.queue.Add(req)
68 | 
69 | 	w.Header().Set("Content-Type", "application/json")
70 | 	w.Write([]byte(fmt.Sprintf(`{ "queueSize": %d }`, c.queue.Length())))
71 | }
72 | 
73 | // NewEngineController returns a controller that searches with engine and
74 | // enqueues inserted videos on queue.
75 | func NewEngineController(engine *engine.Engine, queue queue.Queue) *EngineWebController {
76 | 	return &EngineWebController{
77 | 		engine: engine,
78 | 		queue:  queue,
79 | 	}
80 | }
81 | 
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"net/http"
5 | 
6 | 	"github.com/GuyARoss/clip-video-search/internal/engine"
7 | 	"github.com/GuyARoss/clip-video-search/internal/web"
8 | 	"github.com/GuyARoss/clip-video-search/pkg/inference"
9 | 	"github.com/GuyARoss/clip-video-search/pkg/queue"
10 | 	"github.com/GuyARoss/clip-video-search/pkg/zmqpool"
11 | )
12 | 
13 | func main() {
14 | 	q := queue.New()
15 | 
16 | 	pool := zmqpool.New([]string{
17 | 		"tcp://localhost:5550",
18 | 		"tcp://localhost:5551",
19 | 		"tcp://localhost:5552",
20 | 	})
21 | 	server := inference.New(pool)
22 | 
23 | 	eng, err := engine.NewDefaultEngine(server)
24 | 	if err != nil {
25 | 		panic(err)
26 | 	}
27 | 
28 | 	// Drain videos added through /insert into the engine on a separate goroutine.
29 | 	go queue.LongLivedIterator(q, eng)
30 | 
31 | 	engineController := web.NewEngineController(eng, q)
32 | 
33 | 	http.HandleFunc("/search", engineController.TopNFromText)
34 | 	http.HandleFunc("/insert", engineController.Insert)
35 | 
36 | 	// ListenAndServe always returns a non-nil error (e.g. port already in
37 | 	// use); surface it instead of exiting silently.
38 | 	if err := http.ListenAndServe(":3000", nil); err != nil {
39 | 		panic(err)
40 | 	}
41 | }
42 | 
--------------------------------------------------------------------------------
/pkg/inference/inference.go:
--------------------------------------------------------------------------------
1 | package inference
2 | 
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 | 	"strconv"
7 | 	"strings"
8 | 	"sync"
9 | 
10 | 	"github.com/GuyARoss/clip-video-search/pkg/zmqpool"
11 | )
12 | 
13 | // PixelScene is one entry of SceneFeatures.ClipPixelScenes.
14 | type PixelScene struct {
15 | 	LocalPath string `json:"local_path"`
16 | 	Scene     struct {
17 | 		StartFrameNum int `json:"start_frame_num"`
18 | 		EndFrameNum   int `json:"end_frame_num"`
19 | 	} `json:"scene"`
20 | }
21 | 
22 | // SceneFeatures is the "scene_features" object of a video_features reply.
23 | type SceneFeatures struct {
24 | 	NumOfScenes     int          `json:"num_of_scenes"`
25 | 	ClipPixelScenes []PixelScene `json:"clip_pixel_scenes"`
26 | }
27 | 
28 | // FullVideoFeatures is the decoded JSON reply to the video_features operation.
29 | type FullVideoFeatures struct {
30 | 	SceneFeatures SceneFeatures `json:"scene_features"`
31 | 	VideoDuration float32       `json:"video_duration"`
32 | }
33 | 
34 | // InferenceImplementation is the set of operations forwarded to the
35 | // inference server.
36 | type InferenceImplementation interface {
37 | 	VideoFeatures(string) (*FullVideoFeatures, error)
38 | 	FrameTextProcessor([]string, string) (float64, error)
39 | }
40 | 
41 | // InferenceServer implements InferenceImplementation by sending requests
42 | // through a zmqpool.ZMQPoolImplementation.
43 | type InferenceServer struct {
44 | 	ipc zmqpool.ZMQPoolImplementation
45 | 	// m is not used by any method in this file; New initializes it so a
46 | 	// future use cannot dereference a nil mutex.
47 | 	m *sync.Mutex
48 | }
49 | 
50 | // InferenceOperations names an operation of the inference server; it is
51 | // passed as the operation argument of ZMQPoolImplementation.Send.
52 | type InferenceOperations string
53 | 
54 | const (
55 | 	VideoFeatures      InferenceOperations = "video_features"
56 | 	FrameTextProcessor InferenceOperations = "frame_text_processor"
57 | )
58 | 
59 | // FrameTextProcessor asks the inference server to score text against the
60 | // tensors stored at tensorPath and returns the parsed score.
61 | //
62 | // The paths are sent space-joined in a single argument, so they must not
63 | // contain spaces. NOTE(review): zmqpool joins arguments with ',', so a comma
64 | // in text is presumably ambiguous on the server side - confirm.
65 | func (s *InferenceServer) FrameTextProcessor(tensorPath []string, text string) (float64, error) {
66 | 	ipcResponse, err := s.ipc.Send(string(FrameTextProcessor), strings.Join(tensorPath, " "), text)
67 | 	if err != nil {
68 | 		return 0.0, err
69 | 	}
70 | 
71 | 	// Parse at float64 precision (the return type) and report a malformed
72 | 	// reply instead of silently scoring it as 0.
73 | 	score, err := strconv.ParseFloat(strings.TrimSpace(ipcResponse), 64)
74 | 	if err != nil {
75 | 		return 0.0, fmt.Errorf("inference: invalid %s reply %q: %w", FrameTextProcessor, ipcResponse, err)
76 | 	}
77 | 
78 | 	return score, nil
79 | }
80 | 
81 | // VideoFeatures asks the inference server for the features of the video at
82 | // uri and decodes the JSON reply.
83 | func (s *InferenceServer) VideoFeatures(uri string) (*FullVideoFeatures, error) {
84 | 	ipcResponse, err := s.ipc.Send(string(VideoFeatures), uri)
85 | 	if err != nil {
86 | 		return nil, err
87 | 	}
88 | 
89 | 	features := &FullVideoFeatures{}
90 | 	if err := json.Unmarshal([]byte(ipcResponse), features); err != nil {
91 | 		return nil, fmt.Errorf("inference: invalid %s reply: %w", VideoFeatures, err)
92 | 	}
93 | 
94 | 	return features, nil
95 | }
96 | 
97 | // New returns an InferenceImplementation that talks to the inference server
98 | // through ipc.
99 | func New(ipc zmqpool.ZMQPoolImplementation) InferenceImplementation {
100 | 	return &InferenceServer{
101 | 		ipc: ipc,
102 | 		m:   &sync.Mutex{},
103 | 	}
104 | }
105 | 
--------------------------------------------------------------------------------
/pkg/queue/queue.go:
--------------------------------------------------------------------------------
1 | package queue
2 | 
3 | import "sync"
4 | 
5 | // Queue is a FIFO of arbitrary items.
6 | type Queue interface {
7 | 	Next() interface{}
8 | 	IsEmpty() bool
9 | 	Add(interface{})
10 | 	Length() int
11 | }
12 | 
13 | // SimpleQueue is a slice-backed Queue. Every method takes mu, so one
14 | // goroutine may Add (e.g. an HTTP handler) while another drains the queue
15 | // (e.g. LongLivedIterator).
16 | type SimpleQueue struct {
17 | 	mu   sync.Mutex
18 | 	data []interface{}
19 | }
20 | 
21 | // Next removes and returns the oldest item, or nil if the queue is empty.
22 | func (s *SimpleQueue) Next() interface{} {
23 | 	s.mu.Lock()
24 | 	defer s.mu.Unlock()
25 | 
26 | 	if len(s.data) == 0 {
27 | 		return nil
28 | 	}
29 | 
30 | 	item := s.data[0]
31 | 	// Clear the vacated slot so the backing array does not keep the
32 | 	// dequeued item reachable.
33 | 	s.data[0] = nil
34 | 	s.data = s.data[1:]
35 | 
36 | 	return item
37 | }
38 | 
39 | // Length returns the number of queued items.
40 | func (s *SimpleQueue) Length() int {
41 | 	s.mu.Lock()
42 | 	defer s.mu.Unlock()
43 | 
44 | 	return len(s.data)
45 | }
46 | 
47 | // IsEmpty reports whether the queue holds no items.
48 | func (s *SimpleQueue) IsEmpty() bool {
49 | 	s.mu.Lock()
50 | 	defer s.mu.Unlock()
51 | 
52 | 	return len(s.data) == 0
53 | }
54 | 
55 | // Add appends item to the back of the queue.
56 | func (s *SimpleQueue) Add(item interface{}) {
57 | 	s.mu.Lock()
58 | 	defer s.mu.Unlock()
59 | 
60 | 	s.data = append(s.data, item)
61 | }
62 | 
63 | // New returns an empty SimpleQueue.
64 | func New() Queue {
65 | 	return &SimpleQueue{
66 | 		data: make([]interface{}, 0),
67 | 	}
68 | }
69 | 
--------------------------------------------------------------------------------
/pkg/queue/queue_iterator.go:
--------------------------------------------------------------------------------
1 | package queue
2 | 
3 | import "time"
4 | 
5 | // QueueIterator consumes items taken from a Queue.
6 | type QueueIterator interface {
7 | 	Next(interface{})
8 | }
9 | 
10 | // LongLivedIterator loops forever: it passes each item from q to iter.Next
11 | // and sleeps 200ms whenever q is empty. It never returns, so run it in its
12 | // own goroutine.
13 | func LongLivedIterator(q Queue, iter QueueIterator) {
14 | 	for {
15 | 		if !q.IsEmpty() {
16 | 			iter.Next(q.Next())
17 | 		} else {
18 | 			time.Sleep(time.Millisecond * 200)
19 | 		}
20 | 	}
21 | }
22 | 
--------------------------------------------------------------------------------
/pkg/queue/queue_test.go:
--------------------------------------------------------------------------------
1 | package queue
2 | 
3 | import "testing"
4 | 
5 | func TestQueue(t *testing.T) {
6 | 	q := New()
7 | 	messages := []string{"is this working?", "is this working x2?"}
8 | 
9 | 	for _, m := range messages {
10 | 		q.Add(m)
11 | 	}
12 | 
13 | 	item := q.Next()
14 | 	if item != messages[0] {
15 | 		t.Errorf("expected '%s' got '%s'", messages[0], item)
16 | 	}
17 | 
18 | 	item = q.Next()
19 | 	if item != messages[1] {
20 | 		t.Errorf("expected '%s' got '%s'", messages[1], item)
21 | 	}
22 | }
23 | 
24 | func TestQueueNextOnEmpty(t *testing.T) {
25 | 	q := New()
26 | 	if item := q.Next(); item != nil {
27 | 		t.Errorf("expected nil from an empty queue, got '%v'", item)
28 | 	}
29 | }
30 | 
--------------------------------------------------------------------------------
/pkg/sorted_list/sorted_list.go:
--------------------------------------------------------------------------------
1 | package sortedlist
2 | 
3 | import (
4 | 	"math"
5 | 	"sync"
6 | )
7 | 
8 | // SortableList retains the items with the highest sort keys offered to it.
9 | type SortableList interface {
10 | 	Results() []interface{}
11 | 	MaybeAdd(sortKey float64, item interface{})
12 | 	TotalEvaluated() int
13 | }
14 | 
15 | // SortedList keeps up to topN items in parallel slices ordered by ascending
16 | // key, so index 0 holds the lowest retained key. Unfilled slots hold a nil
17 | // item and a -Inf key. lowestRated always mirrors sortKeys[0].
18 | type SortedList struct {
19 | 	lowestRated float64
20 | 	lowestIndex int
21 | 	evaluated   int
22 | 
23 | 	items    []interface{}
24 | 	sortKeys []float64
25 | 	m        *sync.Mutex
26 | }
27 | 
28 | // Results returns a copy of the retained items in ascending key order;
29 | // slots that were never filled are nil.
30 | func (s *SortedList) Results() []interface{} {
31 | 	s.m.Lock()
32 | 	defer s.m.Unlock()
33 | 
34 | 	out := make([]interface{}, len(s.items))
35 | 	copy(out, s.items)
36 | 	return out
37 | }
38 | 
39 | // addAtIndex stores (sortKey, item) at index, keeping lowestRated in sync
40 | // with slot 0. Callers must hold s.m.
41 | func (s *SortedList) addAtIndex(index int, sortKey float64, item interface{}) {
42 | 	s.sortKeys[index] = sortKey
43 | 	s.items[index] = item
44 | 
45 | 	if index == 0 {
46 | 		s.lowestRated = sortKey
47 | 	}
48 | }
49 | 
50 | // MaybeAdd counts one evaluation and, if sortKey is greater than the lowest
51 | // retained key, inserts item in order and evicts that lowest entry. NaN keys
52 | // are never inserted.
53 | func (s *SortedList) MaybeAdd(sortKey float64, item interface{}) {
54 | 	s.m.Lock()
55 | 	defer s.m.Unlock()
56 | 
57 | 	s.evaluated += 1
58 | 
59 | 	// Because lowestRated == sortKeys[0], passing this check means the loop
60 | 	// below takes the "greater" branch at i == 0 and never writes index -1.
61 | 	if !(sortKey > s.lowestRated) {
62 | 		return
63 | 	}
64 | 
65 | 	for i := 0; i < len(s.sortKeys); i++ {
66 | 		k := s.sortKeys[i]
67 | 
68 | 		if sortKey > k {
69 | 			// Shift entry i down one slot (at i == 1 this evicts slot 0).
70 | 			if i > 0 {
71 | 				s.addAtIndex(i-1, k, s.items[i])
72 | 			}
73 | 			if i == len(s.sortKeys)-1 {
74 | 				s.addAtIndex(i, sortKey, item)
75 | 			}
76 | 			continue
77 | 		}
78 | 
79 | 		// sortKey <= k: the new entry goes just below slot i. Ties are handled
80 | 		// here too, so an equal key still evicts the lowest entry.
81 | 		s.addAtIndex(i-1, sortKey, item)
82 | 		break
83 | 	}
84 | }
85 | 
86 | // TotalEvaluated returns how many times MaybeAdd has been called.
87 | func (s *SortedList) TotalEvaluated() int {
88 | 	s.m.Lock()
89 | 	defer s.m.Unlock()
90 | 
91 | 	return s.evaluated
92 | }
93 | 
94 | // New returns an empty SortedList that retains at most topN items.
95 | func New(topN int) SortableList {
96 | 	sortKeys := make([]float64, topN)
97 | 	// Empty slots rank below every real key, including negative scores, so
98 | 	// they are always filled before any real entry is evicted.
99 | 	for i := range sortKeys {
100 | 		sortKeys[i] = math.Inf(-1)
101 | 	}
102 | 
103 | 	return &SortedList{
104 | 		items:       make([]interface{}, topN),
105 | 		sortKeys:    sortKeys,
106 | 		lowestRated: math.Inf(-1),
107 | 		m:           &sync.Mutex{},
108 | 	}
109 | }
110 | 
--------------------------------------------------------------------------------
/pkg/sorted_list/sorted_list_test.go:
--------------------------------------------------------------------------------
1 | package sortedlist
2 | 
3 | import (
4 | 	"fmt"
5 | 	"sort"
6 | 	"testing"
7 | 
8 | 	"golang.org/x/exp/slices"
9 | )
10 | 
11 | func TestSortedList(t *testing.T) {
12 | 	s := New(3)
13 | 
14 | 	items := []float64{1.2, 2.3, 67.2, 12.6, 7.2, 6.1}
15 | 	for _, v := range items {
16 | 		s.MaybeAdd(v, v)
17 | 	}
18 | 
19 | 	res := s.Results()
20 | 
21 | 	castedResults := make([]float64, 3)
22 | 	for idx, r := range res {
23 | 		if r != nil {
24 | 			castedResults[idx] = r.(float64)
25 | 		} else {
26 | 			castedResults[idx] = 0.0
27 | 		}
28 | 	}
29 | 
30 | 	sort.Sort(sort.Float64Slice(items))
31 | 
32 | 	fin := make([]float64, 3)
33 | 	subItems := items[len(items)-3:]
34 | 	for i := 2; i >= 0; i-- {
35 | 		fmt.Println(i)
36 | 		fin[i] = subItems[i]
37 | 	}
38 | 
39 | 	if slices.Compare(castedResults, fin) != 0 {
40 | 		fmt.Println(castedResults, fin)
41 | 		t.Errorf("slices do not equate")
42 | 	}
43 | }
44 | 
45 | func TestSortedListNegativeKeys(t *testing.T) {
46 | 	s := New(2)
47 | 	s.MaybeAdd(-1.0, "a")
48 | 	s.MaybeAdd(-3.0, "b")
49 | 	s.MaybeAdd(-2.0, "c")
50 | 
51 | 	res := s.Results()
52 | 	if len(res) != 2 || res[0] != "c" || res[1] != "a" {
53 | 		t.Errorf("expected [c a], got %v", res)
54 | 	}
55 | }
56 | 
--------------------------------------------------------------------------------
/pkg/util/tensor.go:
--------------------------------------------------------------------------------
1 | package util
2 | 
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"os"
7 | )
8 | 
9 | // TensorBytesToFile writes tensor to /tmp/<sceneID> (mode 0666) and returns
10 | // that path.
11 | func TensorBytesToFile(sceneID string, tensor []byte) (string, error) {
12 | 	tensorPath := fmt.Sprintf("/tmp/%s", sceneID)
13 | 
14 | 	if err := os.WriteFile(tensorPath, tensor, 0666); err != nil {
15 | 		return "", err
16 | 	}
17 | 
18 | 	return tensorPath, nil
19 | }
20 | 
21 | // TensorCachedPath returns /tmp/<sceneID> if that file exists and "" if it
22 | // does not. Stat errors other than non-existence still return the path.
23 | func TensorCachedPath(sceneID string) string {
24 | 	tensorPath := fmt.Sprintf("/tmp/%s", sceneID)
25 | 
26 | 	if _, err := os.Stat(tensorPath); errors.Is(err, os.ErrNotExist) {
27 | 		return ""
28 | 	}
29 | 
30 | 	return tensorPath
31 | }
32 | 
--------------------------------------------------------------------------------
/pkg/zmqpool/zmqpool.go:
--------------------------------------------------------------------------------
1 | package zmqpool
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 
7 | 	zmq "github.com/pebbe/zmq4"
8 | )
9 | 
10 | // ZMQPoolImplementation sends a request and returns the reply.
11 | type ZMQPoolImplementation interface {
12 | 	Send(string, ...string) (string, error)
13 | }
14 | 
15 | type zeroSocketInstance struct {
16 | 	socket *zmq.Socket
17 | }
18 | 
19 | // ZeroMQPool multiplexes requests over one REQ socket per address. ZeroMQ
20 | // sockets are not safe for concurrent use, so each socket is lent to at most
21 | // one Send at a time: idle sockets wait in the buffered readySocket channel,
22 | // Send receives one (blocking while all are busy) and hands it back when
23 | // done.
24 | type ZeroMQPool struct {
25 | 	sockets     []*zeroSocketInstance
26 | 	readySocket chan *zeroSocketInstance
27 | }
28 | 
29 | // errorReplyPrefix marks a failure reply; the rest of the reply becomes the
30 | // returned error's message.
31 | const errorReplyPrefix = "error"
32 | 
33 | // serializeOutbound joins the message parts with ','.
34 | func serializeOutbound(data []string) string {
35 | 	// TODO: ensure that data does not contain any ','; a part containing one
36 | 	// (e.g. free-form search text) cannot be told apart from a separator.
37 | 	return strings.Join(data, ",")
38 | }
39 | 
40 | // socketDone returns s to the idle set. readySocket has capacity for every
41 | // socket and only borrowed sockets are returned, so this never blocks.
42 | func (pool *ZeroMQPool) socketDone(s *zeroSocketInstance) {
43 | 	pool.readySocket <- s
44 | }
45 | 
46 | // Send borrows an idle socket, sends "operation,msg1,msg2,...", and returns
47 | // the reply. A reply starting with "error" is turned into an error carrying
48 | // the remainder of the reply.
49 | //
50 | // The socket is handed back on every path, including transport errors, so a
51 | // failure cannot permanently shrink the pool and leave later calls blocked.
52 | // NOTE(review): a REQ socket whose Send or Recv failed may reject further
53 | // use; consider recreating it on error.
54 | func (pool *ZeroMQPool) Send(operation string, messages ...string) (string, error) {
55 | 	inst := <-pool.readySocket
56 | 	defer pool.socketDone(inst)
57 | 
58 | 	if _, err := inst.socket.Send(fmt.Sprintf("%s,%s", operation, serializeOutbound(messages)), 0); err != nil {
59 | 		return "", err
60 | 	}
61 | 
62 | 	response, err := inst.socket.Recv(0)
63 | 	if err != nil {
64 | 		return "", err
65 | 	}
66 | 
67 | 	if strings.HasPrefix(response, errorReplyPrefix) {
68 | 		// The reply is server-controlled, so pass it as an argument rather
69 | 		// than as a format string (it may contain '%').
70 | 		return "", fmt.Errorf("%s", strings.TrimPrefix(response, errorReplyPrefix))
71 | 	}
72 | 
73 | 	return response, nil
74 | }
75 | 
76 | // New connects one REQ socket to each address and returns a pool over them.
77 | // It panics if the ZeroMQ context or any socket cannot be created or
78 | // connected, rather than returning a pool holding unusable sockets. With no
79 | // addresses, every Send blocks forever.
80 | func New(addresses []string) ZMQPoolImplementation {
81 | 	zctx, err := zmq.NewContext()
82 | 	if err != nil {
83 | 		panic(fmt.Errorf("zmqpool: creating context: %w", err))
84 | 	}
85 | 
86 | 	sockets := make([]*zeroSocketInstance, len(addresses))
87 | 	// Buffered to the pool size so that socketDone never blocks.
88 | 	ready := make(chan *zeroSocketInstance, len(addresses))
89 | 	for idx, addr := range addresses {
90 | 		s, err := zctx.NewSocket(zmq.REQ)
91 | 		if err != nil {
92 | 			panic(fmt.Errorf("zmqpool: creating socket for %s: %w", addr, err))
93 | 		}
94 | 		if err := s.Connect(addr); err != nil {
95 | 			panic(fmt.Errorf("zmqpool: connecting to %s: %w", addr, err))
96 | 		}
97 | 
98 | 		t := &zeroSocketInstance{socket: s}
99 | 		sockets[idx] = t
100 | 		ready <- t
101 | 	}
102 | 
103 | 	return &ZeroMQPool{
104 | 		sockets:     sockets,
105 | 		readySocket: ready,
106 | 	}
107 | }
108 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # CLIP Video Search
2 | [CLIP (Contrastive Language–Image Pre-training)](https://openai.com/blog/clip/) is a technique _which efficiently learns visual concepts from natural language supervision_. CLIP has found applications in [stable diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion).
3 | 
4 | This repository aims to act as a POC exploring the use of CLIP for natural language video search, as outlined in the article found [here](https://medium.com/@guyallenross/using-clip-to-build-a-natural-language-video-search-engine-6498c03c40d2).
5 | 
6 | Adapted for a [natural language video search engine found here](https://github.com/GuyARoss/movie-search-engine).
7 | 
8 | ## Usage
9 | ### Dependencies
10 | - [libzmq](https://github.com/zeromq/libzmq)
11 | - python >= 3.8
12 | - go >= 1.18
13 | 
14 | ### Running
15 | 1. start up the inference zmq server found in the `./inference` directory with `python3 zmq_server.py`.
16 | 2. start up the go server with `go run main.go`.
17 | 18 | ### Example 19 | Before running this example, please ensure that your environment is correctly configured and the application is running without errors. 20 | 21 | 1. index each of the video clips provided in `examples/videos`, replacing `<video_name>` with `cat_video`, `cutting_pepper`, or `dog_digging`: 22 | ```bash 23 | curl -X POST -d '{"videoURI": "/examples/videos/<video_name>.mp4" }' http://localhost:3000/insert 24 | ``` 25 | 26 | __note__: it can take a moment for the video to become searchable. 27 | 28 | 2. then search for a video: 29 | ```bash 30 | curl -X POST -d '{"input": "a man cutting pepper", "maxResults": 1 }' http://localhost:3000/search 31 | ``` 32 | 33 | ## TODO 34 | - [ ] CLI to remove manual setup process 35 | - [ ] ability to add dedicated inference machines (currently limited to same host) 36 | --------------------------------------------------------------------------------