├── .gitignore ├── README.md ├── assets ├── 1-dl.PNG ├── 10-chat.PNG ├── 11-gen.png ├── 2-start.PNG ├── 3-play.PNG ├── 4-adjust.PNG ├── 5-chat.PNG ├── 6-chat.PNG ├── 7-chat.PNG ├── 8-chat.PNG └── 9-chat.PNG ├── boot.sh ├── demovl.py ├── download.sh ├── environment.yml ├── generation └── seine-v2 │ ├── configs │ └── demo.yaml │ ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── gaussian_diffusion.py │ ├── respace.py │ └── timestep_sampler.py │ ├── functions │ └── video_transforms.py │ ├── models_new │ ├── __init__.py │ ├── attention.py │ ├── clip.py │ ├── resnet.py │ ├── unet.py │ └── unet_blocks.py │ ├── requirements.txt │ ├── seine.py │ └── slurm_scripts │ └── run_inference.sh ├── vinci-inference ├── .env ├── README.md ├── app │ ├── data.py │ ├── exception │ │ ├── __init__.py │ │ └── handler.py │ ├── global_var │ │ ├── __init__.py │ │ └── cache.py │ ├── main.py │ ├── models │ │ ├── __init__.py │ │ ├── internvl.py │ │ └── seine.py │ ├── service │ │ ├── __init__.py │ │ ├── internvl.py │ │ └── seine.py │ └── util │ │ ├── __init__.py │ │ ├── image.py │ │ └── oss.py ├── boot.sh ├── client │ ├── internvl.py │ ├── internvl_sse.py │ └── seine.py ├── demo │ └── demo.mp4 └── requirements │ ├── app.txt │ └── client.txt ├── vinci-local ├── .gitignore ├── README.md └── docker │ ├── README.md │ ├── boot.sh │ ├── clone.sh │ ├── docker-compose-build.yaml │ ├── docker-compose.yaml │ ├── minio │ └── entry.sh │ ├── mysql │ └── init.sql │ ├── nginx │ └── conf.d │ │ └── default.conf │ └── srs │ └── conf │ └── vinci.conf └── vl_open.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | *.out 10 | *.log 11 | *.tar 12 | *.jpg 13 | *.png 14 | Vinci-8B-ckpt/ 15 | Vinci-8B-base/ 16 | seine_weights/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vinci - An Online Egocentric Video-Language Assistant 2 | 3 | Open in Spaces 4 | 5 | > **Vinci: A Real-time Embodied Smart Assistant based on Egocentric Vision-Language Model**
6 | arXiv, 2024
7 |
8 | ## 💬 TL;DR
9 |
10 | - **Overview**: A real-time, embodied smart assistant based on an egocentric vision-language model.
11 | - **Portable Device Compatibility**: Designed for smartphones and wearable cameras, operating in an "always on" mode.
12 | - **Hands-Free Interaction**: Users engage in natural conversations to ask questions and get responses delivered via audio.
13 | - **Real-Time Video Processing**: Processes long video streams to answer queries about current and historical observations.
14 | - **Task Planning and Guidance**: Provides task planning based on past interactions and generates visual task demonstrations.
15 |
16 | ## 📣 Demo video
17 | [https://github.com/user-attachments/assets/ab019895-a7fe-4a1c-aa91-5a1e06dd4f2b](https://github.com/user-attachments/assets/ab019895-a7fe-4a1c-aa91-5a1e06dd4f2b)
18 |
19 | [https://github.com/user-attachments/assets/6be2aa5c-81bb-4a85-b1cf-f08e30d97903](https://github.com/user-attachments/assets/6be2aa5c-81bb-4a85-b1cf-f08e30d97903)
20 |
21 |
22 |
23 | ## 🔨 Installation
24 | ```
25 | git clone https://github.com/OpenGVLab/vinci.git
26 | conda env create -f environment.yml
27 | ```
28 | Requirements:
29 | - Python 3.8 or above
30 | - PyTorch 2.0 or above is recommended
31 | - CUDA 11.4 or above is recommended
32 | - Docker is required when deploying the streaming demo
33 | - Gradio is required when using the local web-based demo
34 |
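Before downloading the large checkpoints, it can be worth confirming that the conda environment and the GPU are visible to PyTorch. The snippet below is only an illustrative sanity check and is not part of the repository scripts:

```python
import torch
import torchvision

# Quick sanity check for the `vinci` conda environment.
print("torch:", torch.__version__)              # 2.0 or above is recommended
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU 0:", torch.cuda.get_device_name(0))
```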
35 |
36 | ### Downloading Checkpoints
37 | ```
38 | bash download.sh
39 | ```
40 | Running download.sh will take up more than 100 GB of disk space.
41 |
42 | ## 🎓 Getting Started
43 | We offer two ways to run the Vinci model.
44 |
45 | ### 🎬 Online Streaming Demo
46 | 1. Start the frontend, backend, and model services:
47 | ```bash
48 | sudo ./boot.sh {start|stop|restart} [--cuda ] [--language chn/eng] [--version v0/v1]
49 | ```
50 |
51 | - --cuda : Specify the GPU devices to run the model
52 | - --language : Choose the language for the demo (default: chn).
53 |   - chn: Chinese
54 |   - eng: English
55 |
56 | - --version : Select the model version (default: v1).
57 |   - v0: Optimized for first-person perspective videos.
58 |   - v1: Generalized model for both first-person and third-person perspective videos.
59 |
60 | Then use a browser to access the frontend page: http://YOUR_IP_ADDRESS:19333 (e.g., http://102.2.52.16:19333)
61 |
62 | 2. Push the live stream
63 | With a smartphone app or a GoPro/DJI camera, push the stream to: `rtmp://YOUR_IP_ADDRESS/vinci/livestream`
64 |
65 | With a webcam, use the following command: `ffmpeg -f video4linux2 -framerate 30 -video_size 1280x720 -i /dev/video1 -f alsa -i default -vcodec libx264 -preset ultrafast -pix_fmt yuv420p -video_size 1280x720 -c:a aac -threads 0 -f flv rtmp://YOUR_IP_ADDRESS:1935/vinci/livestream`
66 |
67 | #### Interact with Online Video Streaming Demo
68 | 1. Activate Model Service: To wake up the model and begin using it, simply say the wake-up phrase "你好望舒 (Ni hao wang shu)" (currently, only the Chinese wake-up command is supported).
69 | 2. Chat with Vinci: Once activated, you can start chatting with Vinci by speaking. The model will respond in both text and speech.
70 | Tip: For the best experience, speak clearly and at a moderate pace.
71 | 3. Generate Predictive Visualizations: If you want to generate a predictive visualization of actions, include the keyword "可视化 (Ke shi hua)" in your command.
72 |
73 | ### 🎬 Gradio Demo for uploaded videos
74 | ```bash
75 | python demovl.py [--language chn/eng] [--version v0/v1]
76 | ```
77 | - --cuda : Specify the GPU devices to run the model
78 | - --language : Choose the language for the demo (default: chn).
79 |   - chn: Chinese
80 |   - eng: English
81 |
82 | - --version : Select the model version (default: v1).
83 |   - v0: Optimized for first-person perspective videos.
84 |   - v1: Generalized model for both first-person and third-person perspective videos.
85 |
86 | #### Interact with Gradio Demo
87 | 1. Upload a local video file
88 |
89 | 90 |
91 | 92 | 2. Click Upload & Start Chat button to initiate the chat session 93 |
94 | 95 |
96 | 97 | 3. Click the play button to start playing the video 98 |
99 | 100 |
101 | 102 | 4. Adjusting the Stride of Memory. This allows you to control the granularity of the model's memory. 103 |
104 | 105 |
106 | 107 | 5. Real-Time Interaction:Type your questions in the chat box. The model will respond based on the current frame and historical context. 108 | 109 |
110 |
111 | 112 |
113 |
Describe current action
114 |
115 |
116 |
117 |
118 | 119 |
120 |
121 | 122 |
123 |
Retrieve object from the history
124 |
125 |
126 |
127 |
128 | 129 |
130 |
131 | 132 |
133 |
Summarize previous actions
134 |
135 |
136 |
137 |
138 | 139 |
140 |
141 | 142 |
143 |
Scene understanding
144 |
145 |
146 |
147 |
148 | 149 |
150 |
151 | 152 |
153 |
Temporal grounding
154 |
155 |
156 |
157 |
158 | 159 |
160 |
161 | 162 |
163 |
Predict future actions
164 |
165 |
166 |
167 |
168 | 169 | 170 | 6. Generate future videos: based on the current frame and the historical context, the model can generate a short future video. 171 |
172 |
173 | 174 |
Generate future actions
175 |
176 |
177 |
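For reference, the "Video how-to demo" button in the Gradio app drives the SEINE image-to-video model with the most recently queried frame (saved to `./lastim.jpg`) and the latest answer text as the prompt, then displays the resulting `./result.mp4`. The sketch below mirrors the `generate_video` callback in `demovl.py`; the example prompt string is made up, and it assumes the SEINE weights and `generation/seine-v2/configs/demo.yaml` have already been set up by `download.sh`:

```python
import sys
sys.path.append('generation/seine-v2/')

from omegaconf import OmegaConf
from seine import gen, model_seine  # importing this module initializes SEINE, as in demovl.py

# Illustrative sketch of the future-video generation step (not part of the original scripts).
omega_conf = OmegaConf.load('generation/seine-v2/configs/demo.yaml')
omega_conf.input_path = './lastim.jpg'                       # last frame shown to the VLM
omega_conf.text_prompt = ['pouring hot water into the cup']  # e.g. the latest chat answer
omega_conf.save_img_path = '.'
gen(omega_conf, model_seine)                                 # writes ./result.mp4
```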
178 | 179 | 180 | ## ♥️ Origin 181 | The name "Vinci" embodies several layers of meaning: 182 | 183 | - It is inspired by the famous Renaissance master Leonardo da Vinci, symbolizing a wealth of knowledge and insight, and suggesting that this assistant can provide equally exceptional service to users. 184 | - The word "Vinci" is derived from the Latin "vincere," meaning "to conquer" or "to overcome," implying that this assistant helps users overcome various difficulties and challenges. 185 | - Phonetically, it resembles "Vision," highlighting the assistant's core function of analyzing and responding based on visual information. 186 | - It represents a fusion of elegance, wisdom, and innovation, complementing the high-tech nature of first-person camera devices. 187 | 188 | 望舒 - 出自《楚辞·离骚》:“前望舒使先驱兮,后飞廉使奔属。” 望舒是神话传说中替月亮驾车的天神,描述象征着引导和指引的意义 189 | 190 | ## ✒️ Citation 191 | If this work is helpful for your research, please consider citing us. 192 | ``` 193 | @article{vinci, 194 | title={Vinci: A Real-time Embodied Smart Assistant based on Egocentric Vision-Language Model}, 195 | author={Huang, Yifei and Xu, Jilan and Pei, Baoqi and He, Yuping and Chen, Guo and Yang, Lijin and Chen, Xinyuan and Wang, Yaohui and Nie, Zheng and Liu, Jinyao and Fan, Guoshun and Lin, Dechen and Fang, Fang and Li, Kunpeng and Yuan, Chang and Wang, Yali and Qiao, Yu and Wang, Limin}, 196 | journal={arXiv preprint arXiv:2412.21080}, 197 | year={2024} 198 | } 199 | ``` 200 | 201 | ``` 202 | @article{pei2024egovideo, 203 | title={EgoVideo: Exploring Egocentric Foundation Model and Downstream Adaptation}, 204 | author={Pei, Baoqi and Chen, Guo and Xu, Jilan and He, Yuping and Liu, Yicheng and Pan, Kanghua and Huang, Yifei and Wang, Yali and Lu, Tong and Wang, Limin and Qiao, Yu}, 205 | journal={arXiv preprint arXiv:2406.18070 }, 206 | year={2024} 207 | } 208 | ``` 209 | ``` 210 | @inproceedings{xu2024retrieval, 211 | title={Retrieval-augmented egocentric video captioning}, 212 | author={Xu, Jilan and Huang, Yifei and Hou, Junlin and Chen, Guo and Zhang, Yuejie and Feng, Rui and Xie, Weidi}, 213 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 214 | pages={13525--13536}, 215 | year={2024} 216 | } 217 | ``` 218 | ``` 219 | @InProceedings{huang2024egoexolearn, 220 | title={EgoExoLearn: A Dataset for Bridging Asynchronous Ego- and Exo-centric View of Procedural Activities in Real World}, 221 | author={Huang, Yifei and Chen, Guo and Xu, Jilan and Zhang, Mingfang and Yang, Lijin and Pei, Baoqi and Zhang, Hongjie and Lu, Dong and Wang, Yali and Wang, Limin and Qiao, Yu}, 222 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 223 | year={2024} 224 | } 225 | ``` 226 | ``` 227 | @inproceedings{chen2023seine, 228 | title={Seine: Short-to-long video diffusion model for generative transition and prediction}, 229 | author={Chen, Xinyuan and Wang, Yaohui and Zhang, Lingjun and Zhuang, Shaobin and Ma, Xin and Yu, Jiashuo and Wang, Yali and Lin, Dahua and Qiao, Yu and Liu, Ziwei}, 230 | booktitle={The Twelfth International Conference on Learning Representations}, 231 | year={2023} 232 | } 233 | ``` 234 | ``` 235 | @article{wang2024internvideo2, 236 | title={Internvideo2: Scaling video foundation models for multimodal video understanding}, 237 | author={Wang, Yi and Li, Kunchang and Li, Xinhao and Yu, Jiashuo and He, Yinan and Wang, Chenting and Chen, Guo and Pei, Baoqi and Zheng, Rongkun and Xu, Jilan and Wang, Zun and others}, 238 | journal={arXiv 
preprint arXiv:2403.15377}, 239 | year={2024} 240 | } 241 | ``` 242 | -------------------------------------------------------------------------------- /assets/1-dl.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/1-dl.PNG -------------------------------------------------------------------------------- /assets/10-chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/10-chat.PNG -------------------------------------------------------------------------------- /assets/11-gen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/11-gen.png -------------------------------------------------------------------------------- /assets/2-start.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/2-start.PNG -------------------------------------------------------------------------------- /assets/3-play.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/3-play.PNG -------------------------------------------------------------------------------- /assets/4-adjust.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/4-adjust.PNG -------------------------------------------------------------------------------- /assets/5-chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/5-chat.PNG -------------------------------------------------------------------------------- /assets/6-chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/6-chat.PNG -------------------------------------------------------------------------------- /assets/7-chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/7-chat.PNG -------------------------------------------------------------------------------- /assets/8-chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/8-chat.PNG -------------------------------------------------------------------------------- /assets/9-chat.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/assets/9-chat.PNG -------------------------------------------------------------------------------- /boot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA="0" 4 | RUNNING_LANGUAGE='chn' 5 | VERSION='v1' 6 | 7 | while [[ "$#" -gt 0 
]]; do 8 | case $1 in 9 | --cuda) CUDA="$2"; shift ;; 10 | --version) VERSION="$2"; shift ;; 11 | --language) RUNNING_LANGUAGE="$2"; shift ;; 12 | start) COMMAND_ACTION="start" ;; 13 | stop) COMMAND_ACTION="stop" ;; 14 | restart) COMMAND_ACTION="restart" ;; 15 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 16 | esac 17 | shift 18 | done 19 | 20 | cd vinci-local/docker 21 | ./boot.sh "$COMMAND_ACTION" 22 | cd ../.. 23 | ./vinci-inference/boot.sh --cuda "$CUDA" --language "$RUNNING_LANGUAGE" --version "$VERSION" "$COMMAND_ACTION" 24 | -------------------------------------------------------------------------------- /demovl.py: -------------------------------------------------------------------------------- 1 | import torch, torchvision 2 | torch.backends.cudnn.enabled = False 3 | import gradio as gr 4 | from gradio.themes.utils import colors, fonts, sizes 5 | import os, subprocess 6 | import threading 7 | 8 | #generation module 9 | import sys 10 | sys.path.append('generation/seine-v2/') 11 | # torchvision.set_video_backend('video_reader') 12 | from seine import gen, model_seine 13 | from omegaconf import OmegaConf 14 | omega_conf = OmegaConf.load('generation/seine-v2/configs/demo.yaml') 15 | omega_conf.run_time = 13 16 | omega_conf.input_path = '' 17 | omega_conf.text_prompt = [] 18 | omega_conf.save_img_path = '.' 19 | 20 | import argparse 21 | 22 | # Create ArgumentParser object 23 | parser = argparse.ArgumentParser(description='Argument Parser Example') 24 | parser.add_argument('--version', type=str, help='v0 or v1', default='v1') 25 | parser.add_argument('--language', type=str, help='chn or eng', default='chn') 26 | args = parser.parse_args() 27 | version = args.version 28 | running_language = args.language 29 | 30 | 31 | get_gr_video_current_time = """async (video, grtime, one, two, three) => { 32 | const videoEl = document.querySelector("#up_video video"); 33 | return [video, videoEl.currentTime, one, two, three]; 34 | }""" 35 | 36 | get_time = """async (video, grtime, one, two, three, four, five) => { 37 | const videoEl = document.querySelector("#up_video video"); 38 | return [video, videoEl.currentTime, one, two, three, four, five]; 39 | }""" 40 | 41 | import numpy as np 42 | import torch 43 | import torchvision.transforms as T 44 | from decord import VideoReader, cpu 45 | from PIL import Image 46 | from torchvision.transforms.functional import InterpolationMode 47 | from transformers import AutoModel, AutoTokenizer 48 | from random import randint 49 | from transformers import TextIteratorStreamer 50 | from threading import Thread 51 | import os 52 | 53 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 54 | IMAGENET_STD = (0.229, 0.224, 0.225) 55 | 56 | 57 | def build_transform(input_size): 58 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 59 | transform = T.Compose([ 60 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 61 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 62 | T.ToTensor(), 63 | T.Normalize(mean=MEAN, std=STD) 64 | ]) 65 | return transform 66 | 67 | 68 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 69 | best_ratio_diff = float('inf') 70 | best_ratio = (1, 1) 71 | area = width * height 72 | for ratio in target_ratios: 73 | target_aspect_ratio = ratio[0] / ratio[1] 74 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 75 | if ratio_diff < best_ratio_diff: 76 | best_ratio_diff = ratio_diff 77 | best_ratio = ratio 78 | elif ratio_diff == best_ratio_diff: 79 | if area > 0.5 * 
image_size * image_size * ratio[0] * ratio[1]: 80 | best_ratio = ratio 81 | return best_ratio 82 | 83 | 84 | def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): 85 | orig_width, orig_height = image.size 86 | aspect_ratio = orig_width / orig_height 87 | 88 | # calculate the existing image aspect ratio 89 | target_ratios = set( 90 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 91 | i * j <= max_num and i * j >= min_num) 92 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 93 | 94 | # find the closest aspect ratio to the target 95 | target_aspect_ratio = find_closest_aspect_ratio( 96 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 97 | 98 | # calculate the target width and height 99 | target_width = image_size * target_aspect_ratio[0] 100 | target_height = image_size * target_aspect_ratio[1] 101 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 102 | 103 | # resize the image 104 | resized_img = image.resize((target_width, target_height)) 105 | processed_images = [] 106 | for i in range(blocks): 107 | box = ( 108 | (i % (target_width // image_size)) * image_size, 109 | (i // (target_width // image_size)) * image_size, 110 | ((i % (target_width // image_size)) + 1) * image_size, 111 | ((i // (target_width // image_size)) + 1) * image_size 112 | ) 113 | # split the image 114 | split_img = resized_img.crop(box) 115 | processed_images.append(split_img) 116 | assert len(processed_images) == blocks 117 | if use_thumbnail and len(processed_images) != 1: 118 | thumbnail_img = image.resize((image_size, image_size)) 119 | processed_images.append(thumbnail_img) 120 | return processed_images 121 | 122 | 123 | class Chat: 124 | def __init__(self, path='Vinci-8B-base', stream=True, device='cuda:0', use_chat_history=False, language='chn', version='v1', max_history=10): 125 | self.device = device 126 | self.vr = None 127 | self.video_fps = None 128 | self.prev_timestamp = 0 129 | self.history = [] 130 | self.chat_history = [] 131 | self.stream = stream 132 | self.use_chat_history = use_chat_history 133 | self.transform = build_transform(input_size=448) 134 | self.language = language 135 | self.version = version 136 | self.max_history = max_history 137 | self.model = AutoModel.from_pretrained( 138 | path, 139 | torch_dtype=torch.bfloat16, 140 | low_cpu_mem_usage=True, 141 | trust_remote_code=True) 142 | self.model_lock = threading.Lock() 143 | 144 | if version == 'v0': 145 | from safetensors.torch import load_file 146 | def merge_dicts(dict1, dict2, dict3, dict4): 147 | result = {**dict1, **dict2, **dict3, **dict4} 148 | return result 149 | path2 = 'Vinci-8B-ckpt' 150 | model_weights1 = load_file(os.path.join(path2,"model-00001-of-00004.safetensors")) 151 | model_weights2 = load_file(os.path.join(path2,"model-00002-of-00004.safetensors")) 152 | model_weights3 = load_file(os.path.join(path2,"model-00003-of-00004.safetensors")) 153 | model_weights4 = load_file(os.path.join(path2,"model-00004-of-00004.safetensors")) 154 | merged_weight = merge_dicts(model_weights1,model_weights2,model_weights3,model_weights4) 155 | self.model.wrap_llm_lora(r=16, lora_alpha=2 * 16) 156 | msg = self.model.load_state_dict(merged_weight,strict=False) 157 | print(msg) 158 | self.model = self.model.eval().cuda() 159 | state1 = self.model.state_dict() 160 | self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) 161 | if self.stream: 162 | self.streamer = 
TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10) 163 | self.generation_config = dict( 164 | num_beams=1, 165 | max_new_tokens=1024, 166 | do_sample=False, 167 | streamer=self.streamer 168 | ) 169 | else: 170 | self.generation_config = dict( 171 | num_beams=1, 172 | max_new_tokens=1024, 173 | do_sample=False, 174 | ) 175 | 176 | def load_video_timestamp(self, timestamp, num_segments=4): 177 | pixel_values_list, num_patches_list = [], [] 178 | offset = np.linspace(-2, 0, num_segments) 179 | rand_offset = randint(-4, 4) 180 | offset = offset + rand_offset 181 | frame_indices = (timestamp+offset) * self.video_fps 182 | frame_indices = frame_indices.astype(np.int64) 183 | if frame_indices[0] < 0: 184 | frame_indices -= frame_indices[0] 185 | print('***** using video timestamps at:', frame_indices) 186 | for i, frame_index in enumerate(frame_indices): 187 | img = Image.fromarray(self.vr[frame_index].asnumpy()).convert('RGB') 188 | if i == len(frame_indices) - 1: 189 | img.save('./lastim.jpg') 190 | img = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=1) 191 | pixel_values = [self.transform(tile) for tile in img] 192 | pixel_values = torch.stack(pixel_values) 193 | num_patches_list.append(pixel_values.shape[0]) 194 | pixel_values_list.append(pixel_values) 195 | pixel_values = torch.cat(pixel_values_list) 196 | return pixel_values, num_patches_list 197 | 198 | 199 | def ask(self,text,conv): 200 | conv['questions'].append(text + '\n') 201 | return conv 202 | 203 | def answer(self, conv, timestamp=0, add_to_history=False): 204 | with self.model_lock: 205 | pixel_values, num_patches_list = self.load_video_timestamp(timestamp) 206 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 207 | video_prefix = ''.join([f'Frame{i+1}: \n' for i in range(len(num_patches_list))]) 208 | if add_to_history: # silent ask 209 | if self.language == 'chn': 210 | question = video_prefix + '现在视频到了 %.1f 秒处. 简单的用一句话描述视频中我的动作.' % timestamp 211 | else: 212 | question = video_prefix + 'Now the video is at %.1f second. Briefly describe my actions in the video in one sentence.' 
% timestamp #conv['questions'][-1] 213 | 214 | response, history = self.model.chat(self.tokenizer, pixel_values, question, self.generation_config, 215 | num_patches_list=num_patches_list, 216 | history=None, return_history=True) 217 | self.history.append((timestamp, response)) 218 | print('VL_HISTORY:', self.history) 219 | else: 220 | self.chat_history.append([conv['questions'][-1]]) 221 | question = self.add_history(conv['questions'][-1]) 222 | question = video_prefix + question 223 | if self.stream: 224 | thread = Thread(target=self.model.chat, kwargs=dict(tokenizer=self.tokenizer, pixel_values=pixel_values, question=question, generation_config=self.generation_config, 225 | num_patches_list=num_patches_list, 226 | history=None, return_history=False)) 227 | thread.start() 228 | response = '' 229 | else: 230 | response = self.model.chat(self.tokenizer, pixel_values, question, self.generation_config, 231 | num_patches_list=num_patches_list, 232 | history=None, return_history=False) 233 | self.chat_history[-1].append(timestamp) 234 | self.chat_history[-1].append(response) 235 | # self.history.append((timestamp, response)) 236 | conv['answers'].append(response + '\n') 237 | # print('Real question at %.1f is |||' % timestamp, question) 238 | # print('Answer at %.1f is ||| '%timestamp, response) 239 | # print('the history is:', self.history) 240 | return response, conv, './lastim.jpg' 241 | 242 | def add_history(self, question): 243 | if not self.history: 244 | print('history not added because self.history is empty') 245 | return question 246 | if len(self.history) > 0: 247 | if self.language == 'chn': 248 | system = "你是一个视频智能助手。仔细观察我拍摄的视频并重点关注物体的运动和人的动作。由于你看不到发生在当前帧之前的部分,现在以文字形式提供给你这个视频的之前的历史供参考:" 249 | else: 250 | system = 'You are an intelligent assistant. You receive video frames from my egocentric viewpoint. Carefully watch the video and pay attention to the movement of objects, and the action of human. Since you cannot see the previous part of the video, I provide you the history of this video for reference. The history is: ' 251 | res = system 252 | for hist in self.history: 253 | ts = hist[0] 254 | a = hist[1] 255 | if self.language == 'chn': 256 | res += '当视频在%.1f秒时, 视频的内容是 "%s". ' % (ts, a.strip()) 257 | else: 258 | res += 'When the video is at %.1f seconds, the video contect is "%s". ' % (ts, a.strip()) 259 | if self.language == 'chn': 260 | res += '以上是所有的视频历史, 表明了之前发生了什么.\n' 261 | else: 262 | res += 'This is the end of the video history that indicates what happened before.\n' 263 | if self.use_chat_history and len(self.chat_history)>1: 264 | if self.language == 'chn': 265 | res += '另外我提供根据之前的视频,我们的对话历史如下: ' 266 | else: 267 | res += 'Also I provide you with our chat history based on the previous video content: ' 268 | for hist in self.chat_history[:-1]: 269 | q = hist[0] 270 | ts = hist[1] 271 | a = hist[2] 272 | if self.language == 'chn': 273 | res += '当视频在%.1f秒时, 问题是: "%s", 回答是"%s". ' % (ts, q.strip(), a.strip()) 274 | else: 275 | res += 'When the video is at %.1f seconds, the question was: "%s", and its answer was: "%s". ' % (ts, q.strip(), a.strip()) 276 | if self.language == 'chn': 277 | res += '以上是所有的对话历史, 表明了之前我们交流了什么,但是不表明现在的任何信息.\n' 278 | else: 279 | res += 'This is the end of the chat history. The chat history indicate what our previous chat was, but does not necessarily contain the current information.\n' 280 | 281 | # res += 'Now the video is at %0.1fs, the action is "%s". 
' % (ts, a) 282 | # res += 'Given the video information, please answer my question: ' 283 | if self.language == 'chn': 284 | res += '请根据当前视频, 同时参照视频历史, 用中文回答我的问题. 注意如果问题与之前发生的事情有关, 请参考视频历史, 否则请只关注当前图像信息. 我的问题是: "%s". 用三句话以内回答.' % question 285 | else: 286 | res += 'Given the current video and using the previous video as reference, answer my question in English: "%s". Note that if the question is about what has been previously done, please only focus on the history. Otherwise, please only focus on the question and the current video input. Do not repeat.' % question 287 | # question = res + '\n' + question 288 | question = res 289 | if len(self.history) > self.max_history: 290 | self.history = self.history[-self.max_history:] 291 | return question 292 | 293 | def upload_video(self, video_path): 294 | self.vr = VideoReader(video_path, ctx=cpu(0)) 295 | self.num_frames = len(self.vr) 296 | self.video_fps = self.vr.get_avg_fps() 297 | return 'succeed' 298 | 299 | 300 | # ======================================== 301 | # Model Initialization 302 | # ======================================== 303 | def init_model(): 304 | print('Initializing VLChat') 305 | chat = Chat(stream=False, version=args.version, language=args.language) 306 | print('Initialization Finished') 307 | return chat 308 | chat = init_model() 309 | 310 | # ======================================== 311 | # Gradio Setting 312 | # ======================================== 313 | def gradio_reset(chat_state): 314 | if chat_state is not None: 315 | print(chat_state) 316 | print(chat_state.keys()) 317 | chat_state['questions'] = [] 318 | chat_state['answers'] = [] 319 | chat.history = [] 320 | chat.chat_history = [] 321 | return gr.update(value=None), gr.update(value=None), gr.update(placeholder='Please upload your video first'), gr.update(value="Upload & Start Chat"), chat_state 322 | 323 | 324 | def upload_img(gr_video, chat_state): 325 | print('gr_video:', gr_video) 326 | num_segments=4 327 | chat_state = { 328 | "questions": [], 329 | "answers": [], 330 | } 331 | # img_list = [] 332 | if gr_video is None: 333 | return None, gr.update(interactive=True),gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, 0.0 334 | else: 335 | llm_message = chat.upload_video(gr_video) 336 | return gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, 0.0 337 | 338 | 339 | def gradio_ask(up_video, gr_video_time, user_message, chatbot, chat_state): 340 | print('gr_video_time:', gr_video_time) 341 | if len(user_message) == 0: 342 | return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state 343 | # time_prompt = 'Now the video is at %.1f second. ' % gr_video_time 344 | time_prompt = '现在视频到了%.1f秒处. 
' % gr_video_time 345 | chat_state = chat.ask(time_prompt+user_message, chat_state) 346 | chatbot = chatbot + [[f'User@{gr_video_time}s: '+user_message, None]] 347 | return '', chatbot, chat_state, gr_video_time 348 | 349 | 350 | def gradio_answer(chatbot, chat_state, gr_video_time): 351 | llm_message, chat_state, last_img_list = chat.answer(chat_state, timestamp=gr_video_time, add_to_history=False) 352 | # llm_message = llm_message.replace("", "") # handle 353 | chatbot[-1][1] = llm_message 354 | # print(chat_state) 355 | print(f"Answer: {llm_message}") 356 | return chatbot, chat_state, last_img_list 357 | 358 | def silent_ask(user_message, chat_state, gr_video_time, memory_size): 359 | chat.max_history = memory_size 360 | # user_message = 'Now the video is at %.1f second. What am I doing?' % gr_video_time 361 | user_message = '现在视频到了%.1f秒处. 描述当前视频中你在环境中所处的位置. 描述出物体的方位, 而不要仅仅描述有什么物体.' % gr_video_time 362 | chat_state = chat.ask(user_message, chat_state) 363 | # chatbot = chatbot + [[f'User@{gr_video_time}s: '+user_message, None]] 364 | return '', chat_state 365 | 366 | def silent_answer(chat_state, gr_video_time): 367 | llm_message, chat_state, last_img_list = chat.answer(chat_state, timestamp=gr_video_time, add_to_history=True) 368 | llm_message = llm_message.replace("", "") # handle 369 | return chat_state 370 | 371 | 372 | class OpenGVLab(gr.themes.base.Base): 373 | def __init__( 374 | self, 375 | *, 376 | primary_hue=colors.blue, 377 | secondary_hue=colors.sky, 378 | neutral_hue=colors.gray, 379 | spacing_size=sizes.spacing_md, 380 | radius_size=sizes.radius_sm, 381 | text_size=sizes.text_md, 382 | font=( 383 | fonts.GoogleFont("Noto Sans"), 384 | "ui-sans-serif", 385 | "sans-serif", 386 | ), 387 | font_mono=( 388 | fonts.GoogleFont("IBM Plex Mono"), 389 | "ui-monospace", 390 | "monospace", 391 | ), 392 | ): 393 | super().__init__( 394 | primary_hue=primary_hue, 395 | secondary_hue=secondary_hue, 396 | neutral_hue=neutral_hue, 397 | spacing_size=spacing_size, 398 | radius_size=radius_size, 399 | text_size=text_size, 400 | font=font, 401 | font_mono=font_mono, 402 | ) 403 | super().set( 404 | body_background_fill="*neutral_50", 405 | ) 406 | 407 | 408 | gvlabtheme = OpenGVLab(primary_hue=colors.blue, 409 | secondary_hue=colors.sky, 410 | neutral_hue=colors.gray, 411 | spacing_size=sizes.spacing_md, 412 | radius_size=sizes.radius_sm, 413 | text_size=sizes.text_md, 414 | ) 415 | 416 | title = """

Vinci

""" 417 | description =""" 418 | An Egocentric Video Foundation Model based Online Intelligent Assistant 419 | """ 420 | 421 | with gr.Blocks(title="Vinci Demo",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo: 422 | gr.Markdown(title) 423 | gr.Markdown(description) 424 | gr_timer = gr.Timer(5, active=False) 425 | silent_time = gr.Number(0.0, visible=False) 426 | with gr.Row(): 427 | with gr.Column(scale=0.5, visible=True) as video_upload: 428 | with gr.Column(elem_id="image", scale=0.5) as img_part: 429 | up_video = gr.Video(interactive=True, elem_id="up_video", height=360,) 430 | upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") 431 | # clear = gr.Button("Restart") 432 | 433 | memory_size = gr.Slider( 434 | minimum=5, 435 | maximum=25, 436 | value=10, 437 | step=1, 438 | interactive=True, 439 | label="size of memory", 440 | ) 441 | 442 | memory_stride = gr.Slider( 443 | minimum=5, 444 | maximum=100, 445 | value=10, 446 | step=0.1, 447 | interactive=True, 448 | label="stride of memory", 449 | ) 450 | 451 | 452 | with gr.Column(visible=True) as input_raws: 453 | chat_state = gr.State() 454 | img_list = gr.State() 455 | last_img_list = gr.State() 456 | chatbot = gr.Chatbot(elem_id="chatbot",label='ChatBot') 457 | with gr.Row(): 458 | with gr.Column(scale=0.7): 459 | text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False) 460 | with gr.Column(scale=0.15, min_width=0): 461 | run = gr.Button("💭Send") 462 | with gr.Column(scale=0.15, min_width=0): 463 | clear = gr.Button("🔄Clear️") 464 | with gr.Row(): 465 | with gr.Column(scale=0.3): 466 | inimage_interface = gr.Image(label="input image", elem_id="gr_inimage", visible=True, height=360) 467 | with gr.Column(scale=0.7): 468 | outvideo_interface = gr.Video(label="output video", elem_id="gr_outvideo", visible=True, height=360) 469 | with gr.Row(): 470 | with gr.Column(scale=0.5): 471 | generate_button = gr.Button(value="Video how-to demo", interactive=True, variant="primary") 472 | with gr.Column(scale=0.5): 473 | generate_clear_button = gr.Button(value="Clear", interactive=True, variant="primary") 474 | gr_video_time = gr.Number(value=-1, visible=False) 475 | def gr_video_time_change(_, video_time): 476 | return video_time 477 | def video_change_init_time(): 478 | return 0, gr.update(active=True) 479 | 480 | def timertick(up_video, gr_video_time, silent_time, text_input, chat_state, memory_stride, memory_size): 481 | if gr_video_time - silent_time < memory_stride: 482 | return silent_time, chat_state, gr_video_time 483 | silent_time = gr_video_time 484 | _, chat_state = silent_ask(text_input, chat_state, gr_video_time, memory_size) 485 | chat_state = silent_answer(chat_state, gr_video_time) 486 | return silent_time, chat_state, gr_video_time 487 | 488 | gr_timer.tick(timertick, [up_video, gr_video_time, silent_time, text_input, chat_state, memory_stride, memory_size], [silent_time, chat_state, gr_video_time], js=get_time) 489 | up_video.play(video_change_init_time, [], [gr_video_time, gr_timer]) 490 | 491 | def generate_video(img, conv, gr_video_time): 492 | text = conv["answers"][-1] 493 | omega_conf.input_path = './lastim.jpg' 494 | omega_conf.text_prompt = [text] 495 | gen(omega_conf, model_seine) 496 | return img, './result.mp4' 497 | generate_button.click(generate_video, [last_img_list, chat_state], [inimage_interface, outvideo_interface]) 
498 | 499 | 500 | def generate_clear(): 501 | return gr.update(value=None), gr.update(value=None) 502 | generate_clear_button.click(generate_clear, [], [inimage_interface, outvideo_interface]) 503 | 504 | upload_button.click(upload_img, [up_video, chat_state], [up_video, text_input, upload_button, chat_state, gr_video_time]) 505 | 506 | text_input.submit(gradio_ask, [up_video, gr_video_time, text_input, chatbot, chat_state], [text_input, chatbot, chat_state, gr_video_time], js=get_gr_video_current_time).then( 507 | gradio_answer, [chatbot, chat_state, gr_video_time], [chatbot, chat_state, last_img_list] 508 | ) 509 | run.click(gradio_ask, [up_video, gr_video_time, text_input, chatbot, chat_state], [text_input, chatbot, chat_state, gr_video_time], js=get_gr_video_current_time).then( 510 | gradio_answer, [chatbot, chat_state, gr_video_time], [chatbot, chat_state, last_img_list] 511 | ) 512 | run.click(lambda: "", None, text_input) 513 | clear.click(gradio_reset, [chat_state], [chatbot, up_video, text_input, upload_button, chat_state], queue=False) 514 | 515 | # demo.launch(share=True, enable_queue=True) 516 | demo.queue(default_concurrency_limit=10) 517 | demo.launch(server_name="0.0.0.0", server_port=10050, debug=True) 518 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to install Git LFS 4 | install_git_lfs() { 5 | echo "Installing Git LFS..." 6 | # Check the operating system and install accordingly 7 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 8 | # For Debian/Ubuntu 9 | sudo apt-get update 10 | sudo apt-get install git-lfs -y 11 | elif [[ "$OSTYPE" == "darwin"* ]]; then 12 | # For macOS 13 | brew install git-lfs 14 | elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" ]]; then 15 | # For Windows (using Chocolatey) 16 | choco install git-lfs 17 | else 18 | echo "Unsupported OS. Please install Git LFS manually." 19 | exit 1 20 | fi 21 | git lfs install 22 | } 23 | 24 | # Check if Git LFS is installed 25 | if ! command -v git-lfs &> /dev/null; then 26 | echo "Git LFS is not installed." 27 | install_git_lfs 28 | else 29 | echo "Git LFS is already installed." 
30 | fi 31 | REPO_URL="https://huggingface.co/hyf015/Vinci-8B-base" 32 | echo "Cloning the repository: $REPO_URL" 33 | git clone "$REPO_URL" 34 | REPO_URL2="https://huggingface.co/hyf015/Vinci-8B-ckpt" 35 | echo "Cloning the repository: $REPO_URL2" 36 | git clone "$REPO_URL2" 37 | REPO_URL3="https://huggingface.co/hyf015/seine_weights" 38 | echo "Cloning the repository: $REPO_URL3" 39 | git clone "$REPO_URL3" 40 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vinci 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - https://repo.anaconda.com/pkgs/main 7 | - https://repo.anaconda.com/pkgs/r 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - blas=1.0=mkl 12 | - brotli-python=1.0.9=py39h6a678d5_8 13 | - bzip2=1.0.8=h5eee18b_6 14 | - ca-certificates=2024.7.2=h06a4308_0 15 | - certifi=2024.7.4=py39h06a4308_0 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - cuda-cudart=11.8.89=0 18 | - cuda-cupti=11.8.87=0 19 | - cuda-libraries=11.8.0=0 20 | - cuda-nvrtc=11.8.89=0 21 | - cuda-nvtx=11.8.86=0 22 | - cuda-runtime=11.8.0=0 23 | - cuda-version=12.5=3 24 | - ffmpeg=4.3=hf484d3e_0 25 | - filelock=3.13.1=py39h06a4308_0 26 | - freetype=2.12.1=h4a9f257_0 27 | - gmp=6.2.1=h295c915_3 28 | - gmpy2=2.1.2=py39heeb90bb_0 29 | - gnutls=3.6.15=he1e5248_0 30 | - idna=3.7=py39h06a4308_0 31 | - intel-openmp=2023.1.0=hdb19cb5_46306 32 | - jinja2=3.1.4=py39h06a4308_0 33 | - jpeg=9e=h5eee18b_1 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.38=h1181459_1 37 | - lerc=3.0=h295c915_0 38 | - libcublas=11.11.3.6=0 39 | - libcufft=10.9.0.58=0 40 | - libcufile=1.10.1.7=0 41 | - libcurand=10.3.6.82=0 42 | - libcusolver=11.4.1.48=0 43 | - libcusparse=11.7.5.86=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_1 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libiconv=1.16=h5eee18b_3 49 | - libidn2=2.3.4=h5eee18b_0 50 | - libnpp=11.8.0.86=0 51 | - libnvjpeg=11.9.0.86=0 52 | - libpng=1.6.39=h5eee18b_0 53 | - libstdcxx-ng=11.2.0=h1234567_1 54 | - libtasn1=4.19.0=h5eee18b_0 55 | - libtiff=4.5.1=h6a678d5_0 56 | - libunistring=0.9.10=h27cfd23_0 57 | - libwebp-base=1.3.2=h5eee18b_0 58 | - lz4-c=1.9.4=h6a678d5_1 59 | - markupsafe=2.1.3=py39h5eee18b_0 60 | - mkl=2023.1.0=h213fc3f_46344 61 | - mkl-service=2.4.0=py39h5eee18b_1 62 | - mkl_fft=1.3.8=py39h5eee18b_0 63 | - mkl_random=1.2.4=py39hdb19cb5_0 64 | - mpc=1.1.0=h10f8cd9_1 65 | - mpfr=4.0.2=hb69a4c5_1 66 | - mpmath=1.3.0=py39h06a4308_0 67 | - ncurses=6.4=h6a678d5_0 68 | - nettle=3.7.3=hbbd107a_1 69 | - networkx=3.2.1=py39h06a4308_0 70 | - numpy=1.26.4=py39h5f9d8c6_0 71 | - numpy-base=1.26.4=py39hb5e798b_0 72 | - openh264=2.1.1=h4ff587b_0 73 | - openjpeg=2.4.0=h9ca470c_2 74 | - openssl=3.0.14=h5eee18b_0 75 | - packaging=24.1=py39h06a4308_0 76 | - pillow=10.4.0=py39h5eee18b_0 77 | - pip=24.0=py39h06a4308_0 78 | - pysocks=1.7.1=py39h06a4308_0 79 | - python=3.9.19=h955ad1f_1 80 | - pytorch=2.0.1=py3.9_cuda11.8_cudnn8.7.0_0 81 | - pytorch-cuda=11.8=h7e8668a_5 82 | - pytorch-mutex=1.0=cuda 83 | - readline=8.2=h5eee18b_0 84 | - requests=2.32.2=py39h06a4308_0 85 | - setuptools=69.5.1=py39h06a4308_0 86 | - sqlite=3.45.3=h5eee18b_0 87 | - sympy=1.12=py39h06a4308_0 88 | - tbb=2021.8.0=hdb19cb5_0 89 | - tk=8.6.14=h39e8969_0 90 | - torchaudio=2.0.2=py39_cu118 91 | - torchtriton=2.0.0=py39 92 | - torchvision=0.15.2=py39_cu118 93 | - 
typing_extensions=4.11.0=py39h06a4308_0 94 | - urllib3=2.2.2=py39h06a4308_0 95 | - wheel=0.43.0=py39h06a4308_0 96 | - xz=5.4.6=h5eee18b_1 97 | - zlib=1.2.13=h5eee18b_1 98 | - zstd=1.5.5=hc292b87_2 99 | - pip: 100 | - absl-py==2.1.0 101 | - accelerate==0.32.1 102 | - aiofiles==23.2.1 103 | - aiohappyeyeballs==2.4.4 104 | - aiohttp==3.11.9 105 | - aiosignal==1.3.1 106 | - aliyun-python-sdk-core==2.15.1 107 | - aliyun-python-sdk-kms==2.16.3 108 | - altair==5.3.0 109 | - annotated-types==0.7.0 110 | - antlr4-python3-runtime==4.9.3 111 | - anyio==4.4.0 112 | - anykeystore==0.2 113 | - apex==0.9.10.dev0 114 | - argon2-cffi==23.1.0 115 | - argon2-cffi-bindings==21.2.0 116 | - arrow==1.3.0 117 | - asttokens==2.4.1 118 | - async-lru==2.0.4 119 | - async-timeout==5.0.1 120 | - attrs==23.2.0 121 | - av==12.2.0 122 | - babel==2.16.0 123 | - beautifulsoup4==4.12.3 124 | - bleach==6.1.0 125 | - cffi==1.16.0 126 | - click==8.1.7 127 | - comm==0.2.2 128 | - contourpy==1.2.1 129 | - crcmod==1.7 130 | - cryptacular==1.6.2 131 | - cryptography==43.0.0 132 | - cycler==0.12.1 133 | - datasets==3.1.0 134 | - debugpy==1.8.6 135 | - decorator==5.1.1 136 | - decord==0.6.0 137 | - deepspeed==0.13.5 138 | - defusedxml==0.7.1 139 | - diffusers==0.29.2 140 | - dill==0.3.8 141 | - dnspython==2.6.1 142 | - docker-pycreds==0.4.0 143 | - einops==0.8.0 144 | - email-validator==2.2.0 145 | - exceptiongroup==1.2.2 146 | - executing==2.0.1 147 | - fastapi==0.111.1 148 | - fastapi-cli==0.0.4 149 | - fastjsonschema==2.20.0 150 | - ffmpy==0.3.2 151 | - flash-attn==2.3.6 152 | - fonttools==4.53.1 153 | - fqdn==1.5.1 154 | - frozenlist==1.5.0 155 | - fsspec==2024.6.1 156 | - gitdb==4.0.11 157 | - gitpython==3.1.43 158 | - gradio==4.38.1 159 | - gradio-client==1.1.0 160 | - greenlet==3.1.1 161 | - grpcio==1.65.0 162 | - h11==0.14.0 163 | - hjson==3.1.0 164 | - httpcore==1.0.5 165 | - httptools==0.6.1 166 | - httpx==0.27.0 167 | - huggingface-hub==0.23.4 168 | - hupper==1.12.1 169 | - imageio==2.34.2 170 | - imgui==2.0.0 171 | - importlib-metadata==8.0.0 172 | - importlib-resources==6.4.0 173 | - ipdb==0.13.13 174 | - ipykernel==6.29.5 175 | - ipython==8.18.1 176 | - isoduration==20.11.0 177 | - jedi==0.19.1 178 | - jmespath==0.10.0 179 | - json5==0.9.25 180 | - jsonpointer==3.0.0 181 | - jsonschema==4.23.0 182 | - jsonschema-specifications==2023.12.1 183 | - jupyter-client==8.6.3 184 | - jupyter-core==5.7.2 185 | - jupyter-events==0.10.0 186 | - jupyter-lsp==2.2.5 187 | - jupyter-server==2.14.2 188 | - jupyter-server-terminals==0.5.3 189 | - jupyterlab==4.2.5 190 | - jupyterlab-pygments==0.3.0 191 | - jupyterlab-server==2.27.3 192 | - kiwisolver==1.4.5 193 | - markdown==3.6 194 | - markdown-it-py==3.0.0 195 | - matplotlib==3.9.1 196 | - matplotlib-inline==0.1.7 197 | - mdurl==0.1.2 198 | - minio==7.2.12 199 | - mistune==3.0.2 200 | - multidict==6.1.0 201 | - multiprocess==0.70.16 202 | - natsort==8.4.0 203 | - nbclient==0.10.0 204 | - nbconvert==7.16.4 205 | - nbformat==5.10.4 206 | - nest-asyncio==1.6.0 207 | - ninja==1.11.1.1 208 | - notebook==7.2.2 209 | - notebook-shim==0.2.4 210 | - oauthlib==3.2.2 211 | - omegaconf==2.3.0 212 | - opencv-python==4.10.0.84 213 | - orjson==3.10.6 214 | - oss2==2.18.6 215 | - overrides==7.7.0 216 | - pandas==2.2.2 217 | - pandocfilters==1.5.1 218 | - parso==0.8.4 219 | - pastedeploy==3.1.0 220 | - pbkdf2==1.3 221 | - peft==0.11.1 222 | - pexpect==4.9.0 223 | - plaster==1.1.2 224 | - plaster-pastedeploy==1.0.1 225 | - platformdirs==4.3.6 226 | - prometheus-client==0.21.0 227 | - 
prompt-toolkit==3.0.47 228 | - propcache==0.2.1 229 | - protobuf==4.25.3 230 | - psutil==6.0.0 231 | - ptyprocess==0.7.0 232 | - pure-eval==0.2.2 233 | - py-cpuinfo==9.0.0 234 | - pyarrow==18.1.0 235 | - pycocoevalcap==1.2 236 | - pycocotools==2.0.8 237 | - pycparser==2.22 238 | - pycryptodome==3.20.0 239 | - pydantic==2.8.2 240 | - pydantic-core==2.20.1 241 | - pydub==0.25.1 242 | - pygments==2.18.0 243 | - pynvml==11.5.2 244 | - pyparsing==3.1.2 245 | - pyramid==2.0.2 246 | - pyramid-mailer==0.15.1 247 | - python-dateutil==2.9.0.post0 248 | - python-dotenv==1.0.1 249 | - python-json-logger==2.0.7 250 | - python-multipart==0.0.9 251 | - python3-openid==3.2.0 252 | - pytz==2024.1 253 | - pyyaml==6.0.1 254 | - pyzmq==26.2.0 255 | - referencing==0.35.1 256 | - regex==2024.5.15 257 | - repoze-sendmail==4.4.1 258 | - requests-oauthlib==2.0.0 259 | - rfc3339-validator==0.1.4 260 | - rfc3986-validator==0.1.1 261 | - rich==13.7.1 262 | - rotary-embedding-torch==0.6.4 263 | - rpds-py==0.19.0 264 | - ruff==0.5.2 265 | - safetensors==0.4.3 266 | - scipy==1.13.1 267 | - semantic-version==2.10.0 268 | - send2trash==1.8.3 269 | - sentencepiece==0.2.0 270 | - sentry-sdk==2.14.0 271 | - setproctitle==1.3.3 272 | - shellingham==1.5.4 273 | - six==1.16.0 274 | - smmap==5.0.1 275 | - sniffio==1.3.1 276 | - soupsieve==2.6 277 | - sqlalchemy==2.0.36 278 | - sse-starlette==2.1.2 279 | - stack-data==0.6.3 280 | - starlette==0.37.2 281 | - tensorboard==2.17.0 282 | - tensorboard-data-server==0.7.2 283 | - termcolor==2.4.0 284 | - terminado==0.18.1 285 | - timm==0.9.12 286 | - tinycss2==1.3.0 287 | - tokenizers==0.15.2 288 | - tomli==2.0.1 289 | - tomlkit==0.12.0 290 | - toolz==0.12.1 291 | - tornado==6.4.1 292 | - tqdm==4.66.4 293 | - traitlets==5.14.3 294 | - transaction==5.0 295 | - transformers==4.37.2 296 | - translationstring==1.4 297 | - typer==0.12.3 298 | - types-python-dateutil==2.9.0.20241003 299 | - tzdata==2024.1 300 | - uri-template==1.3.0 301 | - uvicorn==0.30.1 302 | - uvloop==0.19.0 303 | - velruse==1.1.1 304 | - venusian==3.1.0 305 | - wandb==0.18.1 306 | - watchfiles==0.22.0 307 | - wcwidth==0.2.13 308 | - webcolors==24.8.0 309 | - webencodings==0.5.1 310 | - webob==1.8.9 311 | - websocket-client==1.8.0 312 | - websockets==11.0.3 313 | - werkzeug==3.0.3 314 | - wtforms==3.2.1 315 | - wtforms-recaptcha==0.3.2 316 | - xxhash==3.5.0 317 | - yacs==0.1.8 318 | - yarl==1.18.3 319 | - zipp==3.19.2 320 | - zope-deprecation==5.0 321 | - zope-interface==7.1.1 322 | - zope-sqlalchemy==3.1 323 | prefix: /home/pjlab/.conda/envs/vinci 324 | -------------------------------------------------------------------------------- /generation/seine-v2/configs/demo.yaml: -------------------------------------------------------------------------------- 1 | ckpt: '../0200322.pt' 2 | save_img_path: "." 
3 | pretrained_model_path: 'stable-diffusion-v1-5/' 4 | 5 | finetuned_image_sd_path: null 6 | 7 | # model config: 8 | model: TAVU 9 | num_frames: 16 10 | frame_interval: 1 11 | image_size: [256, 256] 12 | 13 | # model speedup 14 | use_compile: False 15 | use_fp16: True 16 | enable_xformers_memory_efficient_attention: True 17 | 18 | # sample config: 19 | seed: 20 | run_time: 13 21 | cfg_scale: 8.0 22 | sample_method: 'ddim' 23 | num_sampling_steps: 100 24 | text_prompt: [ 25 | 'the man is kicking the football' 26 | ] 27 | 28 | additional_prompt: "" 29 | negative_prompt: "" 30 | do_classifier_free_guidance: True 31 | 32 | use_autoregressive: True 33 | 34 | input_path: /mnt/hwfile/internvideo/share_data/huangyifei/model_weights/lastim.jpg 35 | 36 | researve_frame: 3 37 | mask_type: "first1" 38 | use_mask: True 39 | 40 | demoimage: 41 | demotext: 42 | demosavepath: -------------------------------------------------------------------------------- /generation/seine-v2/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | # learn_sigma=True, 17 | learn_sigma=False, # for unet 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | -------------------------------------------------------------------------------- /generation/seine-v2/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 
15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 
71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /generation/seine-v2/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | import torch 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | # @torch.compile 95 | def training_losses( 96 | self, model, *args, **kwargs 97 | ): # pylint: disable=signature-differs 98 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 99 | 100 | def condition_mean(self, cond_fn, *args, **kwargs): 101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 102 | 103 | def condition_score(self, cond_fn, *args, **kwargs): 104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 105 | 106 | def _wrap_model(self, model): 107 | if isinstance(model, _WrappedModel): 108 | return model 109 | return _WrappedModel( 110 | model, self.timestep_map, self.original_num_steps 111 | ) 112 | 113 | def _scale_timesteps(self, t): 114 | # Scaling is done by the wrapped model. 
115 | return t 116 | 117 | 118 | class _WrappedModel: 119 | def __init__(self, model, timestep_map, original_num_steps): 120 | self.model = model 121 | self.timestep_map = timestep_map 122 | # self.rescale_timesteps = rescale_timesteps 123 | self.original_num_steps = original_num_steps 124 | 125 | def __call__(self, x, ts, **kwargs): 126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 127 | new_ts = map_tensor[ts] 128 | # if self.rescale_timesteps: 129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 130 | return self.model(x, new_ts, **kwargs) 131 | -------------------------------------------------------------------------------- /generation/seine-v2/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
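        Because ranks may hold different numbers of samples, the implementation
        below all-gathers the per-rank batch sizes, pads every batch to the
        largest size, and trims the padding again before updating the history.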
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 
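                # (each timestep keeps only the most recent `history_per_term`
                # losses: shift the row left and write the new loss into the
                # last slot)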
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /generation/seine-v2/functions/video_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numbers 4 | from torchvision.transforms import RandomCrop, RandomResizedCrop 5 | from PIL import Image 6 | 7 | def _is_tensor_video_clip(clip): 8 | if not torch.is_tensor(clip): 9 | raise TypeError("clip should be Tensor. Got %s" % type(clip)) 10 | 11 | if not clip.ndimension() == 4: 12 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 13 | 14 | return True 15 | 16 | 17 | def center_crop_arr(pil_image, image_size): 18 | """ 19 | Center cropping implementation from ADM. 20 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 21 | """ 22 | while min(*pil_image.size) >= 2 * image_size: 23 | pil_image = pil_image.resize( 24 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 25 | ) 26 | 27 | scale = image_size / min(*pil_image.size) 28 | pil_image = pil_image.resize( 29 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 30 | ) 31 | 32 | arr = np.array(pil_image) 33 | crop_y = (arr.shape[0] - image_size) // 2 34 | crop_x = (arr.shape[1] - image_size) // 2 35 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 36 | 37 | 38 | def crop(clip, i, j, h, w): 39 | """ 40 | Args: 41 | clip (torch.tensor): Video clip to be cropped. 
Size is (T, C, H, W) 42 | """ 43 | if len(clip.size()) != 4: 44 | raise ValueError("clip should be a 4D tensor") 45 | return clip[..., i : i + h, j : j + w] 46 | 47 | 48 | def resize(clip, target_size, interpolation_mode): 49 | if len(target_size) != 2: 50 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 51 | return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) 52 | 53 | def resize_scale(clip, target_size, interpolation_mode): 54 | if len(target_size) != 2: 55 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 56 | H, W = clip.size(-2), clip.size(-1) 57 | scale_ = target_size[0] / min(H, W) 58 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 59 | 60 | def resize_with_scale_factor(clip, scale_factor, interpolation_mode): 61 | return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False) 62 | 63 | def resize_scale_with_height(clip, target_size, interpolation_mode): 64 | H, W = clip.size(-2), clip.size(-1) 65 | scale_ = target_size / H 66 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 67 | 68 | def resize_scale_with_weight(clip, target_size, interpolation_mode): 69 | H, W = clip.size(-2), clip.size(-1) 70 | scale_ = target_size / W 71 | return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) 72 | 73 | 74 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 75 | """ 76 | Do spatial cropping and resizing to the video clip 77 | Args: 78 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 79 | i (int): i in (i,j) i.e coordinates of the upper left corner. 80 | j (int): j in (i,j) i.e coordinates of the upper left corner. 81 | h (int): Height of the cropped region. 82 | w (int): Width of the cropped region. 83 | size (tuple(int, int)): height and width of resized clip 84 | Returns: 85 | clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W) 86 | """ 87 | if not _is_tensor_video_clip(clip): 88 | raise ValueError("clip should be a 4D torch.tensor") 89 | clip = crop(clip, i, j, h, w) 90 | clip = resize(clip, size, interpolation_mode) 91 | return clip 92 | 93 | 94 | def center_crop(clip, crop_size): 95 | if not _is_tensor_video_clip(clip): 96 | raise ValueError("clip should be a 4D torch.tensor") 97 | h, w = clip.size(-2), clip.size(-1) 98 | # print(clip.shape) 99 | th, tw = crop_size 100 | if h < th or w < tw: 101 | # print(h, w) 102 | raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) 103 | 104 | i = int(round((h - th) / 2.0)) 105 | j = int(round((w - tw) / 2.0)) 106 | return crop(clip, i, j, th, tw) 107 | 108 | 109 | def center_crop_using_short_edge(clip): 110 | if not _is_tensor_video_clip(clip): 111 | raise ValueError("clip should be a 4D torch.tensor") 112 | h, w = clip.size(-2), clip.size(-1) 113 | if h < w: 114 | th, tw = h, h 115 | i = 0 116 | j = int(round((w - tw) / 2.0)) 117 | else: 118 | th, tw = w, w 119 | i = int(round((h - th) / 2.0)) 120 | j = 0 121 | return crop(clip, i, j, th, tw) 122 | 123 | 124 | def random_shift_crop(clip): 125 | ''' 126 | Slide along the long edge, with the short edge as crop size 127 | ''' 128 | if not _is_tensor_video_clip(clip): 129 | raise ValueError("clip should be a 4D torch.tensor") 130 | h, w = clip.size(-2), clip.size(-1) 131 | 132 | if h <= w: 133 | long_edge = w 134 | short_edge = h 135 | else: 136 | long_edge = h 137 | short_edge =w 138 | 139 | th, tw = short_edge, short_edge 140 | 141 | i = torch.randint(0, h - th + 1, size=(1,)).item() 142 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 143 | return crop(clip, i, j, th, tw) 144 | 145 | 146 | def to_tensor(clip): 147 | """ 148 | Convert tensor data type from uint8 to float, divide value by 255.0 and 149 | permute the dimensions of clip tensor 150 | Args: 151 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 152 | Return: 153 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 154 | """ 155 | _is_tensor_video_clip(clip) 156 | if not clip.dtype == torch.uint8: 157 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) 158 | # return clip.float().permute(3, 0, 1, 2) / 255.0 159 | return clip.float() / 255.0 160 | 161 | 162 | def normalize(clip, mean, std, inplace=False): 163 | """ 164 | Args: 165 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 166 | mean (tuple): pixel RGB mean. Size is (3) 167 | std (tuple): pixel standard deviation. Size is (3) 168 | Returns: 169 | normalized clip (torch.tensor): Size is (T, C, H, W) 170 | """ 171 | if not _is_tensor_video_clip(clip): 172 | raise ValueError("clip should be a 4D torch.tensor") 173 | if not inplace: 174 | clip = clip.clone() 175 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 176 | # print(mean) 177 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 178 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 179 | return clip 180 | 181 | 182 | def hflip(clip): 183 | """ 184 | Args: 185 | clip (torch.tensor): Video clip to be normalized. 
Size is (T, C, H, W) 186 | Returns: 187 | flipped clip (torch.tensor): Size is (T, C, H, W) 188 | """ 189 | if not _is_tensor_video_clip(clip): 190 | raise ValueError("clip should be a 4D torch.tensor") 191 | return clip.flip(-1) 192 | 193 | 194 | class RandomCropVideo: 195 | def __init__(self, size): 196 | if isinstance(size, numbers.Number): 197 | self.size = (int(size), int(size)) 198 | else: 199 | self.size = size 200 | 201 | def __call__(self, clip): 202 | """ 203 | Args: 204 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 205 | Returns: 206 | torch.tensor: randomly cropped video clip. 207 | size is (T, C, OH, OW) 208 | """ 209 | i, j, h, w = self.get_params(clip) 210 | return crop(clip, i, j, h, w) 211 | 212 | def get_params(self, clip): 213 | h, w = clip.shape[-2:] 214 | th, tw = self.size 215 | 216 | if h < th or w < tw: 217 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") 218 | 219 | if w == tw and h == th: 220 | return 0, 0, h, w 221 | 222 | i = torch.randint(0, h - th + 1, size=(1,)).item() 223 | j = torch.randint(0, w - tw + 1, size=(1,)).item() 224 | 225 | return i, j, th, tw 226 | 227 | def __repr__(self) -> str: 228 | return f"{self.__class__.__name__}(size={self.size})" 229 | 230 | class CenterCropResizeVideo: 231 | ''' 232 | First use the short side for cropping length, 233 | center crop video, then resize to the specified size 234 | ''' 235 | def __init__( 236 | self, 237 | size, 238 | interpolation_mode="bilinear", 239 | ): 240 | if isinstance(size, tuple): 241 | if len(size) != 2: 242 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 243 | self.size = size 244 | else: 245 | self.size = (size, size) 246 | 247 | self.interpolation_mode = interpolation_mode 248 | 249 | 250 | def __call__(self, clip): 251 | """ 252 | Args: 253 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 254 | Returns: 255 | torch.tensor: scale resized / center cropped video clip. 256 | size is (T, C, crop_size, crop_size) 257 | """ 258 | # print(clip.shape) 259 | clip_center_crop = center_crop_using_short_edge(clip) 260 | # print(clip_center_crop.shape) 320 512 261 | clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) 262 | return clip_center_crop_resize 263 | 264 | def __repr__(self) -> str: 265 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 266 | 267 | class WebVideo320512: 268 | def __init__( 269 | self, 270 | size, 271 | interpolation_mode="bilinear", 272 | ): 273 | if isinstance(size, tuple): 274 | if len(size) != 2: 275 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 276 | self.size = size 277 | else: 278 | self.size = (size, size) 279 | 280 | self.interpolation_mode = interpolation_mode 281 | 282 | 283 | def __call__(self, clip): 284 | """ 285 | Args: 286 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 287 | Returns: 288 | torch.tensor: scale resized / center cropped video clip. 
289 | size is (T, C, crop_size, crop_size) 290 | """ 291 | # add aditional one pixel for avoiding error in center crop 292 | h, w = clip.size(-2), clip.size(-1) 293 | # print('before resize', clip.shape) 294 | if h < 320: 295 | clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode) 296 | # print('after h resize', clip.shape) 297 | if w < 512: 298 | clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode) 299 | # print('after w resize', clip.shape) 300 | clip_center_crop = center_crop(clip, self.size) 301 | # print(clip_center_crop.shape) 302 | return clip_center_crop 303 | 304 | def __repr__(self) -> str: 305 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 306 | 307 | class WebVideo256256: 308 | def __init__( 309 | self, 310 | size, 311 | interpolation_mode="bilinear", 312 | ): 313 | if isinstance(size, tuple): 314 | if len(size) != 2: 315 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 316 | self.size = size 317 | else: 318 | self.size = (size, size) 319 | 320 | self.interpolation_mode = interpolation_mode 321 | 322 | 323 | def __call__(self, clip): 324 | """ 325 | Args: 326 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 327 | Returns: 328 | torch.tensor: scale resized / center cropped video clip. 329 | size is (T, C, crop_size, crop_size) 330 | """ 331 | # add aditional one pixel for avoiding error in center crop 332 | h, w = clip.size(-2), clip.size(-1) 333 | # print('before resize', clip.shape) 334 | if h < 256: 335 | clip = resize_scale_with_height(clip=clip, target_size=258, interpolation_mode=self.interpolation_mode) 336 | # print('after h resize', clip.shape) 337 | if w < 256: 338 | clip = resize_scale_with_weight(clip=clip, target_size=258, interpolation_mode=self.interpolation_mode) 339 | # print('after w resize', clip.shape) 340 | clip_center_crop = center_crop(clip, self.size) 341 | # print(clip_center_crop.shape) 342 | return clip_center_crop 343 | 344 | def __repr__(self) -> str: 345 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 346 | 347 | class InternVideo320512: 348 | def __init__( 349 | self, 350 | size, 351 | interpolation_mode="bilinear", 352 | ): 353 | if isinstance(size, tuple): 354 | if len(size) != 2: 355 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 356 | self.size = size 357 | else: 358 | self.size = (size, size) 359 | 360 | self.interpolation_mode = interpolation_mode 361 | 362 | 363 | def __call__(self, clip): 364 | """ 365 | Args: 366 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 367 | Returns: 368 | torch.tensor: scale resized / center cropped video clip. 
369 | size is (T, C, crop_size, crop_size) 1.6 370 | """ 371 | # add aditional one pixel for avoiding error in center crop 372 | h, w = clip.size(-2), clip.size(-1) 373 | if h < w and h * 1.6 < w: 374 | clip_center_crop = center_crop(clip, (h, h * 1.6)) 375 | clip = resize(clip_center_crop, (320, 512)) 376 | return clip 377 | else: 378 | # print('before resize', clip.shape) 379 | if h < 320: 380 | clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode) 381 | # print('after h resize', clip.shape) 382 | if w < 512: 383 | clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode) 384 | # print('after w resize', clip.shape) 385 | clip_center_crop = center_crop(clip, self.size) 386 | # print(clip_center_crop.shape) 387 | return clip_center_crop 388 | 389 | def __repr__(self) -> str: 390 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 391 | 392 | class UCFCenterCropVideo: 393 | ''' 394 | First scale to the specified size in equal proportion to the short edge, 395 | then center cropping 396 | ''' 397 | def __init__( 398 | self, 399 | size, 400 | interpolation_mode="bilinear", 401 | ): 402 | if isinstance(size, tuple): 403 | if len(size) != 2: 404 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 405 | self.size = size 406 | else: 407 | self.size = (size, size) 408 | 409 | self.interpolation_mode = interpolation_mode 410 | 411 | 412 | def __call__(self, clip): 413 | """ 414 | Args: 415 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 416 | Returns: 417 | torch.tensor: scale resized / center cropped video clip. 418 | size is (T, C, crop_size, crop_size) 419 | """ 420 | clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) 421 | clip_center_crop = center_crop(clip_resize, self.size) 422 | return clip_center_crop 423 | 424 | def __repr__(self) -> str: 425 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 426 | 427 | class KineticsRandomCropResizeVideo: 428 | ''' 429 | Slide along the long edge, with the short edge as crop size. And resie to the desired size. 430 | ''' 431 | def __init__( 432 | self, 433 | size, 434 | interpolation_mode="bilinear", 435 | ): 436 | if isinstance(size, tuple): 437 | if len(size) != 2: 438 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 439 | self.size = size 440 | else: 441 | self.size = (size, size) 442 | 443 | self.interpolation_mode = interpolation_mode 444 | 445 | def __call__(self, clip): 446 | clip_random_crop = random_shift_crop(clip) 447 | clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) 448 | return clip_resize 449 | 450 | 451 | class CenterCropVideo: 452 | def __init__( 453 | self, 454 | size, 455 | interpolation_mode="bilinear", 456 | ): 457 | if isinstance(size, tuple): 458 | if len(size) != 2: 459 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 460 | self.size = size 461 | else: 462 | self.size = (size, size) 463 | 464 | self.interpolation_mode = interpolation_mode 465 | 466 | 467 | def __call__(self, clip): 468 | """ 469 | Args: 470 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 471 | Returns: 472 | torch.tensor: center cropped video clip. 
473 | size is (T, C, crop_size, crop_size) 474 | """ 475 | clip_center_crop = center_crop(clip, self.size) 476 | return clip_center_crop 477 | 478 | def __repr__(self) -> str: 479 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 480 | 481 | 482 | class NormalizeVideo: 483 | """ 484 | Normalize the video clip by mean subtraction and division by standard deviation 485 | Args: 486 | mean (3-tuple): pixel RGB mean 487 | std (3-tuple): pixel RGB standard deviation 488 | inplace (boolean): whether do in-place normalization 489 | """ 490 | 491 | def __init__(self, mean, std, inplace=False): 492 | self.mean = mean 493 | self.std = std 494 | self.inplace = inplace 495 | 496 | def __call__(self, clip): 497 | """ 498 | Args: 499 | clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) 500 | """ 501 | return normalize(clip, self.mean, self.std, self.inplace) 502 | 503 | def __repr__(self) -> str: 504 | return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" 505 | 506 | 507 | class ToTensorVideo: 508 | """ 509 | Convert tensor data type from uint8 to float, divide value by 255.0 and 510 | permute the dimensions of clip tensor 511 | """ 512 | 513 | def __init__(self): 514 | pass 515 | 516 | def __call__(self, clip): 517 | """ 518 | Args: 519 | clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) 520 | Return: 521 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 522 | """ 523 | return to_tensor(clip) 524 | 525 | def __repr__(self) -> str: 526 | return self.__class__.__name__ 527 | 528 | 529 | class RandomHorizontalFlipVideo: 530 | """ 531 | Flip the video clip along the horizontal direction with a given probability 532 | Args: 533 | p (float): probability of the clip being flipped. Default value is 0.5 534 | """ 535 | 536 | def __init__(self, p=0.5): 537 | self.p = p 538 | 539 | def __call__(self, clip): 540 | """ 541 | Args: 542 | clip (torch.tensor): Size is (T, C, H, W) 543 | Return: 544 | clip (torch.tensor): Size is (T, C, H, W) 545 | """ 546 | if random.random() < self.p: 547 | clip = hflip(clip) 548 | return clip 549 | 550 | def __repr__(self) -> str: 551 | return f"{self.__class__.__name__}(p={self.p})" 552 | 553 | 554 | class ResizeVideo(): 555 | ''' 556 | First use the short side for cropping length, 557 | center crop video, then resize to the specified size 558 | ''' 559 | def __init__( 560 | self, 561 | size, 562 | interpolation_mode="bilinear", 563 | ): 564 | if isinstance(size, tuple): 565 | if len(size) != 2: 566 | raise ValueError(f"size should be tuple (height, width), instead got {size}") 567 | self.size = size 568 | else: 569 | self.size = (size, size) 570 | 571 | self.interpolation_mode = interpolation_mode 572 | 573 | 574 | def __call__(self, clip): 575 | """ 576 | Args: 577 | clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) 578 | Returns: 579 | torch.tensor: scale resized / center cropped video clip. 
580 | size is (T, C, crop_size, crop_size) 581 | """ 582 | clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) 583 | return clip_resize 584 | 585 | def __repr__(self) -> str: 586 | return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" 587 | 588 | 589 | # ------------------------------------------------------------ 590 | # --------------------- Sampling --------------------------- 591 | # ------------------------------------------------------------ 592 | class TemporalRandomCrop(object): 593 | """Temporally crop the given frame indices at a random location. 594 | 595 | Args: 596 | size (int): Desired length of frames will be seen in the model. 597 | """ 598 | 599 | def __init__(self, size): 600 | self.size = size 601 | 602 | def __call__(self, total_frames): 603 | rand_end = max(0, total_frames - self.size - 1) 604 | begin_index = random.randint(0, rand_end) 605 | end_index = min(begin_index + self.size, total_frames) 606 | return begin_index, end_index 607 | 608 | 609 | if __name__ == '__main__': 610 | from torchvision import transforms 611 | import torchvision.io as io 612 | import numpy as np 613 | from torchvision.utils import save_image 614 | import os 615 | 616 | vframes, aframes, info = io.read_video( 617 | filename='./v_Archery_g01_c03.avi', 618 | pts_unit='sec', 619 | output_format='TCHW' 620 | ) 621 | 622 | trans = transforms.Compose([ 623 | ToTensorVideo(), 624 | RandomHorizontalFlipVideo(), 625 | UCFCenterCropVideo(512), 626 | # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 627 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 628 | ]) 629 | 630 | target_video_len = 32 631 | frame_interval = 1 632 | total_frames = len(vframes) 633 | print(total_frames) 634 | 635 | temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) 636 | 637 | 638 | # Sampling video frames 639 | start_frame_ind, end_frame_ind = temporal_sample(total_frames) 640 | # print(start_frame_ind) 641 | # print(end_frame_ind) 642 | assert end_frame_ind - start_frame_ind >= target_video_len 643 | frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) 644 | print(frame_indice) 645 | 646 | select_vframes = vframes[frame_indice] 647 | print(select_vframes.shape) 648 | print(select_vframes.dtype) 649 | 650 | select_vframes_trans = trans(select_vframes) 651 | print(select_vframes_trans.shape) 652 | print(select_vframes_trans.dtype) 653 | 654 | select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) 655 | print(select_vframes_trans_int.dtype) 656 | print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) 657 | 658 | io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) 659 | 660 | for i in range(target_video_len): 661 | save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1)) -------------------------------------------------------------------------------- /generation/seine-v2/models_new/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from .unet import UNet3DConditionModel 4 | 5 | def get_models(args): 6 | pretrained_model_path = args.pretrained_model_path 7 | return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask, finetuned_image_sd_path=args.finetuned_image_sd_path) 
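
# Minimal usage sketch (illustrative only): `args` is a hypothetical namespace
# whose attributes mirror the demo config keys and the weight layout used
# elsewhere in this repo.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(pretrained_model_path="seine_weights",
#                          use_mask=True,
#                          finetuned_image_sd_path=None)
#   unet = get_models(args)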
-------------------------------------------------------------------------------- /generation/seine-v2/models_new/clip.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch.nn as nn 3 | from transformers import CLIPTokenizer, CLIPTextModel 4 | 5 | import transformers 6 | transformers.logging.set_verbosity_error() 7 | 8 | """ 9 | Will encounter following warning: 10 | - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task 11 | or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). 12 | - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model 13 | that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). 14 | 15 | https://github.com/CompVis/stable-diffusion/issues/97 16 | according to this issue, this warning is safe. 17 | 18 | This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. 19 | You can safely ignore the warning, it is not an error. 20 | 21 | This clip usage is from U-ViT and same with Stable Diffusion. 22 | """ 23 | 24 | class AbstractEncoder(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | 28 | def encode(self, *args, **kwargs): 29 | raise NotImplementedError 30 | 31 | 32 | class FrozenCLIPEmbedder(AbstractEncoder): 33 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 34 | # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77): 35 | def __init__(self, path, device="cuda", max_length=77): 36 | super().__init__() 37 | self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer") 38 | self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder') 39 | self.device = device 40 | self.max_length = max_length 41 | self.freeze() 42 | 43 | def freeze(self): 44 | self.transformer = self.transformer.eval() 45 | for param in self.parameters(): 46 | param.requires_grad = False 47 | 48 | def forward(self, text): 49 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 50 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 51 | # tokens = batch_encoding["input_ids"].to(self.device) 52 | tokens = batch_encoding["input_ids"].to(self.transformer.device) 53 | outputs = self.transformer(input_ids=tokens) 54 | 55 | z = outputs.last_hidden_state 56 | return z 57 | 58 | def encode(self, text): 59 | return self(text) 60 | 61 | 62 | class TextEmbedder(nn.Module): 63 | """ 64 | Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. 65 | """ 66 | def __init__(self, path, dropout_prob=0.1): 67 | super().__init__() 68 | self.text_encodder = FrozenCLIPEmbedder(path=path) 69 | self.dropout_prob = dropout_prob 70 | 71 | def token_drop(self, text_prompts, force_drop_ids=None): 72 | """ 73 | Drops text to enable classifier-free guidance. 
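        With probability `dropout_prob`, a prompt is replaced by the empty
        string "" so that the model also learns the unconditional distribution
        used for classifier-free guidance at sampling time.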
74 | """ 75 | if force_drop_ids is None: 76 | drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob 77 | else: 78 | # TODO 79 | drop_ids = force_drop_ids == 1 80 | labels = list(numpy.where(drop_ids, "", text_prompts)) 81 | # print(labels) 82 | return labels 83 | 84 | def forward(self, text_prompts, train, force_drop_ids=None): 85 | use_dropout = self.dropout_prob > 0 86 | if (train and use_dropout) or (force_drop_ids is not None): 87 | text_prompts = self.token_drop(text_prompts, force_drop_ids) 88 | embeddings = self.text_encodder(text_prompts) 89 | return embeddings 90 | 91 | 92 | if __name__ == '__main__': 93 | 94 | r""" 95 | Returns: 96 | 97 | Examples from CLIPTextModel: 98 | 99 | ```python 100 | >>> from transformers import AutoTokenizer, CLIPTextModel 101 | 102 | >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") 103 | >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") 104 | 105 | >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") 106 | 107 | >>> outputs = model(**inputs) 108 | >>> last_hidden_state = outputs.last_hidden_state 109 | >>> pooled_output = outputs.pooler_output # pooled (EOS token) states 110 | ```""" 111 | 112 | import torch 113 | 114 | device = "cuda" if torch.cuda.is_available() else "cpu" 115 | 116 | text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base', 117 | dropout_prob=0.00001).to(device) 118 | 119 | text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]] 120 | # text_prompt = ('None', 'None', 'None') 121 | output = text_encoder(text_prompts=text_prompt, train=False) 122 | # print(output) 123 | print(output.shape) 124 | # print(output.shape) -------------------------------------------------------------------------------- /generation/seine-v2/models_new/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | import os 3 | import sys 4 | sys.path.append(os.path.split(sys.path[0])[0]) 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | 13 | class InflatedConv3d(nn.Conv2d): 14 | def forward(self, x): 15 | video_length = x.shape[2] 16 | 17 | x = rearrange(x, "b c f h w -> (b f) c h w") 18 | x = super().forward(x) 19 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 20 | 21 | return x 22 | 23 | 24 | class Upsample3D(nn.Module): 25 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 26 | super().__init__() 27 | self.channels = channels 28 | self.out_channels = out_channels or channels 29 | self.use_conv = use_conv 30 | self.use_conv_transpose = use_conv_transpose 31 | self.name = name 32 | 33 | conv = None 34 | if use_conv_transpose: 35 | raise NotImplementedError 36 | elif use_conv: 37 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 38 | 39 | if name == "conv": 40 | self.conv = conv 41 | else: 42 | self.Conv2d_0 = conv 43 | 44 | def forward(self, hidden_states, output_size=None): 45 | assert hidden_states.shape[1] == self.channels 46 | 47 | if self.use_conv_transpose: 48 | raise NotImplementedError 49 | 50 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 51 | dtype = 
hidden_states.dtype 52 | if dtype == torch.bfloat16: 53 | hidden_states = hidden_states.to(torch.float32) 54 | 55 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 56 | if hidden_states.shape[0] >= 64: 57 | hidden_states = hidden_states.contiguous() 58 | 59 | # if `output_size` is passed we force the interpolation output 60 | # size and do not make use of `scale_factor=2` 61 | if output_size is None: 62 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 63 | else: 64 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 65 | 66 | # If the input is bfloat16, we cast back to bfloat16 67 | if dtype == torch.bfloat16: 68 | hidden_states = hidden_states.to(dtype) 69 | 70 | if self.use_conv: 71 | if self.name == "conv": 72 | hidden_states = self.conv(hidden_states) 73 | else: 74 | hidden_states = self.Conv2d_0(hidden_states) 75 | 76 | return hidden_states 77 | 78 | 79 | class Downsample3D(nn.Module): 80 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 81 | super().__init__() 82 | self.channels = channels 83 | self.out_channels = out_channels or channels 84 | self.use_conv = use_conv 85 | self.padding = padding 86 | stride = 2 87 | self.name = name 88 | 89 | if use_conv: 90 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 91 | else: 92 | raise NotImplementedError 93 | 94 | if name == "conv": 95 | self.Conv2d_0 = conv 96 | self.conv = conv 97 | elif name == "Conv2d_0": 98 | self.conv = conv 99 | else: 100 | self.conv = conv 101 | 102 | def forward(self, hidden_states): 103 | assert hidden_states.shape[1] == self.channels 104 | if self.use_conv and self.padding == 0: 105 | raise NotImplementedError 106 | 107 | assert hidden_states.shape[1] == self.channels 108 | hidden_states = self.conv(hidden_states) 109 | 110 | return hidden_states 111 | 112 | 113 | class ResnetBlock3D(nn.Module): 114 | def __init__( 115 | self, 116 | *, 117 | in_channels, 118 | out_channels=None, 119 | conv_shortcut=False, 120 | dropout=0.0, 121 | temb_channels=512, 122 | groups=32, 123 | groups_out=None, 124 | pre_norm=True, 125 | eps=1e-6, 126 | non_linearity="swish", 127 | time_embedding_norm="default", 128 | output_scale_factor=1.0, 129 | use_in_shortcut=None, 130 | ): 131 | super().__init__() 132 | self.pre_norm = pre_norm 133 | self.pre_norm = True 134 | self.in_channels = in_channels 135 | out_channels = in_channels if out_channels is None else out_channels 136 | self.out_channels = out_channels 137 | self.use_conv_shortcut = conv_shortcut 138 | self.time_embedding_norm = time_embedding_norm 139 | self.output_scale_factor = output_scale_factor 140 | 141 | if groups_out is None: 142 | groups_out = groups 143 | 144 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 145 | 146 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 147 | 148 | if temb_channels is not None: 149 | if self.time_embedding_norm == "default": 150 | time_emb_proj_out_channels = out_channels 151 | elif self.time_embedding_norm == "scale_shift": 152 | time_emb_proj_out_channels = out_channels * 2 153 | else: 154 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 155 | 156 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 157 | else: 158 | self.time_emb_proj = None 159 | 160 | self.norm2 = 
torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 161 | self.dropout = torch.nn.Dropout(dropout) 162 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 163 | 164 | if non_linearity == "swish": 165 | self.nonlinearity = lambda x: F.silu(x) 166 | elif non_linearity == "mish": 167 | self.nonlinearity = Mish() 168 | elif non_linearity == "silu": 169 | self.nonlinearity = nn.SiLU() 170 | 171 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 172 | 173 | self.conv_shortcut = None 174 | if self.use_in_shortcut: 175 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 176 | 177 | def forward(self, input_tensor, temb): 178 | hidden_states = input_tensor 179 | 180 | hidden_states = self.norm1(hidden_states) 181 | hidden_states = self.nonlinearity(hidden_states) 182 | 183 | hidden_states = self.conv1(hidden_states) 184 | 185 | if temb is not None: 186 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 187 | 188 | if temb is not None and self.time_embedding_norm == "default": 189 | hidden_states = hidden_states + temb 190 | 191 | hidden_states = self.norm2(hidden_states) 192 | 193 | if temb is not None and self.time_embedding_norm == "scale_shift": 194 | scale, shift = torch.chunk(temb, 2, dim=1) 195 | hidden_states = hidden_states * (1 + scale) + shift 196 | 197 | hidden_states = self.nonlinearity(hidden_states) 198 | 199 | hidden_states = self.dropout(hidden_states) 200 | hidden_states = self.conv2(hidden_states) 201 | 202 | if self.conv_shortcut is not None: 203 | input_tensor = self.conv_shortcut(input_tensor) 204 | 205 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 206 | 207 | return output_tensor 208 | 209 | 210 | class Mish(torch.nn.Module): 211 | def forward(self, hidden_states): 212 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /generation/seine-v2/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchaudio==2.0.2 3 | torchvision==0.15.2 4 | decord==0.6.0 5 | diffusers==0.15.0 6 | imageio==2.29.0 7 | transformers==4.29.2 8 | xformers==0.0.20 9 | einops 10 | omegaconf 11 | tensorboard==2.15.1 12 | timm==0.9.10 13 | rotary-embedding-torch==0.3.5 14 | natsort==8.4.0 -------------------------------------------------------------------------------- /generation/seine-v2/seine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Sample new images from a pre-trained DiT. 
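This script loads the SEINE UNet, the Stable Diffusion VAE and a frozen CLIP
text encoder from `seine_weights`, then inpaints a short clip from a single
conditioning image and a text prompt using a "first1" frame mask.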
9 | """ 10 | import io 11 | import os 12 | import sys 13 | # sys.path.append('seine-v2/') 14 | import math 15 | # import seine_utils as utils 16 | from diffusion import create_diffusion 17 | 18 | import torch 19 | torch.backends.cuda.matmul.allow_tf32 = True 20 | torch.backends.cudnn.allow_tf32 = True 21 | import argparse 22 | import torchvision 23 | 24 | from einops import rearrange 25 | # from models import get_models 26 | from torchvision.utils import save_image 27 | from diffusers.models import AutoencoderKL 28 | from models_new.clip import TextEmbedder 29 | from omegaconf import OmegaConf 30 | from PIL import Image 31 | import numpy as np 32 | from torchvision import transforms 33 | from functions import video_transforms 34 | from decord import VideoReader 35 | # from seine_utils import mask_generation_before 36 | from diffusers.utils.import_utils import is_xformers_available 37 | 38 | from models_new.unet import UNet3DConditionModel 39 | current_directory = os.getcwd() 40 | print("Current Working Directory:", current_directory) 41 | # try: 42 | # model = UNet3DConditionModel.from_pretrained_2d("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_concat=True, finetuned_image_sd_path=None) 43 | # except: 44 | ### modify the absolute path if sd1.5 is available locally 45 | model = UNet3DConditionModel.from_pretrained_2d('seine_weights', subfolder="unet", use_concat=True, finetuned_image_sd_path=None) 46 | vae = AutoencoderKL.from_pretrained("seine_weights", subfolder="vae") 47 | text_encoder = TextEmbedder("seine_weights") 48 | # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: 49 | # https://huggingface.co/hyf015/model_weights/blob/main/finetune_seine_256p_15ep_60K.pt 50 | ckpt_path = 'seine_weights/finetune_seine_256p_15ep_60K.pt' 51 | 52 | state_dict = torch.load(ckpt_path, map_location='cpu')['ema'] 53 | res = model.load_state_dict(state_dict) 54 | print('loading succeed') 55 | print(res) 56 | model_seine = model 57 | 58 | def mask_generation_before(mask_type, shape, dtype, device): 59 | b, f, c, h, w = shape 60 | if mask_type.startswith('first'): 61 | num = int(mask_type.split('first')[-1]) 62 | mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device), 63 | torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1) 64 | mask = mask_f.expand(b, -1, c, h, w) 65 | else: 66 | raise ValueError(f"Invalid mask type: {mask_type}") 67 | return mask 68 | 69 | def get_input(args): 70 | input_path = args.input_path 71 | 72 | transform_video = transforms.Compose([ 73 | video_transforms.ToTensorVideo(), # TCHW 74 | video_transforms.ResizeVideo((args.image_h, args.image_w)), 75 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 76 | ]) 77 | temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) 78 | 79 | assert os.path.isfile(input_path), 'Please make sure the given input_path is a JPG/PNG file' 80 | 81 | print(f'loading video from {input_path}') 82 | _, full_file_name = os.path.split(input_path) 83 | file_name, extention = os.path.splitext(full_file_name) 84 | assert extention == '.jpg' or extention == '.png', "Not a PNG or JPG file" 85 | 86 | video_frames = [] 87 | num = int(args.mask_type.split('first')[-1]) 88 | first_frame = torch.as_tensor(np.array(Image.open(input_path), dtype=np.uint8, copy=True)).unsqueeze(0) 89 | for i in range(num): 90 | video_frames.append(first_frame) 91 | num_zeros = args.num_frames-num 92 | for i in range(num_zeros): 93 
| zeros = torch.zeros_like(first_frame) 94 | video_frames.append(zeros) 95 | n = 0 96 | video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w 97 | video_frames = transform_video(video_frames) 98 | return video_frames, n 99 | 100 | 101 | def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,): 102 | b, f, c, h, w = video_input.shape 103 | latent_h = args.image_size[0] // 8 104 | latent_w = args.image_size[1] // 8 105 | 106 | if args.use_fp16: 107 | z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, dtype=torch.float16, device=device) # b,c,f,h,w 108 | masked_video = masked_video.to(dtype=torch.float16) 109 | mask = mask.to(dtype=torch.float16) 110 | else: 111 | z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w 112 | 113 | masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() 114 | masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) 115 | masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() 116 | mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) 117 | 118 | # classifier_free_guidance 119 | if args.do_classifier_free_guidance: 120 | masked_video = torch.cat([masked_video] * 2) 121 | mask = torch.cat([mask] * 2) 122 | z = torch.cat([z] * 2) 123 | prompt_all = [prompt] + [args.negative_prompt] 124 | else: 125 | masked_video = masked_video 126 | mask = mask 127 | z = z 128 | prompt_all = [prompt] 129 | 130 | text_prompt = text_encoder(text_prompts=prompt_all, train=False) 131 | model_kwargs = dict(encoder_hidden_states=text_prompt, 132 | class_labels=None, 133 | cfg_scale=args.cfg_scale, 134 | use_fp16=args.use_fp16, 135 | ) # tav unet 136 | 137 | # Sample images: 138 | if args.sample_method == 'ddim': 139 | samples = diffusion.ddim_sample_loop( 140 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ 141 | mask=mask, x_start=masked_video, use_concat=args.use_mask 142 | ) 143 | elif args.sample_method == 'ddpm': 144 | samples = diffusion.p_sample_loop( 145 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device, \ 146 | mask=mask, x_start=masked_video, use_concat=args.use_mask 147 | ) 148 | samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] 149 | if args.use_fp16: 150 | samples = samples.to(dtype=torch.float16) 151 | 152 | video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] 153 | video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] 154 | return video_clip 155 | 156 | 157 | def gen(args, model, save_path='result.mp4'): 158 | model.cuda() 159 | print('entered gen') 160 | # Setup PyTorch: 161 | if args.seed: 162 | torch.manual_seed(args.seed) 163 | torch.set_grad_enabled(False) 164 | device = "cuda" if torch.cuda.is_available() else "cpu" 165 | 166 | args.latent_h = latent_h = args.image_size[0] // 8 167 | args.latent_w = latent_w = args.image_size[1] // 8 168 | args.image_h = args.image_size[0] 169 | args.image_w = args.image_size[1] 170 | 171 | print('loading model') 172 | if args.use_compile: 173 | model = torch.compile(model) 174 | 175 | if args.enable_xformers_memory_efficient_attention: 176 | print('Using xformers memory efficient attention') 177 | if is_xformers_available(): 178 | model.enable_xformers_memory_efficient_attention() 179 | else: 180 | 
print('Xformers not available') 181 | pass 182 | else: 183 | print('Not using xformers memory efficient attention') 184 | 185 | model.eval() # important! 186 | pretrained_model_path = "seine_weights" 187 | 188 | diffusion = create_diffusion(str(args.num_sampling_steps)) 189 | vae.to(device) 190 | text_encoder.to(device) 191 | 192 | if args.use_fp16: 193 | print('Warnning: using half percision for inferencing!') 194 | vae.to(device, dtype=torch.float16) 195 | model.to(device, dtype=torch.float16) 196 | text_encoder.to(device, dtype=torch.float16) 197 | 198 | assert args.use_autoregressive is True 199 | 200 | video_input, researve_frames = get_input(args) # f,c,h,w 201 | video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w 202 | mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w 203 | # TODO: change the first3 to last3 204 | if args.mask_type.startswith('first') and researve_frames != 0: 205 | masked_video = torch.cat([video_input[:,-researve_frames:], video_input[:,:-researve_frames]], dim=1) * (mask == 0) 206 | else: 207 | masked_video = video_input * (mask == 0) 208 | 209 | print(args.text_prompt, 'current text prompt') 210 | video_clip = auto_inpainting(args, video_input, masked_video, mask, args.text_prompt[0], vae, text_encoder, diffusion, model, device,) 211 | video_clip_ = video_clip.unsqueeze(0) 212 | video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) 213 | print(f'Saving result video to {save_path}') 214 | torchvision.io.write_video(save_path, video_, fps=8) 215 | 216 | if __name__ == "__main__": 217 | parser = argparse.ArgumentParser() 218 | parser.add_argument("--config", type=str, default="./configs/sample_mask.yaml") 219 | parser.add_argument("--run-time", type=int, default=0) 220 | parser.add_argument("--demoimage", type=str, default=None, required=True) 221 | parser.add_argument('--demotext', type=str, default=None, required=True) 222 | parser.add_argument('--demosavepath', type=str, default=None, required=True) 223 | parser.add_argument('--checkpoint', type=str, default=None, required=True) 224 | args = parser.parse_args() 225 | omega_conf = OmegaConf.load(args.config) 226 | omega_conf.run_time = args.run_time 227 | 228 | omega_conf.input_path = args.demoimage 229 | omega_conf.text_prompt = args.demotext 230 | omega_conf.save_img_path = args.demosavepath 231 | omega_conf.ckpt = args.checkpoint 232 | gen(omega_conf, model, args.demosavepath) 233 | -------------------------------------------------------------------------------- /generation/seine-v2/slurm_scripts/run_inference.sh: -------------------------------------------------------------------------------- 1 | srun -p Gvlab-S1 -N1 -n1 --cpus-per-task=8 --quotatype=auto \ 2 | python seine.py \ 3 | --config='./configs/demo.yaml' \ 4 | --demotext="#C C mixes the food in the pot." 
\ 5 | --demoimage="/mnt/petrelfs/xujilan/newtools/seine-v2/input2/test18_mixthefoodinthepot.png" \ 6 | --demosavepath="result.mp4" \ 7 | --checkpoint="/mnt/hwfile/internvideo/share_data/huangyifei/model_weights/seine/finetune_seine_256p_15ep_60K.pt" \ -------------------------------------------------------------------------------- /vinci-inference/.env: -------------------------------------------------------------------------------- 1 | # .env file contents 2 | export access_key=minio_admin # replace to match your deployment 3 | export access_key_secret=minio_admin # replace to match your deployment 4 | export endpoint=127.0.0.1:19000 5 | export bucket=vinci 6 | export external_endpoint=http://127.0.0.1:19000 7 | export cdn=http://127.0.0.1:19000 8 | 9 | -------------------------------------------------------------------------------- /vinci-inference/README.md: -------------------------------------------------------------------------------- 1 | # Vinci Inference 2 | 3 | The code lives under the `egodemo` root directory; all start/stop operations are run from that root directory, and logs are written to `inference.log`. 4 | 5 | ## Configuration 6 | 7 | Edit the .env file to set the environment variables, including the OSS access keys and related settings. 8 | 9 | ## Dependencies 10 | 11 | Model dependencies are not included. 12 | 13 | ``` 14 | pip install -r requirements/app.txt 15 | ``` 16 | 17 | ## Operations 18 | 19 | ### Start 20 | 21 | Use the `--cuda` argument to specify which GPUs to use. 22 | 23 | ``` 24 | ./vinci-inference/boot.sh --cuda 4,5 start 25 | ``` 26 | 27 | ### Stop 28 | 29 | ``` 30 | ./vinci-inference/boot.sh stop 31 | ``` 32 | 33 | ### Restart 34 | 35 | ``` 36 | ./vinci-inference/boot.sh --cuda 4,5 restart 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /vinci-inference/app/data.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | from pydantic import BaseModel, Field 3 | 4 | class SeineInferenceRequest(BaseModel): 5 | prompt: str=Field(title="提示词", description="描述视频如何生成的提示词", example="跳舞") 6 | image: Optional[Any]=Field(default="",title="图片URL链接或者多维数组", description="", example="https://oss.openmmlab.com/vinci/minimalism-p1.PNG") 7 | base64_image: Optional[Any]=Field(default="",title="图片base64字符串", description="", example="https://oss.openmmlab.com/vinci/minimalism-p1.PNG") 8 | 9 | class SeineInferenceResponse(BaseModel): 10 | video: Optional[Any]=Field(default="", title="视频URL链接") 11 | video_base64: Optional[str]=Field(default="") 12 | 13 | class InternvlInferenceRequest(BaseModel): 14 | session_id: Optional[str]=Field("default", title="会话ID", description="同一个会话ID,时间戳必须大于上次的时间戳", example="default") 15 | timestamp: Optional[int]=Field(0, title="时间戳", description="如果是同一个对话,则时间戳必须大于上次的时间戳", example=0) 16 | silent: Optional[bool]=Field(False, title="静默模式", description="非应答模式,用于定时截帧时用静默模式,否则使用应答模式", example=False) 17 | question: str=Field(title="问题", example="视频内容是什么?") 18 | history: Optional[list]=Field(default=[], title="问答记录", description="每次对话后返回的对话历史记录", example=[]) 19 | frames: Optional[list]=Field(default=[], title="视频帧", description="视频帧的URL链接或者多维数组", example=["https://oss.openmmlab.com/vinci/minimalism-p1.PNG"]) 20 | base64_frames: Optional[list]=Field(default=[], title="视频帧", description="视频帧的base64", example=["https://oss.openmmlab.com/vinci/minimalism-p1.PNG"]) 21 | 22 | class IntervlInferenceResponse(BaseModel): 23 | answer: Optional[str]=Field(None, title="回答") 24 | history: Optional[list]=Field(None, title="问答记录", description="每次对话后返回的对话历史记录") 25 | session_id: Optional[str]=Field("default", title="会话ID", description="同一个会话ID,时间戳必须大于上次的时间戳", example="default") 26 | --------------------------------------------------------------------------------
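The request/response models above define the JSON contract of the inference API, so a quick way to sanity-check them is to post a hand-built payload at a running instance of the service. Below is a minimal client sketch; the port (18081, taken from app/main.py), the `/api/v1/inference/internvl` route prefix and the local `frame.jpg` sample are assumptions for illustration, not fixed values.

```python
# Minimal sketch of exercising InternvlInferenceRequest / IntervlInferenceResponse.
# Assumptions: the FastAPI app from app/main.py is running on 127.0.0.1:18081
# and ./frame.jpg is any sample egocentric frame (placeholder path).
import base64
import requests

def encode_frame(path: str) -> str:
    # The API accepts frames either as URLs ("frames") or base64 strings ("base64_frames").
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "session_id": "default",
    "timestamp": 0,
    "silent": False,
    "question": "What am I doing in this video?",
    "history": [],
    "frames": [],
    "base64_frames": [encode_frame("./frame.jpg")],
}

resp = requests.post("http://127.0.0.1:18081/api/v1/inference/internvl", json=payload)
resp.raise_for_status()
body = resp.json()  # shaped like IntervlInferenceResponse
print(body["answer"], body["session_id"], body["history"])
```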
/vinci-inference/app/exception/__init__.py: -------------------------------------------------------------------------------- 1 | from fastapi.exceptions import RequestValidationError 2 | from starlette.exceptions import HTTPException 3 | 4 | from .handler import * 5 | 6 | exception_handlers = { 7 | HTTPException: http_exception_handler, 8 | RequestValidationError: request_validation_exception_handler, 9 | Exception: last_exception_handler 10 | } -------------------------------------------------------------------------------- /vinci-inference/app/exception/handler.py: -------------------------------------------------------------------------------- 1 | from fastapi.encoders import jsonable_encoder 2 | from fastapi.exceptions import RequestValidationError 3 | from starlette.exceptions import HTTPException 4 | from starlette.requests import Request 5 | from starlette.responses import JSONResponse 6 | from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR 7 | 8 | 9 | async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: 10 | headers = getattr(exc, "headers", None) 11 | if headers: 12 | return JSONResponse( 13 | {"detail": exc.detail}, status_code=exc.status_code, headers=headers 14 | ) 15 | else: 16 | return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) 17 | 18 | 19 | async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: 20 | return JSONResponse( 21 | status_code=HTTP_422_UNPROCESSABLE_ENTITY, 22 | content={"detail": jsonable_encoder(exc.errors())}, 23 | ) 24 | 25 | async def last_exception_handler(request: Request, exc: Exception) -> JSONResponse: 26 | return JSONResponse( 27 | status_code=HTTP_500_INTERNAL_SERVER_ERROR, 28 | content=dict(detail=str(exc)), 29 | ) -------------------------------------------------------------------------------- /vinci-inference/app/global_var/__init__.py: -------------------------------------------------------------------------------- 1 | from .cache import FIFOSafeCache -------------------------------------------------------------------------------- /vinci-inference/app/global_var/cache.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from typing import OrderedDict 3 | 4 | 5 | class FIFOSafeCache: 6 | def __init__(self, capacity: int): 7 | self.cache = OrderedDict() 8 | self.capacity = capacity 9 | self.lock = threading.Lock() 10 | 11 | def get(self, key): 12 | with self.lock: 13 | if key not in self.cache: 14 | return None 15 | # 移动到末尾以保持顺序 16 | self.cache.move_to_end(key) 17 | return self.cache[key] 18 | 19 | def put(self, key, value): 20 | with self.lock: 21 | if key in self.cache: 22 | # 如果键已存在,更新值并移动到末尾 23 | self.cache.move_to_end(key) 24 | self.cache[key] = value 25 | if len(self.cache) > self.capacity: 26 | # 移除第一个(最旧的)条目 27 | self.cache.popitem(last=False) -------------------------------------------------------------------------------- /vinci-inference/app/main.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | import threading 3 | 4 | from fastapi import FastAPI 5 | from fastapi import APIRouter 6 | from sse_starlette.sse import EventSourceResponse 7 | 8 | import data 9 | import service 10 | 11 | from exception import exception_handlers 12 | 13 | 14 | router = APIRouter() 15 | 16 | @router.post("/inference/internvl", response_model=data.IntervlInferenceResponse) 17 | async def 
internvl_inference(req: data.InternvlInferenceRequest): 18 | return service.internvl_inference(req.question, req.history, req.base64_frames, req.frames, 19 | req.session_id, req.timestamp, req.silent) 20 | 21 | chat_lock = threading.Lock() 22 | @router.post("/inference/internvl/stream", response_model=data.IntervlInferenceResponse) 23 | async def internvl_stream_inference(req: data.InternvlInferenceRequest): 24 | try: 25 | chat_lock.acquire() 26 | stream_generator = service.internvl_stream_inference(req.question, req.history, req.base64_frames, req.frames, 27 | req.session_id, req.timestamp, req.silent, model_index=0) 28 | return EventSourceResponse(stream_generator) 29 | finally: 30 | chat_lock.release() 31 | 32 | 33 | screen_shot_chat_lock = threading.Lock() 34 | @router.post("/inference/internvl/screenshot", response_model=data.IntervlInferenceResponse) 35 | async def internvl_screenshot_inference(req: data.InternvlInferenceRequest): 36 | try: 37 | chat_lock.acquire() 38 | return service.internvl_inference(req.question, req.history, req.base64_frames, req.frames, 39 | req.session_id, req.timestamp, req.silent, model_index=-1) 40 | finally: 41 | chat_lock.release() 42 | 43 | @router.post("/inference/seine", response_model=data.SeineInferenceResponse) 44 | async def seine_inference(req: data.SeineInferenceRequest): 45 | return service.seine_inference(req.prompt, req.base64_image, req.image) 46 | 47 | def build_app(): 48 | fast_kwargs = { 49 | "include_in_schema": True, 50 | "docs_url": '/swagger-ui'} 51 | app = FastAPI(exception_handlers=exception_handlers, **fast_kwargs) 52 | app.include_router(router, prefix="/api/v1", tags=["Model"]) 53 | 54 | return app 55 | 56 | def start_fast_server(*args, **kwarg): 57 | uvicorn.run(build_app(), host="0.0.0.0", port=int(18081)) 58 | 59 | if __name__ == "__main__": 60 | start_fast_server() 61 | -------------------------------------------------------------------------------- /vinci-inference/app/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .internvl import chat as internvl_chat 2 | from .internvl import stream_chat as internvl_stream_chat -------------------------------------------------------------------------------- /vinci-inference/app/models/internvl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | import time 5 | 6 | from PIL import Image 7 | from threading import Thread 8 | 9 | # Make the repository root importable so the model wrapper (vl_open) can be loaded 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', ".."))) 11 | from vl_open import Chat, dynamic_preprocess 12 | 13 | from global_var import FIFOSafeCache 14 | 15 | history_cache = FIFOSafeCache(capacity=1000) 16 | first_chat_timestamp_cache = FIFOSafeCache(capacity=1000) 17 | 18 | def init_model(sep_chat: bool=False, stream: bool=False, device: str='cuda:0', language: str='chn', version: str='v1'): 19 | print(f'Initializing VLChat, sep_chat: {sep_chat}, stream: {stream}, device: {device}') 20 | with torch.device(device): 21 | chat = Chat(sep_chat=sep_chat, stream=stream, language=language, version=version) 22 | print('Initialization Finished') 23 | return chat 24 | 25 | 26 | def add_history(question, history, sep_chat: bool=False, language: str='chn'): 27 | if not history: 28 | print('history not added because self.history is empty') 29 | return question 30 | if len(history) > 0: 31 | if language == 'chn': 32 | system = 
"你是一个在增强现实(AR)眼镜上的智能助手。看到的图像是来自我第一人称视角的视频帧。仔细观察视频并重点关注物体的运动和我的动作。由于你看不到发生在当前帧之前的部分,现在以文字形式提供给你这个视频的之前的历史供参考。视频历史是:" 33 | else: 34 | system = 'You are an intelligent assistant on AR glasses. The AR glasses receive video frames from my egocentric viewpoint. Carefully watch the video and pay attention to the movement of objects, and the action of human. Since you cannot see the previous part of the video, I provide you the history of this video for reference. The history is: ' 35 | res = system 36 | if sep_chat: 37 | for hist in history[:-1]: 38 | ts = hist[0] 39 | a = hist[1] 40 | if language == 'chn': 41 | res += '当视频在%.1f秒时, 视频的内容是 "%s"' % (ts, a) 42 | else: 43 | res += 'When the video is at %.1f seconds, the video content is "%s". ' % (ts, a) 44 | ts = history[-1][0] 45 | a = history[-1][1] 46 | if language == 'chn': 47 | res += '以上是所有的视频历史, 表明了之前发生了什么.\n现在视频到了 %.1f秒, 视频的内容是 "%s". ' % (ts, a) 48 | else: 49 | res += 'This is the end of the history which indicate what have previously happened.\n Now the video is at %.1f seconds, the video content is: "%s". ' % (ts, a) 50 | else: 51 | for hist in history: 52 | ts = hist[0] 53 | a = hist[1] 54 | if language == 'chn': 55 | res += '当视频在%.1f秒时, 视频的内容是 "%s". ' % (ts, a.strip()) 56 | else: 57 | res += 'When the video is at %.1f seconds, the video contect is "%s". ' % (ts, a.strip()) 58 | if language == 'chn': 59 | res += '以上是所有的视频历史, 表明了之前发生了什么. 如果后面的问题问到了之前发生的事情, 可以作为参考.\n' 60 | else: 61 | res += 'This is the end of the video history that indicates what happened before.\n' 62 | 63 | if language == 'chn': 64 | res += '请根据当前视频, 用中文回答我的问题: "%s".' % question 65 | else: 66 | res += 'Given the current video and using the previous video as reference, answer my question in English: "%s". Note that if the question is about what has been previously done, please only focus on the history. Otherwise, please only focus on the question and the current video input. If the question is about future planning, provide at most 3 steps.' 
% question 67 | 68 | question = res 69 | return question 70 | 71 | 72 | class IntervlChat(): 73 | 74 | def __init__(self, sep_chat: bool=False, stream: bool=False, 75 | device: str='cuda:0', language: str='chn', version: str='v1'): 76 | self.sep_chat = sep_chat 77 | self.language = language 78 | self.device = device 79 | self.stream = stream 80 | self.version = version 81 | self.intervl_chat = init_model(sep_chat, stream=self.stream, device=device, language=language, version=version) 82 | self.origin_intervl_model = self.intervl_chat.model 83 | 84 | # frame是否可以resize 85 | def load_video_frames(self, frames: list): 86 | pixel_values_list, num_patches_list = [], [] 87 | for i, frame in enumerate(frames): 88 | # frame = np.array(frame, dtype=np.uint8) 89 | img = Image.fromarray(frame).convert('RGB') 90 | if i == len(frames) - 1: 91 | img.save('./lastim.jpg') 92 | 93 | img = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=1) 94 | pixel_values = [self.intervl_chat.transform(tile) for tile in img] 95 | pixel_values = torch.stack(pixel_values) 96 | num_patches_list.append(pixel_values.shape[0]) 97 | pixel_values_list.append(pixel_values) 98 | pixel_values = torch.cat(pixel_values_list) 99 | 100 | return pixel_values, num_patches_list 101 | 102 | def _chat(self, model, tokenizer, pixel_values, question: str, num_patches_list: list): 103 | return model.chat(tokenizer, pixel_values, question, self.intervl_chat.generation_config, 104 | num_patches_list=num_patches_list, history=None, return_history=True) 105 | 106 | def _chat_stream(self, model, tokenizer, pixel_values, question: str, generation_config, num_patches_list: list): 107 | thread = Thread(target=model.chat, kwargs=dict(tokenizer=tokenizer, pixel_values=pixel_values, question=question, generation_config=generation_config, 108 | num_patches_list=num_patches_list, history=None, return_history=False)) 109 | thread.start() 110 | 111 | return self.intervl_chat.streamer 112 | 113 | def chat_unsafe_silent(self, question: str, frames: list, history: list=[], timestamp: int=0,): 114 | """ 115 | 在静默模式下聊天交互。 116 | 117 | 118 | 参数: 119 | - question: str 120 | 输入的问题字符串,目前此参数在函数内部被重新构造而不是使用。 121 | - frames: list 122 | 视频帧的列表,每个帧代表视频的一个时间点。 123 | - history: list 124 | 之前的聊天历史记录,默认为空列表。 125 | - session_id: str 126 | 会话的唯一标识符,默认为"default"。 127 | - timestamp: int 128 | 当前视频的时间戳,用于构造问题和记录历史。 129 | 130 | 返回: 131 | - response: str 132 | 生成的回答。 133 | - history: list 134 | 更新后的聊天历史记录,包括新增的回答。 135 | """ 136 | print(f"silent chat history: {history}") 137 | print(f"silent chat timestamp: {timestamp}") 138 | 139 | with torch.device(self.device): 140 | pixel_values, num_patches_list = self.load_video_frames(frames) 141 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 142 | 143 | video_prefix = ''.join([f'Frame{i+1}: \n' for i in range(len(num_patches_list))]) 144 | question = video_prefix + '现在视频到了 %.1f 秒处. 简单的描述视频中我的动作.' 
% timestamp 145 | response, _history = self._chat(self.origin_intervl_model, self.intervl_chat.tokenizer, pixel_values, question, num_patches_list) 146 | history.append((timestamp, response)) 147 | # print('VL_HISTORY:', _history) 148 | 149 | # print('Real question at %2.1f is |||' % (timestamp), question) 150 | # print('Answer at %2.1f is ||| '% (timestamp), response) 151 | 152 | return response, history 153 | 154 | def chat_unsafe(self, question: str, frames: list, history: list=[], timestamp: int=0, silent: bool=False): 155 | """ 156 | 实现与用户视频内容相关的聊天功能,支持实时流式输出。 157 | 158 | 参数: 159 | - question: str, 用户提出的问题。 160 | - frames: list, 视频的帧数据列表。 161 | - history: list, 聊天历史记录,默认为空列表。 162 | - session_id: str, 会话ID,默认为"default"。 163 | - timestamp: int, 视频的当前时间戳,默认为0。 164 | - silent: bool, 是否静默模式,默认为False。 165 | 166 | 该函数首先从缓存中获取特定会话的历史记录(如果存在),然后将视频帧数据加载到设备上进行处理。 167 | 根据配置决定是使用流式输出还是非流式输出来生成回答。聊天历史记录会被更新并缓存。 168 | """ 169 | print(f"chat history: {history}") 170 | print(f"chat timestamp: {timestamp}") 171 | 172 | with torch.device(self.device): 173 | pixel_values, num_patches_list = self.load_video_frames(frames) 174 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 175 | 176 | video_prefix = ''.join([f'Frame{i+1}: \n' for i in range(len(num_patches_list))]) 177 | 178 | if self.sep_chat: 179 | quest = video_prefix + '现在视频到了 %.1f 秒处. 描述视频中我的动作.' % timestamp 180 | response, _ = self._chat(self.origin_intervl_model, self.intervl_chat.tokenizer, pixel_values, quest, num_patches_list) 181 | 182 | history.append((timestamp, response)) 183 | question = add_history(question, history, self.sep_chat, self.language) 184 | # question = add_history(question, history) 185 | 186 | print(f"chat question: {question}") 187 | if self.stream: 188 | streamer = self._chat_stream(self.intervl_chat.lmmodel, self.intervl_chat.tokenizer, None, question, 189 | self.intervl_chat.lmgeneration_config, num_patches_list) 190 | 191 | for new_text in streamer: 192 | if new_text == self.intervl_chat.lmmodel.conv_template.sep: 193 | return new_text, history 194 | 195 | yield new_text, history 196 | else: 197 | response, _ = self._chat(self.intervl_chat.model, self.intervl_chat.tokenizer, None, question, 198 | self.intervl_chat.generation_config, num_patches_list) 199 | 200 | print('Real question at %2.1f is |||' % timestamp, question) 201 | print('Answer at %2.1f is ||| '%timestamp, response) 202 | 203 | yield response, history 204 | else: 205 | question = '现在视频到了%.1f秒处. 
' % timestamp + question 206 | question = add_history(question, history, self.sep_chat) 207 | question = video_prefix + question 208 | 209 | print(f"chat question: {question}") 210 | 211 | if self.stream: 212 | streamer = self._chat_stream(self.intervl_chat.model, self.intervl_chat.tokenizer, pixel_values, question, 213 | self.intervl_chat.generation_config, num_patches_list) 214 | 215 | for new_text in self.intervl_chat.streamer: 216 | if new_text == self.intervl_chat.model.conv_template.sep: 217 | print(f"set history: {history}") 218 | return new_text, history 219 | 220 | yield new_text, history 221 | # response, _ = self._chat(self.intervl_chat.model, self.intervl_chat.tokenizer, pixel_values, question, num_patches_list) 222 | 223 | # print('Real question at %2.1f is |||' % timestamp, question) 224 | # print('Answer at %2.1f is ||| '%timestamp, response) 225 | 226 | # yield response, history 227 | 228 | else: 229 | response, _ = self._chat(self.intervl_chat.model, self.intervl_chat.tokenizer, pixel_values, question, num_patches_list) 230 | 231 | print('Real question at %2.1f is |||' % timestamp, question) 232 | print('Answer at %2.1f is ||| '%timestamp, response) 233 | 234 | yield response, history 235 | 236 | # 获取所有可用的CUDA设备 237 | available_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())] 238 | print(f"availabel devices: {available_devices}") 239 | 240 | sep_chat = False 241 | stream = True 242 | running_language = os.environ.get('RUNNING_LANGUAGE') 243 | version = os.environ.get('VERSION') 244 | if running_language is None: 245 | running_language = 'chn' 246 | chats = [IntervlChat(sep_chat, stream, device, running_language, version) for device in available_devices] 247 | 248 | def get_timestamp(session_id: str): 249 | current_timestamp = time.time() 250 | 251 | first_timestamp = first_chat_timestamp_cache.get(session_id) 252 | if first_timestamp: 253 | return int(current_timestamp) - int(first_timestamp) 254 | 255 | first_chat_timestamp_cache.put(session_id, current_timestamp) 256 | 257 | return 0 258 | 259 | def get_history(session_id: str): 260 | history = history_cache.get(session_id) 261 | if not history: 262 | return [] 263 | 264 | return history[-10:] 265 | 266 | def chat(question: str, frames: list, history: list=[], session_id: str="default", 267 | timestamp: int=0, silent: bool=False, model_index: int=0): 268 | print(f"session id: {session_id}") 269 | 270 | chat = chats[model_index] 271 | timestamp = get_timestamp(session_id) 272 | history = get_history(session_id) 273 | 274 | if silent: 275 | response, history = chat.chat_unsafe_silent(question, frames, history, timestamp=timestamp) 276 | 277 | # 记录历史 278 | print(f"set silent history: {history}") 279 | history_cache.put(session_id, history) 280 | 281 | return response, history 282 | 283 | answers = chat.chat_unsafe(question, frames, history, timestamp=timestamp, silent=silent) 284 | 285 | generate_txt = '' 286 | history = [] 287 | for answer, history in answers: 288 | generate_txt += answer 289 | history = history 290 | 291 | # 记录历史 292 | print(f"set history: {history}") 293 | history_cache.put(session_id, history) 294 | 295 | return generate_txt, history 296 | 297 | def stream_chat(question: str, frames: list, history: list, session_id: str="default", 298 | timestamp: int=0, silent: bool=False, model_index: int=0): 299 | print(f"session id: {session_id}") 300 | 301 | chat = chats[model_index] 302 | timestamp = get_timestamp(session_id) 303 | history = get_history(session_id) 304 | 305 | if silent: 306 | 
response, history = chat.chat_unsafe_silent(question, frames, history, timestamp=timestamp) 307 | 308 | # 记录历史 309 | print(f"set silent history: {history}") 310 | history_cache.put(session_id, history) 311 | 312 | yield response, history 313 | return 314 | 315 | answers = chat.chat_unsafe(question, frames, history, timestamp=timestamp, silent=silent) 316 | 317 | generate_txt = '' 318 | for answer, history in answers: 319 | generate_txt += answer 320 | yield generate_txt, history 321 | 322 | # 记录历史 323 | print(f"set history: {history}") 324 | history_cache.put(session_id, history) 325 | 326 | print(f"chat answer: {generate_txt}") -------------------------------------------------------------------------------- /vinci-inference/app/models/seine.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/vinci-inference/app/models/seine.py -------------------------------------------------------------------------------- /vinci-inference/app/service/__init__.py: -------------------------------------------------------------------------------- 1 | from .internvl import inference as internvl_inference 2 | from .internvl import stream_inference as internvl_stream_inference 3 | from .seine import inference as seine_inference -------------------------------------------------------------------------------- /vinci-inference/app/service/internvl.py: -------------------------------------------------------------------------------- 1 | from sse_starlette import ServerSentEvent 2 | from typing import List 3 | 4 | from data import IntervlInferenceResponse 5 | from models import internvl_chat, internvl_stream_chat 6 | from util import url_to_ndarray, base64_to_ndarray 7 | 8 | def parse_image_from_urls(image_urls: List[str]): 9 | """ 10 | 从URL列表中解析图像。 11 | 12 | 该函数接收一个图像URL列表,将每个URL解析为图像,并返回一个图像列表。 13 | 如果URL无效,将抛出异常。 14 | """ 15 | images = [] 16 | for image in image_urls: 17 | if isinstance(image, str): 18 | image = url_to_ndarray(image) 19 | if image is None: 20 | raise Exception("Invalid image url") 21 | 22 | images.append(image) 23 | 24 | return images 25 | 26 | def parse_image_from_base64(image_base64: List[str]): 27 | """ 28 | 从Base64列表中解析图像。 29 | 30 | 该函数接收一个Base64图像列表,将每个Base64图像解析为图像,并返回一个图像列表。 31 | 如果Base64图像无效,将抛出异常。 32 | """ 33 | images = [] 34 | for image in image_base64: 35 | if isinstance(image, str): 36 | image = base64_to_ndarray(image) 37 | if image is None: 38 | raise Exception("Invalid image url") 39 | 40 | images.append(image) 41 | 42 | return images 43 | 44 | 45 | def inference(question: str, history: list, base64_frames: list, frames: list, session_id: str="default", 46 | timestamp: int=0, silent: bool=False, model_index: int=0): 47 | if base64_frames: 48 | frames = parse_image_from_base64(base64_frames) 49 | else: 50 | frames = parse_image_from_urls(frames) 51 | 52 | answer, history = internvl_chat(question, frames, history, session_id, timestamp, silent, 53 | model_index=model_index) 54 | return IntervlInferenceResponse(answer=answer, history=[[1, "test"]], session_id=session_id) 55 | 56 | def stream_inference(question: str, history: list, base64_frames: list, frames: list, session_id: str="default", 57 | timestamp: int=0, silent: bool=False, model_index: int=0): 58 | if base64_frames: 59 | frames = parse_image_from_base64(base64_frames) 60 | else: 61 | frames = parse_image_from_urls(frames) 62 | 63 | for answer, history in internvl_stream_chat(question, frames, history, 
session_id, timestamp, silent, 64 | model_index=model_index): 65 | yield ServerSentEvent(event="message", id=0, 66 | data=IntervlInferenceResponse(answer=answer, history=[[1, "test"]], session_id=session_id).model_dump_json()) 67 | 68 | yield ServerSentEvent(event="end", data="Connection closed") 69 | -------------------------------------------------------------------------------- /vinci-inference/app/service/seine.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import sys 3 | import os 4 | import uuid 5 | import base64 6 | import os.path as osp 7 | import numpy as np 8 | 9 | from PIL import Image 10 | 11 | from util import url_to_ndarray, base64_to_ndarray, save_video, OssClient 12 | from data import SeineInferenceResponse 13 | 14 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'generation', 'seine-v2'))) 15 | 16 | print(sys.path) 17 | from seine import gen, model_seine 18 | from omegaconf import OmegaConf 19 | 20 | omega_conf = OmegaConf.load('generation/seine-v2/configs/demo.yaml') 21 | omega_conf.run_time = 13 22 | omega_conf.input_path = '' 23 | omega_conf.text_prompt = [] 24 | omega_conf.save_img_path = '.' 25 | 26 | from util import OssClient 27 | 28 | oss_client = None 29 | 30 | access_key = os.getenv('access_key') 31 | access_key_secret = os.getenv('access_key_secret') 32 | endpoint = os.getenv('endpoint') 33 | bucket = os.getenv('bucket') 34 | cdn = os.getenv('cdn') 35 | if access_key: 36 | oss_client = OssClient(access_key=access_key, 37 | access_key_secret=access_key_secret, 38 | endpoint=endpoint, 39 | bucket=bucket, 40 | cdn=cdn) 41 | 42 | # 创建视频存储目录 43 | video_save_dir = osp.join('/tmp', 'vinci') 44 | if not osp.exists(video_save_dir): 45 | os.makedirs(video_save_dir) 46 | 47 | def upload_video(video_path: str='vid.mp4', object_name = 'vinci/vid.mp4'): 48 | # save_video(frames, video_path, fps=fps, h_256=h_256) 49 | # oss_client.put_object_from_file(object_name, video_path) 50 | oss_client.upload_local_file_then_remove('', video_path, object_name) 51 | 52 | return oss_client.sign_url(object_name, cdn=True) 53 | 54 | def inference(prompt: str, base64_image, image): 55 | if base64_image: 56 | image = base64_to_ndarray(base64_image) 57 | else: 58 | image = url_to_ndarray(image) 59 | 60 | image = Image.fromarray(image) 61 | 62 | input_img_path = f'{str(uuid.uuid1())}.png' 63 | image.save(input_img_path) 64 | 65 | omega_conf.input_path = input_img_path 66 | omega_conf.text_prompt = [prompt] 67 | 68 | video_name = str(uuid.uuid1()) + '.mp4' 69 | video_path = osp.join(video_save_dir, video_name) 70 | gen(omega_conf, model_seine, save_path=video_path) 71 | 72 | # frames_array = None 73 | # try: 74 | # # 视频文件路径 75 | # video_path = 'vid.mp4' 76 | # cap = cv2.VideoCapture(video_path) 77 | 78 | # frames = [] 79 | # while True: 80 | # ret, frame = cap.read() 81 | 82 | # if not ret: 83 | # break 84 | 85 | # frames.append(frame) 86 | 87 | # frames_array = np.array(frames) 88 | # finally: 89 | # cap.release() 90 | 91 | url = "" 92 | try: 93 | object_name = f'vinci/{video_name}' 94 | url = upload_video(video_path=video_path, object_name=object_name) 95 | print(f"generate video url: {url}") 96 | except Exception as e: 97 | print(e) 98 | 99 | # 清理临时文件 100 | if osp.exists(input_img_path): 101 | os.remove(input_img_path) 102 | if osp.exists(video_path): 103 | os.remove(video_path) 104 | 105 | return SeineInferenceResponse(video=url, video_base64="") 
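Since `inference()` above uploads the generated clip to object storage and returns a presigned URL in `SeineInferenceResponse.video` rather than raw frames, a caller only needs to fetch that URL. The following is a hedged end-to-end sketch against the HTTP route; the host/port and the `./still.png` input are placeholders, and the exact URL shape depends on the MinIO/OSS settings in `.env`.

```python
# Sketch of driving the SEINE route exposed by app/main.py and saving the result.
# Assumptions: the service listens on 127.0.0.1:18081 and ./still.png exists locally.
import base64
import requests

with open("./still.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:18081/api/v1/inference/seine",
    json={"prompt": "dancing", "base64_image": image_b64},
)
resp.raise_for_status()
video_url = resp.json()["video"]  # presigned URL produced by upload_video()/sign_url()

if video_url:  # an empty string means the upload step failed server-side
    clip = requests.get(video_url)
    clip.raise_for_status()
    with open("generated.mp4", "wb") as f:
        f.write(clip.content)
```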
-------------------------------------------------------------------------------- /vinci-inference/app/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import * 2 | from .oss import OssClient 3 | -------------------------------------------------------------------------------- /vinci-inference/app/util/image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import requests 3 | import os 4 | import cv2 5 | import uuid 6 | import subprocess 7 | 8 | import numpy as np 9 | 10 | from PIL import Image 11 | from io import BytesIO 12 | 13 | def url_to_ndarray(url): 14 | """Convert an image URL to an ndarray. 15 | """ 16 | response = requests.get(url) 17 | if response.status_code != 200: 18 | print(f"Failed to download image from {url}, status: {response.status_code}") 19 | return None 20 | 21 | image = Image.open(BytesIO(response.content)) 22 | return np.array(image) 23 | 24 | def base64_to_ndarray(base64_str): 25 | """Convert a base64-encoded image string to an ndarray.""" 26 | try: 27 | # 解码 base64 字符串 28 | image_data = base64.b64decode(base64_str) 29 | 30 | # 使用 BytesIO 处理解码后的二进制数据 31 | image = Image.open(BytesIO(image_data)) 32 | 33 | # 将 PIL Image 转换为 NumPy 数组 34 | return np.array(image) 35 | 36 | except Exception as e: 37 | print(f"Failed to convert base64 string to image ndarray: {e}") 38 | return None 39 | 40 | def convert_h256(src_path, dst_path): 41 | cmd = ["ffmpeg", "-y", "-i", src_path, "-vcodec", "h264", dst_path] 42 | try: 43 | subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) 44 | except subprocess.CalledProcessError as ex: 45 | print(f'returncode:{ex.returncode}, \ncmd: {ex.cmd}, \noutput: {ex.output}, \nstderr: {ex.stderr}, \nstdout: {ex.stdout}'.format()) 46 | raise ex 47 | finally: 48 | os.remove(src_path) 49 | 50 | def save_video(video, 51 | local_path, 52 | size=None, 53 | fps: int=30, 54 | h_256: bool=True, 55 | rgb2bgr: bool=False): 56 | dir_path = os.path.dirname(local_path) 57 | if dir_path and not os.path.exists(dir_path): 58 | os.makedirs(dir_path) 59 | 60 | save_path = local_path 61 | 62 | if h_256: 63 | save_path = "tmp-" + str(uuid.uuid4()) + ".mp4" 64 | 65 | try: 66 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 67 | frame1 = video[0] 68 | size = (frame1.shape[1], frame1.shape[0]) 69 | video_writer = cv2.VideoWriter(save_path, fourcc, fps, size) 70 | 71 | for frame in video: 72 | if rgb2bgr: 73 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 74 | video_writer.write(frame) 75 | finally: 76 | if video_writer: 77 | video_writer.release() 78 | 79 | if h_256: 80 | convert_h256(save_path, local_path) -------------------------------------------------------------------------------- /vinci-inference/app/util/oss.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import oss2 4 | import sys 5 | 6 | from urllib.parse import urljoin, urlparse, urlunparse 7 | from minio import Minio 8 | from minio.error import S3Error 9 | 10 | 11 | class OssClient: 12 | def __init__(self, access_key, access_key_secret, 13 | endpoint, 14 | bucket, 15 | cdn): 16 | 17 | self.access_key_id = access_key 18 | self.access_key_secret= access_key_secret 19 | self.endpoint = endpoint 20 | self.external_endpoint = endpoint 21 | self.bucket = bucket 22 | self.cdn = cdn 23 | 24 | self.storage_type = 'minio' 25 | if self.storage_type == 'minio': 26 | self._init_minio() 27 | else: 28 | raise('Only 
support minio at the moment') 29 | 30 | def _init_minio(self): 31 | """初始化 MinIO 客户端""" 32 | self.client = Minio(self.endpoint, access_key=self.access_key_id, 33 | secret_key=self.access_key_secret, secure=False) 34 | 35 | if not self.client.bucket_exists(self.bucket): 36 | self.client.make_bucket(self.bucket) 37 | 38 | self.result_url_prefix = urljoin(self.external_endpoint, self.bucket) 39 | 40 | def _init_oss(self): 41 | """初始化阿里云 OSS 客户端""" 42 | auth = oss2.Auth(self.access_key_id, self.access_key_secret) 43 | self.client = oss2.Bucket(auth, self.endpoint, self.bucket) 44 | 45 | parsed_url = urlparse(self.external_endpoint) 46 | new_netloc = self.bucket + "." + parsed_url.netloc 47 | 48 | self.result_url_prefix = urlunparse((parsed_url.scheme, new_netloc, parsed_url.path, parsed_url.params, 49 | parsed_url.query, parsed_url.fragment)) 50 | 51 | async def upload_local_file(self, key_prefix: str, file_path: str, name: str) -> str: 52 | if self.storage_type == 'minio': 53 | return await self._upload_local_file_minio(key_prefix, file_path, name) 54 | elif self.storage_type == 'oss': 55 | return await self._upload_local_file_oss(key_prefix, file_path, name) 56 | 57 | async def upload_image_base64(self, key_prefix: str, name: str, base64_image: str) -> str: 58 | if self.storage_type == 'minio': 59 | return await self._upload_image_base64_minio(key_prefix, name, base64_image) 60 | elif self.storage_type == 'oss': 61 | return await self._upload_image_base64_oss(key_prefix, name, base64_image) 62 | 63 | def upload_local_file_then_remove(self, key_prefix: str, file_path: str, name: str) -> str: 64 | if self.storage_type == 'minio': 65 | return self._upload_local_file_then_remove_minio(key_prefix, file_path, name) 66 | elif self.storage_type == 'oss': 67 | return self._upload_local_file_then_remove_oss(key_prefix, file_path, name) 68 | 69 | async def _upload_local_file_minio(self, key_prefix: str, file_path: str, name: str) -> str: 70 | object_name = key_prefix + name 71 | try: 72 | self.client.fput_object(self.bucket, object_name, file_path) 73 | except S3Error as e: 74 | logger.error(f"upload_local_file_minio err {e}") 75 | return "" 76 | 77 | return self.result_url_prefix + "/" + object_name 78 | 79 | async def _upload_image_base64_minio(self, key_prefix: str, name: str, base64_image: str) -> str: 80 | object_name = key_prefix + name 81 | try: 82 | decoded_data = base64.b64decode(base64_image) 83 | self.client.put_object(self.bucket, key_prefix + name, data=decoded_data, length=len(decoded_data)) 84 | except S3Error as e: 85 | logger.error(f"upload_image_base64_minio err {e}") 86 | return "" 87 | 88 | return self.result_url_prefix + "/" + object_name 89 | 90 | def _upload_local_file_then_remove_minio(self, key_prefix: str, file_path: str, name: str) -> str: 91 | object_name = key_prefix + name 92 | try: 93 | self.client.fput_object(self.bucket, object_name, file_path) 94 | os.remove(file_path) 95 | except S3Error as e: 96 | logger.error(f"upload_local_file_then_remove_minio err {e}") 97 | return "" 98 | 99 | return self.result_url_prefix + "/" + object_name 100 | 101 | async def _upload_local_file_oss(self, key_prefix: str, file_path: str, name: str) -> str: 102 | object_name = key_prefix + name 103 | try: 104 | with open(file_path, 'rb') as file: 105 | self.client.put_object(object_name, file) 106 | except Exception as e: 107 | logger.error(f"upload_local_file_oss err {e}") 108 | return "" 109 | 110 | return self.result_url_prefix + "/" + object_name 111 | 112 | async def 
_upload_image_base64_oss(self, key_prefix: str, name: str, base64_image: str) -> str: 113 | object_name = key_prefix + name 114 | try: 115 | self.client.put_object(object_name, base64.b64decode(base64_image)) 116 | except Exception as e: 117 | logger.error(f"upload_image_base64_oss err {e}") 118 | return "" 119 | 120 | return urljoin(self.result_url_prefix, self.bucket + "/" + object_name) 121 | 122 | def _upload_local_file_then_remove_oss(self, key_prefix: str, file_path: str, name: str) -> str: 123 | object_name = key_prefix + name 124 | try: 125 | with open(file_path, 'rb') as file: 126 | self.client.put_object(object_name, file) 127 | os.remove(file_path) 128 | except Exception as e: 129 | logger.error(f"upload_local_file_then_remove_oss err {e}") 130 | return "" 131 | 132 | return urljoin(self.result_url_prefix, self.bucket + "/" + object_name) 133 | 134 | def sign_url(self, 135 | object_name, 136 | method = "GET", 137 | cdn=False, 138 | internal=3600, 139 | slash_safe=True): 140 | # if cdn is True: 141 | # return urljoin(self.cdn, object_name) 142 | 143 | return self.client.presigned_get_object(self.bucket, object_name) 144 | 145 | # object_storage = ObjectStorageManager(config=c) 146 | 147 | # async def upload_local_file(key_prefix: str, file_path: str, name: str) -> str: 148 | # return await object_storage.upload_local_file(key_prefix, file_path, name) 149 | 150 | # async def upload_image_base64(key_prefix: str, name: str, base64_image: str) -> str: 151 | # return await object_storage.upload_image_base64(key_prefix, name, base64_image) 152 | 153 | # def upload_local_file_then_remove(key_prefix: str, file_path: str, name: str) -> str: 154 | # return object_storage.upload_local_file_then_remove(key_prefix, file_path, name) 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | ''' 163 | import os 164 | 165 | from urllib.parse import urljoin 166 | 167 | import oss2 168 | 169 | 170 | class OssClient(): 171 | 172 | def __init__(self, 173 | access_key: str, 174 | access_key_secret:str, 175 | endpoint: str, 176 | **kwargs): 177 | self.id = access_key 178 | self.secret = access_key_secret 179 | self.endpoint = endpoint 180 | self._bucket_name = kwargs.pop("bucket", None) 181 | self.cdn = kwargs.pop("cdn", None) 182 | 183 | self.auth = oss2.Auth(access_key, access_key_secret) 184 | 185 | if self._bucket_name: 186 | self.bucket_client = oss2.Bucket(self.auth, self.endpoint, self._bucket_name) 187 | 188 | def bucket(self, bucket: str): 189 | if bucket: 190 | return oss2.Bucket(self.auth, self.endpoint, bucket) 191 | 192 | return self.bucket_client 193 | 194 | def put_object(self, object_name: str, content): 195 | self.bucket_client.put_object(object_name, content) 196 | 197 | def put_object_from_file(self, 198 | object_name: str, 199 | file, 200 | delete_local: bool=True): 201 | self.bucket_client.put_object_from_file(object_name, file) 202 | if delete_local: 203 | os.remove(file) 204 | 205 | def sign_url(self, 206 | object_name, 207 | method = "GET", 208 | cdn=False, 209 | internal=3600, 210 | slash_safe=True): 211 | if cdn is True: 212 | return urljoin(self.cdn, object_name) 213 | 214 | return self.bucket_client.sign_url(method, object_name, internal, slash_safe=slash_safe) 215 | ''' 216 | -------------------------------------------------------------------------------- /vinci-inference/boot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 获取脚本所在目录 4 | SCRIPT_DIR=$(dirname "$(realpath "$0")") 5 | 6 | # 加载环境变量文件 7 | if [ -f 
"$SCRIPT_DIR/.env" ]; then 8 | source "$SCRIPT_DIR/.env" 9 | else 10 | echo ".env file not found in $SCRIPT_DIR. Exiting." 11 | exit 1 12 | fi 13 | 14 | # Conda 环境名称 15 | CONDA_ENV=vinci 16 | 17 | # 默认的 CUDA 设备 18 | DEFAULT_CUDA_VISIBLE_DEVICES="0" 19 | DEFAULT_RUNNING_LANGUAGE='chn' 20 | DEFAULT_VERSION="v1" 21 | 22 | # Python 启动命令 23 | COMMAND="python vinci-inference/app/main.py" 24 | LOG_FILE="vinci_inference.log" 25 | PID_FILE="/tmp/.vinci/vinci_inference.pid" 26 | 27 | # 确保目录存在 28 | mkdir -p "$(dirname "$PID_FILE")" 29 | 30 | # 函数:启动服务 31 | start_service() { 32 | export CUDA_VISIBLE_DEVICES=$1 33 | export RUNNING_LANGUAGE=$2 34 | export VERSION=$3 35 | 36 | echo "Activating conda environment: $CONDA_ENV" 37 | source $(conda info --base)/etc/profile.d/conda.sh 38 | conda activate $CONDA_ENV 39 | 40 | echo "Starting service with CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES..." 41 | $COMMAND & 42 | echo $! > $PID_FILE 43 | echo "Service started with PID $(cat $PID_FILE)" 44 | } 45 | 46 | # 函数:停止服务 47 | stop_service() { 48 | if [ -f $PID_FILE ]; then 49 | PID=$(cat $PID_FILE) 50 | echo "Stopping service with PID $PID..." 51 | kill $PID 52 | rm -f $PID_FILE 53 | echo "Service stopped." 54 | else 55 | echo "Service is not running." 56 | fi 57 | } 58 | 59 | # 函数:重启服务 60 | restart_service() { 61 | echo "Restarting service..." 62 | stop_service 63 | start_service $1 $2 64 | echo "Service restarted." 65 | } 66 | 67 | # 主逻辑 68 | CUDA_VISIBLE_DEVICES=$DEFAULT_CUDA_VISIBLE_DEVICES 69 | RUNNING_LANGUAGE=$DEFAULT_RUNNING_LANGUAGE 70 | VERSION=$DEFAULT_VERSION 71 | 72 | while [[ "$#" -gt 0 ]]; do 73 | case $1 in 74 | --cuda) CUDA_VISIBLE_DEVICES="$2"; shift ;; 75 | --version) VERSION="$2"; shift ;; 76 | --language) RUNNING_LANGUAGE="$2"; shift ;; 77 | start) COMMAND_ACTION="start" ;; 78 | stop) COMMAND_ACTION="stop" ;; 79 | restart) COMMAND_ACTION="restart" ;; 80 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 81 | esac 82 | shift 83 | done 84 | 85 | case "$COMMAND_ACTION" in 86 | start) 87 | start_service $CUDA_VISIBLE_DEVICES $RUNNING_LANGUAGE $VERSION 88 | ;; 89 | stop) 90 | stop_service 91 | ;; 92 | restart) 93 | restart_service $CUDA_VISIBLE_DEVICES $RUNNING_LANGUAGE $VERSION 94 | ;; 95 | *) 96 | echo "Usage: $0 {start|stop|restart} [--cuda ] [--language chn/eng] [--version v0/v1]" 97 | exit 1 98 | ;; 99 | esac 100 | 101 | exit 0 102 | -------------------------------------------------------------------------------- /vinci-inference/client/internvl.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import uuid 3 | import requests 4 | import os 5 | import json 6 | import subprocess 7 | import sseclient 8 | 9 | import numpy as np 10 | 11 | from PIL import Image 12 | from decord import VideoReader, cpu 13 | 14 | video_name = 'demo1.mp4' 15 | video_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'demo', video_name)) 16 | 17 | def resize_frame(frame, max_size=320): 18 | """ 19 | Resize a single video frame, ensuring that its dimensions do not exceed max_size, 20 | while maintaining the aspect ratio. 
21 | """ 22 | image = Image.fromarray(frame) 23 | orig_width, orig_height = image.size 24 | aspect_ratio = orig_width / orig_height 25 | 26 | # Calculate new dimensions 27 | if orig_width > orig_height: 28 | new_width = min(orig_width, max_size) 29 | new_height = int(new_width / aspect_ratio) 30 | else: 31 | new_height = min(orig_height, max_size) 32 | new_width = int(new_height * aspect_ratio) 33 | 34 | # Ensure dimensions do not exceed max_size 35 | if new_width > max_size: 36 | new_width = max_size 37 | new_height = int(new_width / aspect_ratio) 38 | if new_height > max_size: 39 | new_height = max_size 40 | new_width = int(new_height * aspect_ratio) 41 | 42 | # Resize the frame 43 | resized_image = image.resize((new_width, new_height), Image.LANCZOS) 44 | resized_frame = np.array(resized_image) 45 | 46 | return resized_frame 47 | 48 | def load_video_slice(video_path, step=10, num=10): 49 | vr = VideoReader(video_path, ctx=cpu(0)) 50 | num_frames = len(vr) 51 | 52 | frames = [] 53 | for i in range(1, num_frames, step): 54 | if len(frames) >= num: 55 | break 56 | 57 | frames.append(vr[i].asnumpy()) 58 | 59 | return vr.get_avg_fps(), len(vr), frames 60 | 61 | def send_inference(stream: bool=False): 62 | fps, num_frames, frames = load_video_slice(video_path) 63 | print(f"video fps: {fps}, num frames: {num_frames}") 64 | 65 | frames = [resize_frame(frame).tolist() for frame in frames] 66 | 67 | input = { 68 | "question": "视频内容是什么?", 69 | "history": [], 70 | "frames": frames, 71 | } 72 | 73 | print(f'request body size: {len(json.dumps(input).encode("utf-8"))}') 74 | url = 'http://10.140.0.243:18080/api/v1/inference/internvl' 75 | if stream: 76 | url += "/stream" 77 | 78 | return requests.post(url, json=input, stream=stream) 79 | 80 | def internvl_inference(): 81 | response = send_inference() 82 | 83 | # 检查响应 84 | if response.status_code == 200: 85 | print(f'Inference successfully.') 86 | result = response.json() 87 | print(result) 88 | else: 89 | print('Failed to upload image:', response.status_code) 90 | 91 | def internvl_stream_inference(): 92 | response = send_inference(stream=True) 93 | 94 | # 检查响应 95 | if response.status_code == 200: 96 | print(f'Inference successfully.') 97 | sse_client = sseclient.SSEClient(response) 98 | for event in sse_client.events(): 99 | print(event.data) 100 | else: 101 | print('Failed to upload image:', response.status_code) 102 | 103 | if __name__ == '__main__': 104 | internvl_stream_inference() -------------------------------------------------------------------------------- /vinci-inference/client/internvl_sse.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import sseclient 4 | 5 | from typing import List 6 | 7 | def chat(question: str, frames: List[str], session_id: str, silent=False): 8 | post_body = { 9 | "session_id": session_id, 10 | "timestamp": 12, 11 | "silent": silent, 12 | "question": question, 13 | "history": [], 14 | "frames": [] 15 | } 16 | 17 | url = "http://10.140.0.243:18081/api/v1/inference/internvl/stream" 18 | body_json = json.dumps(post_body, ensure_ascii=False) 19 | print(f"chat sse url: {url}, body: {body_json}") 20 | 21 | with requests.post(url, data=body_json.encode(), stream=True) as response: 22 | if response.status_code != 200: 23 | raise Exception("intern chat sse failed", response) 24 | client = sseclient.SSEClient(response) 25 | for event in client.events(): 26 | if len(event.data) != 0 and event.data != "Connection closed": 27 | yield 
json.loads(event.data) 28 | 29 | if __name__ == "__main__": 30 | frames = ["https://oss.openmmlab.com/vinci/minimalism-p1.PNG"] * 3 31 | for data in chat(question="视频内容是什么?", frames=frames, session_id="1", silent=False): 32 | print(data) -------------------------------------------------------------------------------- /vinci-inference/client/seine.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import uuid 3 | import requests 4 | import os 5 | import subprocess 6 | 7 | import numpy as np 8 | 9 | from PIL import Image 10 | 11 | image_name = '4.png' 12 | image_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'demo', image_name)) 13 | 14 | # 读取图片并转换为 ndarray 15 | def image_to_array(image_path): 16 | img = Image.open(image_path) 17 | img_array = np.array(img) 18 | return img_array 19 | 20 | def convert_h256(src_path, dst_path): 21 | cmd = ["ffmpeg", "-y", "-i", src_path, "-vcodec", "h264", dst_path] 22 | try: 23 | subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) 24 | except subprocess.CalledProcessError as ex: 25 | print(f'returncode:{ex.returncode}, \ncmd: {ex.cmd}, \noutput: {ex.output}, \nstderr: {ex.stderr}, \nstdout: {ex.stdout}'.format()) 26 | raise ex 27 | finally: 28 | os.remove(src_path) 29 | 30 | def save_video(video, 31 | local_path, 32 | size=None, 33 | fps: int=30, 34 | h_256: bool=True, 35 | rgb2bgr: bool=False): 36 | dir_path = os.path.dirname(local_path) 37 | if dir_path and not os.path.exists(dir_path): 38 | os.makedirs(dir_path) 39 | 40 | save_path = local_path 41 | 42 | if h_256: 43 | save_path = "tmp-" + str(uuid.uuid4()) + ".mp4" 44 | 45 | try: 46 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 47 | frame1 = video[0] 48 | size = (frame1.shape[1], frame1.shape[0]) 49 | video_writer = cv2.VideoWriter(save_path, fourcc, fps, size) 50 | 51 | for frame in video: 52 | if rgb2bgr: 53 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 54 | video_writer.write(frame) 55 | finally: 56 | if video_writer: 57 | video_writer.release() 58 | 59 | if h_256: 60 | convert_h256(save_path, local_path) 61 | 62 | 63 | def seine_inference(): 64 | image_array = image_to_array(image_path) 65 | input = { 66 | "prompt": "跳舞", 67 | "image": image_array.tolist() 68 | } 69 | 70 | fps = 8 71 | url = 'http://10.140.0.243:18080/api/v1/inference/seine' 72 | response = requests.post(url, json=input) 73 | 74 | # 检查响应 75 | if response.status_code == 200: 76 | print(f'Inference successfully.') 77 | result = response.json() 78 | output = result['output'] 79 | output = np.array(output, dtype=np.uint8) 80 | save_video(output, 'vid.mp4', fps=fps) 81 | else: 82 | print('Failed to upload image:', response.status_code) 83 | 84 | if __name__ == '__main__': 85 | seine_inference() -------------------------------------------------------------------------------- /vinci-inference/demo/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/vinci/797d2d383efa264a16502063cbedc5ae39517aac/vinci-inference/demo/demo.mp4 -------------------------------------------------------------------------------- /vinci-inference/requirements/app.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.73.0 2 | uvicorn[standard]==0.17.4 3 | pydantic 4 | oss2==2.17.0 5 | sse-starlette==2.1.2 6 | transformers -------------------------------------------------------------------------------- 
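The sample clients above print each raw SSE event; because `stream_chat` in app/models/internvl.py yields the cumulatively growing answer text, a consumer usually only needs the last `answer` value it has seen. A small sketch of that pattern follows; the host/port is an assumption taken from the sample clients, and the frame URL is simply the demo image used elsewhere in this repo.

```python
# Hypothetical consumer of the /inference/internvl/stream SSE endpoint.
# Assumption: the app from app/main.py is reachable at 127.0.0.1:18081.
import json
import requests
import sseclient

payload = {
    "session_id": "demo",
    "timestamp": 0,
    "silent": False,
    "question": "What is happening in the video?",
    "history": [],
    "frames": ["https://oss.openmmlab.com/vinci/minimalism-p1.PNG"],
}

final_answer = ""
with requests.post("http://127.0.0.1:18081/api/v1/inference/internvl/stream",
                   json=payload, stream=True) as resp:
    resp.raise_for_status()
    client = sseclient.SSEClient(resp)
    for event in client.events():
        if event.event == "end" or event.data == "Connection closed":
            break
        final_answer = json.loads(event.data)["answer"]  # cumulative text so far

print(final_answer)
```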
/vinci-inference/requirements/client.txt: -------------------------------------------------------------------------------- 1 | opencv-python-headless==4.5.5.64 2 | decord==0.6.0 3 | numpy==1.26.4 4 | requests 5 | pillow 6 | # sseclient 7 | sseclient-py -------------------------------------------------------------------------------- /vinci-local/.gitignore: -------------------------------------------------------------------------------- 1 | **.cache 2 | .idea -------------------------------------------------------------------------------- /vinci-local/README.md: -------------------------------------------------------------------------------- 1 | # vinci_local 2 | 3 | 4 | 5 | ## Built-in services 6 | 7 | * srs 8 | * mysql 9 | * oss 10 | * llm 11 | * retrieval 12 | * tts 13 | 14 | ## Frontend/backend services && model services 15 | 16 | 1. vinci-be 17 | 18 | `git checkout develop` 19 | 20 | -------------------------------------------------------------------------------- /vinci-local/docker/README.md: -------------------------------------------------------------------------------- 1 | Docker Compose 2.0 support 2 | 3 | ## Quick start 4 | 5 | Before running any commands, make sure the environment variable `CANDIDATE` is set to the IP (LAN address) used to reach the services. 6 | 7 | - By default, the local IP is obtained automatically via the `hostname` command. If the detected IP is incorrect, it can be set manually with the `--hostname` argument. 8 | 9 | ### Start the services 10 | 11 | Start all services with: 12 | 13 | ```bash 14 | boot.sh start 15 | ``` 16 | 17 | ### Usage flow 18 | 1. On your phone, download and install an RTMP streaming app and set the push URL (make sure the phone and the computer are on the same LAN): 19 | `rtmp://{hostname}:1935/vinci/livestream`, then start streaming. 20 | 21 | 2. Open a browser and visit: 22 | http://{hostname}:19333, then click the "开启会话" (start session) button on the page. 23 | 24 | 3. Use the phone for voice wake-up and to issue commands. 25 | 26 | ## Service verification 27 | 28 | After the services are up, verify them as follows: 29 | 30 | #### Minio 31 | 32 | Open in a browser: 33 | `http://127.0.0.1:19001/` 34 | 35 | #### SRS 36 | 37 | * Open the player page in a browser: 38 | `http://127.0.0.1:18080/players/whep.html` 39 | 40 | * Check the stream info: 41 | http://127.0.0.1:18080/api/v1/streams/ 42 | 43 | #### Vinci backend (vinci-be) 44 | 45 | Check the backend service status with: 46 | 47 | ```curl http://127.0.0.1:18000/session/info``` 48 | 49 | #### Vinci frontend (vinci-fe) 50 | 51 | Open in a browser: 52 | `http://127.0.0.1:19333` 53 | 54 | ## Requirements 55 | 56 | Make sure you have sufficient storage, memory and GPU support: 57 | * Image storage requirement: about 50GB -------------------------------------------------------------------------------- /vinci-local/docker/boot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" 4 | 5 | HOSTNAME_ARG="" 6 | PULL_ARG="" 7 | ACTION="start" 8 | 9 | while [[ "$#" -gt 0 ]]; do 10 | case $1 in 11 | --hostname) 12 | HOSTNAME_ARG="$2"; shift ;; 13 | --pull) 14 | PULL_ARG="true" ;; 15 | start|stop|restart|clean|ps) 16 | ACTION="$1" ;; 17 | *) 18 | echo "Unknown parameter passed: $1"; exit 1 ;; 19 | esac 20 | shift 21 | done 22 | 23 | if [ -n "$HOSTNAME_ARG" ]; then 24 | LAN_IP="$HOSTNAME_ARG" 25 | else 26 | LAN_IP=$(hostname -I | awk '{print $1}') 27 | fi 28 | 29 | echo "LAN IP: $LAN_IP" 30 | 31 | export CANDIDATE="$LAN_IP" 32 | 33 | # Start/stop the docker-compose services 34 | cd "$SCRIPT_DIR" 35 | 36 | # Pull the latest images if requested 37 | if [ "$PULL_ARG" == "true" ]; then 38 | echo "Pulling latest images..." 39 | docker compose pull 40 | fi 41 | 42 | case "$ACTION" in 43 | start) 44 | echo "Starting docker-compose services..." 45 | docker compose up -d 46 | ;; 47 | stop) 48 | echo "Stopping docker-compose services..." 49 | docker compose down 50 | ;; 51 | restart) 52 | echo "Restarting docker-compose services..." 53 | docker compose down 54 | docker compose up -d 55 | ;; 56 | ps) 57 | echo "Listing docker-compose services status..." 
58 | docker compose ps 59 | ;; 60 | clean) 61 | echo "Stopping docker-compose services and cleaning up..." 62 | docker compose down 63 | if [ -d "$SCRIPT_DIR/.cache" ]; then 64 | rm -rf "$SCRIPT_DIR/.cache" 65 | echo "Removed .cache directory." 66 | else 67 | echo "No .cache directory found." 68 | fi 69 | ;; 70 | *) 71 | echo "No valid action specified. Use 'start' or 'stop'." 72 | ;; 73 | esac 74 | -------------------------------------------------------------------------------- /vinci-local/docker/clone.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 固定的多个仓库 URL 和分支名称数组 4 | REPO_URLS=("https://gitlab.pjlab.org.cn/cloud/vinci-be.git" "https://gitlab.pjlab.org.cn/openxlabs/vinci-fe.git" "https://gitlab.pjlab.org.cn/cloud/vinci-inference.git" "https://gitlab.pjlab.org.cn/liujinyao/vinci-retrieval.git") 5 | BRANCH_NAMES=("privatization" "privatization" "privatization" "privatization") 6 | 7 | # 检查数组长度是否匹配 8 | if [ ${#REPO_URLS[@]} -ne ${#BRANCH_NAMES[@]} ]; then 9 | echo "Error: The number of repositories and branches must match." 10 | exit 1 11 | fi 12 | 13 | # 循环处理每个仓库 14 | for i in "${!REPO_URLS[@]}"; do 15 | REPO_URL="${REPO_URLS[$i]}" 16 | BRANCH_NAME="${BRANCH_NAMES[$i]}" 17 | 18 | # 从仓库 URL 获取目录名 19 | REPO_DIR=$(basename "$REPO_URL" .git) 20 | 21 | echo "Processing repository $REPO_URL on branch $BRANCH_NAME..." 22 | 23 | # 检查仓库目录是否存在 24 | if [ ! -d "$REPO_DIR" ]; then 25 | # 如果目录不存在,克隆仓库 26 | echo "Cloning repository from $REPO_URL..." 27 | git clone "$REPO_URL" 28 | if [ $? -ne 0 ]; then 29 | echo "Failed to clone repository $REPO_URL." 30 | continue 31 | fi 32 | fi 33 | 34 | # 进入仓库目录 35 | cd "$REPO_DIR" || exit 36 | 37 | # 拉取最新的远程分支 38 | echo "Fetching latest changes from remote..." 39 | git fetch origin 40 | 41 | # 检查指定的分支是否存在 42 | if git show-ref --verify --quiet refs/heads/"$BRANCH_NAME"; then 43 | # 如果本地已经有该分支,切换到分支并拉取最新的代码 44 | echo "Switching to branch $BRANCH_NAME..." 45 | git checkout "$BRANCH_NAME" 46 | else 47 | # 如果本地没有该分支,尝试从远程创建并切换 48 | echo "Branch $BRANCH_NAME does not exist locally. Checking out from origin..." 49 | git checkout -b "$BRANCH_NAME" origin/"$BRANCH_NAME" 50 | fi 51 | 52 | # 拉取最新代码 53 | echo "Pulling latest changes for branch $BRANCH_NAME..." 54 | git pull origin "$BRANCH_NAME" 55 | 56 | echo "Repository $REPO_URL is now on branch $BRANCH_NAME." 57 | 58 | # 回到上一级目录以继续处理下一个仓库 59 | cd .. 60 | 61 | done 62 | 63 | echo "All repositories processed." 
64 | -------------------------------------------------------------------------------- /vinci-local/docker/docker-compose-build.yaml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | 3 | services: 4 | chat: 5 | image: eng-center-registry.cn-shanghai.cr.aliyuncs.com/public/vinci-chat:latest 6 | container_name: chat 7 | build: 8 | context: ./vinci-inference 9 | dockerfile: docker/Dockerfile 10 | ports: 11 | - '18081:18081' 12 | environment: 13 | - CUDA_VISIBLE_DEVICES=1,2 14 | - access_key=minio_admin 15 | - access_key_secret=minio_admin 16 | - endpoint=minio:9000 17 | - bucket=vinci 18 | - cdn=http://${CANDIDATE}:19000/vinci 19 | restart: always 20 | networks: 21 | - vinci_local 22 | 23 | retrieval: 24 | image: eng-center-registry.cn-shanghai.cr.aliyuncs.com/public/vinci-retrieval:latest 25 | container_name: retrieval 26 | build: 27 | context: ./vinci-retrieval 28 | dockerfile: docker/Dockerfile 29 | restart: always 30 | ports: 31 | - '18082:18082' 32 | environment: 33 | config_path: /opt/worker/app/config.yaml 34 | oss__access_key: minio_admin 35 | oss__access_key_secret: minio_admin 36 | oss__endpoint: minio:9000 37 | oss__bucket: vinci 38 | oss__cdn: http://${CANDIDATE}:19000/vinci 39 | networks: 40 | - vinci_local 41 | 42 | srs: 43 | image: registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5 44 | container_name: srs 45 | restart: always 46 | ports: 47 | - '1935:1935' 48 | - '18080:8080' 49 | - '18081:8081/udp' 50 | - "1985:1985" 51 | volumes: 52 | - ./srs/conf/:/usr/local/srs/vinci/ 53 | environment: 54 | - CANDIDATE=${CANDIDATE} 55 | command: ./objs/srs -c vinci/vinci.conf 56 | privileged: true 57 | networks: 58 | - vinci_local 59 | 60 | mysql: 61 | image: mysql:8.0 62 | container_name: mysql 63 | restart: always 64 | ports: 65 | - '13306:3306' 66 | volumes: 67 | - ./mysql/init.sql:/docker-entrypoint-initdb.d/init.sql 68 | - ./.cache/mysql/data:/bitnami/mysql/data 69 | environment: 70 | - BITNAMI_DEBUG=true 71 | - TZ=Asia/Shanghai 72 | - MYSQL_CHARACTER_SET=utf8mb4 73 | - MYSQL_COLLATE=utf8mb4_general_ci 74 | - MYSQL_ROOT_PASSWORD=123456 75 | privileged: true 76 | networks: 77 | - vinci_local 78 | 79 | minio: 80 | image: minio/minio 81 | container_name: minio 82 | environment: 83 | - MINIO_ROOT_USER=minio_admin 84 | - MINIO_ROOT_PASSWORD=minio_admin 85 | ports: 86 | - "19000:9000" 87 | - "19001:9001" 88 | volumes: 89 | - .cache/minio:/data 90 | - ./minio/entry.sh:/usr/local/bin/entry.sh 91 | networks: 92 | - vinci_local 93 | entrypoint: /bin/sh -c "/usr/local/bin/entry.sh" 94 | 95 | nginx: 96 | image: nginx:1.21-alpine 97 | container_name: nginx 98 | ports: 99 | - "19333:80" 100 | volumes: 101 | - ./nginx/conf.d:/etc/nginx/conf.d 102 | networks: 103 | - vinci_local 104 | depends_on: 105 | - vinci-be 106 | - vinci-fe 107 | 108 | vinci-be: 109 | image: eng-center-registry.cn-shanghai.cr.aliyuncs.com/public/vinci-be:latest 110 | container_name: vinci-be 111 | build: 112 | context: ./vinci-be 113 | dockerfile: Dockerfile 114 | restart: always 115 | environment: 116 | - env=private 117 | - srs_lb_host=${CANDIDATE} 118 | - srs_host=srs 119 | - rtmp_port=1935 120 | - webrtc_port=18080 121 | - api_port=8080 122 | - api_auth_token=cGpsYWI6UFpNdmM4VWhKN3hNQ2JlTQ== 123 | - internal_app_name=internal 124 | - mysql_host=mysql:3306 125 | - mysql_user=root 126 | - mysql_password=123456 127 | - oss_access_key_id=minio_admin 128 | - oss_access_key_secret=minio_admin 129 | - oss_endpoint=minio:9000 130 | - oss_bucket=vinci 131 | - 
oss_external_endpoint=http://${CANDIDATE}:19000 132 | - intern_endpoint=http://chat:18081 133 | - retrieval_endpoint=http://retrieval:18082 134 | - speech_language=en-US 135 | - tts=0 136 | - puyu_base_url=https://puyu.openxlab.org.cn/puyu/api/v1/ 137 | - puyu_api_key=eyJ0eXBlIjoiSldUIiwiYWxnIjoiSFM1MTIifQ.eyJqdGkiOiI1MDE4MjM4OSIsInJvbCI6IlJPTEVfUkVHSVNURVIiLCJpc3MiOiJPcGVuWExhYiIsImlhdCI6MTcyNDEyMjk3MCwiY2xpZW50SWQiOiJtcXprcGxtbnc5N29wa28zNmpxaiIsInBob25lIjoiIiwidXVpZCI6IjYzNTgxOTg0LTUzYjctNDAwMS1hYzJlLTBmOTFlMTc3ZDFkYyIsImVtYWlsIjoidmluY2lAcGpsYWIub3JnLmNuIiwiZXhwIjoxODgxODAyOTcwfQ.E01dyrnLZZgaPTD9dtrtXtnl_wkXYb3BHbaHQYnjN89RkdP88r7FY7VxeszW9ujk-913XurIc2e9ngbZCFbh2w 138 | ports: 139 | - '18000:8000' 140 | privileged: true 141 | networks: 142 | - vinci_local 143 | 144 | vinci-fe: 145 | image: eng-center-registry.cn-shanghai.cr.aliyuncs.com/public/vinci-fe:latest 146 | container_name: vinci-fe 147 | build: 148 | context: ./vinci-fe 149 | dockerfile: deploy/Dockerfile 150 | args: 151 | deployDir: deploy 152 | restart: always 153 | ports: 154 | - '19330:80' 155 | privileged: true 156 | networks: 157 | - vinci_local 158 | 159 | networks: 160 | vinci_local: 161 | driver: bridge 162 | -------------------------------------------------------------------------------- /vinci-local/docker/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | 3 | services: 4 | srs: 5 | image: registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5 6 | container_name: srs 7 | restart: always 8 | ports: 9 | - '1935:1935' 10 | - '8080:8080/udp' 11 | - '8080:8080/tcp' 12 | - '8000:8000/udp' 13 | - "1985:1985" 14 | volumes: 15 | - ./srs/conf/:/usr/local/srs/vinci/ 16 | environment: 17 | - CANDIDATE=${CANDIDATE} 18 | command: ./objs/srs -c vinci/vinci.conf 19 | privileged: true 20 | networks: 21 | - vinci_local 22 | 23 | mysql: 24 | image: mysql:8.0 25 | container_name: mysql 26 | restart: always 27 | ports: 28 | - '13306:3306' 29 | volumes: 30 | - ./mysql/init.sql:/docker-entrypoint-initdb.d/init.sql 31 | - ./.cache/mysql/data:/bitnami/mysql/data 32 | environment: 33 | - BITNAMI_DEBUG=true 34 | - TZ=Asia/Shanghai 35 | - MYSQL_CHARACTER_SET=utf8mb4 36 | - MYSQL_COLLATE=utf8mb4_general_ci 37 | - MYSQL_ROOT_PASSWORD=123456 38 | privileged: true 39 | networks: 40 | - vinci_local 41 | 42 | minio: 43 | image: minio/minio 44 | container_name: minio 45 | environment: 46 | - MINIO_ROOT_USER=minio_admin 47 | - MINIO_ROOT_PASSWORD=minio_admin 48 | ports: 49 | - "19000:9000" 50 | - "19001:9001" 51 | volumes: 52 | - .cache/minio:/data 53 | - ./minio/entry.sh:/usr/local/bin/entry.sh 54 | privileged: true 55 | networks: 56 | - vinci_local 57 | entrypoint: /bin/sh -c "/usr/local/bin/entry.sh" 58 | 59 | nginx: 60 | image: nginx:1.21-alpine 61 | container_name: nginx 62 | ports: 63 | - "19333:80" 64 | volumes: 65 | - ./nginx/conf.d:/etc/nginx/conf.d 66 | privileged: true 67 | networks: 68 | - vinci_local 69 | depends_on: 70 | - vinci-be 71 | - vinci-fe 72 | 73 | vinci-be: 74 | image: crpi-dn1nyq7vw8amhh61.cn-shanghai.personal.cr.aliyuncs.com/vinci-dl/vinci-be:latest 75 | container_name: vinci-be 76 | restart: always 77 | environment: 78 | - env=private 79 | - srs_lb_host=${CANDIDATE} 80 | - srs_host=srs 81 | - rtmp_port=1935 82 | - webrtc_port=8080 83 | - api_port=8080 84 | - api_auth_token=cGpsYWI6UFpNdmM4VWhKN3hNQ2JlTQ== 85 | - internal_app_name=internal 86 | - mysql_host=mysql:3306 87 | - mysql_user=root 88 | - mysql_password=123456 89 | - oss_access_key_id=minio_admin 90 | - 
oss_access_key_secret=minio_admin 91 | - oss_endpoint=minio:9000 92 | - oss_bucket=vinci 93 | - oss_external_endpoint=http://${CANDIDATE}:19000 94 | - intern_endpoint=http://10.6.20.128:18081 95 | - retrieval_endpoint=http://10.6.20.128:18082 96 | - speech_language=en-US 97 | - tts=0 98 | - puyu_base_url=https://puyu.openxlab.org.cn/puyu/api/v1/ 99 | - puyu_api_key=eyJ0eXBlIjoiSldUIiwiYWxnIjoiSFM1MTIifQ.eyJqdGkiOiI1MDE4MjM4OSIsInJvbCI6IlJPTEVfUkVHSVNURVIiLCJpc3MiOiJPcGVuWExhYiIsImlhdCI6MTcyNDEyMjk3MCwiY2xpZW50SWQiOiJtcXprcGxtbnc5N29wa28zNmpxaiIsInBob25lIjoiIiwidXVpZCI6IjYzNTgxOTg0LTUzYjctNDAwMS1hYzJlLTBmOTFlMTc3ZDFkYyIsImVtYWlsIjoidmluY2lAcGpsYWIub3JnLmNuIiwiZXhwIjoxODgxODAyOTcwfQ.E01dyrnLZZgaPTD9dtrtXtnl_wkXYb3BHbaHQYnjN89RkdP88r7FY7VxeszW9ujk-913XurIc2e9ngbZCFbh2w 100 | ports: 101 | - '18000:8000' 102 | privileged: true 103 | networks: 104 | - vinci_local 105 | 106 | vinci-fe: 107 | image: crpi-dn1nyq7vw8amhh61.cn-shanghai.personal.cr.aliyuncs.com/vinci-dl/vinci-fe:latest 108 | container_name: vinci-fe 109 | restart: always 110 | ports: 111 | - '19330:80' 112 | privileged: true 113 | networks: 114 | - vinci_local 115 | 116 | networks: 117 | vinci_local: 118 | ipam: 119 | driver: default 120 | config: 121 | - subnet: 172.28.0.0/16 122 | gateway: 172.28.0.1 123 | -------------------------------------------------------------------------------- /vinci-local/docker/minio/entry.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | minio server /data --console-address :9001 & 4 | 5 | # Wait for the MinIO service (port 9000) to come up 6 | while ! curl -s http://localhost:9000/minio/health/live; do 7 | echo "Waiting for the MinIO service to start..." 8 | sleep 3 9 | done 10 | 11 | mc alias set local_minio http://localhost:9000 $MINIO_ROOT_USER $MINIO_ROOT_PASSWORD 12 | 13 | # Create the bucket 14 | if mc ls local_minio/vinci > /dev/null 2>&1; then 15 | echo "Bucket 'vinci' already exists, skipping creation" 16 | else 17 | mc mb local_minio/vinci 18 | fi 19 | 20 | echo "Setting bucket 'vinci' to public access" 21 | mc anonymous set public local_minio/vinci 22 | 23 | tail -f /dev/null 24 | -------------------------------------------------------------------------------- /vinci-local/docker/mysql/init.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE IF NOT EXISTS `vinci` /*!40100 DEFAULT CHARACTER SET utf8mb4 */; 2 | USE `vinci`; 3 | -- MySQL dump 10.13 Distrib 5.7.33, for Linux (x86_64) 4 | -- 5 | -- Host: rm-uf60p2114565x449u4o.mysql.rds.aliyuncs.com Database: vinci 6 | -- ------------------------------------------------------ 7 | -- Server version 5.7.42-log 8 | 9 | /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; 10 | /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; 11 | /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; 12 | /*!40101 SET NAMES utf8 */; 13 | /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; 14 | /*!40103 SET TIME_ZONE='+00:00' */; 15 | /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; 16 | /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; 17 | /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; 18 | /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; 19 | SET @MYSQLDUMP_TEMP_LOG_BIN = @@SESSION.SQL_LOG_BIN; 20 | SET @@SESSION.SQL_LOG_BIN= 0; 21 | 22 | -- 23 | -- GTID state at the beginning of the backup 24 | -- 25 | 26 | SET @@GLOBAL.GTID_PURGED='11d9f446-5278-11ef-9ae1-00163e3a02d1:1-430758, 27 | 
514a3f81-2c5c-11ee-bd18-00163e285d0f:1-214379, 28 | 5391b628-46be-11ef-83f8-00163e36f15d:1-87510, 29 | 607a026e-481d-11ee-9243-00163e0a5235:1-1508997, 30 | 69a206a0-fdaf-11ee-a3dc-00163e3a0e88:1-473717, 31 | c28b70f8-3aef-11ef-8124-00163e108a1c:1-87559'; 32 | 33 | -- 34 | -- Table structure for table `session_info` 35 | -- 36 | 37 | DROP TABLE IF EXISTS `session_info`; 38 | /*!40101 SET @saved_cs_client = @@character_set_client */; 39 | /*!40101 SET character_set_client = utf8 */; 40 | CREATE TABLE `session_info` ( 41 | `session_id` varchar(50) NOT NULL COMMENT 'session id', 42 | `stream_group_name` varchar(50) NOT NULL DEFAULT '' COMMENT 'group name of the pushed video stream', 43 | `stream_name` varchar(50) NOT NULL DEFAULT '' COMMENT 'name of the pushed video stream', 44 | `is_delete` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'whether the record is deleted', 45 | `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, 46 | `update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 47 | PRIMARY KEY (`session_id`) 48 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 49 | /*!40101 SET character_set_client = @saved_cs_client */; 50 | 51 | -- 52 | -- Table structure for table `user_chat_record` 53 | -- 54 | 55 | DROP TABLE IF EXISTS `user_chat_record`; 56 | /*!40101 SET @saved_cs_client = @@character_set_client */; 57 | /*!40101 SET character_set_client = utf8 */; 58 | CREATE TABLE `user_chat_record` ( 59 | `id` bigint(20) NOT NULL, 60 | `session_id` varchar(50) NOT NULL DEFAULT '' COMMENT 'session id', 61 | `request_text` longtext COMMENT 'request text', 62 | `response_text` longtext COMMENT 'response text', 63 | `duration` bigint(20) NOT NULL DEFAULT '0' COMMENT 'duration of this Q&A round, in milliseconds', 64 | `request_frame_urls` longtext COMMENT 'request frame image urls', 65 | `response_audio_url` varchar(500) NOT NULL DEFAULT '' COMMENT 'response audio url', 66 | `response_video_url` varchar(500) NOT NULL DEFAULT '' COMMENT 'response video url', 67 | `request_history` longtext COMMENT 'request history produced by this Q&A round', 68 | `response_history` longtext COMMENT 'response history produced by this Q&A round', 69 | `silent` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'whether this Q&A round is silent', 70 | `status` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'status of this round: 0 normal, 1 cancelled by the user', 71 | `is_delete` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'whether the record is deleted', 72 | `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, 73 | `update_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 74 | `retrieval_video_urls` text, 75 | PRIMARY KEY (`id`) 76 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; 77 | /*!40101 SET character_set_client = @saved_cs_client */; 78 | SET @@SESSION.SQL_LOG_BIN = @MYSQLDUMP_TEMP_LOG_BIN; 79 | /*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; 80 | 81 | /*!40101 SET SQL_MODE=@OLD_SQL_MODE */; 82 | /*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; 83 | /*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; 84 | /*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; 85 | /*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; 86 | /*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; 87 | /*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; 88 | 89 | -- Dump completed on 2024-10-16 15:47:03 -------------------------------------------------------------------------------- /vinci-local/docker/nginx/conf.d/default.conf: -------------------------------------------------------------------------------- 1 | server { 2 | listen 80; 3 | 4 | # Backend service 5 | location /gw/vinci/ { 6 | rewrite ^/gw/vinci(/.*)$ $1 break; 7 | proxy_pass http://vinci-be:8000; 8 | proxy_set_header Host $host; 9 | proxy_set_header X-Real-IP $remote_addr; 10 | proxy_set_header X-Forwarded-For 
$proxy_add_x_forwarded_for; 11 | proxy_set_header X-Forwarded-Proto $scheme; 12 | } 13 | 14 | # Frontend service 15 | location / { 16 | proxy_pass http://vinci-fe:80; 17 | #proxy_pass http://10.6.20.128:19330; 18 | proxy_set_header Host $host; 19 | proxy_set_header X-Real-IP $remote_addr; 20 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; 21 | proxy_set_header X-Forwarded-Proto $scheme; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /vinci-local/docker/srs/conf/vinci.conf: -------------------------------------------------------------------------------- 1 | listen 1935; 2 | max_connections 1000; 3 | daemon off; 4 | srs_log_tank console; 5 | 6 | http_server { 7 | enabled on; 8 | listen 8080; 9 | dir ./objs/nginx/html; 10 | } 11 | 12 | http_api { 13 | enabled on; 14 | listen 8080; # this is also the port used for RTC playback (pull) 15 | auth { 16 | # whether enable the HTTP AUTH. 17 | # Overwrite by env SRS_HTTP_API_AUTH_ENABLED 18 | # default: off 19 | enabled off; 20 | # The username of Basic authentication: 21 | # Overwrite by env SRS_HTTP_API_AUTH_USERNAME 22 | username pjlab; 23 | # The password of Basic authentication: 24 | # Overwrite by env SRS_HTTP_API_AUTH_PASSWORD 25 | password PZMvc8UhJ7xMCbeM; 26 | } 27 | } 28 | stats { 29 | network 0; 30 | } 31 | rtc_server { 32 | enabled on; 33 | listen 8080; # UDP port 34 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate 35 | candidate 127.0.0.1; 36 | } 37 | 38 | vhost srs-xcomposer-dev.intern-ai.org.cn { 39 | tcp_nodelay on; 40 | min_latency on; 41 | play { 42 | gop_cache off; 43 | queue_length 10; 44 | mw_latency 100; 45 | } 46 | publish { 47 | mr off; 48 | } 49 | rtc { 50 | enabled on; 51 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc 52 | rtmp_to_rtc on; 53 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp 54 | rtc_to_rtmp on; 55 | } 56 | http_remux { 57 | enabled on; 58 | mount [vhost]/[app]/[stream].flv; 59 | } 60 | } 61 | 62 | vhost srs-vinci-dev.intern-ai.org.cn { 63 | tcp_nodelay on; 64 | min_latency on; 65 | play { 66 | gop_cache off; 67 | queue_length 10; 68 | mw_latency 100; 69 | } 70 | publish { 71 | mr off; 72 | } 73 | rtc { 74 | enabled on; 75 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc 76 | rtmp_to_rtc on; 77 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp 78 | rtc_to_rtmp off; 79 | } 80 | http_remux { 81 | enabled on; 82 | mount [vhost]/[app]/[stream].flv; 83 | } 84 | http_hooks { 85 | enabled on; 86 | on_unpublish http://vinci-be-service:8000/callback/stream/on_unpublish; 87 | # on_publish http://vinci-be-service:8000/callback/stream/on_publish; 88 | # on_play http://vinci-be-service:8000/callback/stream/on_play; 89 | # on_stop http://vinci-be-service:8000/callback/stream/on_stop; 90 | } 91 | } 92 | 93 | vhost __defaultVhost__ { 94 | tcp_nodelay on; 95 | min_latency on; 96 | play { 97 | gop_cache off; 98 | queue_length 10; 99 | mw_latency 100; 100 | } 101 | publish { 102 | mr off; 103 | } 104 | rtc { 105 | enabled on; 106 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc 107 | rtmp_to_rtc on; 108 | # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp 109 | rtc_to_rtmp off; 110 | } 111 | http_remux { 112 | enabled on; 113 | mount [vhost]/[app]/[stream].flv; 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /vl_open.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 
| import torch 3 | import torchvision.transforms as T 4 | from PIL import Image 5 | from torchvision.transforms.functional import InterpolationMode 6 | from transformers import AutoModel, AutoTokenizer 7 | from random import randint 8 | from transformers import TextIteratorStreamer 9 | from threading import Thread 10 | import os 11 | 12 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 13 | IMAGENET_STD = (0.229, 0.224, 0.225) 14 | 15 | 16 | def build_transform(input_size): 17 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 18 | transform = T.Compose([ 19 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 20 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 21 | T.ToTensor(), 22 | T.Normalize(mean=MEAN, std=STD) 23 | ]) 24 | return transform 25 | 26 | 27 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 28 | best_ratio_diff = float('inf') 29 | best_ratio = (1, 1) 30 | area = width * height 31 | for ratio in target_ratios: 32 | target_aspect_ratio = ratio[0] / ratio[1] 33 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 34 | if ratio_diff < best_ratio_diff: 35 | best_ratio_diff = ratio_diff 36 | best_ratio = ratio 37 | elif ratio_diff == best_ratio_diff: 38 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 39 | best_ratio = ratio 40 | return best_ratio 41 | 42 | 43 | def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): 44 | orig_width, orig_height = image.size 45 | aspect_ratio = orig_width / orig_height 46 | 47 | # calculate the existing image aspect ratio 48 | target_ratios = set( 49 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 50 | i * j <= max_num and i * j >= min_num) 51 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 52 | 53 | # find the closest aspect ratio to the target 54 | target_aspect_ratio = find_closest_aspect_ratio( 55 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 56 | 57 | # calculate the target width and height 58 | target_width = image_size * target_aspect_ratio[0] 59 | target_height = image_size * target_aspect_ratio[1] 60 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 61 | 62 | # resize the image 63 | resized_img = image.resize((target_width, target_height)) 64 | processed_images = [] 65 | for i in range(blocks): 66 | box = ( 67 | (i % (target_width // image_size)) * image_size, 68 | (i // (target_width // image_size)) * image_size, 69 | ((i % (target_width // image_size)) + 1) * image_size, 70 | ((i // (target_width // image_size)) + 1) * image_size 71 | ) 72 | # split the image 73 | split_img = resized_img.crop(box) 74 | processed_images.append(split_img) 75 | assert len(processed_images) == blocks 76 | if use_thumbnail and len(processed_images) != 1: 77 | thumbnail_img = image.resize((image_size, image_size)) 78 | processed_images.append(thumbnail_img) 79 | return processed_images 80 | 81 | 82 | def load_image(image_file, input_size=448, max_num=6): 83 | image = Image.open(image_file).convert('RGB') 84 | transform = build_transform(input_size=input_size) 85 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 86 | pixel_values = [transform(image) for image in images] 87 | pixel_values = torch.stack(pixel_values) 88 | return pixel_values 89 | 90 | 91 | class Chat(): 92 | def __init__(self, path='Vinci-8B-base', path2='Vinci-8B-ckpt', sep_chat=False, stream=True, device='cuda:0', 
use_chat_history=False, language='chn', version='v0'): 93 | super().__init__() 94 | self.device = device 95 | self.vr = None 96 | self.video_fps = None 97 | self.prev_timestamp = 0 98 | self.history = [] 99 | self.chat_history = [] 100 | self.stream = stream 101 | self.sep_chat = sep_chat; self.use_chat_history = use_chat_history  # both flags are read later by add_history() 102 | self.transform = build_transform(input_size=448) 103 | self.language = language 104 | 105 | from safetensors.torch import load_file 106 | 107 | def merge_dicts(dict1, dict2, dict3, dict4): 108 | result = {**dict1, **dict2, **dict3, **dict4} 109 | return result 110 | 111 | self.model = AutoModel.from_pretrained( 112 | path, 113 | torch_dtype=torch.bfloat16, 114 | low_cpu_mem_usage=True, 115 | trust_remote_code=True) 116 | if version == 'v0':  # load and merge the sharded fine-tuned checkpoint from path2 117 | model_weights1 = load_file(os.path.join(path2,"model-00001-of-00004.safetensors")) 118 | model_weights2 = load_file(os.path.join(path2,"model-00002-of-00004.safetensors")) 119 | model_weights3 = load_file(os.path.join(path2,"model-00003-of-00004.safetensors")) 120 | model_weights4 = load_file(os.path.join(path2,"model-00004-of-00004.safetensors")) 121 | merged_weight = merge_dicts(model_weights1,model_weights2,model_weights3,model_weights4) 122 | self.model.wrap_llm_lora(r=16, lora_alpha=2 * 16) 123 | msg = self.model.load_state_dict(merged_weight,strict=False) 124 | 125 | self.model = self.model.eval().cuda() 126 | state1 = self.model.state_dict() 127 | self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) 128 | if self.stream: 129 | self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10) 130 | self.generation_config = dict( 131 | num_beams=1, 132 | max_new_tokens=1024, 133 | do_sample=False, 134 | streamer=self.streamer 135 | ) 136 | else: 137 | self.generation_config = dict( 138 | num_beams=1, 139 | max_new_tokens=1024, 140 | do_sample=False, 141 | ) 142 | 143 | def load_video_timestamp(self, timestamp, num_segments=4): 144 | pixel_values_list, num_patches_list = [], [] 145 | offset = np.linspace(-2, 0, num_segments)  # num_segments sample times spanning the 2 seconds before the query timestamp 146 | rand_offset = randint(-4, 4) 147 | offset = offset + rand_offset 148 | frame_indices = (timestamp+offset) * self.video_fps 149 | frame_indices = frame_indices.astype(np.int64) 150 | if frame_indices[0] < 0: 151 | frame_indices -= frame_indices[0] 152 | print('***** using video timestamps at:', frame_indices) 153 | for i, frame_index in enumerate(frame_indices): 154 | img = Image.fromarray(self.vr[frame_index].asnumpy()).convert('RGB') 155 | if i == len(frame_indices) - 1: 156 | img.save('./lastim.jpg') 157 | img = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=1) 158 | pixel_values = [self.transform(tile) for tile in img] 159 | pixel_values = torch.stack(pixel_values) 160 | num_patches_list.append(pixel_values.shape[0]) 161 | pixel_values_list.append(pixel_values) 162 | pixel_values = torch.cat(pixel_values_list) 163 | return pixel_values, num_patches_list 164 | 165 | 166 | def answer(self, conv, timestamp=0, add_to_history=False): 167 | pixel_values, num_patches_list = self.load_video_timestamp(timestamp) 168 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 169 | video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))]) 170 | if add_to_history: # silent ask 171 | if self.language == 'chn': 172 | question = video_prefix + '现在视频到了 %.1f 秒处. 简单的描述视频中我的动作.' % timestamp 173 | else: 174 | question = video_prefix + 'Now the video is at %.1f seconds. Briefly describe my actions in the video.' 
% timestamp #conv['questions'][-1] 175 | 176 | response, history = self.model.chat(self.tokenizer, pixel_values, question, self.generation_config, 177 | num_patches_list=num_patches_list, 178 | history=None, return_history=True) 179 | self.history.append((timestamp, response)) 180 | else: 181 | if True: 182 | self.chat_history.append([conv['questions'][-1]]) 183 | question = self.add_history(conv['questions'][-1]) 184 | question = video_prefix + question 185 | if self.stream: 186 | thread = Thread(target=self.model.chat, kwargs=dict(tokenizer=self.tokenizer, pixel_values=pixel_values, question=question, generation_config=self.generation_config, 187 | num_patches_list=num_patches_list, 188 | history=None, return_history=False)) 189 | thread.start() 190 | response = '' 191 | else: 192 | response = self.model.chat(self.tokenizer, pixel_values, question, self.generation_config, 193 | num_patches_list=num_patches_list, 194 | history=None, return_history=False) 195 | self.chat_history[-1].append(timestamp) 196 | self.chat_history[-1].append(response) 197 | # self.history.append((timestamp, response)) 198 | conv['answers'].append(response + '\n') 199 | 200 | return response, conv, './lastim.jpg' 201 | 202 | def add_history(self, question): 203 | if not self.history: 204 | print('history not added because self.history is empty') 205 | return question 206 | if len(self.history) > 0: 207 | if self.language == 'chn': 208 | system = "你是一个在增强现实(AR)眼镜上的智能助手。看到的图像是来自我第一人称视角的视频帧。仔细观察视频并重点关注物体的运动和我的动作。由于你看不到发生在当前帧之前的部分,现在以文字形式提供给你这个视频的之前的历史供参考。视频历史是:" 209 | else: 210 | system = 'You are an intelligent assistant on AR glasses. The AR glasses receive video frames from my egocentric viewpoint. Carefully watch the video and pay attention to the movement of objects and to my actions. Since you cannot see the previous part of the video, I provide you with the history of this video for reference. The history is: ' 211 | res = system 212 | if self.sep_chat: 213 | for hist in self.history[:-1]: 214 | ts = hist[0] 215 | a = hist[1] 216 | if self.language == 'chn': 217 | res += '当视频在%.1f秒时, 视频的内容是 "%s"' % (ts, a) 218 | else: 219 | res += 'When the video is at %.1f seconds, the video content is "%s". ' % (ts, a) 220 | ts = self.history[-1][0] 221 | a = self.history[-1][1] 222 | if self.language == 'chn': 223 | res += '以上是所有的视频历史, 表明了之前发生了什么.\n现在视频到了 %.1f秒, 视频的内容是 "%s". ' % (ts, a) 224 | else: 225 | res += 'This is the end of the history, which indicates what has previously happened.\n Now the video is at %.1f seconds, the video content is: "%s". ' % (ts, a) 226 | else: 227 | for hist in self.history: 228 | ts = hist[0] 229 | a = hist[1] 230 | if self.language == 'chn': 231 | res += '当视频在%.1f秒时, 视频的内容是 "%s". ' % (ts, a.strip()) 232 | else: 233 | res += 'When the video is at %.1f seconds, the video content is "%s". ' % (ts, a.strip()) 234 | if self.language == 'chn': 235 | res += '以上是所有的视频历史, 表明了之前发生了什么, 如果后面的问题问到了之前发生的事情, 请参照.\n' 236 | else: 237 | res += 'This is the end of the video history that indicates what happened before.\n' 238 | if self.use_chat_history and len(self.chat_history)>1: 239 | if self.language == 'chn': 240 | res += '另外我提供根据之前的视频,我们的对话历史如下: ' 241 | else: 242 | res += 'Also I provide you with our chat history based on the previous video content: ' 243 | for hist in self.chat_history[:-1]: 244 | q = hist[0] 245 | ts = hist[1] 246 | a = hist[2] 247 | if self.language == 'chn': 248 | res += '当视频在%.1f秒时, 问题是: "%s", 回答是"%s". 
' % (ts, q.strip(), a.strip()) 249 | else: 250 | res += 'When the video is at %.1f seconds, the question was: "%s", and its answer was: "%s". ' % (ts, q.strip(), a.strip()) 251 | if self.language == 'chn': 252 | res += '以上是所有的对话历史, 表明了之前我们交流了什么,但是不表明现在的任何信息.\n' 253 | else: 254 | res += 'This is the end of the chat history. The chat history indicates what we previously discussed, but does not necessarily contain the current information.\n' 255 | 256 | if self.language == 'chn': 257 | res += '请根据当前视频, 用中文回答我的问题: "%s". 注意如果问题与之前发生的事情有关, 请参考视频历史, 否则请只关注图像信息. 如果问题是对未来的规划,给出最多3步规划. 用三句话以内回答.' % question 258 | else: 259 | res += 'Given the current video and using the previous video as reference, answer my question in English: "%s". Note that if the question is about what has been previously done, please only focus on the history. Otherwise, please only focus on the question and the current video input. If the question is about future planning, provide at most 3 steps.' % question 260 | 261 | question = res 262 | return question 263 | --------------------------------------------------------------------------------
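
Usage note (not part of the repository): `vl_open.Chat` never opens a video itself. `Chat.vr` and `Chat.video_fps` are expected to be attached by the caller, and the `self.vr[frame_index].asnumpy()` access pattern matches decord's `VideoReader` API, so decord is assumed below. The following is a minimal, hypothetical driver sketch under those assumptions; the checkpoint directory names follow the entries in `.gitignore` (`Vinci-8B-base`, `Vinci-8B-ckpt`), and `demo.mp4` is a placeholder path.

```python
# Hypothetical driver for vl_open.Chat (assumes decord is installed and the
# Vinci checkpoints have been downloaded into the directories named below).
from decord import VideoReader, cpu

from vl_open import Chat

chat = Chat(path='Vinci-8B-base', path2='Vinci-8B-ckpt',
            stream=False,          # return the full answer instead of streaming tokens
            language='eng',        # any value other than 'chn' selects the English prompts
            version='v0')

# The video reader and its frame rate are attached from outside the class.
chat.vr = VideoReader('demo.mp4', ctx=cpu(0))  # placeholder video file
chat.video_fps = chat.vr.get_avg_fps()

conv = {'questions': [], 'answers': []}

# Silent pass: caption the frames around t=10 s and store the caption in
# chat.history, which add_history() later injects into the prompt as "video history".
chat.answer(conv, timestamp=10.0, add_to_history=True)

# Interactive pass: ask a question about the frames around t=12 s.
conv['questions'].append('What am I holding right now?')
response, conv, last_frame_path = chat.answer(conv, timestamp=12.0)
print(response)
```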