├── app ├── __init__.py ├── dependencies.py ├── settings.py ├── models │ ├── move-models.json │ ├── gen-models.json │ └── v2v-models.json ├── event_callback.py ├── models.py ├── cache.py ├── schema.py ├── main.py └── user_client.py ├── requirements-dev.txt ├── .env.template ├── .idea ├── .gitignore ├── vcs.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml ├── domoai-api.iml └── misc.xml ├── requirements.txt ├── streamlit_demo ├── 🏠_Home.py ├── auth.py ├── pages │ ├── Animate.py │ ├── Move.py │ ├── Video.py │ ├── Real.py │ └── Gen.py └── utils.py ├── Dockerfile ├── README_CN.md ├── streamlit_demo.dockerfile ├── README.md ├── LICENSE ├── scripts └── update_models.py └── .gitignore /app/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | watchdog 3 | streamlit-authenticator -------------------------------------------------------------------------------- /.env.template: -------------------------------------------------------------------------------- 1 | DISCORD_TOKEN= 2 | DISCORD_GUILD_ID= 3 | DISCORD_CHANNEL_ID= 4 | 5 | # Optional 6 | # REDIS_URI= 7 | 8 | # EVENT_CALLBACK_URL= 9 | 10 | # API_AUTH_TOKEN= -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn[standard] 3 |
python-multipart 4 | pydantic-settings 5 | discord.py-self @ git+https://github.com/dolfies/discord.py-self.git 6 | 7 | redis 8 | httpx 9 | tenacity -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /streamlit_demo/🏠_Home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from streamlit_demo.auth import check_password 4 | 5 | st.set_page_config(page_title="DomoAI API") 6 | 7 | 8 | if not check_password(): 9 | st.stop() 10 | 11 | st.title("Unofficial DomoAI API Demo") 12 | st.link_button(label='GitHub', url='https://github.com/T0nYie1dYaw/domoai-api') 13 | -------------------------------------------------------------------------------- /.idea/domoai-api.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 10 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | WORKDIR /code 4 | 5 | COPY ./requirements.txt
/code/requirements.txt 6 | 7 | RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt 8 | 9 | COPY ./app /code/app 10 | 11 | ENV PYTHONUNBUFFERED 1 12 | 13 | ENV DISCORD_TOKEN '' 14 | ENV DISCORD_GUILD_ID '' 15 | ENV DISCORD_CHANNEL_ID '' 16 | 17 | ENV EVENT_CALLBACK_URL '' 18 | 19 | ENV API_AUTH_TOKEN '' 20 | 21 | ENV REDIS_URI '' 22 | 23 | EXPOSE 80 24 | 25 | CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"] 26 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | DomoAI API 2 | === 3 | **非官方** DomoAI API. 4 | 5 | [English Documentation](README.md) 6 | 7 | 特性 8 | --- 9 | 10 | 1. 支持全部 [DomoAI](https://domoai.app/) 的AI指令. 11 | + `/video` 12 | + `/move` 13 | + `/animate` 14 | + `/gen` 15 | + `/real` 16 | 2. 支持 `upscale`(放大) 和 `vary`(变体) 17 | 3. 支持使用 `FAST`(高速)模式 和 `RELAX`(休闲)模式 18 | 4. 支持查询任务状态 19 | 5. 支持事件回调 20 | 6. 支持 Docker 部署 21 | 7.
提供基于 Streamlit 的示例 22 | 23 | TODO 24 | --- 25 | 26 | - [ ] 多账号/账号池 27 | - [ ] 操作队列 28 | - [ ] 标准化错误的响应 29 | - [ ] 完善文档 30 | 31 | 参考链接 32 | --- 33 | 34 | - [DomoAI Official Website](https://domoai.app/) 35 | - [Github - discord.py-self](https://github.com/dolfies/discord.py-self) 36 | - [Github - FastAPI](https://github.com/tiangolo/fastapi) -------------------------------------------------------------------------------- /app/dependencies.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from typing import Optional 3 | 4 | from fastapi import Depends, HTTPException 5 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 6 | from starlette import status 7 | 8 | from app.settings import Settings, get_settings 9 | 10 | # auto_error=False: a request without an Authorization header must still reach api_auth, 11 | # because authentication is optional when no API token is configured. 12 | security = HTTPBearer(scheme_name="DomoAI", auto_error=False) 13 | 14 | 15 | async def api_auth( 16 | token: Optional[HTTPAuthorizationCredentials] = Depends(security), 17 | settings: Settings = Depends(get_settings) 18 | ): 19 | """Reject the request with 401 unless it carries the configured bearer token. 20 | 21 | When ``settings.api_auth_token`` is empty or unset, every request is accepted. 22 | """ 23 | if not settings.api_auth_token: 24 | return 25 | # compare_digest is constant-time, so the token cannot be recovered through response timing. 26 | if token is None or not secrets.compare_digest(token.credentials.encode(), settings.api_auth_token.encode()): 27 | raise HTTPException( 28 | status_code=status.HTTP_401_UNAUTHORIZED, 29 | detail="Could not validate credentials", 30 | headers={"WWW-Authenticate": "Bearer"}, 31 | ) 32 | -------------------------------------------------------------------------------- /streamlit_demo.dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | WORKDIR /code 4 | 5 | COPY ./requirements.txt /code/requirements.txt 6 | COPY ./requirements-dev.txt /code/requirements-dev.txt 7 | 8 | RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt 9 | RUN pip install --no-cache-dir --upgrade -r /code/requirements-dev.txt 10 | 11 | COPY ./app /code/app 12 | COPY ./streamlit_demo /code/streamlit_demo 13 | 14 | ENV PYTHONUNBUFFERED 1 15 | 16 | ENV STREAMLIT_BASE_URL '' 17 | ENV STREAMLIT_AUTH '' 18 | ENV API_AUTH_TOKEN '' 19 | 20 | EXPOSE 8501 21 | 22 | HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health 23 |
24 | ENTRYPOINT ["python", "-m", "streamlit", "run", "/code/streamlit_demo/🏠_Home.py", "--server.port=8501", "--server.address=0.0.0.0"] -------------------------------------------------------------------------------- /app/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Optional 4 | 5 | from pydantic_settings import BaseSettings, SettingsConfigDict 6 | 7 | 8 | class Settings(BaseSettings): # values come from environment variables and the env file configured below 9 | model_config = SettingsConfigDict( 10 | env_file=os.environ.get('ENV_FILE', '.env'), # ENV_FILE overrides the default '.env' path 11 | env_file_encoding='utf-8', 12 | extra='ignore' # unknown keys are ignored rather than rejected, so a misspelled variable name has no effect 13 | ) 14 | discord_token: str 15 | discord_guild_id: int 16 | discord_channel_id: int 17 | 18 | domoai_application_id: int = 1153984868804468756 # presumably the DomoAI bot's Discord application id — verify 19 | 20 | redis_uri: Optional[str] = None # optional; presumably selects Redis instead of in-memory caching — confirm in app/main.py 21 | 22 | event_callback_url: Optional[str] = None # optional webhook URL; presumably passed to EventCallback, which sends nothing when it is empty 23 | 24 | cache_prefix: str = 'domoai:' # presumably passed as the Cache key prefix — confirm in app/main.py 25 | 26 | api_auth_token: Optional[str] = None # read from API_AUTH_TOKEN; when empty, app.dependencies.api_auth skips the token comparison 27 | 28 | 29 | @lru_cache() # build Settings once per process and reuse it 30 | def get_settings() -> Settings: 31 | return Settings() 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DomoAI API 2 | === 3 | **Unofficial** DomoAI API. 4 | 5 | [中文说明/Chinese Documentation](README_CN.md) 6 | 7 | Feature 8 | --- 9 | 10 | 1. Support all AI commands of [DomoAI](https://domoai.app/). 11 | + `/video` 12 | + `/move` 13 | + `/animate` 14 | + `/gen` 15 | + `/real` 16 | 2. Support `upscale` and `vary` 17 | 3. Support `FAST` mode and `RELAX` mode 18 | 4. Support query task status 19 | 5. Support event callback 20 | 6. Support Docker 21 | 7.
Streamlit Demo provided 22 | 23 | TODO 24 | --- 25 | 26 | - [ ] Multi-Account and Account Pool 27 | - [ ] Action Queue 28 | - [ ] Standardized error response format 29 | - [ ] Usage documentation 30 | 31 | Reference 32 | --- 33 | 34 | - [DomoAI Official Website](https://domoai.app/) 35 | - [Github - discord.py-self](https://github.com/dolfies/discord.py-self) 36 | - [Github - FastAPI](https://github.com/tiangolo/fastapi) -------------------------------------------------------------------------------- /app/models/move-models.json: -------------------------------------------------------------------------------- 1 | [{"id": 16004, "name": "Realistic v1", "description": "Realistic model", "cover": {"url": "https://imgo.domoai.app/ai-model/22a058ef-ca13-4e33-b7b6-61577674f84a.png", "width": 1024, "height": 1648}, "prompt_args": "--real v1"}, {"id": 16001, "name": "Anime v6", "description": "Detail anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/9657fb98-93ef-4052-9df9-4669f34451d6.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v6"}, {"id": 16002, "name": "Anime v2", "description": "Japanese anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/2b2da758-b341-4738-8812-c1163338e14e.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v2"}, {"id": 16003, "name": "Anime v1.1", "description": "Flat color anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/dbe689f5-795d-424d-8c54-2503c9abc316.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v1.1"}] -------------------------------------------------------------------------------- /app/event_callback.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import Optional 3 | 4 | import httpx 5 | from tenacity import retry, wait_fixed, stop_after_attempt 6 | 7 | from app.schema import TaskCacheData, TaskStateOut 8 | 9 | 10 | class EventType(enum.Enum): # value is sent as the 'event' field of callback payloads 11 | TASK_SUCCESS = "TASK_SUCCESS" 12 |
13 | 14 | class EventCallback: 15 | """Posts task events to an optional webhook URL.""" 16 | 17 | def __init__(self, callback_url: Optional[str]): 18 | self.callback_url = callback_url 19 | 20 | # Up to 3 attempts, 2 s apart. With reraise=False, tenacity raises RetryError 21 | # (wrapping the last error) once every attempt has failed. 22 | @retry(wait=wait_fixed(2), stop=stop_after_attempt(3), reraise=False) 23 | async def send_task_success(self, task_id: str, data: TaskCacheData): 24 | """POST a TASK_SUCCESS event for ``task_id``; does nothing when no callback URL is set.""" 25 | if not self.callback_url: 26 | return 27 | out = TaskStateOut.from_cache_data(data) 28 | async with httpx.AsyncClient() as client: 29 | response = await client.post(self.callback_url, json={ 30 | 'event': EventType.TASK_SUCCESS.value, 31 | 'task_id': task_id, 32 | # NOTE(review): 'data' is a JSON-encoded *string*, not a nested object, so receivers 33 | # must decode it twice; kept as-is for wire compatibility with existing consumers. 34 | 'data': out.model_dump_json() 35 | }) 36 | # Treat 4xx/5xx responses as failures so the @retry above actually retries them. 37 | response.raise_for_status() 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 T0nYie1dYaw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 |
22 | -------------------------------------------------------------------------------- /scripts/update_models.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import httpx 5 | from pydantic_settings import BaseSettings, SettingsConfigDict 6 | 7 | current_dir = os.path.dirname(os.path.realpath(__file__)) 8 | 9 | 10 | class ScriptSettings(BaseSettings): 11 | model_config = SettingsConfigDict( 12 | env_file=os.path.join(current_dir, '../.env'), 13 | env_file_encoding='utf-8', 14 | extra='ignore' 15 | ) 16 | domoai_web_token: str 17 | 18 | 19 | def update_models(req_end_path: str, filename: str, web_token: str): 20 | res = httpx.get(f'https://api.domoai.app/web-post/model/{req_end_path}?offset=0&limit=100&locale=en', headers={ 21 | 'Authorization': f'Bearer {web_token}' 22 | }) 23 | 24 | models = res.json()['data']['models'] 25 | with open(os.path.join(current_dir, f'../app/models/{filename}.json'), 'w') as f: 26 | json.dump(models, f) 27 | 28 | 29 | if __name__ == '__main__': 30 | settings = ScriptSettings() 31 | 32 | update_models(req_end_path='video-models', filename='v2v-models', web_token=settings.domoai_web_token) 33 | update_models(req_end_path='gen-models', filename='gen-models', web_token=settings.domoai_web_token) 34 | update_models(req_end_path='move-models', filename='move-models', web_token=settings.domoai_web_token) 35 | -------------------------------------------------------------------------------- /streamlit_demo/auth.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import streamlit as st 4 | from streamlit_authenticator import Authenticate 5 | 6 | 7 | def check_password(): 8 | STREAMLIT_AUTH = os.environ.get("STREAMLIT_AUTH") 9 | 10 | if not STREAMLIT_AUTH: 11 | return True 12 | 13 | username, password = STREAMLIT_AUTH.split(":") 14 | 15 | if not username or not password: 16 | return True 17 | 18 | """Returns `True` if the 
user had a correct password.""" 19 | 20 | authenticator = Authenticate( 21 | credentials={"usernames": {username: {"email": "demo@example.com", "name": "demo", "password": password}}}, 22 | cookie_name="domo_ai_demo", 23 | cookie_key="zgt!uay.aug4juv*FAF" 24 | ) 25 | 26 | name, authentication_status, username = authenticator.login('main') 27 | 28 | if authentication_status: 29 | st.markdown(""" 30 | 32 | """, unsafe_allow_html=True) 33 | with st.sidebar.container(): 34 | authenticator.logout('Logout', 'main') 35 | return True 36 | elif authentication_status == False: 37 | st.error('Username/password is incorrect') 38 | return False 39 | elif authentication_status == None: 40 | st.warning('Please enter your username and password') 41 | return False 42 | -------------------------------------------------------------------------------- /app/models.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import os 3 | from functools import lru_cache 4 | from typing import List, Optional 5 | 6 | from pydantic import TypeAdapter 7 | 8 | from app.schema import GenModelInfo, MoveModelInfo, VideoModelInfo 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | @lru_cache() 14 | def get_v2v_models() -> List[VideoModelInfo]: 15 | ta = TypeAdapter(List[VideoModelInfo]) 16 | v2v_models_json_path = os.path.join(current_dir, "models/v2v-models.json") 17 | with open(v2v_models_json_path, 'r') as f: 18 | models = ta.validate_json(f.read()) 19 | return models 20 | 21 | 22 | @lru_cache() 23 | def get_move_models() -> List[MoveModelInfo]: 24 | ta = TypeAdapter(List[MoveModelInfo]) 25 | v2v_models_json_path = os.path.join(current_dir, "models/move-models.json") 26 | with open(v2v_models_json_path, 'r') as f: 27 | models = ta.validate_json(f.read()) 28 | return models 29 | 30 | 31 | @lru_cache() 32 | def get_gen_models() -> List[GenModelInfo]: 33 | ta = TypeAdapter(List[GenModelInfo]) 34 | v2v_models_json_path = 
os.path.join(current_dir, "models/gen-models.json") 35 | with open(v2v_models_json_path, 'r') as f: 36 | models = ta.validate_json(f.read()) 37 | return models 38 | 39 | 40 | @lru_cache() 41 | def get_v2v_model_info_by_instructions(instructions: str) -> Optional[VideoModelInfo]: 42 | all_model_info = get_v2v_models() 43 | for model in all_model_info: 44 | if instructions in model.prompt_args: # NOTE(review): substring match, so an instruction that is a prefix of a longer one (e.g. 'ani v1' vs 'ani v1.1') can hit the wrong model; the first hit in file order wins 45 | return model 46 | return None 47 | 48 | 49 | GenModel = enum.Enum( 50 | 'GenModel', 51 | {x.name.replace(' ', '_').upper(): x.prompt_args.removeprefix('--') for x in get_gen_models()} 52 | ) 53 | 54 | MoveModel = enum.Enum( 55 | 'MoveModel', 56 | {x.name.replace(' ', '_').upper(): x.prompt_args.removeprefix('--') for x in get_move_models()} 57 | ) 58 | 59 | VideoModel = enum.Enum( 60 | 'VideoModel', 61 | {x.name.replace(' ', '_').upper(): x.prompt_args.removeprefix('--') for x in get_v2v_models()} 62 | ) 63 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 46 | -------------------------------------------------------------------------------- /streamlit_demo/pages/Animate.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import httpx 4 | import streamlit as st 5 | from streamlit.runtime.uploaded_file_manager import UploadedFile 6 | 7 | from app.schema import AnimateLength, AnimateIntensity, Mode 8 | from streamlit_demo.auth import check_password 9 | from streamlit_demo.utils import polling_check_state, BASE_URL, BASE_HEADERS 10 | 11 | if not check_password(): 12 | st.stop() 13 | 14 | st.title("Animate") 15 | 16 | 17 | async def run_animate(prompt, intensity, length, image: UploadedFile, mode): 18 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 19 | response = await client.post('/v1/animate', data={ 20 |
"prompt": prompt, 21 | "intensity": intensity, 22 | "length": length, 23 | "mode": mode if mode != 'auto' else None 24 | }, files={'image': (image.name, image.read(), image.type)}, timeout=30) 25 | if not response.is_success: 26 | st.error(f"Generate Fail: {response.status_code} {response.text}") 27 | return None 28 | 29 | task_id = response.json()['task_id'] 30 | 31 | result = await polling_check_state(task_id=task_id) 32 | # polling_check_state returns None when the task is unknown (HTTP 404). 33 | if not result or not result.get('videos'): 34 | st.error(f"Task {task_id} returned no video") 35 | return None 36 | return result['videos'][0]['proxy_url'] 37 | 38 | 39 | with st.form("animate_form", border=True): 40 | mode = st.radio(label="Mode(*)", options=['auto'] + list(map(lambda x: x.value, Mode)), horizontal=True) 41 | 42 | length = st.radio(label="Length(*)", options=list(map(lambda x: x.value, AnimateLength)), horizontal=True) 43 | 44 | intensity = st.radio(label="Intensity(*)", options=list(map(lambda x: x.value, AnimateIntensity)), horizontal=True) 45 | 46 | image = st.file_uploader(label="Source Image(*)", type=['jpg', 'png']) 47 | 48 | prompt = st.text_area(label="Prompt") 49 | 50 | submitted = st.form_submit_button("Submit") 51 | 52 | # run_animate reads image.name/read()/type, so refuse to submit without an upload. 53 | if submitted and not image: 54 | st.error('Source Image is required') 55 | st.stop() 56 | 57 | if submitted: 58 | source_col, result_col = st.columns(2) 59 | with source_col: 60 | st.text("Source Image") 61 | if image: 62 | st.image(image) 63 | 64 | with result_col: 65 | st.text("Result Video") 66 | result_video = st.empty() 67 | 68 | with result_col: 69 | with st.spinner('Wait for completion...'): 70 | videos_url = asyncio.run( 71 | run_animate(prompt=prompt, intensity=intensity, length=length, image=image, mode=mode)) 72 | if videos_url: 73 | result_video.video(videos_url) 74 | st.success('Done!') 75 | -------------------------------------------------------------------------------- /streamlit_demo/pages/Move.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import httpx 4 | import streamlit as st 5 | from streamlit.runtime.uploaded_file_manager import UploadedFile 6 | 7 | from app.models import MoveModel 8 | from app.schema import VideoLength, Mode, VideoKey 9 | from
streamlit_demo.auth import check_password 10 | from streamlit_demo.utils import polling_check_state, BASE_URL, BASE_HEADERS 11 | 12 | if not check_password(): 13 | st.stop() 14 | 15 | st.title("Move") 16 | 17 | 18 | async def run_move(prompt, model, length, video: UploadedFile, image: UploadedFile, mode, video_key): 19 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 20 | response = await client.post('/v1/move', data={ 21 | "prompt": prompt, 22 | "model": model, 23 | "length": length, 24 | "mode": mode if mode != 'auto' else None, 25 | "video_key": video_key if video_key != 'None' else None, 26 | }, files={'video': (video.name, video.read(), video.type), 'image': (image.name, image.read(), image.type)}, 27 | timeout=30) 28 | if not response.is_success: 29 | st.error(f"Generate Fail: {response.status_code} {response.text}") 30 | return None 31 | 32 | task_id = response.json()['task_id'] 33 | 34 | result = await polling_check_state(task_id=task_id) 35 | # polling_check_state returns None when the task is unknown (HTTP 404). 36 | if not result or not result.get('videos'): 37 | st.error(f"Task {task_id} returned no video") 38 | return None 39 | return result['videos'][0]['proxy_url'] 40 | 41 | 42 | with st.form("move_form", border=True): 43 | mode = st.radio(label="Mode(*)", options=['auto'] + list(map(lambda x: x.value, Mode)), horizontal=True) 44 | 45 | length = st.radio(label="Length(*)", options=list(map(lambda x: x.value, VideoLength)), horizontal=True) 46 | 47 | model = st.selectbox(label="Model(*)", options=list(map(lambda x: x.value, MoveModel))) 48 | 49 | video_key = st.radio(label="Video Key", options=['None'] + list(map(lambda x: x.value, VideoKey)), 50 | horizontal=True) 51 | 52 | prompt = st.text_area(label="Prompt(*)") 53 | 54 | image = st.file_uploader(label="Source Image(*)", type=['jpg', 'png']) 55 | video = st.file_uploader(label="Source Video(*)", type=['mp4']) 56 | 57 | submitted = st.form_submit_button("Submit") 58 | 59 | # run_move reads name/read()/type from both uploads, so refuse to submit without them. 60 | if submitted and (not image or not video): 61 | st.error('Source Image and Source Video are required') 62 | st.stop() 63 | 64 | if submitted: 65 | source_image_col, source_video_col, result_col = st.columns(3) 66 | with source_image_col: 67 | st.text("Source Image") 68 | if image: 69 | st.image(image) 70 | 71 | with source_video_col: 72 |
st.text("Source Video") 73 | if video: 74 | st.video(video) 75 | 76 | with result_col: 77 | st.text("Result Video") 78 | result_video = st.empty() 79 | 80 | with result_col: 81 | with st.spinner('Wait for completion...'): 82 | video_url = asyncio.run( 83 | run_move(prompt=prompt, model=model, length=length, video=video, image=image, mode=mode, video_key=video_key)) 84 | if video_url: 85 | result_video.video(video_url) 86 | st.success('Done!') 87 | -------------------------------------------------------------------------------- /app/cache.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Optional 4 | 5 | from redis.asyncio import Redis, from_url 6 | 7 | from app.schema import TaskCacheData 8 | 9 | 10 | class Cache: 11 | """Base task cache; subclasses implement set_value/get_value/close and prepend ``prefix`` to every key.""" 12 | def __init__(self, prefix: str = ''): 13 | self.prefix = prefix 14 | 15 | async def set_value(self, key: str, value: Any, ex: Optional[int] = None): 16 | pass 17 | 18 | async def get_value(self, key: str): 19 | pass 20 | 21 | async def close(self): 22 | pass 23 | 24 | @staticmethod 25 | def __get_message_id2task_id_key(message_id: str) -> str: 26 | return f'message_id2task_id:{message_id}' 27 | 28 | async def set_message_id2task_id( 29 | self, 30 | message_id: str, 31 | task_id: str 32 | ): 33 | await self.set_value(key=self.__get_message_id2task_id_key(message_id), value=task_id) 34 | 35 | async def get_task_id_by_message_id(self, message_id: str) -> Optional[str]: 36 | value = await self.get_value(key=self.__get_message_id2task_id_key(message_id)) 37 | # str() cannot fail on the task ids stored by set_message_id2task_id, so no exception handling is needed. 38 | return str(value) if value else None 39 | 40 | @staticmethod 41 | def __get_task_id2data_key(task_id: str) -> str: 42 | return f'task_id2data:{task_id}' 43 | 44 | async def set_task_id2data( 45 | self, 46 | task_id: str, 47 | data: TaskCacheData 48 | ):
49 | await self.set_value(key=self.__get_task_id2data_key(task_id), value=data.model_dump_json()) 50 | 51 | async def get_task_data_by_id(self, task_id: str) -> Optional[TaskCacheData]: 52 | value = await self.get_value(key=self.__get_task_id2data_key(task_id)) 53 | if value: 54 | try: 55 | return TaskCacheData.model_validate_json(value) 56 | except ValueError: 57 | # pydantic's ValidationError subclasses ValueError: a corrupt or outdated 58 | # entry is treated as a cache miss rather than failing the caller. 59 | pass 60 | return None 61 | 62 | 63 | class MemoryCache(Cache): 64 | def __init__(self, prefix: str = ''): 65 | super().__init__(prefix=prefix) 66 | self.data = {} 67 | 68 | async def set_value(self, key: str, value: Any, ex: Optional[int] = None): 69 | # NOTE: ``ex`` is ignored, so in-memory entries never expire. 70 | self.data[f"{self.prefix}{key}"] = value 71 | 72 | async def get_value(self, key: str): 73 | return self.data.get(f"{self.prefix}{key}") 74 | 75 | 76 | class RedisCache(Cache): 77 | def __init__(self, redis: Redis, prefix: str = ''): 78 | super().__init__(prefix=prefix) 79 | self.redis = redis 80 | 81 | async def set_value(self, key: str, value: Any, ex: Optional[int] = None): 82 | await self.redis.set(name=f"{self.prefix}{key}", value=value, ex=ex) 83 | 84 | async def get_value(self, key: str): 85 | return await self.redis.get(name=f"{self.prefix}{key}") 86 | 87 | async def close(self): 88 | await self.redis.close() 89 | 90 | @staticmethod 91 | async def init_redis_pool(redis_uri: str) -> Redis: 92 | redis = await from_url( 93 | redis_uri, 94 | encoding="utf-8", 95 | decode_responses=True, 96 | ) 97 | return redis 98 | -------------------------------------------------------------------------------- /streamlit_demo/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import Optional, List, Callable 4 | 5 | import httpx 6 | import streamlit as st 7 | from pydantic import BaseModel 8 | 9 | BASE_URL = os.environ.get('STREAMLIT_BASE_URL') 10 | 11 | API_AUTH_TOKEN = os.environ.get('API_AUTH_TOKEN') 12 | 13 | if not BASE_URL: 14 | BASE_URL =
'http://127.0.0.1:8000' 15 | 16 | if API_AUTH_TOKEN: 17 | BASE_HEADERS = { 18 | 'Authorization': f'Bearer {API_AUTH_TOKEN}' 19 | } 20 | else: 21 | BASE_HEADERS = { 22 | 23 | } 24 | 25 | 26 | async def polling_check_state(task_id: str, timeout: Optional[float] = None, interval: float = 1) -> Optional[dict]: 27 | """Poll ``/v1/task-data/{task_id}`` until the task reports SUCCESS. 28 | 29 | Returns the task data, or None if the task does not exist (HTTP 404) or if 30 | ``timeout`` seconds pass first; ``timeout=None`` (the default) waits forever. 31 | Any other non-2xx response raises ``httpx.HTTPStatusError``. 32 | """ 33 | loop = asyncio.get_running_loop() 34 | deadline = None if timeout is None else loop.time() + timeout 35 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 36 | while True: 37 | response = await client.get(f'/v1/task-data/{task_id}') 38 | if response.status_code == 404: 39 | return None 40 | # e.g. a 401 from a wrong API_AUTH_TOKEN: surface it instead of a KeyError on 'status'. 41 | response.raise_for_status() 42 | response_json = response.json() 43 | if response_json['status'] == 'SUCCESS': 44 | return response_json 45 | if deadline is not None and loop.time() >= deadline: 46 | return None 47 | await asyncio.sleep(interval) 48 | 49 |
50 | def build_upscale_vary_buttons( 51 | task_id: str, 52 | upscale_indices: List[int], 53 | vary_indices: List[int], 54 | on_click_upscale: Callable[[str, int], None], 55 | on_click_vary: Callable[[str, int], None] 56 | ): 57 | """Render U1-U4 (upscale) and V1-V4 (vary) buttons side by side, each set in a 2x2 grid. 58 | 59 | A button is disabled unless its 1-based index is in the matching ``*_indices`` list; 60 | clicking it calls the matching callback with ``(task_id, index)``. 61 | """ 62 | upscale_container, vary_container = st.columns(2) 63 | for container, prefix, icon, indices, on_click in ( 64 | (upscale_container, 'U', ':mag:', upscale_indices, on_click_upscale), 65 | (vary_container, 'V', ':magic_wand:', vary_indices, on_click_vary), 66 | ): 67 | # Two rows of two columns: buttons 1-2 on the first row, 3-4 on the second. 68 | cols = [*container.columns(2), *container.columns(2)] 69 | for index, col in enumerate(cols, start=1): 70 | col.button( 71 | label=f"{icon} {prefix}{index}", 72 | use_container_width=True, 73 | disabled=index not in indices, 74 | on_click=on_click, 75 | args=(task_id, index), 76 | key=f'{prefix}{index}-{task_id}' 77 | ) 78 | 79 | 80 | class UVResult(BaseModel): 81 | task_id: str 82 | image_url: str 83 | upscale_indices: List[int] 84 | vary_indices: List[int] 85 | -------------------------------------------------------------------------------- /app/models/gen-models.json: -------------------------------------------------------------------------------- 1 | [{"id": 10017, "name": "Animate XL v1", "description": "Enhanced anime models", "cover": {"url": "https://imgo.domoai.app/ai-model/4b87efa6-a2b3-4fb2-8160-42d97a3a7e45.jpg", "width": 1024, "height": 1648}, "prompt_args": "--anixl v1"}, {"id": 10026, "name": "Animate XL v2", "description": "Detail anime model", "cover": {"url": "https://imgo.domoai.app/ai-model/a8668543-be38-4783-bb2d-c9643591529c.jpg", "width": 1024, "height": 1648}, "prompt_args": "--anixl v2"}, {"id": 10018, "name": "Realistic XL v1", "description": "Enhanced realistic model", "cover": {"url": "https://imgo.domoai.app/ai-model/40f18070-1dea-4c42-be5d-0cfd57fe13dd.jpg", "width": 1024, "height": 1648}, "prompt_args": "--realxl v1"}, {"id": 10027, "name": "Realistic XL v2", "description": "Dark gothic style", "cover":
{"url": "https://imgo.domoai.app/ai-model/2cdd3d92-b009-4799-9c55-12ddebf91a02.jpg", "width": 1024, "height": 1648}, "prompt_args": "--realxl v2"}, {"id": 10019, "name": "Illustration XL v1", "description": "Enhanced illustration model", "cover": {"url": "https://imgo.domoai.app/ai-model/3d4e4935-d73c-4cb8-907e-c619f53a7690.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illusxl v1"}, {"id": 10020, "name": "Illustration XL v2", "description": "Dark comic style", "cover": {"url": "https://imgo.domoai.app/ai-model/605a9509-7670-4c3e-a450-a14ca9748b42.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illusxl v2"}, {"id": 10022, "name": "Animate v1", "description": "Dreamy japanese anime", "cover": {"url": "https://imgo.domoai.app/ai-model/18c4b38d-945c-46bd-b048-a4d1eacee19b.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v1"}, {"id": 10011, "name": "Animate v2", "description": "Japanese anime style, more 3D", "cover": {"url": "https://imgo.domoai.app/ai-model/6a4d877e-2dc6-45c0-9435-b69426859b03.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v2"}, {"id": 10012, "name": "Animate v3", "description": "American comics style", "cover": {"url": "https://imgo.domoai.app/ai-model/ef552c4f-365f-4446-a9e2-dd7692807f86.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v3"}, {"id": 10006, "name": "Animate v4", "description": "CG style", "cover": {"url": "https://imgo.domoai.app/ai-model/9cd447c9-b141-42de-82d4-5e09e0e0d1bb.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v4"}, {"id": 10023, "name": "Animate v5", "description": "Line comic style", "cover": {"url": "https://imgo.domoai.app/ai-model/9dbf7f9b-0d5d-40ba-993d-6d2a6222707f.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v5"}, {"id": 10024, "name": "Animate v6", "description": "Watercolor anime", "cover": {"url": "https://imgo.domoai.app/ai-model/fc7eee20-aabe-42ef-b4e3-b9cc5fe0c06d.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v6"}, 
{"id": 10025, "name": "Animate v7", "description": "Oilpainting anime", "cover": {"url": "https://imgo.domoai.app/ai-model/310443d1-6238-4a8e-be91-0130f0a0b865.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v7"}, {"id": 10028, "name": "Illustration v1", "description": "3D cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/885872cf-c546-4da0-8707-113280f85551.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illus v1"}, {"id": 10029, "name": "Illustration v2", "description": "Storybook cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/51822397-d1e4-4921-965d-21b587a84613.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illus v2"}, {"id": 10030, "name": "Realistic v1", "description": "CG art", "cover": {"url": "https://imgo.domoai.app/ai-model/fc673afd-cb14-4a6c-a4ac-8736721ca98e.jpg", "width": 1024, "height": 1648}, "prompt_args": "--real v1"}, {"id": 10031, "name": "Realistic v2", "description": "Realistic portrait", "cover": {"url": "https://imgo.domoai.app/ai-model/cefd410a-4d79-44a5-9fa9-9218818a0e04.jpg", "width": 1024, "height": 1648}, "prompt_args": "--real v2"}, {"id": 10016, "name": "Realistic v3", "description": "Game character style", "cover": {"url": "https://imgo.domoai.app/ai-model/7c656b9a-5a4a-4f2a-9694-57fb6a00e23c.jpg", "width": 1024, "height": 1648}, "prompt_args": "--real v3"}] -------------------------------------------------------------------------------- /streamlit_demo/pages/Video.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | 4 | import httpx 5 | import streamlit as st 6 | from streamlit.runtime.uploaded_file_manager import UploadedFile 7 | 8 | from app.models import VideoModel, get_v2v_model_info_by_instructions 9 | from app.schema import VideoLength, VideoReferMode, Mode, VideoKey 10 | from streamlit_demo.auth import check_password 11 | from streamlit_demo.utils import polling_check_state, 
BASE_URL, BASE_HEADERS 12 | 13 | if not check_password(): 14 | st.stop() 15 | 16 | st.title("Video") 17 | 18 | 19 | async def run_video(prompt, refer_mode, model, length, video: UploadedFile, mode, 20 | video_key, subject_only, 21 | lip_sync, image: Optional[UploadedFile] = None): 22 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 23 | files = {'video': (video.name, video.read(), video.type)} 24 | if image: 25 | files['image'] = (image.name, image.read(), image.type) 26 | response = await client.post('/v1/video', data={ 27 | "prompt": prompt, 28 | "refer_mode": refer_mode, 29 | "model": model, 30 | "length": length, 31 | "mode": mode if mode != 'auto' else None, 32 | "video_key": video_key if video_key != 'None' and subject_only is not True else None, 33 | "subject_only": subject_only, 34 | "lip_sync": lip_sync, 35 | }, files=files, timeout=30) 36 | if not response.is_success: 37 | st.error(f"Generate Fail: {response}") 38 | return None 39 | 40 | task_id = response.json().get('task_id', None) 41 | if task_id is None: 42 | return False, response.json() 43 | 44 | result = await polling_check_state(task_id=task_id) 45 | return True, result['videos'][0]['proxy_url'] 46 | 47 | 48 | mode = st.radio(label="Mode(*)", options=['auto'] + list(map(lambda x: x.value, Mode)), horizontal=True) 49 | 50 | length = st.radio(label="Length(*)", options=list(map(lambda x: x.value, VideoLength)), horizontal=True) 51 | 52 | video_models_value = list(map(lambda x: x.value, VideoModel)) 53 | 54 | model = st.selectbox(label="Model(*)", options=video_models_value, 55 | index=video_models_value.index(VideoModel.ANIME_V1.value)) 56 | 57 | model_info = get_v2v_model_info_by_instructions(model) 58 | 59 | refer_mode = st.radio(label="Refer Mode(*)", 60 | options=list(map(lambda x: x.value, filter(lambda x: x in model_info.allowed_refer_modes, 61 | list(VideoReferMode)))), 62 | horizontal=True) 63 | 64 | lip_sync = st.checkbox(label="Lip Sync", key='lips', 
disabled=model_info.allowed_lip_sync is False) 65 | 66 | subject_only = st.checkbox(label="Subject Only", key='so') 67 | 68 | video_key = st.selectbox(label="Video Key", options=['None'] + list(map(lambda x: x.value, VideoKey)), 69 | disabled=subject_only) 70 | 71 | prompt = st.text_area(label="Prompt(*)") 72 | 73 | video = st.file_uploader(label="Source Video(*)", type=['mp4']) 74 | 75 | image = None 76 | if model_info.allowed_reference_image: 77 | image = st.file_uploader(label="Source Image(*)", type=['jpg', 'jpeg', 'png']) 78 | 79 | submitted = st.button("Submit") 80 | 81 | if submitted: 82 | source_col, result_col = st.columns(2) 83 | with source_col: 84 | st.text("Source Video") 85 | if video: 86 | st.video(video) 87 | with source_col: 88 | st.text("Source Image") 89 | if image: 90 | st.image(image) 91 | 92 | with result_col: 93 | st.text("Result Video") 94 | result_video = st.empty() 95 | 96 | with result_col: 97 | with st.spinner('Wait for completion...'): 98 | # asyncio.run(asyncio.sleep(5)) 99 | # result_video.video(video) 100 | status, result = asyncio.run( 101 | run_video( 102 | prompt=prompt, refer_mode=refer_mode, model=model, length=length, video=video, mode=mode, 103 | video_key=video_key, subject_only=subject_only, lip_sync=lip_sync, image=image, 104 | ) 105 | ) 106 | if not status: 107 | st.text(result) 108 | 109 | if status and result: 110 | result_video.video(result) 111 | st.success('Done!') 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 3 | 4 | # User-specific stuff 5 | .idea/**/workspace.xml 6 | .idea/**/tasks.xml 7 | .idea/**/usage.statistics.xml 8 | .idea/**/dictionaries 9 | .idea/**/shelf 10 | 11 | # AWS 
User-specific 12 | .idea/**/aws.xml 13 | 14 | # Generated files 15 | .idea/**/contentModel.xml 16 | 17 | # Sensitive or high-churn files 18 | .idea/**/dataSources/ 19 | .idea/**/dataSources.ids 20 | .idea/**/dataSources.local.xml 21 | .idea/**/sqlDataSources.xml 22 | .idea/**/dynamic.xml 23 | .idea/**/uiDesigner.xml 24 | .idea/**/dbnavigator.xml 25 | 26 | # Gradle 27 | .idea/**/gradle.xml 28 | .idea/**/libraries 29 | 30 | # Gradle and Maven with auto-import 31 | # When using Gradle or Maven with auto-import, you should exclude module files, 32 | # since they will be recreated, and may cause churn. Uncomment if using 33 | # auto-import. 34 | # .idea/artifacts 35 | # .idea/compiler.xml 36 | # .idea/jarRepositories.xml 37 | # .idea/modules.xml 38 | # .idea/*.iml 39 | # .idea/modules 40 | # *.iml 41 | # *.ipr 42 | 43 | # CMake 44 | cmake-build-*/ 45 | 46 | # Mongo Explorer plugin 47 | .idea/**/mongoSettings.xml 48 | 49 | # File-based project format 50 | *.iws 51 | 52 | # IntelliJ 53 | out/ 54 | 55 | # mpeltonen/sbt-idea plugin 56 | .idea_modules/ 57 | 58 | # JIRA plugin 59 | atlassian-ide-plugin.xml 60 | 61 | # Cursive Clojure plugin 62 | .idea/replstate.xml 63 | 64 | # Crashlytics plugin (for Android Studio and IntelliJ) 65 | com_crashlytics_export_strings.xml 66 | crashlytics.properties 67 | crashlytics-build.properties 68 | fabric.properties 69 | 70 | # Editor-based Rest Client 71 | .idea/httpRequests 72 | 73 | # Android studio 3.1+ serialized cache file 74 | .idea/caches/build_file_checksums.ser 75 | 76 | # Byte-compiled / optimized / DLL files 77 | __pycache__/ 78 | *.py[cod] 79 | *$py.class 80 | 81 | # C extensions 82 | *.so 83 | 84 | # Distribution / packaging 85 | .Python 86 | build/ 87 | develop-eggs/ 88 | dist/ 89 | downloads/ 90 | eggs/ 91 | .eggs/ 92 | lib/ 93 | lib64/ 94 | parts/ 95 | sdist/ 96 | var/ 97 | wheels/ 98 | share/python-wheels/ 99 | *.egg-info/ 100 | .installed.cfg 101 | *.egg 102 | MANIFEST 103 | 104 | # PyInstaller 105 | # Usually these files 
are written by a python script from a template 106 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 107 | *.manifest 108 | *.spec 109 | 110 | # Installer logs 111 | pip-log.txt 112 | pip-delete-this-directory.txt 113 | 114 | # Unit test / coverage reports 115 | htmlcov/ 116 | .tox/ 117 | .nox/ 118 | .coverage 119 | .coverage.* 120 | .cache 121 | nosetests.xml 122 | coverage.xml 123 | *.cover 124 | *.py,cover 125 | .hypothesis/ 126 | .pytest_cache/ 127 | cover/ 128 | 129 | # Translations 130 | *.mo 131 | *.pot 132 | 133 | # Django stuff: 134 | *.log 135 | local_settings.py 136 | db.sqlite3 137 | db.sqlite3-journal 138 | 139 | # Flask stuff: 140 | instance/ 141 | .webassets-cache 142 | 143 | # Scrapy stuff: 144 | .scrapy 145 | 146 | # Sphinx documentation 147 | docs/_build/ 148 | 149 | # PyBuilder 150 | .pybuilder/ 151 | target/ 152 | 153 | # Jupyter Notebook 154 | .ipynb_checkpoints 155 | 156 | # IPython 157 | profile_default/ 158 | ipython_config.py 159 | 160 | # pyenv 161 | # For a library or package, you might want to ignore these files since the code is 162 | # intended to run in multiple environments; otherwise, check them in: 163 | # .python-version 164 | 165 | # pipenv 166 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 167 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 168 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 169 | # install all needed dependencies. 170 | #Pipfile.lock 171 | 172 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 173 | __pypackages__/ 174 | 175 | # Celery stuff 176 | celerybeat-schedule 177 | celerybeat.pid 178 | 179 | # SageMath parsed files 180 | *.sage.py 181 | 182 | # Environments 183 | .env 184 | .venv 185 | env/ 186 | venv/ 187 | ENV/ 188 | env.bak/ 189 | venv.bak/ 190 | 191 | # Spyder project settings 192 | .spyderproject 193 | .spyproject 194 | 195 | # Rope project settings 196 | .ropeproject 197 | 198 | # mkdocs documentation 199 | /site 200 | 201 | # mypy 202 | .mypy_cache/ 203 | .dmypy.json 204 | dmypy.json 205 | 206 | # Pyre type checker 207 | .pyre/ 208 | 209 | # pytype static type analyzer 210 | .pytype/ 211 | 212 | # Cython debug symbols 213 | cython_debug/ 214 | 215 | .env.* 216 | !.env.template 217 | publish*.sh -------------------------------------------------------------------------------- /app/schema.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import enum 4 | from typing import Optional, List, Dict 5 | 6 | import discord 7 | from pydantic import BaseModel 8 | 9 | 10 | class Mode(enum.Enum): 11 | FAST = "fast" 12 | RELAX = "relax" 13 | 14 | 15 | class AnimateIntensity(enum.Enum): 16 | LOW = "low" 17 | MEDIUM = "mid" 18 | HIGH = "high" 19 | 20 | 21 | class AnimateLength(enum.Enum): 22 | LENGTH_3S = "3s" 23 | LENGTH_5S = "5s" 24 | 25 | 26 | class BaseModelInfo(BaseModel): 27 | name: str 28 | prompt_args: str 29 | 30 | 31 | class GenModelInfo(BaseModelInfo): 32 | pass 33 | 34 | 35 | class MoveModelInfo(BaseModelInfo): 36 | pass 37 | 38 | 39 | class VideoModelInfo(BaseModelInfo): 40 | allowed_refer_modes: List[VideoReferMode] 41 | allowed_lip_sync: bool 42 | allowed_reference_image: bool 43 | 44 | 45 | class VideoReferMode(enum.Enum): 46 | REFER_TO_SOURCE_VIDEO_MORE = "VIDEO_MORE" 47 | REFER_TO_MY_PROMPT_MORE = "PROMPT_MORE" 48 | 49 | 50 | class VideoApiError(enum.Enum): 51 | VIDEO_MODEL_ERROR: 10000 52 | NOT_ALLOW_REFER: 10001 53 | 
NOT_ALLOW_LIP_SYNC: 10002 54 | MODEL_NEED_REFERENCE_IMAGE: 10003 55 | 56 | 57 | class VideoKey(enum.Enum): 58 | WHITE = "WHITE" 59 | BLACK = "BLACK" 60 | RED = "RED" 61 | ORANGE = "ORANGE" 62 | YELLOW = "YELLOW" 63 | CYAN = "CYAN" 64 | GREEN = "GREEN" 65 | BLUE = "BLUE" 66 | PINK = "PINK" 67 | BROWN = "BROWN" 68 | PURPLE = "PURPLE" 69 | GREY = "GREY" 70 | GRAY = "GRAY" 71 | MAGENTA = "MAGENTA" 72 | NAVY = "NAVY" 73 | BEIGE = "BEIGE" 74 | GOLD = "GOLD" 75 | SILVER = "SILVER" 76 | 77 | 78 | class VideoLength(enum.Enum): 79 | LENGTH_3S = "3s" 80 | LENGTH_5S = "5s" 81 | LENGTH_10S = "10s" 82 | LENGTH_20S = "20s" 83 | 84 | 85 | class TaskAsset(BaseModel): 86 | size: Optional[int] = None 87 | width: Optional[int] = None 88 | height: Optional[int] = None 89 | url: str 90 | proxy_url: str 91 | content_type: Optional[str] = None 92 | 93 | @staticmethod 94 | def from_attachment(attachment: discord.Attachment) -> TaskAsset: 95 | return TaskAsset( 96 | size=attachment.size, 97 | width=attachment.width, 98 | height=attachment.height, 99 | url=attachment.url, 100 | proxy_url=attachment.proxy_url, 101 | content_type=attachment.content_type 102 | ) 103 | 104 | 105 | class TaskStatus(enum.Enum): 106 | RUNNING = "RUNNING" 107 | SUCCESS = "SUCCESS" 108 | 109 | 110 | class TaskCommand(enum.Enum): 111 | GEN = "GEN" 112 | REAL = "REAL" 113 | MOVE = "MOVE" 114 | VIDEO = "VIDEO" 115 | ANIMATE = "ANIMATE" 116 | 117 | 118 | class TaskCacheData(BaseModel): 119 | command: TaskCommand 120 | channel_id: str 121 | guild_id: Optional[str] 122 | message_id: str 123 | images: Optional[List[TaskAsset]] = None 124 | videos: Optional[List[TaskAsset]] = None 125 | status: TaskStatus 126 | upscale_custom_ids: Optional[Dict[str, str]] = None 127 | vary_custom_ids: Optional[Dict[str, str]] = None 128 | 129 | 130 | class CreateTaskOut(BaseModel): 131 | success: bool 132 | task_id: str 133 | message_id: str 134 | 135 | 136 | class TaskStateOut(BaseModel): 137 | command: TaskCommand 138 | channel_id: str 139 
| guild_id: Optional[str] 140 | message_id: str 141 | images: Optional[List[TaskAsset]] = None 142 | videos: Optional[List[TaskAsset]] = None 143 | status: TaskStatus 144 | upscale_indices: Optional[List[int]] = None 145 | vary_indices: Optional[List[int]] = None 146 | 147 | @staticmethod 148 | def from_cache_data(data: TaskCacheData) -> TaskStateOut: 149 | upscale_index_map = { 150 | "U1": 1, 151 | "U2": 2, 152 | "U3": 3, 153 | "U4": 4 154 | } 155 | vary_index_map = { 156 | "V1": 1, 157 | "V2": 2, 158 | "V3": 3, 159 | "V4": 4 160 | } 161 | if data.upscale_custom_ids is not None: 162 | upscale_indices = [] 163 | for key, value in data.upscale_custom_ids.items(): 164 | index = upscale_index_map.get(key) 165 | if index is not None: 166 | upscale_indices.append(index) 167 | upscale_indices.sort() 168 | else: 169 | upscale_indices = None 170 | if data.vary_custom_ids is not None: 171 | vary_indices = [] 172 | for key, value in data.vary_custom_ids.items(): 173 | index = vary_index_map.get(key) 174 | if index is not None: 175 | vary_indices.append(index) 176 | vary_indices.sort() 177 | else: 178 | vary_indices = None 179 | 180 | return TaskStateOut( 181 | command=data.command, 182 | channel_id=data.channel_id, 183 | guild_id=data.guild_id, 184 | message_id=data.message_id, 185 | images=data.images, 186 | videos=data.videos, 187 | status=data.status, 188 | upscale_indices=upscale_indices, 189 | vary_indices=vary_indices 190 | ) 191 | -------------------------------------------------------------------------------- /streamlit_demo/pages/Real.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import httpx 4 | import streamlit as st 5 | 6 | from app.schema import Mode 7 | from streamlit_demo.auth import check_password 8 | from streamlit_demo.utils import polling_check_state, build_upscale_vary_buttons, UVResult, BASE_URL, \ 9 | BASE_HEADERS 10 | 11 | if not check_password(): 12 | st.stop() 13 | 14 | 
st.title("REAL") 15 | 16 | if 'real_result' not in st.session_state: 17 | st.session_state.real_result = None 18 | 19 | if 'real_uv_results' not in st.session_state: 20 | st.session_state.real_uv_results = [] 21 | 22 | 23 | async def real(prompt, image, mode): 24 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 25 | if prompt: 26 | data = { 27 | "prompt": prompt, 28 | "mode": mode if mode != 'auto' else None 29 | } 30 | else: 31 | data = { 32 | "mode": mode if mode != 'auto' else None 33 | } 34 | response = await client.post( 35 | '/v1/real', 36 | data=data, 37 | files={'image': (image.name, image.read(), image.type)}, 38 | timeout=30 39 | ) 40 | if not response.is_success: 41 | st.error(f"Generate Fail: {response}") 42 | return None 43 | 44 | task_id = response.json()['task_id'] 45 | 46 | result = await polling_check_state(task_id=task_id) 47 | return task_id, result['images'][0]['proxy_url'], result['upscale_indices'], result['vary_indices'] 48 | 49 | 50 | async def upscale(task_id, index): 51 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 52 | response = await client.post('/v1/upscale', data={ 53 | "task_id": task_id, 54 | "index": index 55 | }, timeout=30) 56 | if not response.is_success: 57 | st.error(f"Upscale Fail: {response}") 58 | return None 59 | 60 | task_id = response.json()['task_id'] 61 | 62 | result = await polling_check_state(task_id=task_id) 63 | return task_id, result['images'][0]['proxy_url'], result['upscale_indices'], result['vary_indices'] 64 | 65 | 66 | async def vary(task_id, index): 67 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 68 | response = await client.post('/v1/vary', data={ 69 | "task_id": task_id, 70 | "index": index 71 | }, timeout=30) 72 | if not response.is_success: 73 | st.error(f"Vary Fail: {response}") 74 | return None 75 | 76 | task_id = response.json()['task_id'] 77 | 78 | result = await 
polling_check_state(task_id=task_id) 79 | return task_id, result['images'][0]['proxy_url'], result['upscale_indices'], result['vary_indices'] 80 | 81 | 82 | with st.form("real_form", border=False): 83 | mode = st.radio(label="Mode(*)", options=['auto'] + list(map(lambda x: x.value, Mode)), horizontal=True) 84 | image = st.file_uploader(label="Reference Image(*)", type=['jpg', 'png']) 85 | prompt = st.text_area(label="Prompt") 86 | submitted = st.form_submit_button("Submit") 87 | 88 | 89 | def on_click_upscale(task_id: str, index: int): 90 | with st.spinner('Wait for completion...'): 91 | task_id, images_url, upscale_indices, vary_indices = asyncio.run(upscale(task_id=task_id, index=index)) 92 | st.session_state.real_uv_results.append(UVResult( 93 | task_id=task_id, 94 | image_url=images_url, 95 | upscale_indices=upscale_indices, 96 | vary_indices=vary_indices 97 | )) 98 | 99 | 100 | def on_click_vary(task_id: str, index: int): 101 | with st.spinner('Wait for completion...'): 102 | task_id, images_url, upscale_indices, vary_indices = asyncio.run(vary(task_id=task_id, index=index)) 103 | st.session_state.real_uv_results.append(UVResult( 104 | task_id=task_id, 105 | image_url=images_url, 106 | upscale_indices=upscale_indices, 107 | vary_indices=vary_indices 108 | )) 109 | 110 | 111 | if submitted: 112 | st.session_state.real_result = None 113 | 114 | if submitted or st.session_state.real_result: 115 | if image: 116 | source_col, result_col = st.columns(2) 117 | with source_col: 118 | st.text("Reference") 119 | if image: 120 | st.image(image) 121 | else: 122 | result_col = st.container() 123 | 124 | with result_col: 125 | st.text("Result") 126 | result_image = st.empty() 127 | 128 | with result_col: 129 | with st.spinner('Wait for completion...'): 130 | if not st.session_state.real_result: 131 | st.session_state.real_result = asyncio.run(real(prompt, image, mode)) 132 | task_id, image_url, upscale_indices, vary_indices = st.session_state.real_result 133 | 
result_image.image(image_url) 134 | build_upscale_vary_buttons( 135 | task_id=task_id, 136 | upscale_indices=upscale_indices, 137 | vary_indices=vary_indices, 138 | on_click_upscale=on_click_upscale, 139 | on_click_vary=on_click_vary 140 | ) 141 | 142 | for item in st.session_state.real_uv_results: 143 | result: UVResult = item 144 | st.image(result.image_url) 145 | build_upscale_vary_buttons( 146 | task_id=result.task_id, 147 | upscale_indices=result.upscale_indices, 148 | vary_indices=result.vary_indices, 149 | on_click_upscale=on_click_upscale, 150 | on_click_vary=on_click_vary 151 | ) 152 | -------------------------------------------------------------------------------- /streamlit_demo/pages/Gen.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | 4 | import httpx 5 | import streamlit as st 6 | from streamlit.runtime.uploaded_file_manager import UploadedFile 7 | 8 | from app.models import GenModel 9 | from app.schema import Mode 10 | from streamlit_demo.auth import check_password 11 | from streamlit_demo.utils import polling_check_state, build_upscale_vary_buttons, UVResult, BASE_URL, \ 12 | BASE_HEADERS 13 | 14 | if not check_password(): 15 | st.stop() 16 | 17 | st.title("Gen") 18 | 19 | if 'gen_result' not in st.session_state: 20 | st.session_state.gen_result = None 21 | 22 | if 'gen_uv_results' not in st.session_state: 23 | st.session_state.gen_uv_results = [] 24 | 25 | 26 | async def gen(prompt: str, image: Optional[UploadedFile], mode: str, model: Optional[str]): 27 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 28 | data = { 29 | "prompt": prompt, 30 | "mode": mode if mode != 'auto' else None 31 | } 32 | if model: 33 | data['model'] = model 34 | response = await client.post( 35 | '/v1/gen', 36 | data=data, 37 | files={'image': (image.name, image.read(), image.type)} if image else None, 38 | timeout=30 39 | ) 40 | if not 
response.is_success: 41 | st.error(f"Generate Fail: {response}") 42 | return None 43 | 44 | task_id = response.json()['task_id'] 45 | 46 | result = await polling_check_state(task_id=task_id) 47 | return task_id, result['images'][0]['proxy_url'], result['upscale_indices'], result['vary_indices'] 48 | 49 | 50 | async def upscale(task_id, index): 51 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 52 | response = await client.post('/v1/upscale', data={ 53 | "task_id": task_id, 54 | "index": index 55 | }, timeout=30) 56 | if not response.is_success: 57 | st.error(f"Upscale Fail: {response}") 58 | return None 59 | 60 | task_id = response.json()['task_id'] 61 | 62 | result = await polling_check_state(task_id=task_id) 63 | return task_id, result['images'][0]['proxy_url'], result['upscale_indices'], result['vary_indices'] 64 | 65 | 66 | async def vary(task_id, index): 67 | async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client: 68 | response = await client.post('/v1/vary', data={ 69 | "task_id": task_id, 70 | "index": index 71 | }, timeout=30) 72 | if not response.is_success: 73 | st.error(f"Vary Fail: {response}") 74 | return None 75 | 76 | task_id = response.json()['task_id'] 77 | 78 | result = await polling_check_state(task_id=task_id) 79 | return task_id, result['images'][0]['proxy_url'], result['upscale_indices'], result['vary_indices'] 80 | 81 | 82 | with st.form("gen_form", border=False): 83 | mode = st.radio(label="Mode(*)", options=['auto'] + list(map(lambda x: x.value, Mode)), horizontal=True) 84 | prompt = st.text_area(label="Prompt(*)") 85 | image = st.file_uploader(label="Reference Image", type=['jpg', 'png']) 86 | gen_models_value = [''] + list(map(lambda x: x.value, GenModel)) 87 | 88 | model = st.selectbox( 89 | label="Model", 90 | options=gen_models_value 91 | ) 92 | 93 | submitted = st.form_submit_button("Submit") 94 | 95 | 96 | def on_click_upscale(task_id: str, index: int): 97 | with 
st.spinner('Wait for completion...'): 98 | task_id, images_url, upscale_indices, vary_indices = asyncio.run(upscale(task_id=task_id, index=index)) 99 | st.session_state.gen_uv_results.append(UVResult( 100 | task_id=task_id, 101 | image_url=images_url, 102 | upscale_indices=upscale_indices, 103 | vary_indices=vary_indices 104 | )) 105 | 106 | 107 | def on_click_vary(task_id: str, index: int): 108 | with st.spinner('Wait for completion...'): 109 | task_id, images_url, upscale_indices, vary_indices = asyncio.run(vary(task_id=task_id, index=index)) 110 | st.session_state.gen_uv_results.append(UVResult( 111 | task_id=task_id, 112 | image_url=images_url, 113 | upscale_indices=upscale_indices, 114 | vary_indices=vary_indices 115 | )) 116 | 117 | 118 | if submitted: 119 | st.session_state.gen_result = None 120 | 121 | if submitted or st.session_state.gen_result: 122 | if image: 123 | source_col, result_col = st.columns(2) 124 | with source_col: 125 | st.text("Reference") 126 | if image: 127 | st.image(image) 128 | else: 129 | result_col = st.container() 130 | 131 | with result_col: 132 | st.text("Result") 133 | result_image = st.empty() 134 | 135 | with result_col: 136 | with st.spinner('Wait for completion...'): 137 | if not st.session_state.gen_result: 138 | st.session_state.gen_result = asyncio.run(gen(prompt, image, mode, model)) 139 | task_id, image_url, upscale_indices, vary_indices = st.session_state.gen_result 140 | result_image.image(image_url) 141 | build_upscale_vary_buttons( 142 | task_id=task_id, 143 | upscale_indices=upscale_indices, 144 | vary_indices=vary_indices, 145 | on_click_upscale=on_click_upscale, 146 | on_click_vary=on_click_vary 147 | ) 148 | 149 | for item in st.session_state.gen_uv_results: 150 | result: UVResult = item 151 | st.image(result.image_url) 152 | build_upscale_vary_buttons( 153 | task_id=result.task_id, 154 | upscale_indices=result.upscale_indices, 155 | vary_indices=result.vary_indices, 156 | on_click_upscale=on_click_upscale, 157 | 
on_click_vary=on_click_vary 158 | ) 159 | -------------------------------------------------------------------------------- /app/models/v2v-models.json: -------------------------------------------------------------------------------- 1 | [{"id": 15035, "name": "Fusion style v2", "description": "Any style by image", "cover": {"url": "https://imgo.domoai.app/ai-model/010623bf-2857-4679-a8cb-3bc60bdb9647.png", "width": 1024, "height": 1648}, "prompt_args": "--fs v2", "allowed_refer_modes": ["PROMPT_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": true}, {"id": 15034, "name": "Illustration v16", "description": "Storybook style", "cover": {"url": "https://imgo.domoai.app/ai-model/9a7d4712-78ac-47f7-8826-e033ed8caea1.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v16", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": false, "tech_version": 300, "allowed_reference_image": false}, {"id": 15033, "name": "Illustration v15", "description": "Renaissance art style", "cover": {"url": "https://imgo.domoai.app/ai-model/4aa3e571-72c8-4e18-a36f-420c7b6ccb41.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v15", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": false, "tech_version": 300, "allowed_reference_image": false}, {"id": 15032, "name": "Anime v8", "description": "Sketch anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/01f1a743-e180-4808-ac67-5fa865bd8029.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v8", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": false, "tech_version": 300, "allowed_reference_image": false}, {"id": 15031, "name": "Illustration v14", "description": "Ukiyo-e style", "cover": {"url": "https://imgo.domoai.app/ai-model/6651151b-6a5b-4d8d-8b6b-660db8c32d5f.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v14", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, 
{"id": 15030, "name": "Illustration v1.2", "description": "3D cartoon style 3.0", "cover": {"url": "https://imgo.domoai.app/ai-model/6b8f9c99-60dd-4369-a3f5-8700f3f59704.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v1.2", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15029, "name": "Illustration v13.1", "description": "Clay cartoon style 3.0", "cover": {"url": "https://imgo.domoai.app/ai-model/401d4cf6-fa12-40d7-a7cc-ec5982e6dd37.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v13.1", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15028, "name": "Anime v5.2", "description": "Japanese anime 3.0", "cover": {"url": "https://imgo.domoai.app/ai-model/4249fe12-2456-4002-a39c-814c260e815c.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v5.2", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15027, "name": "Illustration v13", "description": "Clay cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/073d2fa2-f678-415b-94fe-1cde9142d37b.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v13", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15026, "name": "Illustration v12", "description": "Lego style", "cover": {"url": "https://imgo.domoai.app/ai-model/6c3064fd-f161-4f88-b1fa-726ab6cc29ec.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v12", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15025, "name": "Illustration v11", "description": "American comicbook", "cover": {"url": "https://imgo.domoai.app/ai-model/c0223c6b-db06-4c09-950c-d0a5019f1c45.png", "width": 1024, 
"height": 1648}, "prompt_args": "--illus v11", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15024, "name": "Anime v7", "description": "Color pencil style", "cover": {"url": "https://imgo.domoai.app/ai-model/3b2629c6-d79e-4b21-b7bf-a5c454476007.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v7", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15023, "name": "Illustration v10", "description": "PS2 game style", "cover": {"url": "https://imgo.domoai.app/ai-model/e5394ee2-a0b7-437e-9e3d-1795d5de5125.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v10", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15022, "name": "Illustration v9", "description": "2.5D illustration style", "cover": {"url": "https://imgo.domoai.app/ai-model/8e25862d-428f-4c7c-956e-a70eed5548c8.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v9", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15021, "name": "Illustration v7.1", "description": "Paper art style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/475bf829-96b0-4ceb-b3f3-ca555fb12f22.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v7.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15020, "name": "Illustration v3.1", "description": "Pixel style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/1b3b7497-4a89-4eeb-86ec-288f5c2743e0.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v3.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, 
"allowed_reference_image": false}, {"id": 15019, "name": "Illustration v1.1", "description": "3D cartoon style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/d6d095be-dc29-417e-b5d8-22660fc3df3d.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v1.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15018, "name": "Anime v4.1", "description": "Chinese ink painting 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/6db5f48d-8c53-4e88-abac-e4d550b641eb.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v4.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15017, "name": "Anime v1.1", "description": "Flat color anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/5afd5dac-6843-4901-b67f-46f728261070.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v1.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15016, "name": "Anime v5.1", "description": "Japanese anime 2.1", "cover": {"url": "https://imgo.domoai.app/ai-model/3843c175-b640-4697-a89e-a785583c4a5c.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v5.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15015, "name": "Anime v6", "description": "Detail anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/f59317bc-0417-46ab-8fee-72c8fbd3c8eb.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v6", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15014, "name": "Fusion Style v1", "description": "Any style by prompt", "cover": {"url": 
"https://imgo.domoai.app/ai-model/bcb0d19e-f88f-4fc2-85a7-9abf4624525a.png", "width": 1024, "height": 1648}, "prompt_args": "--fs v1", "allowed_refer_modes": ["PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15012, "name": "Illustration v8", "description": "Van gogh style", "cover": {"url": "https://imgo.domoai.app/ai-model/cdd10dd9-1cf3-4b00-af4e-c8bd77362f46.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v8", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15011, "name": "Illustration v7", "description": "Paper art style", "cover": {"url": "https://imgo.domoai.app/ai-model/ae4f9929-f3c5-44a0-a95c-05c524afa15d.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v7", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15010, "name": "Anime v4", "description": "Chinese ink painting", "cover": {"url": "https://imgo.domoai.app/ai-model/d2a71189-1e1c-4b0c-ba52-31fc79d6bfdc.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v4", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15009, "name": "Illustration v6", "description": "Grand theft game", "cover": {"url": "https://imgo.domoai.app/ai-model/31f82d00-7ef6-4287-8544-1ec4854cbdc9.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v6", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15008, "name": "Illustration v5", "description": "Color illustration", "cover": {"url": "https://imgo.domoai.app/ai-model/b7fd23a2-459b-47cf-b012-96899cf38f61.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v5", "allowed_refer_modes": ["VIDEO_MORE", 
"PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15007, "name": "Illustration v4", "description": "Storybook cartoon", "cover": {"url": "https://imgo.domoai.app/ai-model/0e3ca4ad-4400-4977-b828-41913e43759a.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v4", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15006, "name": "Anime v3", "description": "Live anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/c82a7ce4-ade0-4c7d-bfb5-4546d636e12f.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v3", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15005, "name": "Anime v2", "description": "Japanese anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/5a11a3ea-2884-4630-b504-edd0b67ad239.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v2", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15004, "name": "Illustration v3", "description": "Pixel style", "cover": {"url": "https://imgo.domoai.app/ai-model/0c3c8928-e376-4462-8325-abd8f971f650.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v3", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15003, "name": "Illustration v2", "description": "Comic style", "cover": {"url": "https://imgo.domoai.app/ai-model/8420802c-f777-4047-a65c-25ea1d4181a7.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v2", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15002, "name": "Illustration v1", "description": "3D cartoon style", 
"cover": {"url": "https://imgo.domoai.app/ai-model/49ab8034-c804-47cd-8125-3810dd65310a.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15001, "name": "Anime v1", "description": "Flat color anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/2accc87e-0bbd-43c3-8d64-b207747761fb.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}] -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import io 3 | import uuid 4 | from typing import Optional 5 | 6 | import discord 7 | from fastapi import FastAPI, Request, UploadFile, Form, HTTPException, Depends 8 | from starlette import status 9 | from starlette.responses import JSONResponse 10 | 11 | from app.cache import RedisCache, MemoryCache, Cache 12 | from app.dependencies import api_auth 13 | from app.models import GenModel, MoveModel, VideoModel, get_v2v_model_info_by_instructions 14 | from app.schema import VideoReferMode, VideoLength, TaskCacheData, TaskStatus, CreateTaskOut, \ 15 | TaskCommand, TaskStateOut, AnimateLength, AnimateIntensity, Mode, VideoKey, VideoApiError 16 | from app.settings import get_settings 17 | from app.user_client import DiscordUserClient 18 | 19 | app = FastAPI() 20 | 21 | settings = get_settings() 22 | 23 | 24 | async def __did_send_interaction( 25 | wait_message_desc_keyword: str, 26 | discord_user_client: DiscordUserClient, 27 | cache: Cache, 28 | command: TaskCommand 29 | ) -> CreateTaskOut: 30 | message: discord.Message = await discord_user_client.wait_for_generating_message( 31 | embeds_desc_keyword=wait_message_desc_keyword 32 | 
) 33 | task_id = str(uuid.uuid4()) 34 | await cache.set_message_id2task_id(message_id=str(message.id), task_id=task_id) 35 | await cache.set_task_id2data(task_id=task_id, data=TaskCacheData( 36 | command=command, 37 | status=TaskStatus.RUNNING, 38 | channel_id=str(discord_user_client.channel_id), 39 | guild_id=str(discord_user_client.guild_id), 40 | message_id=str(message.id) 41 | )) 42 | return CreateTaskOut( 43 | success=True, 44 | task_id=task_id, 45 | message_id=str(message.id) 46 | ) 47 | 48 | 49 | @app.post("/v1/gen") 50 | async def gen_api( 51 | request: Request, 52 | auth=Depends(api_auth), 53 | image: Optional[UploadFile] = None, 54 | prompt: str = Form(...), 55 | mode: Optional[Mode] = Form(default=None), 56 | model: Optional[GenModel] = Form(default=None) 57 | ): 58 | discord_user_client: DiscordUserClient = request.app.state.discord_user_client 59 | if image: 60 | image_bytes = await image.read() 61 | image_file = discord.File(io.BytesIO(image_bytes), filename=image.filename) 62 | else: 63 | image_file = None 64 | interaction = await discord_user_client.gen(prompt=prompt, image=image_file, mode=mode, model=model) 65 | print(f"gen, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 66 | 67 | if not interaction.successful: 68 | # TODO: 69 | return {"success": interaction.successful} 70 | 71 | result = await __did_send_interaction( 72 | wait_message_desc_keyword='Waiting to start', 73 | command=TaskCommand.GEN, 74 | cache=request.app.state.cache, 75 | discord_user_client=discord_user_client 76 | ) 77 | return result 78 | 79 | 80 | @app.post("/v1/real") 81 | async def real_api( 82 | request: Request, 83 | auth=Depends(api_auth), 84 | image: UploadFile = Form(...), 85 | prompt: Optional[str] = Form(default=None), 86 | mode: Optional[Mode] = Form(default=None) 87 | ): 88 | discord_user_client: DiscordUserClient = request.app.state.discord_user_client 89 | image_bytes = await image.read() 90 | image_file = 
discord.File(io.BytesIO(image_bytes), filename=image.filename) 91 | interaction = await discord_user_client.real(prompt=prompt, image=image_file, mode=mode) 92 | print(f"real, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 93 | 94 | if not interaction.successful: 95 | # TODO: 96 | return {"success": interaction.successful} 97 | 98 | result = await __did_send_interaction( 99 | wait_message_desc_keyword='Waiting to start', 100 | command=TaskCommand.REAL, 101 | cache=request.app.state.cache, 102 | discord_user_client=discord_user_client 103 | ) 104 | return result 105 | 106 | 107 | @app.post("/v1/animate") 108 | async def animate_api( 109 | request: Request, 110 | auth=Depends(api_auth), 111 | image: UploadFile = Form(...), 112 | length: AnimateLength = Form(...), 113 | intensity: AnimateIntensity = Form(...), 114 | prompt: Optional[str] = Form(default=None), 115 | mode: Optional[Mode] = Form(default=None) 116 | ): 117 | discord_user_client: DiscordUserClient = request.app.state.discord_user_client 118 | image_bytes = await image.read() 119 | image_file = discord.File(io.BytesIO(image_bytes), filename=image.filename) 120 | interaction = await discord_user_client.animate( 121 | prompt=prompt, 122 | image=image_file, 123 | length=length, 124 | intensity=intensity, 125 | mode=mode 126 | ) 127 | print(f"animate, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 128 | 129 | if not interaction.successful: 130 | # TODO: 131 | return {"success": interaction.successful} 132 | 133 | result = await __did_send_interaction( 134 | wait_message_desc_keyword='Waiting to start', 135 | command=TaskCommand.ANIMATE, 136 | cache=request.app.state.cache, 137 | discord_user_client=discord_user_client 138 | ) 139 | return result 140 | 141 | 142 | @app.post("/v1/upscale") 143 | async def upscale_api( 144 | request: Request, 145 | auth=Depends(api_auth), 146 | task_id: str = Form(...), 147 | index: int = Form(..., ge=1, le=4) 148 | ): 149 | 
discord_user_client: DiscordUserClient = request.app.state.discord_user_client 150 | 151 | cache: Cache = request.app.state.cache 152 | data = await cache.get_task_data_by_id(task_id=task_id) 153 | if not data: 154 | raise HTTPException( 155 | status_code=status.HTTP_404_NOT_FOUND, 156 | ) 157 | label = f"U{index}" 158 | custom_id = data.upscale_custom_ids.get(label) 159 | if not custom_id: 160 | raise HTTPException( 161 | status_code=status.HTTP_404_NOT_FOUND, 162 | ) 163 | interaction = await discord_user_client.click_button(custom_id=custom_id, message_id=int(data.message_id)) 164 | print(f"upscale, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 165 | 166 | if not interaction.successful: 167 | # TODO: 168 | return {"success": interaction.successful} 169 | 170 | result = await __did_send_interaction( 171 | wait_message_desc_keyword='Waiting to start', 172 | command=TaskCommand.GEN, 173 | cache=request.app.state.cache, 174 | discord_user_client=discord_user_client 175 | ) 176 | return result 177 | 178 | 179 | @app.post("/v1/vary") 180 | async def vary_api( 181 | request: Request, 182 | auth=Depends(api_auth), 183 | task_id: str = Form(...), 184 | index: int = Form(..., ge=1, le=4) 185 | ): 186 | discord_user_client: DiscordUserClient = request.app.state.discord_user_client 187 | 188 | cache: Cache = request.app.state.cache 189 | data = await cache.get_task_data_by_id(task_id=task_id) 190 | if not data: 191 | raise HTTPException( 192 | status_code=status.HTTP_404_NOT_FOUND, 193 | ) 194 | label = f"V{index}" 195 | custom_id = data.vary_custom_ids.get(label) 196 | if not custom_id: 197 | raise HTTPException( 198 | status_code=status.HTTP_404_NOT_FOUND, 199 | ) 200 | interaction = await discord_user_client.click_button(custom_id=custom_id, message_id=int(data.message_id)) 201 | print(f"vary, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 202 | 203 | if not interaction.successful: 204 | # TODO: 205 | return 
{"success": interaction.successful} 206 | 207 | result = await __did_send_interaction( 208 | wait_message_desc_keyword='Waiting to start', 209 | command=TaskCommand.GEN, 210 | cache=request.app.state.cache, 211 | discord_user_client=discord_user_client 212 | ) 213 | return result 214 | 215 | 216 | @app.post("/v1/video") 217 | async def video_api( 218 | request: Request, 219 | video: UploadFile, 220 | image: Optional[UploadFile] = None, 221 | auth=Depends(api_auth), 222 | model: VideoModel = Form(...), 223 | refer_mode: VideoReferMode = Form(...), 224 | length: VideoLength = Form(...), 225 | prompt: str = Form(...), 226 | video_key: VideoKey = Form(default=None), 227 | subject_only: bool = Form(default=None), 228 | lip_sync: bool = Form(default=None), 229 | mode: Optional[Mode] = Form(default=None), 230 | ): 231 | # size_mb = video.size / 1024.0 / 1024.0 232 | discord_user_client: DiscordUserClient = request.app.state.discord_user_client 233 | video_bytes = await video.read() 234 | video_file = discord.File(io.BytesIO(video_bytes), filename=video.filename) 235 | image_file = None 236 | if image: 237 | image_bytes = await image.read() 238 | image_file = discord.File(io.BytesIO(image_bytes), filename=image.filename) 239 | 240 | model_info = get_v2v_model_info_by_instructions(model.value) 241 | if model_info is None: 242 | return JSONResponse( 243 | status_code=status.HTTP_400_BAD_REQUEST, 244 | content={"code": VideoApiError.VIDEO_MODEL_ERROR} 245 | ) 246 | 247 | if refer_mode not in model_info.allowed_refer_modes: 248 | return JSONResponse( 249 | status_code=status.HTTP_400_BAD_REQUEST, 250 | content={"error": VideoApiError.NOT_ALLOW_REFER} 251 | ) 252 | 253 | if not model_info.allowed_lip_sync and lip_sync: 254 | return JSONResponse( 255 | status_code=status.HTTP_400_BAD_REQUEST, 256 | content={"error": VideoApiError.NOT_ALLOW_LIP_SYNC} 257 | ) 258 | 259 | if model_info.allowed_reference_image and image is None: 260 | return JSONResponse( 261 | 
status_code=status.HTTP_400_BAD_REQUEST, 262 | content={"error": VideoApiError.MODEL_NEED_REFERENCE_IMAGE} 263 | ) 264 | 265 | interaction = await discord_user_client.video( 266 | prompt=prompt, 267 | video=video_file, 268 | image=image_file, 269 | model=model, 270 | refer_mode=refer_mode, 271 | length=length, 272 | mode=mode, 273 | video_key=video_key, 274 | subject_only=subject_only, 275 | lip_sync=lip_sync, 276 | ) 277 | print(f"video, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 278 | 279 | if not interaction.successful: 280 | # TODO: 281 | return {"success": interaction.successful} 282 | 283 | result = await __did_send_interaction( 284 | wait_message_desc_keyword='Generating', 285 | command=TaskCommand.VIDEO, 286 | cache=request.app.state.cache, 287 | discord_user_client=discord_user_client 288 | ) 289 | return result 290 | 291 | 292 | @app.post("/v1/move") 293 | async def move_api( 294 | request: Request, 295 | image: UploadFile, 296 | video: UploadFile, 297 | auth=Depends(api_auth), 298 | model: MoveModel = Form(...), 299 | length: VideoLength = Form(...), 300 | prompt: str = Form(...), 301 | video_key: VideoKey = Form(default=None), 302 | mode: Optional[Mode] = Form(default=None), 303 | ): 304 | # size_mb = video.size / 1024.0 / 1024.0 305 | discord_user_client: DiscordUserClient = request.app.state.discord_user_client 306 | image_bytes = await image.read() 307 | image_file = discord.File(io.BytesIO(image_bytes), filename=image.filename) 308 | video_bytes = await video.read() 309 | video_file = discord.File(io.BytesIO(video_bytes), filename=video.filename) 310 | interaction = await discord_user_client.move( 311 | prompt=prompt, 312 | image=image_file, 313 | video=video_file, 314 | model=model, 315 | length=length, 316 | mode=mode, 317 | video_key=video_key, 318 | ) 319 | print(f"move, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}") 320 | 321 | if not interaction.successful: 322 | # TODO: 323 | return 
{"success": interaction.successful} 324 | 325 | result = await __did_send_interaction( 326 | wait_message_desc_keyword='Generating', 327 | command=TaskCommand.MOVE, 328 | cache=request.app.state.cache, 329 | discord_user_client=discord_user_client 330 | ) 331 | return result 332 | 333 | 334 | @app.get("/v1/task-data/{task_id}") 335 | async def task_data( 336 | request: Request, 337 | task_id: str, 338 | auth=Depends(api_auth) 339 | ): 340 | cache: Cache = request.app.state.cache 341 | data = await cache.get_task_data_by_id(task_id=task_id) 342 | if not data: 343 | raise HTTPException( 344 | status_code=status.HTTP_404_NOT_FOUND, 345 | ) 346 | return TaskStateOut.from_cache_data(data=data) 347 | 348 | 349 | @app.on_event("startup") 350 | async def startup_event(): 351 | if settings.redis_uri: 352 | redis = await RedisCache.init_redis_pool(redis_uri=settings.redis_uri) 353 | app.state.cache = RedisCache(redis=redis, prefix=settings.cache_prefix) 354 | else: 355 | app.state.cache = MemoryCache(prefix=settings.cache_prefix) 356 | 357 | discord_user_client = DiscordUserClient( 358 | guild_id=settings.discord_guild_id, 359 | channel_id=settings.discord_channel_id, 360 | application_id=settings.domoai_application_id, 361 | cache=app.state.cache, 362 | event_callback_url=settings.event_callback_url 363 | ) 364 | 365 | app.state.discord_user_client = discord_user_client 366 | await discord_user_client.login(settings.discord_token) 367 | app.state.discord_start_task = asyncio.create_task(discord_user_client.connect(reconnect=True)) 368 | await discord_user_client.wait_until_ready() 369 | 370 | 371 | @app.on_event("shutdown") 372 | async def shutdown_event(): 373 | if app.state.discord_start_task: 374 | app.state.discord_start_task.cancel() 375 | 376 | discord_user_client: DiscordUserClient = app.state.discord_user_client 377 | if not discord_user_client.is_closed(): 378 | await discord_user_client.close() 379 | 380 | cache: Cache = app.state.cache 381 | await cache.close() 
382 | -------------------------------------------------------------------------------- /app/user_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import re 3 | from typing import Dict, List, Optional 4 | 5 | import discord 6 | from discord import ComponentType, InteractionType, InvalidData 7 | from discord.http import Route 8 | from discord.utils import _generate_nonce 9 | 10 | from app.cache import Cache 11 | from app.event_callback import EventCallback 12 | from app.models import GenModel, MoveModel, VideoModel 13 | from app.schema import VideoReferMode, VideoLength, TaskCacheData, TaskAsset, TaskStatus, \ 14 | TaskCommand, AnimateIntensity, AnimateLength, Mode, VideoKey 15 | 16 | 17 | class DiscordUserClient(discord.Client): 18 | 19 | def __init__( 20 | self, 21 | channel_id: int, 22 | guild_id: int, 23 | application_id: int, 24 | cache: Cache, 25 | event_callback_url: Optional[str], 26 | **options 27 | ): 28 | super().__init__(**options) 29 | self.event_callback = EventCallback(callback_url=event_callback_url) 30 | self.application_id = application_id 31 | self.commands: Dict[str, discord.SlashCommand] = {} 32 | 33 | self.guild_id = guild_id 34 | self.guild = None 35 | 36 | self.channel_id = channel_id 37 | self.channel = None 38 | self.cache = cache 39 | 40 | self.bot_user_id = None 41 | 42 | async def setup_hook(self): 43 | self.bot_user_id = self.user.id 44 | self.guild = await self.fetch_guild(self.guild_id) 45 | self.channel = await self.fetch_channel(self.channel_id) 46 | await self.__init_slash_commands() 47 | 48 | print(f'guild: {self.guild}') 49 | print(f'channel: {self.channel}') 50 | print(f'slash commands: {self.commands.keys()}') 51 | 52 | async def on_ready(self): 53 | print(f'Logged on as {self.user}') 54 | 55 | async def __init_slash_commands(self): 56 | commands: List[discord.SlashCommand] = [ 57 | x for x in await self.guild.application_commands() if 58 | isinstance(x, 
discord.SlashCommand) and str(x.application_id) == str(self.application_id) 59 | ] 60 | for command in commands: 61 | self.commands[command.name] = command 62 | 63 | async def on_message(self, message: discord.Message): 64 | # if message.application_id != self.application_id: 65 | # return 66 | 67 | print( 68 | f"message: {message}, interaction_id: {message.interaction.id if message.interaction else None}, interaction.nonce: {message.interaction.nonce if message.interaction else None}") 69 | 70 | if message.content == 'ping': 71 | await message.channel.send('pong') 72 | 73 | async def on_message_edit(self, before: discord.Message, after: discord.Message): 74 | if after.author.id != self.application_id: 75 | return 76 | print( 77 | f"message edit, before {before}, after: {after}") 78 | if after.embeds: 79 | embed: discord.Embed = after.embeds[0] 80 | if embed.title.startswith('/gen'): 81 | await self.handle_gen_result(message=after) 82 | elif embed.title.startswith('/real'): 83 | await self.handle_real_result(message=after) 84 | elif embed.title == '/animate': 85 | await self.handle_animate_result(message=after) 86 | elif embed.title == '/video': 87 | await self.handle_video_result(message=after) 88 | elif embed.title == '/move': 89 | await self.handle_move_result(message=after) 90 | elif 'After:' in after.content and 'Before:' in after.content: 91 | await self.handle_video_result(message=after) 92 | elif 'Result:' in after.content and 'Image:' in after.content and 'Video:' in after.content: 93 | await self.handle_move_result(message=after) 94 | 95 | async def wait_for_generating_message(self, embeds_desc_keyword: str) -> discord.Message: 96 | def check(message: discord.Message): 97 | if not message.embeds or not message.mentions: 98 | return False 99 | if message.mentions[0].id != self.bot_user_id: 100 | return False 101 | return embeds_desc_keyword in message.embeds[0].description 102 | 103 | return await self.wait_for( 104 | 'message', 105 | check=check, 106 | 
timeout=20 107 | ) 108 | 109 | async def handle_gen_result( 110 | self, 111 | message: discord.Message 112 | ): 113 | if not message.attachments: 114 | return 115 | attachment = message.attachments[0] 116 | 117 | task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id)) 118 | if not task_id: 119 | return 120 | upscale_custom_ids = {} 121 | vary_custom_ids = {} 122 | for row in message.components: 123 | for component in row.children: 124 | if component.disabled or not component.label or not component.custom_id: 125 | continue 126 | if component.label.startswith("U"): 127 | upscale_custom_ids[component.label] = component.custom_id 128 | elif component.label.startswith("Vary"): 129 | vary_custom_ids["V1"] = component.custom_id 130 | elif component.label.startswith("V"): 131 | vary_custom_ids[component.label] = component.custom_id 132 | 133 | data = TaskCacheData( 134 | command=TaskCommand.GEN, 135 | channel_id=str(message.channel.id), 136 | guild_id=str(message.guild.id) if message.guild else None, 137 | message_id=str(message.id), 138 | images=[TaskAsset.from_attachment(attachment)], 139 | status=TaskStatus.SUCCESS, 140 | upscale_custom_ids=upscale_custom_ids, 141 | vary_custom_ids=vary_custom_ids 142 | ) 143 | await self.cache.set_task_id2data(task_id=task_id, data=data) 144 | await self.event_callback.send_task_success(task_id=task_id, data=data) 145 | 146 | async def handle_real_result( 147 | self, 148 | message: discord.Message 149 | ): 150 | if not message.attachments: 151 | return 152 | attachment = message.attachments[0] 153 | 154 | task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id)) 155 | if not task_id: 156 | return 157 | upscale_custom_ids = {} 158 | vary_custom_ids = {} 159 | for row in message.components: 160 | for component in row.children: 161 | if component.disabled or not component.label or not component.custom_id: 162 | continue 163 | if component.label.startswith("U"): 164 | 
upscale_custom_ids[component.label] = component.custom_id 165 | elif component.label.startswith("Vary"): 166 | vary_custom_ids["V1"] = component.custom_id 167 | elif component.label.startswith("V"): 168 | vary_custom_ids[component.label] = component.custom_id 169 | 170 | data = TaskCacheData( 171 | command=TaskCommand.REAL, 172 | channel_id=str(message.channel.id), 173 | guild_id=str(message.guild.id) if message.guild else None, 174 | message_id=str(message.id), 175 | images=[TaskAsset.from_attachment(attachment)], 176 | status=TaskStatus.SUCCESS, 177 | upscale_custom_ids=upscale_custom_ids, 178 | vary_custom_ids=vary_custom_ids 179 | ) 180 | 181 | await self.cache.set_task_id2data(task_id=task_id, data=data) 182 | await self.event_callback.send_task_success(task_id=task_id, data=data) 183 | 184 | async def handle_video_result( 185 | self, 186 | message: discord.Message 187 | ): 188 | task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id)) 189 | if not task_id: 190 | return 191 | 192 | if message.attachments: 193 | attachment = message.attachments[0] 194 | asset = TaskAsset.from_attachment(attachment) 195 | else: 196 | match = re.search(r"After:.*?(https:.*)", message.content) 197 | if not match: 198 | return 199 | video_url = match.group(1) 200 | asset = TaskAsset( 201 | url=video_url, 202 | proxy_url=video_url 203 | ) 204 | 205 | data = TaskCacheData( 206 | command=TaskCommand.VIDEO, 207 | channel_id=str(message.channel.id), 208 | guild_id=str(message.guild.id) if message.guild else None, 209 | message_id=str(message.id), 210 | videos=[asset], 211 | status=TaskStatus.SUCCESS 212 | ) 213 | 214 | await self.cache.set_task_id2data(task_id=task_id, data=data) 215 | await self.event_callback.send_task_success(task_id=task_id, data=data) 216 | 217 | async def handle_animate_result( 218 | self, 219 | message: discord.Message 220 | ): 221 | task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id)) 222 | if not task_id: 223 
| return 224 | 225 | if not message.attachments: 226 | return 227 | 228 | attachment = message.attachments[0] 229 | asset = TaskAsset.from_attachment(attachment) 230 | 231 | data = TaskCacheData( 232 | command=TaskCommand.ANIMATE, 233 | channel_id=str(message.channel.id), 234 | guild_id=str(message.guild.id) if message.guild else None, 235 | message_id=str(message.id), 236 | videos=[asset], 237 | status=TaskStatus.SUCCESS 238 | ) 239 | await self.cache.set_task_id2data(task_id=task_id, data=data) 240 | await self.event_callback.send_task_success(task_id=task_id, data=data) 241 | 242 | async def handle_move_result( 243 | self, 244 | message: discord.Message 245 | ): 246 | task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id)) 247 | if not task_id: 248 | return 249 | 250 | if message.attachments: 251 | attachment = message.attachments[0] 252 | asset = TaskAsset.from_attachment(attachment) 253 | else: 254 | match = re.search(r"Result:.*?(https:.*)", message.content) 255 | if not match: 256 | return 257 | video_url = match.group(1) 258 | asset = TaskAsset( 259 | url=video_url, 260 | proxy_url=video_url 261 | ) 262 | 263 | data = TaskCacheData( 264 | command=TaskCommand.MOVE, 265 | channel_id=str(message.channel.id), 266 | guild_id=str(message.guild.id) if message.guild else None, 267 | message_id=str(message.id), 268 | videos=[asset], 269 | status=TaskStatus.SUCCESS 270 | ) 271 | await self.cache.set_task_id2data(task_id=task_id, data=data) 272 | await self.event_callback.send_task_success(task_id=task_id, data=data) 273 | 274 | async def gen( 275 | self, 276 | prompt: str, 277 | image: Optional[discord.File] = None, 278 | mode: Optional[Mode] = None, 279 | model: Optional[GenModel] = None 280 | ) -> Optional[discord.Interaction]: 281 | command = self.commands.get('gen') 282 | if not command: 283 | return None 284 | request_prompt = prompt 285 | if mode: 286 | request_prompt += f" --{mode.value}" 287 | if model: 288 | request_prompt += f" 
--{model.value}" 289 | options = dict( 290 | prompt=request_prompt 291 | ) 292 | if image: 293 | uploaded_image_files = await self.channel.upload_files(image) 294 | image.close() 295 | options['img2img'] = uploaded_image_files[0] 296 | 297 | interaction = await command(self.channel, **options) 298 | return interaction 299 | 300 | async def real( 301 | self, 302 | image: discord.File, 303 | prompt: Optional[str] = None, 304 | mode: Optional[Mode] = None 305 | ) -> Optional[discord.Interaction]: 306 | command = self.commands.get('real') 307 | if not command: 308 | return None 309 | uploaded_image_files = await self.channel.upload_files(image) 310 | image.close() 311 | options = dict( 312 | image=uploaded_image_files[0] 313 | ) 314 | request_prompt_parts = [] 315 | if prompt: 316 | request_prompt_parts.append(prompt) 317 | 318 | if mode: 319 | request_prompt_parts.append(f'--{mode.value}') 320 | 321 | if request_prompt_parts: 322 | options['prompt'] = ' '.join(request_prompt_parts) 323 | 324 | interaction = await command(self.channel, **options) 325 | return interaction 326 | 327 | async def click_button( 328 | self, 329 | custom_id: str, 330 | message_id: int 331 | ) -> Optional[discord.Interaction]: 332 | data = { 333 | "component_type": ComponentType.button.value, 334 | "custom_id": custom_id 335 | } 336 | 337 | nonce = _generate_nonce() 338 | 339 | try: 340 | payload = { 341 | 'application_id': self.application_id, 342 | 'channel_id': self.channel_id, 343 | 'data': data, 344 | 'nonce': nonce, 345 | 'session_id': '44efec80d647a97968b4c60d26d3c032', 346 | 'type': InteractionType.component.value, 347 | 'guild_id': self.guild_id, 348 | 'message_flags': 0, 349 | 'message_id': message_id 350 | } 351 | await self.http.request(Route('POST', '/interactions'), json=payload, form=[], files=None) 352 | # await self.http.interact( 353 | # type=InteractionType.component, 354 | # nonce=nonce, 355 | # data=data, 356 | # channel=self.channel, 357 | # 
application_id=self.application_id 358 | # ) 359 | # The maximum possible time a response can take is 3 seconds, 360 | # +/- a few milliseconds for network latency 361 | # However, people have been getting errors because their gateway 362 | # disconnects while waiting for the interaction, causing the 363 | # response to be delayed until the gateway is reconnected 364 | # 12 seconds should be enough to account for this 365 | i = await self.wait_for( 366 | 'interaction_finish', 367 | check=lambda d: d.nonce == nonce, 368 | timeout=12, 369 | ) 370 | return i 371 | except (asyncio.TimeoutError, asyncio.CancelledError) as exc: 372 | raise InvalidData('Did not receive a response from Discord') from exc 373 | 374 | async def move( 375 | self, 376 | image: discord.File, 377 | video: discord.File, 378 | prompt: str, 379 | model: MoveModel, 380 | length: VideoLength, 381 | mode: Optional[Mode] = None, 382 | video_key: Optional[VideoKey] = None, 383 | ) -> Optional[discord.Interaction]: 384 | command = self.commands.get('move') 385 | if not command: 386 | return None 387 | uploaded_images = await self.channel.upload_files(image) 388 | image.close() 389 | 390 | uploaded_videos = await self.channel.upload_files(video) 391 | video.close() 392 | request_prompt = f"{prompt} --{model.value} --length {length.value}" 393 | if mode: 394 | request_prompt += f' --{mode.value}' 395 | if video_key: 396 | request_prompt += f' --key {video_key.value.lower()}' 397 | options = dict( 398 | prompt=request_prompt, 399 | video=uploaded_videos[0], 400 | image=uploaded_images[0] 401 | ) 402 | return await command(self.channel, **options) 403 | 404 | async def video( 405 | self, 406 | video: discord.File, 407 | image: Optional[discord.File], 408 | prompt: str, 409 | model: VideoModel, 410 | refer_mode: VideoReferMode, 411 | length: VideoLength, 412 | mode: Optional[Mode] = None, 413 | video_key: Optional[VideoKey] = None, 414 | subject_only: Optional[bool] = None, 415 | lip_sync: Optional[bool] = 
None, 416 | ) -> Optional[discord.Interaction]: 417 | command = self.commands.get('video') 418 | if not command: 419 | return None 420 | uploaded_videos = await self.channel.upload_files(video) 421 | uploaded_image = None 422 | if image: 423 | uploaded_image = await self.channel.upload_files(image) 424 | video.close() 425 | if refer_mode == VideoReferMode.REFER_TO_MY_PROMPT_MORE: 426 | refer_mode_value = 'p' 427 | else: 428 | refer_mode_value = 'v' 429 | request_prompt = f"{prompt} --{model.value} --refer {refer_mode_value} --length {length.value}" 430 | if mode: 431 | request_prompt += f' --{mode.value}' 432 | if video_key: 433 | request_prompt += f' --key {video_key.value.lower()}' 434 | if subject_only: 435 | request_prompt += f' --so' 436 | if lip_sync: 437 | request_prompt += f' --lips' 438 | options = dict( 439 | prompt=request_prompt, 440 | video=uploaded_videos[0] 441 | ) 442 | if uploaded_image: 443 | options['image'] = uploaded_image[0] 444 | return await command(self.channel, **options) 445 | 446 | async def animate( 447 | self, 448 | image: discord.File, 449 | intensity: AnimateIntensity, 450 | length: AnimateLength, 451 | prompt: Optional[str] = None, 452 | mode: Optional[Mode] = None 453 | ) -> Optional[discord.Interaction]: 454 | command = self.commands.get('animate') 455 | if not command: 456 | return None 457 | uploaded_images = await self.channel.upload_files(image) 458 | image.close() 459 | request_prompt = f"--intensity {intensity.value} --length {length.value}" 460 | if prompt: 461 | request_prompt = f"{prompt} " + request_prompt 462 | if mode: 463 | request_prompt += f' --{mode.value}' 464 | options = dict( 465 | prompt=request_prompt, 466 | image=uploaded_images[0] 467 | ) 468 | return await command(self.channel, **options) 469 | --------------------------------------------------------------------------------