├── app
├── __init__.py
├── dependencies.py
├── settings.py
├── models
│ ├── move-models.json
│ ├── gen-models.json
│ └── v2v-models.json
├── event_callback.py
├── models.py
├── cache.py
├── schema.py
├── main.py
└── user_client.py
├── requirements-dev.txt
├── .env.template
├── .idea
├── .gitignore
├── vcs.xml
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── modules.xml
├── domoai-api.iml
└── misc.xml
├── requirements.txt
├── streamlit_demo
├── 🏠_Home.py
├── auth.py
├── pages
│ ├── Animate.py
│ ├── Move.py
│ ├── Video.py
│ ├── Real.py
│ └── Gen.py
└── utils.py
├── Dockerfile
├── README_CN.md
├── streamlit_demo.dockerfile
├── README.md
├── LICENSE
├── scripts
└── update_models.py
└── .gitignore
/app/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | watchdog
3 | streamlit-authenticator
--------------------------------------------------------------------------------
/.env.template:
--------------------------------------------------------------------------------
1 | DISCORD_TOKEN=
2 | DISCORD_GUILD_ID=
3 | DISCORD_CHANNEL_ID=
4 |
5 | # Optional
6 | # REDIS_URI=
7 |
8 | # EVENT_CALLBACK_URL=
9 |
10 | # API_AUTH_TOKEN=
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # 默认忽略的文件
2 | /shelf/
3 | /workspace.xml
4 | # 基于编辑器的 HTTP 客户端请求
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | uvicorn[standard]
3 | python-multipart
4 | pydantic-settings
5 | discord.py-self @ git+https://github.com/dolfies/discord.py-self.git
6 |
7 | redis
8 | httpx
9 | tenacity
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/streamlit_demo/🏠_Home.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from streamlit_demo.auth import check_password
4 |
# Landing page of the Streamlit demo.
st.set_page_config(page_title="DomoAI API")

# Stop rendering the page unless the password check passes.
is_authenticated = check_password()
if not is_authenticated:
    st.stop()

st.title("Unofficial DomoAI API Demo")
st.link_button(label='GitHub', url='https://github.com/T0nYie1dYaw/domoai-api')
13 |
--------------------------------------------------------------------------------
/.idea/domoai-api.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Image for the FastAPI service (app.main:app), served by uvicorn on port 80.
FROM python:3.10

WORKDIR /code

# Copy and install the requirements before the source so this layer is reused
# when only application code changes.
COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY ./app /code/app

# ENV uses the key=value form; the space-separated form is legacy syntax.
ENV PYTHONUNBUFFERED=1

# Placeholders, expected to be overridden when the container is started.
ENV DISCORD_TOKEN=''
ENV DISCORD_GUILD_ID=''
ENV DISCORD_CHANNEL_ID=''

ENV EVENT_CALLBACK_URL=''

# Named API_AUTH_TOKEN to match the `api_auth_token` field in app/settings.py.
ENV API_AUTH_TOKEN=''

ENV REDIS_URI=''

EXPOSE 80

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
26 |
--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
1 | DomoAI API
2 | ===
3 | **非官方** DomoAI API.
4 |
5 | [English Documentation](README.md)
6 |
7 | 特性
8 | ---
9 |
10 | 1. 支持全部 [DomoAI](https://domoai.app/) 的AI指令.
11 | + `/video`
12 | + `/move`
13 | + `/animate`
14 | + `/gen`
15 | + `/real`
16 | 2. 支持 `upscale`(放大) 和 `vary`(变体)
17 | 3. 支持使用 `FAST`(高速)模式 和 `RELAX`(休闲)模式
18 | 4. 支持查询任务状态
19 | 5. 支持事件回调
20 | 6. 支持 Docker 部署
21 | 7. 提供基于 Streamlit 的示例
22 |
23 | TODO
24 | ---
25 |
26 | - [ ] 多账号/账号池
27 | - [ ] 操作队列
28 | - [ ] 标准化错误的响应
29 | - [ ] 完善文档
30 |
31 | 参考链接
32 | ---
33 |
34 | - [DomoAI Official Website](https://domoai.app/)
35 | - [Github - discord.py-self](https://github.com/dolfies/discord.py-self)
36 | - [Github - FastAPI](https://github.com/tiangolo/fastapi)
--------------------------------------------------------------------------------
/app/dependencies.py:
--------------------------------------------------------------------------------
1 | from fastapi import Depends, HTTPException
2 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
3 | from starlette import status
4 |
5 | from app.settings import Settings, get_settings
6 |
7 | security = HTTPBearer(scheme_name="DomoAI")
8 |
9 |
async def api_auth(
        token: HTTPAuthorizationCredentials = Depends(security),
        settings: Settings = Depends(get_settings)
):
    """Reject the request unless its bearer token matches the configured one.

    When ``settings.api_auth_token`` is empty or unset, no token comparison
    is performed.

    Raises:
        HTTPException: 401 when a token is configured and the presented
            credentials do not match it.
    """
    import secrets

    expected_token = settings.api_auth_token
    if not expected_token:
        return

    # Compare as bytes in constant time: avoids leaking the match position
    # through timing, and compare_digest only accepts non-ASCII input as bytes.
    presented = token.credentials.encode('utf-8')
    if not secrets.compare_digest(presented, expected_token.encode('utf-8')):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
20 |
--------------------------------------------------------------------------------
/streamlit_demo.dockerfile:
--------------------------------------------------------------------------------
# Image for the Streamlit demo UI, served on port 8501.
FROM python:3.10

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt
COPY ./requirements-dev.txt /code/requirements-dev.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements-dev.txt

# The demo pages import from `app`, so both packages are copied.
COPY ./app /code/app
COPY ./streamlit_demo /code/streamlit_demo

# ENV uses the key=value form; the space-separated form is legacy syntax.
ENV PYTHONUNBUFFERED=1

# Placeholders, expected to be overridden when the container is started.
ENV STREAMLIT_BASE_URL=''
ENV STREAMLIT_AUTH=''
ENV API_AUTH_TOKEN=''

EXPOSE 8501

HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health

ENTRYPOINT ["python", "-m", "streamlit", "run", "/code/streamlit_demo/🏠_Home.py", "--server.port=8501", "--server.address=0.0.0.0"]
--------------------------------------------------------------------------------
/app/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Optional
4 |
5 | from pydantic_settings import BaseSettings, SettingsConfigDict
6 |
7 |
class Settings(BaseSettings):
    # Application configuration loaded by pydantic-settings.
    # The env file path is taken from ENV_FILE (default ".env"); extra keys
    # found there are ignored.
    model_config = SettingsConfigDict(
        env_file=os.environ.get('ENV_FILE', '.env'),
        env_file_encoding='utf-8',
        extra='ignore',
    )

    # Required values (no defaults).
    discord_token: str
    discord_guild_id: int
    discord_channel_id: int

    domoai_application_id: int = 1153984868804468756

    # Optional values; None when not provided.
    redis_uri: Optional[str] = None
    event_callback_url: Optional[str] = None

    cache_prefix: str = 'domoai:'

    api_auth_token: Optional[str] = None
27 |
28 |
@lru_cache()
def get_settings() -> Settings:
    """Return the application settings; the instance is built once and cached."""
    settings = Settings()
    return settings
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | DomoAI API
2 | ===
3 | **Unofficial** DomoAI API.
4 |
5 | [中文说明/Chinese Documentation](README_CN.md)
6 |
7 | Features
8 | ---
9 |
10 | 1. Support all AI commands of [DomoAI](https://domoai.app/).
11 | + `/video`
12 | + `/move`
13 | + `/animate`
14 | + `/gen`
15 | + `/real`
16 | 2. Support `upscale` and `vary`
17 | 3. Support `FAST` mode and `RELAX` mode
18 | 4. Support querying task status
19 | 5. Support event callback
20 | 6. Support Docker
21 | 7. Streamlit Demo provided
22 |
23 | TODO
24 | ---
25 |
26 | - [ ] Multi-Account and Account Pool
27 | - [ ] Action Queue
28 | - [ ] Standardized error response format
29 | - [ ] Usage documentation
30 |
31 | Reference
32 | ---
33 |
34 | - [DomoAI Official Website](https://domoai.app/)
35 | - [Github - discord.py-self](https://github.com/dolfies/discord.py-self)
36 | - [Github - FastAPI](https://github.com/tiangolo/fastapi)
--------------------------------------------------------------------------------
/app/models/move-models.json:
--------------------------------------------------------------------------------
1 | [{"id": 16004, "name": "Realistic v1", "description": "Realistic model", "cover": {"url": "https://imgo.domoai.app/ai-model/22a058ef-ca13-4e33-b7b6-61577674f84a.png", "width": 1024, "height": 1648}, "prompt_args": "--real v1"}, {"id": 16001, "name": "Anime v6", "description": "Detail anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/9657fb98-93ef-4052-9df9-4669f34451d6.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v6"}, {"id": 16002, "name": "Anime v2", "description": "Japanese anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/2b2da758-b341-4738-8812-c1163338e14e.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v2"}, {"id": 16003, "name": "Anime v1.1", "description": "Flat color anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/dbe689f5-795d-424d-8c54-2503c9abc316.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v1.1"}]
--------------------------------------------------------------------------------
/app/event_callback.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from typing import Optional
3 |
4 | import httpx
5 | from tenacity import retry, wait_fixed, stop_after_attempt
6 |
7 | from app.schema import TaskCacheData, TaskStateOut
8 |
9 |
class EventType(enum.Enum):
    """Event names placed in the ``event`` field of callback payloads."""

    TASK_SUCCESS = "TASK_SUCCESS"
12 |
13 |
class EventCallback:
    """Send task events to an optional HTTP callback URL."""

    def __init__(self, callback_url: Optional[str]):
        # When falsy, send_task_success() returns without sending anything.
        self.callback_url = callback_url

    @retry(wait=wait_fixed(2), stop=stop_after_attempt(3), reraise=False)
    async def send_task_success(self, task_id: str, data: TaskCacheData):
        """POST a TASK_SUCCESS event for ``task_id`` to the callback URL.

        The call is attempted up to three times, two seconds apart. An HTTP
        error status from the callback endpoint now counts as a failed
        attempt, so it is retried like a connection error.
        """
        if not self.callback_url:
            return
        out = TaskStateOut.from_cache_data(data)
        payload = {
            'event': EventType.TASK_SUCCESS.value,
            'task_id': task_id,
            # Kept as a JSON-encoded string so existing receivers see the same payload.
            'data': out.model_dump_json()
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(self.callback_url, json=payload)
            # Without this, a 4xx/5xx reply would be treated as a successful delivery.
            response.raise_for_status()
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 T0nYie1dYaw
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/update_models.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import httpx
5 | from pydantic_settings import BaseSettings, SettingsConfigDict
6 |
7 | current_dir = os.path.dirname(os.path.realpath(__file__))
8 |
9 |
class ScriptSettings(BaseSettings):
    # Settings for this script, read from the `.env` file one directory above
    # the script; extra keys in that file are ignored.
    model_config = SettingsConfigDict(
        env_file=os.path.join(current_dir, '../.env'),
        env_file_encoding='utf-8',
        extra='ignore',
    )

    # Sent as the bearer token by update_models().
    domoai_web_token: str
17 |
18 |
def update_models(req_end_path: str, filename: str, web_token: str):
    """Download a model list and store it as ``app/models/<filename>.json``.

    Args:
        req_end_path: Last path segment of the model-list endpoint.
        filename: Output file name without the ``.json`` extension.
        web_token: Token sent as the ``Authorization: Bearer`` header.

    Raises:
        httpx.HTTPStatusError: If the API answers with an error status.
    """
    res = httpx.get(f'https://api.domoai.app/web-post/model/{req_end_path}?offset=0&limit=100&locale=en', headers={
        'Authorization': f'Bearer {web_token}'
    })
    # Fail with the HTTP status instead of a KeyError on the payload below.
    res.raise_for_status()

    # Extract the list before opening the file so a bad payload does not
    # truncate the existing JSON file.
    models = res.json()['data']['models']
    output_path = os.path.join(current_dir, f'../app/models/{filename}.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(models, f)
27 |
28 |
if __name__ == '__main__':
    settings = ScriptSettings()

    # (endpoint suffix, output file stem), refreshed in this order.
    targets = (
        ('video-models', 'v2v-models'),
        ('gen-models', 'gen-models'),
        ('move-models', 'move-models'),
    )
    for req_end_path, filename in targets:
        update_models(req_end_path=req_end_path, filename=filename, web_token=settings.domoai_web_token)
35 |
--------------------------------------------------------------------------------
/streamlit_demo/auth.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import streamlit as st
4 | from streamlit_authenticator import Authenticate
5 |
6 |
def check_password():
    """Returns `True` if the user had a correct password.

    The expected credentials come from the `STREAMLIT_AUTH` environment
    variable in the form `username:password`. The check is skipped (returns
    `True`) when the variable is unset or the username or password is empty.

    Raises:
        ValueError: If `STREAMLIT_AUTH` is set but contains no ":".
    """
    streamlit_auth = os.environ.get("STREAMLIT_AUTH")

    if not streamlit_auth:
        return True

    # Split on the first ":" only, so the password may contain colons.
    expected_username, separator, expected_password = streamlit_auth.partition(":")
    if not separator:
        raise ValueError("STREAMLIT_AUTH must be in the form 'username:password'")

    if not expected_username or not expected_password:
        return True

    authenticator = Authenticate(
        credentials={"usernames": {
            expected_username: {"email": "demo@example.com", "name": "demo", "password": expected_password}
        }},
        cookie_name="domo_ai_demo",
        # NOTE(review): hard-coded cookie key; consider reading it from the environment.
        cookie_key="zgt!uay.aug4juv*FAF"
    )

    # Only the status is used; the returned name and username are ignored.
    _, authentication_status, _ = authenticator.login('main')

    if authentication_status:
        # NOTE(review): the markup passed here looks empty in this listing and
        # may have been stripped; confirm against the real file.
        st.markdown("""

""", unsafe_allow_html=True)
        with st.sidebar.container():
            authenticator.logout('Logout', 'main')
        return True
    elif authentication_status is False:
        st.error('Username/password is incorrect')
        return False
    elif authentication_status is None:
        st.warning('Please enter your username and password')
        return False
42 |
--------------------------------------------------------------------------------
/app/models.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import os
3 | from functools import lru_cache
4 | from typing import List, Optional
5 |
6 | from pydantic import TypeAdapter
7 |
8 | from app.schema import GenModelInfo, MoveModelInfo, VideoModelInfo
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 |
12 |
def _load_models(filename: str, model_type):
    """Parse ``models/<filename>`` next to this module into a list of ``model_type``."""
    models_json_path = os.path.join(current_dir, "models", filename)
    with open(models_json_path, 'r', encoding='utf-8') as f:
        return TypeAdapter(List[model_type]).validate_json(f.read())


@lru_cache()
def get_v2v_models() -> List[VideoModelInfo]:
    """Return the models from ``models/v2v-models.json``; the list is cached after the first call."""
    return _load_models("v2v-models.json", VideoModelInfo)


@lru_cache()
def get_move_models() -> List[MoveModelInfo]:
    """Return the models from ``models/move-models.json``; the list is cached after the first call."""
    return _load_models("move-models.json", MoveModelInfo)


@lru_cache()
def get_gen_models() -> List[GenModelInfo]:
    """Return the models from ``models/gen-models.json``; the list is cached after the first call."""
    return _load_models("gen-models.json", GenModelInfo)
38 |
39 |
@lru_cache()
def get_v2v_model_info_by_instructions(instructions: str) -> Optional[VideoModelInfo]:
    """Return the first video model whose ``prompt_args`` contains ``instructions``.

    Matching is a substring test, in the order of ``get_v2v_models()``.
    Returns ``None`` when no model matches.
    """
    candidates = (info for info in get_v2v_models() if instructions in info.prompt_args)
    return next(candidates, None)
47 |
48 |
def _build_model_enum(enum_name, models):
    """Create an Enum with one member per model.

    Member names are the model names upper-cased with spaces replaced by
    underscores; values are ``prompt_args`` without the leading ``--``.
    """
    members = {}
    for info in models:
        member_name = info.name.replace(' ', '_').upper()
        members[member_name] = info.prompt_args.removeprefix('--')
    return enum.Enum(enum_name, members)


GenModel = _build_model_enum('GenModel', get_gen_models())

MoveModel = _build_model_enum('MoveModel', get_move_models())

VideoModel = _build_model_enum('VideoModel', get_v2v_models())
63 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
19 |
20 |
21 |
29 |
30 |
31 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
/streamlit_demo/pages/Animate.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import httpx
4 | import streamlit as st
5 | from streamlit.runtime.uploaded_file_manager import UploadedFile
6 |
7 | from app.schema import AnimateLength, AnimateIntensity, Mode
8 | from streamlit_demo.auth import check_password
9 | from streamlit_demo.utils import polling_check_state, BASE_URL, BASE_HEADERS
10 |
11 | if not check_password():
12 | st.stop()
13 |
14 | st.title("Animate")
15 |
16 |
async def run_animate(prompt, intensity, length, image: UploadedFile, mode):
    """Submit an animate task and wait for its result.

    Posts the form to ``/v1/animate``, then polls the task until it succeeds.

    Returns:
        The ``proxy_url`` of the first result video, or ``None`` when no
        image was supplied, the request failed, or the task was not found.
        Failures are reported with ``st.error``.
    """
    # The uploader yields nothing when the form is submitted without a file.
    if image is None:
        st.error("Generate Fail: source image is required")
        return None

    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/animate', data={
            "prompt": prompt,
            "intensity": intensity,
            "length": length,
            "mode": mode if mode != 'auto' else None
        }, files={'image': (image.name, image.read(), image.type)}, timeout=30)

    if not response.is_success:
        st.error(f"Generate Fail: {response}")
        return None

    task_id = response.json()['task_id']

    result = await polling_check_state(task_id=task_id)
    # polling_check_state returns None when the task lookup answers 404.
    if result is None:
        st.error(f"Generate Fail: task {task_id} not found")
        return None
    return result['videos'][0]['proxy_url']
33 |
34 |
# Input form; the widget values are read once the form is submitted.
with st.form("real_form", border=True):
    mode_options = ['auto'] + [item.value for item in Mode]
    mode = st.radio(label="Mode(*)", options=mode_options, horizontal=True)

    length_options = [item.value for item in AnimateLength]
    length = st.radio(label="Length(*)", options=length_options, horizontal=True)

    intensity_options = [item.value for item in AnimateIntensity]
    intensity = st.radio(label="Intensity(*)", options=intensity_options, horizontal=True)

    image = st.file_uploader(label="Source Image(*)", type=['jpg', 'png'])

    prompt = st.text_area(label="Prompt")

    submitted = st.form_submit_button("Submit")

if submitted:
    source_col, result_col = st.columns(2)

    with source_col:
        st.text("Source Image")
        if image:
            st.image(image)

    with result_col:
        st.text("Result Video")
        # Placeholder that is filled in once the task has finished.
        result_video = st.empty()

    with result_col:
        with st.spinner('Wait for completion...'):
            animate_call = run_animate(prompt=prompt, intensity=intensity, length=length, image=image, mode=mode)
            videos_url = asyncio.run(animate_call)
            if videos_url:
                result_video.video(videos_url)
                st.success('Done!')
66 |
--------------------------------------------------------------------------------
/streamlit_demo/pages/Move.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import httpx
4 | import streamlit as st
5 | from streamlit.runtime.uploaded_file_manager import UploadedFile
6 |
7 | from app.models import MoveModel
8 | from app.schema import VideoLength, Mode, VideoKey
9 | from streamlit_demo.auth import check_password
10 | from streamlit_demo.utils import polling_check_state, BASE_URL, BASE_HEADERS
11 |
12 | if not check_password():
13 | st.stop()
14 |
15 | st.title("Move")
16 |
17 |
async def run_move(prompt, model, length, video: UploadedFile, image: UploadedFile, mode, video_key):
    """Submit a move task and wait for its result.

    Posts the form to ``/v1/move``, then polls the task until it succeeds.

    Returns:
        The ``proxy_url`` of the first result video, or ``None`` when an
        upload is missing, the request failed, or the task was not found.
        Failures are reported with ``st.error``.
    """
    # Either uploader yields nothing when the form is submitted without a file.
    if video is None or image is None:
        st.error("Generate Fail: source image and source video are required")
        return None

    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/move', data={
            "prompt": prompt,
            "model": model,
            "length": length,
            "mode": mode if mode != 'auto' else None,
            "video_key": video_key if video_key != 'None' else None,
        }, files={'video': (video.name, video.read(), video.type), 'image': (image.name, image.read(), image.type)},
            timeout=30)

    if not response.is_success:
        st.error(f"Generate Fail: {response}")
        return None

    task_id = response.json()['task_id']

    result = await polling_check_state(task_id=task_id)
    # polling_check_state returns None when the task lookup answers 404.
    if result is None:
        st.error(f"Generate Fail: task {task_id} not found")
        return None
    return result['videos'][0]['proxy_url']
36 |
37 |
# Input form; the widget values are read once the form is submitted.
with st.form("move_form", border=True):
    mode_options = ['auto'] + [item.value for item in Mode]
    mode = st.radio(label="Mode(*)", options=mode_options, horizontal=True)

    length_options = [item.value for item in VideoLength]
    length = st.radio(label="Length(*)", options=length_options, horizontal=True)

    model = st.selectbox(label="Model(*)", options=[item.value for item in MoveModel])

    video_key_options = ['None'] + [item.value for item in VideoKey]
    video_key = st.radio(label="Video Key", options=video_key_options, horizontal=True)

    prompt = st.text_area(label="Prompt(*)")

    image = st.file_uploader(label="Source Image(*)", type=['jpg', 'png'])
    video = st.file_uploader(label="Source Video(*)", type=['mp4'])

    submitted = st.form_submit_button("Submit")

if submitted:
    source_image_col, source_video_col, result_col = st.columns(3)

    with source_image_col:
        st.text("Source Image")
        if image:
            st.image(image)

    with source_video_col:
        st.text("Source Video")
        if video:
            st.video(video)

    with result_col:
        st.text("Result Video")
        # Placeholder that is filled in once the task has finished.
        result_video = st.empty()

    with result_col:
        with st.spinner('Wait for completion...'):
            move_call = run_move(
                prompt=prompt, model=model, length=length, video=video, image=image, mode=mode,
                video_key=video_key)
            video_url = asyncio.run(move_call)
            if video_url:
                result_video.video(video_url)
                st.success('Done!')
80 |
--------------------------------------------------------------------------------
/app/cache.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, Optional
4 |
5 | from redis.asyncio import Redis, from_url
6 |
7 | from app.schema import TaskCacheData
8 |
9 |
class Cache:
    """Base cache with helpers for the message-id and task-id mappings.

    The storage methods defined here do nothing; ``MemoryCache`` and
    ``RedisCache`` override ``set_value`` and ``get_value``. ``prefix`` is
    stored so those subclasses can prepend it to their keys.
    """

    def __init__(self, prefix: str = ''):
        self.prefix = prefix

    async def set_value(self, key: str, value: Any, ex: Optional[int] = None):
        """Store ``value`` under ``key``; a no-op in the base class."""
        pass

    async def get_value(self, key: str):
        """Return the value stored under ``key``; always ``None`` in the base class."""
        pass

    async def close(self):
        """Release backend resources; a no-op in the base class."""
        pass

    @staticmethod
    def __get_message_id2task_id_key(message_id: str) -> str:
        return f'message_id2task_id:{message_id}'

    async def set_message_id2task_id(
            self,
            message_id: str,
            task_id: str
    ):
        """Record which task a message belongs to."""
        await self.set_value(key=self.__get_message_id2task_id_key(message_id), value=task_id)

    async def get_task_id_by_message_id(self, message_id: str) -> Optional[str]:
        """Return the task id recorded for ``message_id``, or ``None`` if absent or empty."""
        value = await self.get_value(key=self.__get_message_id2task_id_key(message_id))
        return str(value) if value else None

    @staticmethod
    def __get_task_id2data_key(task_id: str) -> str:
        return f'task_id2data:{task_id}'

    async def set_task_id2data(
            self,
            task_id: str,
            data: TaskCacheData
    ):
        """Store ``data`` for ``task_id`` as its JSON serialization."""
        await self.set_value(key=self.__get_task_id2data_key(task_id), value=data.model_dump_json())

    async def get_task_data_by_id(self, task_id: str) -> Optional[TaskCacheData]:
        """Return the cached data for ``task_id``.

        Returns ``None`` when nothing is stored or the stored value cannot
        be parsed; a parse failure is logged instead of silently dropped.
        """
        value = await self.get_value(key=self.__get_task_id2data_key(task_id))
        if not value:
            return None
        try:
            return TaskCacheData.model_validate_json(value)
        except Exception:
            # Still best effort (the entry is treated as missing), but leave a
            # trace so corrupt cache entries can be diagnosed.
            import logging
            logging.getLogger(__name__).warning(
                "Discarding unparsable cached data for task %s", task_id, exc_info=True
            )
            return None
64 |
65 |
class MemoryCache(Cache):
    """Cache that keeps its values in a dict on the instance.

    The ``ex`` argument of ``set_value`` is accepted but not used.
    """

    def __init__(self, prefix: str = ''):
        super().__init__(prefix=prefix)
        self.data = {}

    async def set_value(self, key: str, value: Any, ex: Optional[int] = None):
        full_key = f"{self.prefix}{key}"
        self.data[full_key] = value

    async def get_value(self, key: str):
        full_key = f"{self.prefix}{key}"
        return self.data.get(full_key)
76 |
77 |
class RedisCache(Cache):
    """Cache that stores its values through a Redis client, under prefixed keys."""

    def __init__(self, redis: Redis, prefix: str = ''):
        super().__init__(prefix=prefix)
        self.redis = redis

    async def set_value(self, key: str, value: Any, ex: Optional[int] = None):
        full_key = f"{self.prefix}{key}"
        # `ex` is passed straight through to the client's set() call.
        await self.redis.set(name=full_key, value=value, ex=ex)

    async def get_value(self, key: str):
        full_key = f"{self.prefix}{key}"
        stored = await self.redis.get(name=full_key)
        return stored

    async def close(self):
        await self.redis.close()

    @staticmethod
    async def init_redis_pool(redis_uri: str) -> Redis:
        """Create a client for ``redis_uri`` with UTF-8 decoded responses."""
        return await from_url(redis_uri, encoding="utf-8", decode_responses=True)
100 |
--------------------------------------------------------------------------------
/streamlit_demo/utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from typing import Optional, List, Callable
4 |
5 | import httpx
6 | import streamlit as st
7 | from pydantic import BaseModel
8 |
# Base URL used for API requests; falls back to the local default when
# STREAMLIT_BASE_URL is unset or empty.
BASE_URL = os.environ.get('STREAMLIT_BASE_URL') or 'http://127.0.0.1:8000'

API_AUTH_TOKEN = os.environ.get('API_AUTH_TOKEN')

# Default request headers: a bearer token when one is configured, otherwise none.
BASE_HEADERS = {'Authorization': f'Bearer {API_AUTH_TOKEN}'} if API_AUTH_TOKEN else {}
24 |
25 |
async def polling_check_state(
        task_id: str,
        interval: float = 1,
        timeout: Optional[float] = None
) -> Optional[dict]:
    """Poll ``/v1/task-data/{task_id}`` until the task status is ``SUCCESS``.

    Args:
        task_id: Identifier of the task to poll.
        interval: Seconds to sleep between polls (default 1, as before).
        timeout: Maximum seconds to keep polling; ``None`` (the default)
            polls without a limit, as before.

    Returns:
        The decoded JSON response once its status is ``SUCCESS``, or
        ``None`` when the endpoint answers 404.

    Raises:
        httpx.HTTPStatusError: On an error status other than 404.
        TimeoutError: When ``timeout`` elapses before the task succeeds.
    """
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        while True:
            response = await client.get(f'/v1/task-data/{task_id}')
            if response.status_code == 404:
                return None
            # Report other error statuses directly rather than as a KeyError
            # on the missing 'status' field below.
            response.raise_for_status()
            response_json = response.json()
            if response_json['status'] == 'SUCCESS':
                return response_json
            if deadline is not None and loop.time() >= deadline:
                raise TimeoutError(f'Task {task_id} did not succeed within {timeout} seconds')
            await asyncio.sleep(interval)
36 |
37 |
def build_upscale_vary_buttons(
        task_id: str,
        upscale_indices: List[int],
        vary_indices: List[int],
        on_click_upscale: Callable[[str, int], None],
        on_click_vary: Callable[[str, int], None]
):
    """Render U1-U4 and V1-V4 buttons as two side-by-side 2x2 grids.

    A button is disabled when its index is not in the matching indices
    list. Clicking one calls the matching callback with ``(task_id, index)``.
    """
    upscale_container, vary_container = st.columns(2)

    # (container, label prefix, key prefix, enabled indices, click callback)
    button_groups = (
        (upscale_container, ':mag: U', 'U', upscale_indices, on_click_upscale),
        (vary_container, ':magic_wand: V', 'V', vary_indices, on_click_vary),
    )
    for container, label_prefix, key_prefix, enabled_indices, callback in button_groups:
        # Both rows are created before any button, as in the unrolled version.
        top_left, top_right = container.columns(2)
        bottom_left, bottom_right = container.columns(2)
        slots = (top_left, top_right, bottom_left, bottom_right)
        for index, slot in enumerate(slots, start=1):
            slot.button(
                label=f"{label_prefix}{index}",
                use_container_width=True,
                disabled=index not in enabled_indices,
                on_click=callback,
                args=(task_id, index),
                key=f'{key_prefix}{index}-{task_id}'
            )
115 |
116 |
class UVResult(BaseModel):
    # One result image together with the upscale/vary indices listed for it.
    task_id: str
    image_url: str
    upscale_indices: List[int]
    vary_indices: List[int]
122 |
--------------------------------------------------------------------------------
/app/models/gen-models.json:
--------------------------------------------------------------------------------
1 | [{"id": 10017, "name": "Animate XL v1", "description": "Enhanced anime models", "cover": {"url": "https://imgo.domoai.app/ai-model/4b87efa6-a2b3-4fb2-8160-42d97a3a7e45.jpg", "width": 1024, "height": 1648}, "prompt_args": "--anixl v1"}, {"id": 10026, "name": "Animate XL v2", "description": "Detail anime model", "cover": {"url": "https://imgo.domoai.app/ai-model/a8668543-be38-4783-bb2d-c9643591529c.jpg", "width": 1024, "height": 1648}, "prompt_args": "--anixl v2"}, {"id": 10018, "name": "Realistic XL v1", "description": "Enhanced realistic model", "cover": {"url": "https://imgo.domoai.app/ai-model/40f18070-1dea-4c42-be5d-0cfd57fe13dd.jpg", "width": 1024, "height": 1648}, "prompt_args": "--realxl v1"}, {"id": 10027, "name": "Realistic XL v2", "description": "Dark gothic style", "cover": {"url": "https://imgo.domoai.app/ai-model/2cdd3d92-b009-4799-9c55-12ddebf91a02.jpg", "width": 1024, "height": 1648}, "prompt_args": "--realxl v2"}, {"id": 10019, "name": "Illustration XL v1", "description": "Enhanced illustration model", "cover": {"url": "https://imgo.domoai.app/ai-model/3d4e4935-d73c-4cb8-907e-c619f53a7690.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illusxl v1"}, {"id": 10020, "name": "Illustration XL v2", "description": "Dark comic style", "cover": {"url": "https://imgo.domoai.app/ai-model/605a9509-7670-4c3e-a450-a14ca9748b42.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illusxl v2"}, {"id": 10022, "name": "Animate v1", "description": "Dreamy japanese anime", "cover": {"url": "https://imgo.domoai.app/ai-model/18c4b38d-945c-46bd-b048-a4d1eacee19b.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v1"}, {"id": 10011, "name": "Animate v2", "description": "Japanese anime style, more 3D", "cover": {"url": "https://imgo.domoai.app/ai-model/6a4d877e-2dc6-45c0-9435-b69426859b03.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v2"}, {"id": 10012, "name": "Animate v3", "description": "American comics style", "cover": {"url": 
"https://imgo.domoai.app/ai-model/ef552c4f-365f-4446-a9e2-dd7692807f86.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v3"}, {"id": 10006, "name": "Animate v4", "description": "CG style", "cover": {"url": "https://imgo.domoai.app/ai-model/9cd447c9-b141-42de-82d4-5e09e0e0d1bb.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v4"}, {"id": 10023, "name": "Animate v5", "description": "Line comic style", "cover": {"url": "https://imgo.domoai.app/ai-model/9dbf7f9b-0d5d-40ba-993d-6d2a6222707f.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v5"}, {"id": 10024, "name": "Animate v6", "description": "Watercolor anime", "cover": {"url": "https://imgo.domoai.app/ai-model/fc7eee20-aabe-42ef-b4e3-b9cc5fe0c06d.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v6"}, {"id": 10025, "name": "Animate v7", "description": "Oilpainting anime", "cover": {"url": "https://imgo.domoai.app/ai-model/310443d1-6238-4a8e-be91-0130f0a0b865.jpg", "width": 1024, "height": 1648}, "prompt_args": "--ani v7"}, {"id": 10028, "name": "Illustration v1", "description": "3D cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/885872cf-c546-4da0-8707-113280f85551.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illus v1"}, {"id": 10029, "name": "Illustration v2", "description": "Storybook cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/51822397-d1e4-4921-965d-21b587a84613.jpg", "width": 1024, "height": 1648}, "prompt_args": "--illus v2"}, {"id": 10030, "name": "Realistic v1", "description": "CG art", "cover": {"url": "https://imgo.domoai.app/ai-model/fc673afd-cb14-4a6c-a4ac-8736721ca98e.jpg", "width": 1024, "height": 1648}, "prompt_args": "--real v1"}, {"id": 10031, "name": "Realistic v2", "description": "Realistic portrait", "cover": {"url": "https://imgo.domoai.app/ai-model/cefd410a-4d79-44a5-9fa9-9218818a0e04.jpg", "width": 1024, "height": 1648}, "prompt_args": "--real v2"}, {"id": 10016, "name": "Realistic v3", 
"description": "Game character style", "cover": {"url": "https://imgo.domoai.app/ai-model/7c656b9a-5a4a-4f2a-9694-57fb6a00e23c.jpg", "width": 1024, "height": 1648}, "prompt_args": "--real v3"}]
--------------------------------------------------------------------------------
/streamlit_demo/pages/Video.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional
3 |
4 | import httpx
5 | import streamlit as st
6 | from streamlit.runtime.uploaded_file_manager import UploadedFile
7 |
8 | from app.models import VideoModel, get_v2v_model_info_by_instructions
9 | from app.schema import VideoLength, VideoReferMode, Mode, VideoKey
10 | from streamlit_demo.auth import check_password
11 | from streamlit_demo.utils import polling_check_state, BASE_URL, BASE_HEADERS
12 |
# Render nothing further unless the demo's password check passes.
authorized = check_password()
if not authorized:
    st.stop()

st.title("Video")
17 |
18 |
async def run_video(prompt, refer_mode, model, length, video: UploadedFile, mode,
                    video_key, subject_only,
                    lip_sync, image: Optional[UploadedFile] = None):
    """Submit a video task to ``/v1/video`` and wait for its result.

    Always returns a ``(status, result)`` pair so callers can unpack it safely:

    * ``(True, proxy_url)`` when the task finished; ``proxy_url`` is taken
      from the first entry of the polled state's ``videos`` list.
    * ``(False, detail)`` when no source video was supplied, the request was
      rejected, or the response carried no ``task_id``.
    """
    # The source video is mandatory; without it there is nothing to upload.
    if video is None:
        return False, "Source video is required"

    files = {'video': (video.name, video.read(), video.type)}
    if image:
        files['image'] = (image.name, image.read(), image.type)

    form = {
        "prompt": prompt,
        "refer_mode": refer_mode,
        "model": model,
        "length": length,
        # 'auto' is a UI-only choice; it is sent as "no mode".
        "mode": mode if mode != 'auto' else None,
        # The key colour is dropped when "subject only" is selected.
        "video_key": video_key if video_key != 'None' and subject_only is not True else None,
        "subject_only": subject_only,
        "lip_sync": lip_sync,
    }

    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/video', data=form, files=files, timeout=30)
        if not response.is_success:
            st.error(f"Generate Fail: {response}")
            # Return a pair (not a bare None) so the caller's unpacking works.
            return False, response.text

        payload = response.json()
        task_id = payload.get('task_id', None)
        if task_id is None:
            return False, payload

        result = await polling_check_state(task_id=task_id)
        return True, result['videos'][0]['proxy_url']
46 |
47 |
# --- Request options -------------------------------------------------------
mode = st.radio(label="Mode(*)", options=['auto'] + [item.value for item in Mode], horizontal=True)

length = st.radio(label="Length(*)", options=[item.value for item in VideoLength], horizontal=True)

video_models_value = [item.value for item in VideoModel]

model = st.selectbox(label="Model(*)", options=video_models_value,
                     index=video_models_value.index(VideoModel.ANIME_V1.value))

# The selected model decides which of the following widgets are usable.
model_info = get_v2v_model_info_by_instructions(model)

refer_mode = st.radio(label="Refer Mode(*)",
                      options=[item.value for item in VideoReferMode
                               if item in model_info.allowed_refer_modes],
                      horizontal=True)

lip_sync = st.checkbox(label="Lip Sync", key='lips', disabled=model_info.allowed_lip_sync is False)

subject_only = st.checkbox(label="Subject Only", key='so')

video_key = st.selectbox(label="Video Key", options=['None'] + [item.value for item in VideoKey],
                         disabled=subject_only)

prompt = st.text_area(label="Prompt(*)")

video = st.file_uploader(label="Source Video(*)", type=['mp4'])

# The reference-image uploader is only offered for models that accept one.
image = None
if model_info.allowed_reference_image:
    image = st.file_uploader(label="Source Image(*)", type=['jpg', 'jpeg', 'png'])

submitted = st.button("Submit")

# --- Submission ------------------------------------------------------------
if submitted:
    source_col, result_col = st.columns(2)
    with source_col:
        st.text("Source Video")
        if video:
            st.video(video)
        st.text("Source Image")
        if image:
            st.image(image)

    with result_col:
        st.text("Result Video")
        result_video = st.empty()

        if video is None:
            # Without this guard run_video would fail on the missing upload.
            st.error("Source video is required")
        else:
            with st.spinner('Wait for completion...'):
                outcome = asyncio.run(
                    run_video(
                        prompt=prompt, refer_mode=refer_mode, model=model, length=length, video=video, mode=mode,
                        video_key=video_key, subject_only=subject_only, lip_sync=lip_sync, image=image,
                    )
                )
            # run_video may yield None for a rejected request; treat it as a failure
            # instead of crashing while unpacking.
            status, result = outcome if outcome is not None else (False, None)

            if not status and result is not None:
                st.text(result)

            if status and result:
                result_video.video(result)
                st.success('Done!')
112 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
3 |
4 | # User-specific stuff
5 | .idea/**/workspace.xml
6 | .idea/**/tasks.xml
7 | .idea/**/usage.statistics.xml
8 | .idea/**/dictionaries
9 | .idea/**/shelf
10 |
11 | # AWS User-specific
12 | .idea/**/aws.xml
13 |
14 | # Generated files
15 | .idea/**/contentModel.xml
16 |
17 | # Sensitive or high-churn files
18 | .idea/**/dataSources/
19 | .idea/**/dataSources.ids
20 | .idea/**/dataSources.local.xml
21 | .idea/**/sqlDataSources.xml
22 | .idea/**/dynamic.xml
23 | .idea/**/uiDesigner.xml
24 | .idea/**/dbnavigator.xml
25 |
26 | # Gradle
27 | .idea/**/gradle.xml
28 | .idea/**/libraries
29 |
30 | # Gradle and Maven with auto-import
31 | # When using Gradle or Maven with auto-import, you should exclude module files,
32 | # since they will be recreated, and may cause churn. Uncomment if using
33 | # auto-import.
34 | # .idea/artifacts
35 | # .idea/compiler.xml
36 | # .idea/jarRepositories.xml
37 | # .idea/modules.xml
38 | # .idea/*.iml
39 | # .idea/modules
40 | # *.iml
41 | # *.ipr
42 |
43 | # CMake
44 | cmake-build-*/
45 |
46 | # Mongo Explorer plugin
47 | .idea/**/mongoSettings.xml
48 |
49 | # File-based project format
50 | *.iws
51 |
52 | # IntelliJ
53 | out/
54 |
55 | # mpeltonen/sbt-idea plugin
56 | .idea_modules/
57 |
58 | # JIRA plugin
59 | atlassian-ide-plugin.xml
60 |
61 | # Cursive Clojure plugin
62 | .idea/replstate.xml
63 |
64 | # Crashlytics plugin (for Android Studio and IntelliJ)
65 | com_crashlytics_export_strings.xml
66 | crashlytics.properties
67 | crashlytics-build.properties
68 | fabric.properties
69 |
70 | # Editor-based Rest Client
71 | .idea/httpRequests
72 |
73 | # Android studio 3.1+ serialized cache file
74 | .idea/caches/build_file_checksums.ser
75 |
76 | # Byte-compiled / optimized / DLL files
77 | __pycache__/
78 | *.py[cod]
79 | *$py.class
80 |
81 | # C extensions
82 | *.so
83 |
84 | # Distribution / packaging
85 | .Python
86 | build/
87 | develop-eggs/
88 | dist/
89 | downloads/
90 | eggs/
91 | .eggs/
92 | lib/
93 | lib64/
94 | parts/
95 | sdist/
96 | var/
97 | wheels/
98 | share/python-wheels/
99 | *.egg-info/
100 | .installed.cfg
101 | *.egg
102 | MANIFEST
103 |
104 | # PyInstaller
105 | # Usually these files are written by a python script from a template
106 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
107 | *.manifest
108 | *.spec
109 |
110 | # Installer logs
111 | pip-log.txt
112 | pip-delete-this-directory.txt
113 |
114 | # Unit test / coverage reports
115 | htmlcov/
116 | .tox/
117 | .nox/
118 | .coverage
119 | .coverage.*
120 | .cache
121 | nosetests.xml
122 | coverage.xml
123 | *.cover
124 | *.py,cover
125 | .hypothesis/
126 | .pytest_cache/
127 | cover/
128 |
129 | # Translations
130 | *.mo
131 | *.pot
132 |
133 | # Django stuff:
134 | *.log
135 | local_settings.py
136 | db.sqlite3
137 | db.sqlite3-journal
138 |
139 | # Flask stuff:
140 | instance/
141 | .webassets-cache
142 |
143 | # Scrapy stuff:
144 | .scrapy
145 |
146 | # Sphinx documentation
147 | docs/_build/
148 |
149 | # PyBuilder
150 | .pybuilder/
151 | target/
152 |
153 | # Jupyter Notebook
154 | .ipynb_checkpoints
155 |
156 | # IPython
157 | profile_default/
158 | ipython_config.py
159 |
160 | # pyenv
161 | # For a library or package, you might want to ignore these files since the code is
162 | # intended to run in multiple environments; otherwise, check them in:
163 | # .python-version
164 |
165 | # pipenv
166 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
167 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
168 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
169 | # install all needed dependencies.
170 | #Pipfile.lock
171 |
172 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
173 | __pypackages__/
174 |
175 | # Celery stuff
176 | celerybeat-schedule
177 | celerybeat.pid
178 |
179 | # SageMath parsed files
180 | *.sage.py
181 |
182 | # Environments
183 | .env
184 | .venv
185 | env/
186 | venv/
187 | ENV/
188 | env.bak/
189 | venv.bak/
190 |
191 | # Spyder project settings
192 | .spyderproject
193 | .spyproject
194 |
195 | # Rope project settings
196 | .ropeproject
197 |
198 | # mkdocs documentation
199 | /site
200 |
201 | # mypy
202 | .mypy_cache/
203 | .dmypy.json
204 | dmypy.json
205 |
206 | # Pyre type checker
207 | .pyre/
208 |
209 | # pytype static type analyzer
210 | .pytype/
211 |
212 | # Cython debug symbols
213 | cython_debug/
214 |
215 | .env.*
216 | !.env.template
217 | publish*.sh
--------------------------------------------------------------------------------
/app/schema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import enum
4 | from typing import Optional, List, Dict
5 |
6 | import discord
7 | from pydantic import BaseModel
8 |
9 |
class Mode(enum.Enum):
    """Values accepted for the ``mode`` field of a task request."""

    FAST = "fast"
    RELAX = "relax"
13 |
14 |
class AnimateIntensity(enum.Enum):
    """Intensity levels for the animate command."""

    LOW = "low"
    # The member is named MEDIUM but its wire value is the short form "mid".
    MEDIUM = "mid"
    HIGH = "high"
19 |
20 |
class AnimateLength(enum.Enum):
    """Clip lengths available to the animate command."""

    LENGTH_3S = "3s"
    LENGTH_5S = "5s"
24 |
25 |
class BaseModelInfo(BaseModel):
    """Fields shared by every model-catalog entry."""

    # Display name of the model.
    name: str
    # Prompt flag string for the model, e.g. "--ani v1" in the bundled catalogs.
    prompt_args: str
29 |
30 |
class GenModelInfo(BaseModelInfo):
    """Catalog entry for a gen model; adds no fields to ``BaseModelInfo``."""
33 |
34 |
class MoveModelInfo(BaseModelInfo):
    """Catalog entry for a move model; adds no fields to ``BaseModelInfo``."""
37 |
38 |
class VideoModelInfo(BaseModelInfo):
    """Catalog entry for a video model, including what the model permits."""

    # Refer modes that may be selected together with this model.
    allowed_refer_modes: List[VideoReferMode]
    # Whether lip sync may be enabled for this model.
    allowed_lip_sync: bool
    # Whether this model takes a reference image.
    allowed_reference_image: bool
43 |
44 |
class VideoReferMode(enum.Enum):
    """Whether a video task should lean on the source video or on the prompt."""

    REFER_TO_SOURCE_VIDEO_MORE = "VIDEO_MORE"
    REFER_TO_MY_PROMPT_MORE = "PROMPT_MORE"
48 |
49 |
class VideoApiError(enum.Enum):
    """Numeric error codes for rejected video requests.

    Each member must be bound with ``=``. Writing ``NAME: 10000`` is only an
    annotation and defines no member, which would leave the enum empty.
    """

    VIDEO_MODEL_ERROR = 10000
    NOT_ALLOW_REFER = 10001
    NOT_ALLOW_LIP_SYNC = 10002
    MODEL_NEED_REFERENCE_IMAGE = 10003
55 |
56 |
class VideoKey(enum.Enum):
    """Colour names selectable as the ``video_key`` of a video task.

    Every member's value equals its name.
    """

    WHITE = "WHITE"
    BLACK = "BLACK"
    RED = "RED"
    ORANGE = "ORANGE"
    YELLOW = "YELLOW"
    CYAN = "CYAN"
    GREEN = "GREEN"
    BLUE = "BLUE"
    PINK = "PINK"
    BROWN = "BROWN"
    PURPLE = "PURPLE"
    # Both spellings exist as separate members with separate values.
    GREY = "GREY"
    GRAY = "GRAY"
    MAGENTA = "MAGENTA"
    NAVY = "NAVY"
    BEIGE = "BEIGE"
    GOLD = "GOLD"
    SILVER = "SILVER"
76 |
77 |
class VideoLength(enum.Enum):
    """Clip lengths available to the video command."""

    LENGTH_3S = "3s"
    LENGTH_5S = "5s"
    LENGTH_10S = "10s"
    LENGTH_20S = "20s"
83 |
84 |
class TaskAsset(BaseModel):
    """Metadata for a single image or video attached to a task."""

    size: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None
    url: str
    proxy_url: str
    content_type: Optional[str] = None

    @staticmethod
    def from_attachment(attachment: discord.Attachment) -> TaskAsset:
        """Create a ``TaskAsset`` from the same-named attributes of *attachment*."""
        copied_fields = ("size", "width", "height", "url", "proxy_url", "content_type")
        return TaskAsset(**{field: getattr(attachment, field) for field in copied_fields})
103 |
104 |
class TaskStatus(enum.Enum):
    """States a task can be in; each value equals its member name."""

    RUNNING = "RUNNING"
    SUCCESS = "SUCCESS"
108 |
109 |
class TaskCommand(enum.Enum):
    """Command that created a task; each value equals its member name."""

    GEN = "GEN"
    REAL = "REAL"
    MOVE = "MOVE"
    VIDEO = "VIDEO"
    ANIMATE = "ANIMATE"
116 |
117 |
class TaskCacheData(BaseModel):
    """Task record in its cached form (converted by ``TaskStateOut.from_cache_data``)."""

    command: TaskCommand
    channel_id: str
    guild_id: Optional[str]
    message_id: str
    images: Optional[List[TaskAsset]] = None
    videos: Optional[List[TaskAsset]] = None
    status: TaskStatus
    # Keyed by labels such as "U1".."U4"; values are presumably the custom ids
    # of the corresponding buttons -- confirm against the code that fills them.
    upscale_custom_ids: Optional[Dict[str, str]] = None
    # Keyed by labels such as "V1".."V4"; same caveat as above.
    vary_custom_ids: Optional[Dict[str, str]] = None
128 |
129 |
class CreateTaskOut(BaseModel):
    """Output model carrying a success flag plus the task and message ids."""

    success: bool
    task_id: str
    message_id: str
134 |
135 |
class TaskStateOut(BaseModel):
    """Task state as exposed to clients.

    Mirrors ``TaskCacheData`` but replaces the custom-id mappings with sorted
    lists of the indices (1-4) that are available for upscale / vary.
    """

    command: TaskCommand
    channel_id: str
    guild_id: Optional[str]
    message_id: str
    images: Optional[List[TaskAsset]] = None
    videos: Optional[List[TaskAsset]] = None
    status: TaskStatus
    upscale_indices: Optional[List[int]] = None
    vary_indices: Optional[List[int]] = None

    @staticmethod
    def from_cache_data(data: TaskCacheData) -> TaskStateOut:
        """Convert a cached task record into its client-facing form.

        A custom-id mapping of ``None`` stays ``None``; otherwise its known
        labels are translated to indices and returned in ascending order.
        Labels not present in the lookup table are ignored.
        """
        upscale_labels = {
            "U1": 1,
            "U2": 2,
            "U3": 3,
            "U4": 4
        }
        vary_labels = {
            "V1": 1,
            "V2": 2,
            "V3": 3,
            "V4": 4
        }

        def to_indices(custom_ids, label_to_index):
            # Only the mapping's keys matter here; its values are not used.
            if custom_ids is None:
                return None
            return sorted(label_to_index[label] for label in custom_ids if label in label_to_index)

        return TaskStateOut(
            command=data.command,
            channel_id=data.channel_id,
            guild_id=data.guild_id,
            message_id=data.message_id,
            images=data.images,
            videos=data.videos,
            status=data.status,
            upscale_indices=to_indices(data.upscale_custom_ids, upscale_labels),
            vary_indices=to_indices(data.vary_custom_ids, vary_labels)
        )
191 |
--------------------------------------------------------------------------------
/streamlit_demo/pages/Real.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import httpx
4 | import streamlit as st
5 |
6 | from app.schema import Mode
7 | from streamlit_demo.auth import check_password
8 | from streamlit_demo.utils import polling_check_state, build_upscale_vary_buttons, UVResult, BASE_URL, \
9 | BASE_HEADERS
10 |
# Render nothing further unless the demo's password check passes.
authorized = check_password()
if not authorized:
    st.stop()

st.title("REAL")

# Seed the per-session slots this page relies on, without clobbering values
# that already exist from an earlier run of the script.
for _slot, _initial in (('real_result', None), ('real_uv_results', [])):
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
21 |
22 |
async def real(prompt, image, mode):
    """Post a ``/v1/real`` request for *image* and wait for the task to finish.

    Returns ``(task_id, image_url, upscale_indices, vary_indices)``, or ``None``
    after showing an error when the request is not successful.
    """
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        # The prompt is optional: the key is left out entirely when it is empty.
        form = {"prompt": prompt} if prompt else {}
        # 'auto' is a UI-only choice and is sent as "no mode".
        form["mode"] = None if mode == 'auto' else mode

        upload = {'image': (image.name, image.read(), image.type)}
        response = await client.post('/v1/real', data=form, files=upload, timeout=30)
        if not response.is_success:
            st.error(f"Generate Fail: {response}")
            return None

        new_task_id = response.json()['task_id']

        state = await polling_check_state(task_id=new_task_id)
        first_image_url = state['images'][0]['proxy_url']
        return new_task_id, first_image_url, state['upscale_indices'], state['vary_indices']
48 |
49 |
async def upscale(task_id, index):
    """Post a ``/v1/upscale`` request for *index* of *task_id* and wait for it.

    Returns ``(task_id, image_url, upscale_indices, vary_indices)`` for the new
    task, or ``None`` after showing an error when the request is not successful.
    """
    form = {
        "task_id": task_id,
        "index": index
    }
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/upscale', data=form, timeout=30)
        if not response.is_success:
            st.error(f"Upscale Fail: {response}")
            return None

        new_task_id = response.json()['task_id']

        state = await polling_check_state(task_id=new_task_id)
        first_image_url = state['images'][0]['proxy_url']
        return new_task_id, first_image_url, state['upscale_indices'], state['vary_indices']
64 |
65 |
async def vary(task_id, index):
    """Post a ``/v1/vary`` request for *index* of *task_id* and wait for it.

    Returns ``(task_id, image_url, upscale_indices, vary_indices)`` for the new
    task, or ``None`` after showing an error when the request is not successful.
    """
    form = {
        "task_id": task_id,
        "index": index
    }
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/vary', data=form, timeout=30)
        if not response.is_success:
            st.error(f"Vary Fail: {response}")
            return None

        new_task_id = response.json()['task_id']

        state = await polling_check_state(task_id=new_task_id)
        first_image_url = state['images'][0]['proxy_url']
        return new_task_id, first_image_url, state['upscale_indices'], state['vary_indices']
80 |
81 |
# Input form: widget values are only sent when "Submit" is pressed.
with st.form("real_form", border=False):
    mode_options = ['auto'] + [item.value for item in Mode]
    mode = st.radio(label="Mode(*)", options=mode_options, horizontal=True)
    image = st.file_uploader(label="Reference Image(*)", type=['jpg', 'png'])
    prompt = st.text_area(label="Prompt")
    submitted = st.form_submit_button("Submit")
87 |
88 |
def on_click_upscale(task_id: str, index: int):
    """Button callback: upscale image *index* of *task_id* and store the result.

    The finished task is appended to ``st.session_state.real_uv_results``.
    Nothing is stored when the request failed.
    """
    with st.spinner('Wait for completion...'):
        outcome = asyncio.run(upscale(task_id=task_id, index=index))

    # upscale() returns None after reporting a failed request; unpacking that
    # would raise TypeError, so bail out instead.
    if outcome is None:
        return

    new_task_id, image_url, upscale_indices, vary_indices = outcome
    st.session_state.real_uv_results.append(UVResult(
        task_id=new_task_id,
        image_url=image_url,
        upscale_indices=upscale_indices,
        vary_indices=vary_indices
    ))
98 |
99 |
def on_click_vary(task_id: str, index: int):
    """Button callback: create a variation of image *index* of *task_id*.

    The finished task is appended to ``st.session_state.real_uv_results``.
    Nothing is stored when the request failed.
    """
    with st.spinner('Wait for completion...'):
        outcome = asyncio.run(vary(task_id=task_id, index=index))

    # vary() returns None after reporting a failed request; unpacking that
    # would raise TypeError, so bail out instead.
    if outcome is None:
        return

    new_task_id, image_url, upscale_indices, vary_indices = outcome
    st.session_state.real_uv_results.append(UVResult(
        task_id=new_task_id,
        image_url=image_url,
        upscale_indices=upscale_indices,
        vary_indices=vary_indices
    ))
109 |
110 |
# A fresh submission discards the previous result so a new request is sent.
if submitted:
    st.session_state.real_result = None

if submitted or st.session_state.real_result:
    if image:
        source_col, result_col = st.columns(2)
        with source_col:
            st.text("Reference")
            st.image(image)
    else:
        result_col = st.container()

    with result_col:
        st.text("Result")
        result_image = st.empty()

        if not st.session_state.real_result:
            if image:
                with st.spinner('Wait for completion...'):
                    st.session_state.real_result = asyncio.run(real(prompt, image, mode))
            else:
                # real() reads the upload's name/content, so it cannot run without one.
                st.error("Reference image is required")

        # real() yields None when the request failed; only unpack a real result.
        if st.session_state.real_result:
            task_id, image_url, upscale_indices, vary_indices = st.session_state.real_result
            result_image.image(image_url)
            build_upscale_vary_buttons(
                task_id=task_id,
                upscale_indices=upscale_indices,
                vary_indices=vary_indices,
                on_click_upscale=on_click_upscale,
                on_click_vary=on_click_vary
            )

        # Re-render every upscale/vary result collected during this session.
        for item in st.session_state.real_uv_results:
            result: UVResult = item
            st.image(result.image_url)
            build_upscale_vary_buttons(
                task_id=result.task_id,
                upscale_indices=result.upscale_indices,
                vary_indices=result.vary_indices,
                on_click_upscale=on_click_upscale,
                on_click_vary=on_click_vary
            )
152 |
--------------------------------------------------------------------------------
/streamlit_demo/pages/Gen.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional
3 |
4 | import httpx
5 | import streamlit as st
6 | from streamlit.runtime.uploaded_file_manager import UploadedFile
7 |
8 | from app.models import GenModel
9 | from app.schema import Mode
10 | from streamlit_demo.auth import check_password
11 | from streamlit_demo.utils import polling_check_state, build_upscale_vary_buttons, UVResult, BASE_URL, \
12 | BASE_HEADERS
13 |
# Render nothing further unless the demo's password check passes.
authorized = check_password()
if not authorized:
    st.stop()

st.title("Gen")

# Seed the per-session slots this page relies on, without clobbering values
# that already exist from an earlier run of the script.
for _slot, _initial in (('gen_result', None), ('gen_uv_results', [])):
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
24 |
25 |
async def gen(prompt: str, image: Optional[UploadedFile], mode: str, model: Optional[str]):
    """Post a ``/v1/gen`` request and wait for the task to finish.

    *image* and *model* are optional and only sent when provided. Returns
    ``(task_id, image_url, upscale_indices, vary_indices)``, or ``None`` after
    showing an error when the request is not successful.
    """
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        form = {
            "prompt": prompt,
            # 'auto' is a UI-only choice and is sent as "no mode".
            "mode": None if mode == 'auto' else mode
        }
        if model:
            form['model'] = model

        upload = None
        if image:
            upload = {'image': (image.name, image.read(), image.type)}

        response = await client.post('/v1/gen', data=form, files=upload, timeout=30)
        if not response.is_success:
            st.error(f"Generate Fail: {response}")
            return None

        new_task_id = response.json()['task_id']

        state = await polling_check_state(task_id=new_task_id)
        first_image_url = state['images'][0]['proxy_url']
        return new_task_id, first_image_url, state['upscale_indices'], state['vary_indices']
48 |
49 |
async def upscale(task_id, index):
    """Post a ``/v1/upscale`` request for *index* of *task_id* and wait for it.

    Returns ``(task_id, image_url, upscale_indices, vary_indices)`` for the new
    task, or ``None`` after showing an error when the request is not successful.
    """
    form = {
        "task_id": task_id,
        "index": index
    }
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/upscale', data=form, timeout=30)
        if not response.is_success:
            st.error(f"Upscale Fail: {response}")
            return None

        new_task_id = response.json()['task_id']

        state = await polling_check_state(task_id=new_task_id)
        first_image_url = state['images'][0]['proxy_url']
        return new_task_id, first_image_url, state['upscale_indices'], state['vary_indices']
64 |
65 |
async def vary(task_id, index):
    """Post a ``/v1/vary`` request for *index* of *task_id* and wait for it.

    Returns ``(task_id, image_url, upscale_indices, vary_indices)`` for the new
    task, or ``None`` after showing an error when the request is not successful.
    """
    form = {
        "task_id": task_id,
        "index": index
    }
    async with httpx.AsyncClient(base_url=BASE_URL, headers=BASE_HEADERS) as client:
        response = await client.post('/v1/vary', data=form, timeout=30)
        if not response.is_success:
            st.error(f"Vary Fail: {response}")
            return None

        new_task_id = response.json()['task_id']

        state = await polling_check_state(task_id=new_task_id)
        first_image_url = state['images'][0]['proxy_url']
        return new_task_id, first_image_url, state['upscale_indices'], state['vary_indices']
80 |
81 |
# Input form: widget values are only sent when "Submit" is pressed.
with st.form("gen_form", border=False):
    mode_options = ['auto'] + [item.value for item in Mode]
    mode = st.radio(label="Mode(*)", options=mode_options, horizontal=True)
    prompt = st.text_area(label="Prompt(*)")
    image = st.file_uploader(label="Reference Image", type=['jpg', 'png'])
    # The leading empty entry stands for "no model selected".
    gen_models_value = [''] + [item.value for item in GenModel]

    model = st.selectbox(label="Model", options=gen_models_value)

    submitted = st.form_submit_button("Submit")
94 |
95 |
def on_click_upscale(task_id: str, index: int):
    """Button callback: upscale image *index* of *task_id* and store the result.

    The finished task is appended to ``st.session_state.gen_uv_results``.
    Nothing is stored when the request failed.
    """
    with st.spinner('Wait for completion...'):
        outcome = asyncio.run(upscale(task_id=task_id, index=index))

    # upscale() returns None after reporting a failed request; unpacking that
    # would raise TypeError, so bail out instead.
    if outcome is None:
        return

    new_task_id, image_url, upscale_indices, vary_indices = outcome
    st.session_state.gen_uv_results.append(UVResult(
        task_id=new_task_id,
        image_url=image_url,
        upscale_indices=upscale_indices,
        vary_indices=vary_indices
    ))
105 |
106 |
def on_click_vary(task_id: str, index: int):
    """Button callback: create a variation of image *index* of *task_id*.

    The finished task is appended to ``st.session_state.gen_uv_results``.
    Nothing is stored when the request failed.
    """
    with st.spinner('Wait for completion...'):
        outcome = asyncio.run(vary(task_id=task_id, index=index))

    # vary() returns None after reporting a failed request; unpacking that
    # would raise TypeError, so bail out instead.
    if outcome is None:
        return

    new_task_id, image_url, upscale_indices, vary_indices = outcome
    st.session_state.gen_uv_results.append(UVResult(
        task_id=new_task_id,
        image_url=image_url,
        upscale_indices=upscale_indices,
        vary_indices=vary_indices
    ))
116 |
117 |
# A fresh submission discards the previous result so a new request is sent.
if submitted:
    st.session_state.gen_result = None

if submitted or st.session_state.gen_result:
    if image:
        source_col, result_col = st.columns(2)
        with source_col:
            st.text("Reference")
            st.image(image)
    else:
        result_col = st.container()

    with result_col:
        st.text("Result")
        result_image = st.empty()

        if not st.session_state.gen_result:
            with st.spinner('Wait for completion...'):
                st.session_state.gen_result = asyncio.run(gen(prompt, image, mode, model))

        # gen() yields None when the request failed; only unpack a real result.
        if st.session_state.gen_result:
            task_id, image_url, upscale_indices, vary_indices = st.session_state.gen_result
            result_image.image(image_url)
            build_upscale_vary_buttons(
                task_id=task_id,
                upscale_indices=upscale_indices,
                vary_indices=vary_indices,
                on_click_upscale=on_click_upscale,
                on_click_vary=on_click_vary
            )

        # Re-render every upscale/vary result collected during this session.
        for item in st.session_state.gen_uv_results:
            result: UVResult = item
            st.image(result.image_url)
            build_upscale_vary_buttons(
                task_id=result.task_id,
                upscale_indices=result.upscale_indices,
                vary_indices=result.vary_indices,
                on_click_upscale=on_click_upscale,
                on_click_vary=on_click_vary
            )
159 |
--------------------------------------------------------------------------------
/app/models/v2v-models.json:
--------------------------------------------------------------------------------
1 | [{"id": 15035, "name": "Fusion style v2", "description": "Any style by image", "cover": {"url": "https://imgo.domoai.app/ai-model/010623bf-2857-4679-a8cb-3bc60bdb9647.png", "width": 1024, "height": 1648}, "prompt_args": "--fs v2", "allowed_refer_modes": ["PROMPT_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": true}, {"id": 15034, "name": "Illustration v16", "description": "Storybook style", "cover": {"url": "https://imgo.domoai.app/ai-model/9a7d4712-78ac-47f7-8826-e033ed8caea1.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v16", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": false, "tech_version": 300, "allowed_reference_image": false}, {"id": 15033, "name": "Illustration v15", "description": "Renaissance art style", "cover": {"url": "https://imgo.domoai.app/ai-model/4aa3e571-72c8-4e18-a36f-420c7b6ccb41.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v15", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": false, "tech_version": 300, "allowed_reference_image": false}, {"id": 15032, "name": "Anime v8", "description": "Sketch anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/01f1a743-e180-4808-ac67-5fa865bd8029.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v8", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": false, "tech_version": 300, "allowed_reference_image": false}, {"id": 15031, "name": "Illustration v14", "description": "Ukiyo-e style", "cover": {"url": "https://imgo.domoai.app/ai-model/6651151b-6a5b-4d8d-8b6b-660db8c32d5f.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v14", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15030, "name": "Illustration v1.2", "description": "3D cartoon style 3.0", "cover": {"url": "https://imgo.domoai.app/ai-model/6b8f9c99-60dd-4369-a3f5-8700f3f59704.png", "width": 1024, "height": 1648}, "prompt_args": "--illus 
v1.2", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15029, "name": "Illustration v13.1", "description": "Clay cartoon style 3.0", "cover": {"url": "https://imgo.domoai.app/ai-model/401d4cf6-fa12-40d7-a7cc-ec5982e6dd37.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v13.1", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15028, "name": "Anime v5.2", "description": "Japanese anime 3.0", "cover": {"url": "https://imgo.domoai.app/ai-model/4249fe12-2456-4002-a39c-814c260e815c.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v5.2", "allowed_refer_modes": ["VIDEO_MORE"], "allowed_lip_sync": true, "tech_version": 300, "allowed_reference_image": false}, {"id": 15027, "name": "Illustration v13", "description": "Clay cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/073d2fa2-f678-415b-94fe-1cde9142d37b.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v13", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15026, "name": "Illustration v12", "description": "Lego style", "cover": {"url": "https://imgo.domoai.app/ai-model/6c3064fd-f161-4f88-b1fa-726ab6cc29ec.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v12", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15025, "name": "Illustration v11", "description": "American comicbook", "cover": {"url": "https://imgo.domoai.app/ai-model/c0223c6b-db06-4c09-950c-d0a5019f1c45.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v11", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15024, "name": "Anime v7", "description": 
"Color pencil style", "cover": {"url": "https://imgo.domoai.app/ai-model/3b2629c6-d79e-4b21-b7bf-a5c454476007.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v7", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15023, "name": "Illustration v10", "description": "PS2 game style", "cover": {"url": "https://imgo.domoai.app/ai-model/e5394ee2-a0b7-437e-9e3d-1795d5de5125.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v10", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15022, "name": "Illustration v9", "description": "2.5D illustration style", "cover": {"url": "https://imgo.domoai.app/ai-model/8e25862d-428f-4c7c-956e-a70eed5548c8.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v9", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15021, "name": "Illustration v7.1", "description": "Paper art style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/475bf829-96b0-4ceb-b3f3-ca555fb12f22.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v7.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15020, "name": "Illustration v3.1", "description": "Pixel style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/1b3b7497-4a89-4eeb-86ec-288f5c2743e0.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v3.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15019, "name": "Illustration v1.1", "description": "3D cartoon style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/d6d095be-dc29-417e-b5d8-22660fc3df3d.png", "width": 1024, 
"height": 1648}, "prompt_args": "--illus v1.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15018, "name": "Anime v4.1", "description": "Chinese ink painting 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/6db5f48d-8c53-4e88-abac-e4d550b641eb.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v4.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15017, "name": "Anime v1.1", "description": "Flat color anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/5afd5dac-6843-4901-b67f-46f728261070.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v1.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15016, "name": "Anime v5.1", "description": "Japanese anime 2.1", "cover": {"url": "https://imgo.domoai.app/ai-model/3843c175-b640-4697-a89e-a785583c4a5c.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v5.1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15015, "name": "Anime v6", "description": "Detail anime style 2.0", "cover": {"url": "https://imgo.domoai.app/ai-model/f59317bc-0417-46ab-8fee-72c8fbd3c8eb.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v6", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, "allowed_reference_image": false}, {"id": 15014, "name": "Fusion Style v1", "description": "Any style by prompt", "cover": {"url": "https://imgo.domoai.app/ai-model/bcb0d19e-f88f-4fc2-85a7-9abf4624525a.png", "width": 1024, "height": 1648}, "prompt_args": "--fs v1", "allowed_refer_modes": ["PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 200, 
"allowed_reference_image": false}, {"id": 15012, "name": "Illustration v8", "description": "Van gogh style", "cover": {"url": "https://imgo.domoai.app/ai-model/cdd10dd9-1cf3-4b00-af4e-c8bd77362f46.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v8", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15011, "name": "Illustration v7", "description": "Paper art style", "cover": {"url": "https://imgo.domoai.app/ai-model/ae4f9929-f3c5-44a0-a95c-05c524afa15d.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v7", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15010, "name": "Anime v4", "description": "Chinese ink painting", "cover": {"url": "https://imgo.domoai.app/ai-model/d2a71189-1e1c-4b0c-ba52-31fc79d6bfdc.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v4", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15009, "name": "Illustration v6", "description": "Grand theft game", "cover": {"url": "https://imgo.domoai.app/ai-model/31f82d00-7ef6-4287-8544-1ec4854cbdc9.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v6", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15008, "name": "Illustration v5", "description": "Color illustration", "cover": {"url": "https://imgo.domoai.app/ai-model/b7fd23a2-459b-47cf-b012-96899cf38f61.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v5", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15007, "name": "Illustration v4", "description": "Storybook cartoon", "cover": {"url": 
"https://imgo.domoai.app/ai-model/0e3ca4ad-4400-4977-b828-41913e43759a.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v4", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15006, "name": "Anime v3", "description": "Live anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/c82a7ce4-ade0-4c7d-bfb5-4546d636e12f.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v3", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15005, "name": "Anime v2", "description": "Japanese anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/5a11a3ea-2884-4630-b504-edd0b67ad239.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v2", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15004, "name": "Illustration v3", "description": "Pixel style", "cover": {"url": "https://imgo.domoai.app/ai-model/0c3c8928-e376-4462-8325-abd8f971f650.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v3", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15003, "name": "Illustration v2", "description": "Comic style", "cover": {"url": "https://imgo.domoai.app/ai-model/8420802c-f777-4047-a65c-25ea1d4181a7.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v2", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15002, "name": "Illustration v1", "description": "3D cartoon style", "cover": {"url": "https://imgo.domoai.app/ai-model/49ab8034-c804-47cd-8125-3810dd65310a.png", "width": 1024, "height": 1648}, "prompt_args": "--illus v1", "allowed_refer_modes": ["VIDEO_MORE", 
"PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}, {"id": 15001, "name": "Anime v1", "description": "Flat color anime style", "cover": {"url": "https://imgo.domoai.app/ai-model/2accc87e-0bbd-43c3-8d64-b207747761fb.png", "width": 1024, "height": 1648}, "prompt_args": "--ani v1", "allowed_refer_modes": ["VIDEO_MORE", "PROMPT_MORE"], "allowed_lip_sync": false, "tech_version": 100, "allowed_reference_image": false}]
--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import io
3 | import uuid
4 | from typing import Optional
5 |
6 | import discord
7 | from fastapi import FastAPI, Request, UploadFile, Form, HTTPException, Depends
8 | from starlette import status
9 | from starlette.responses import JSONResponse
10 |
11 | from app.cache import RedisCache, MemoryCache, Cache
12 | from app.dependencies import api_auth
13 | from app.models import GenModel, MoveModel, VideoModel, get_v2v_model_info_by_instructions
14 | from app.schema import VideoReferMode, VideoLength, TaskCacheData, TaskStatus, CreateTaskOut, \
15 | TaskCommand, TaskStateOut, AnimateLength, AnimateIntensity, Mode, VideoKey, VideoApiError
16 | from app.settings import get_settings
17 | from app.user_client import DiscordUserClient
18 |
# ASGI application object; the route and lifecycle decorators below attach to it.
app = FastAPI()

# Resolved at import time and read by the startup hook further down.
settings = get_settings()
22 |
23 |
async def __did_send_interaction(
        wait_message_desc_keyword: str,
        discord_user_client: DiscordUserClient,
        cache: Cache,
        command: TaskCommand
) -> CreateTaskOut:
    """Register a new task once the "in progress" message shows up.

    Waits for the message whose first embed description contains
    ``wait_message_desc_keyword``, allocates a fresh task id, and stores both
    the message-id -> task-id mapping and the initial RUNNING task record.

    :param wait_message_desc_keyword: text expected in the embed description.
    :param discord_user_client: client used to wait for the message.
    :param cache: store for the task bookkeeping.
    :param command: command the task was created by.
    :return: the created task id together with the Discord message id.
    :raises HTTPException: 504 when no matching message arrives in time.
    """
    try:
        message: discord.Message = await discord_user_client.wait_for_generating_message(
            embeds_desc_keyword=wait_message_desc_keyword
        )
    except asyncio.TimeoutError as exc:
        # Surface the wait timeout as a gateway timeout instead of an
        # unhandled 500 with a bare traceback.
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail="Timed out waiting for the generating message from Discord",
        ) from exc

    message_id = str(message.id)
    task_id = str(uuid.uuid4())
    await cache.set_message_id2task_id(message_id=message_id, task_id=task_id)
    await cache.set_task_id2data(task_id=task_id, data=TaskCacheData(
        command=command,
        status=TaskStatus.RUNNING,
        channel_id=str(discord_user_client.channel_id),
        guild_id=str(discord_user_client.guild_id),
        message_id=message_id
    ))
    return CreateTaskOut(
        success=True,
        task_id=task_id,
        message_id=message_id
    )
47 |
48 |
@app.post("/v1/gen")
async def gen_api(
        request: Request,
        auth=Depends(api_auth),
        image: Optional[UploadFile] = None,
        prompt: str = Form(...),
        mode: Optional[Mode] = Form(default=None),
        model: Optional[GenModel] = Form(default=None)
):
    """Start a ``gen`` task, optionally seeded with an uploaded image.

    Returns ``{"success": False}`` when the interaction could not be sent or
    was not successful; otherwise returns the created task description.
    """
    discord_user_client: DiscordUserClient = request.app.state.discord_user_client
    image_file = None
    if image:
        image_file = discord.File(io.BytesIO(await image.read()), filename=image.filename)

    interaction = await discord_user_client.gen(prompt=prompt, image=image_file, mode=mode, model=model)
    if interaction is None:
        # The client returns None when the slash command is not registered;
        # report a failure instead of crashing on ``interaction.id``.
        return {"success": False}
    print(f"gen, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    return await __did_send_interaction(
        wait_message_desc_keyword='Waiting to start',
        command=TaskCommand.GEN,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
78 |
79 |
@app.post("/v1/real")
async def real_api(
        request: Request,
        auth=Depends(api_auth),
        image: UploadFile = Form(...),
        prompt: Optional[str] = Form(default=None),
        mode: Optional[Mode] = Form(default=None)
):
    """Start a ``real`` task from an uploaded image and an optional prompt.

    Returns ``{"success": False}`` when the interaction could not be sent or
    was not successful; otherwise returns the created task description.
    """
    discord_user_client: DiscordUserClient = request.app.state.discord_user_client
    image_file = discord.File(io.BytesIO(await image.read()), filename=image.filename)

    interaction = await discord_user_client.real(prompt=prompt, image=image_file, mode=mode)
    if interaction is None:
        # The client returns None when the slash command is not registered;
        # report a failure instead of crashing on ``interaction.id``.
        return {"success": False}
    print(f"real, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    return await __did_send_interaction(
        wait_message_desc_keyword='Waiting to start',
        command=TaskCommand.REAL,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
105 |
106 |
@app.post("/v1/animate")
async def animate_api(
        request: Request,
        auth=Depends(api_auth),
        image: UploadFile = Form(...),
        length: AnimateLength = Form(...),
        intensity: AnimateIntensity = Form(...),
        prompt: Optional[str] = Form(default=None),
        mode: Optional[Mode] = Form(default=None)
):
    """Start an ``animate`` task for an uploaded image.

    Returns ``{"success": False}`` when the interaction could not be sent or
    was not successful; otherwise returns the created task description.
    """
    discord_user_client: DiscordUserClient = request.app.state.discord_user_client
    image_file = discord.File(io.BytesIO(await image.read()), filename=image.filename)

    interaction = await discord_user_client.animate(
        prompt=prompt,
        image=image_file,
        length=length,
        intensity=intensity,
        mode=mode
    )
    if interaction is None:
        # The client returns None when the slash command is not registered;
        # report a failure instead of crashing on ``interaction.id``.
        return {"success": False}
    print(f"animate, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    return await __did_send_interaction(
        wait_message_desc_keyword='Waiting to start',
        command=TaskCommand.ANIMATE,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
140 |
141 |
@app.post("/v1/upscale")
async def upscale_api(
        request: Request,
        auth=Depends(api_auth),
        task_id: str = Form(...),
        index: int = Form(..., ge=1, le=4)
):
    """Press the ``U<index>`` button of a finished task's result message.

    :raises HTTPException: 404 when the task is unknown or has no enabled
        ``U<index>`` button recorded.
    """
    discord_user_client: DiscordUserClient = request.app.state.discord_user_client

    cache: Cache = request.app.state.cache
    data = await cache.get_task_data_by_id(task_id=task_id)
    if not data:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Tasks that are still running (or are not image tasks) are stored
    # without button ids; treat a missing mapping as "button not found".
    custom_id = (data.upscale_custom_ids or {}).get(f"U{index}")
    if not custom_id:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
        )

    interaction = await discord_user_client.click_button(custom_id=custom_id, message_id=int(data.message_id))
    print(f"upscale, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    # NOTE(review): the follow-up task is recorded as TaskCommand.GEN, as in
    # the original code - confirm that this is intended for upscales.
    return await __did_send_interaction(
        wait_message_desc_keyword='Waiting to start',
        command=TaskCommand.GEN,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
177 |
178 |
@app.post("/v1/vary")
async def vary_api(
        request: Request,
        auth=Depends(api_auth),
        task_id: str = Form(...),
        index: int = Form(..., ge=1, le=4)
):
    """Press the ``V<index>`` button of a finished task's result message.

    :raises HTTPException: 404 when the task is unknown or has no enabled
        ``V<index>`` button recorded.
    """
    discord_user_client: DiscordUserClient = request.app.state.discord_user_client

    cache: Cache = request.app.state.cache
    data = await cache.get_task_data_by_id(task_id=task_id)
    if not data:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
        )

    # Tasks that are still running (or are not image tasks) are stored
    # without button ids; treat a missing mapping as "button not found".
    custom_id = (data.vary_custom_ids or {}).get(f"V{index}")
    if not custom_id:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
        )

    interaction = await discord_user_client.click_button(custom_id=custom_id, message_id=int(data.message_id))
    print(f"vary, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    # NOTE(review): the follow-up task is recorded as TaskCommand.GEN, as in
    # the original code - confirm that this is intended for variations.
    return await __did_send_interaction(
        wait_message_desc_keyword='Waiting to start',
        command=TaskCommand.GEN,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
214 |
215 |
@app.post("/v1/video")
async def video_api(
        request: Request,
        video: UploadFile,
        image: Optional[UploadFile] = None,
        auth=Depends(api_auth),
        model: VideoModel = Form(...),
        refer_mode: VideoReferMode = Form(...),
        length: VideoLength = Form(...),
        prompt: str = Form(...),
        video_key: Optional[VideoKey] = Form(default=None),
        subject_only: Optional[bool] = Form(default=None),
        lip_sync: Optional[bool] = Form(default=None),
        mode: Optional[Mode] = Form(default=None),
):
    """Start a ``video`` (video-to-video) task.

    The request is validated against the selected model's capabilities
    first; any violation is answered with HTTP 400 and a ``VideoApiError``
    value. Returns ``{"success": False}`` when the interaction could not be
    sent or was not successful; otherwise returns the created task
    description.
    """
    # Validate before reading the uploads so rejected requests do not pull
    # the whole video into memory.
    model_info = get_v2v_model_info_by_instructions(model.value)
    if model_info is None:
        # NOTE(review): this branch answers under the key "code" while the
        # ones below use "error"; kept as-is because clients may rely on it.
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"code": VideoApiError.VIDEO_MODEL_ERROR}
        )

    if refer_mode not in model_info.allowed_refer_modes:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"error": VideoApiError.NOT_ALLOW_REFER}
        )

    if not model_info.allowed_lip_sync and lip_sync:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"error": VideoApiError.NOT_ALLOW_LIP_SYNC}
        )

    if model_info.allowed_reference_image and image is None:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"error": VideoApiError.MODEL_NEED_REFERENCE_IMAGE}
        )

    discord_user_client: DiscordUserClient = request.app.state.discord_user_client
    video_file = discord.File(io.BytesIO(await video.read()), filename=video.filename)
    image_file = None
    if image:
        image_file = discord.File(io.BytesIO(await image.read()), filename=image.filename)

    interaction = await discord_user_client.video(
        prompt=prompt,
        video=video_file,
        image=image_file,
        model=model,
        refer_mode=refer_mode,
        length=length,
        mode=mode,
        video_key=video_key,
        subject_only=subject_only,
        lip_sync=lip_sync,
    )
    if interaction is None:
        # The client returns None when the slash command is not registered;
        # report a failure instead of crashing on ``interaction.id``.
        return {"success": False}
    print(f"video, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    return await __did_send_interaction(
        wait_message_desc_keyword='Generating',
        command=TaskCommand.VIDEO,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
290 |
291 |
@app.post("/v1/move")
async def move_api(
        request: Request,
        image: UploadFile,
        video: UploadFile,
        auth=Depends(api_auth),
        model: MoveModel = Form(...),
        length: VideoLength = Form(...),
        prompt: str = Form(...),
        video_key: Optional[VideoKey] = Form(default=None),
        mode: Optional[Mode] = Form(default=None),
):
    """Start a ``move`` task from an uploaded image and an uploaded video.

    Returns ``{"success": False}`` when the interaction could not be sent or
    was not successful; otherwise returns the created task description.
    """
    discord_user_client: DiscordUserClient = request.app.state.discord_user_client
    image_file = discord.File(io.BytesIO(await image.read()), filename=image.filename)
    video_file = discord.File(io.BytesIO(await video.read()), filename=video.filename)

    interaction = await discord_user_client.move(
        prompt=prompt,
        image=image_file,
        video=video_file,
        model=model,
        length=length,
        mode=mode,
        video_key=video_key,
    )
    if interaction is None:
        # The client returns None when the slash command is not registered;
        # report a failure instead of crashing on ``interaction.id``.
        return {"success": False}
    print(f"move, interaction_id: {interaction.id}, interaction.nonce: {interaction.nonce}")

    if not interaction.successful:
        # TODO:
        return {"success": interaction.successful}

    return await __did_send_interaction(
        wait_message_desc_keyword='Generating',
        command=TaskCommand.MOVE,
        cache=request.app.state.cache,
        discord_user_client=discord_user_client
    )
332 |
333 |
@app.get("/v1/task-data/{task_id}")
async def task_data(
        request: Request,
        task_id: str,
        auth=Depends(api_auth)
):
    """Return the cached state of ``task_id``, or 404 if it is not cached."""
    store: Cache = request.app.state.cache
    cached = await store.get_task_data_by_id(task_id=task_id)
    if cached:
        return TaskStateOut.from_cache_data(data=cached)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
    )
347 |
348 |
@app.on_event("startup")
async def startup_event():
    """Create the cache and the Discord client and wait until it is ready."""
    # Redis is used when a URI is configured, the in-memory cache otherwise.
    if not settings.redis_uri:
        app.state.cache = MemoryCache(prefix=settings.cache_prefix)
    else:
        pool = await RedisCache.init_redis_pool(redis_uri=settings.redis_uri)
        app.state.cache = RedisCache(redis=pool, prefix=settings.cache_prefix)

    client = DiscordUserClient(
        guild_id=settings.discord_guild_id,
        channel_id=settings.discord_channel_id,
        application_id=settings.domoai_application_id,
        cache=app.state.cache,
        event_callback_url=settings.event_callback_url
    )
    app.state.discord_user_client = client

    await client.login(settings.discord_token)
    # The gateway connection runs as a background task; the handle is kept
    # on app.state so the shutdown hook can cancel it.
    app.state.discord_start_task = asyncio.create_task(client.connect(reconnect=True))
    await client.wait_until_ready()
369 |
370 |
@app.on_event("shutdown")
async def shutdown_event():
    """Cancel the gateway task, close the Discord client and close the cache.

    Each resource is looked up defensively: if startup failed part-way, some
    attributes were never set on ``app.state``, and a missing one must not
    prevent the remaining resources from being released.
    """
    start_task = getattr(app.state, "discord_start_task", None)
    if start_task:
        start_task.cancel()

    discord_user_client: Optional[DiscordUserClient] = getattr(app.state, "discord_user_client", None)
    if discord_user_client is not None and not discord_user_client.is_closed():
        await discord_user_client.close()

    cache: Optional[Cache] = getattr(app.state, "cache", None)
    if cache is not None:
        await cache.close()
382 |
--------------------------------------------------------------------------------
/app/user_client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import re
3 | from typing import Dict, List, Optional
4 |
5 | import discord
6 | from discord import ComponentType, InteractionType, InvalidData
7 | from discord.http import Route
8 | from discord.utils import _generate_nonce
9 |
10 | from app.cache import Cache
11 | from app.event_callback import EventCallback
12 | from app.models import GenModel, MoveModel, VideoModel
13 | from app.schema import VideoReferMode, VideoLength, TaskCacheData, TaskAsset, TaskStatus, \
14 | TaskCommand, AnimateIntensity, AnimateLength, Mode, VideoKey
15 |
16 |
class DiscordUserClient(discord.Client):
    """Discord client that drives one application's slash commands.

    The client invokes the ``gen``/``real``/``animate``/``video``/``move``
    slash commands of ``application_id`` in a single channel, and watches
    message edits from that application to record finished results in the
    cache and forward them to the event callback.
    """

    def __init__(
            self,
            channel_id: int,
            guild_id: int,
            application_id: int,
            cache: Cache,
            event_callback_url: Optional[str],
            **options
    ):
        """Store the configuration; remote objects are resolved in ``setup_hook``.

        :param channel_id: channel in which the commands are invoked.
        :param guild_id: guild that owns the channel.
        :param application_id: application whose slash commands are used.
        :param cache: store for task bookkeeping.
        :param event_callback_url: optional URL handed to ``EventCallback``.
        :param options: forwarded unchanged to ``discord.Client``.
        """
        super().__init__(**options)
        self.event_callback = EventCallback(callback_url=event_callback_url)
        self.application_id = application_id
        # Slash commands of the target application, keyed by command name.
        self.commands: Dict[str, discord.SlashCommand] = {}

        self.guild_id = guild_id
        self.guild = None

        self.channel_id = channel_id
        self.channel = None
        self.cache = cache

        self.bot_user_id = None

    async def setup_hook(self):
        """Resolve the guild, the channel and the slash commands."""
        self.bot_user_id = self.user.id
        self.guild = await self.fetch_guild(self.guild_id)
        self.channel = await self.fetch_channel(self.channel_id)
        await self.__init_slash_commands()

        print(f'guild: {self.guild}')
        print(f'channel: {self.channel}')
        print(f'slash commands: {self.commands.keys()}')

    async def on_ready(self):
        """Log the account the client is logged in as."""
        print(f'Logged on as {self.user}')

    async def __init_slash_commands(self):
        """Index the guild's slash commands that belong to the application."""
        wanted_application_id = str(self.application_id)
        for command in await self.guild.application_commands():
            if not isinstance(command, discord.SlashCommand):
                continue
            # Compared as strings so an int/str id mismatch cannot hide commands.
            if str(command.application_id) == wanted_application_id:
                self.commands[command.name] = command

    async def on_message(self, message: discord.Message):
        """Log every incoming message and answer ``ping`` with ``pong``."""
        interaction = message.interaction
        interaction_id = interaction.id if interaction else None
        interaction_nonce = interaction.nonce if interaction else None
        print(
            f"message: {message}, interaction_id: {interaction_id}, interaction.nonce: {interaction_nonce}")

        if message.content == 'ping':
            await message.channel.send('pong')

    async def on_message_edit(self, before: discord.Message, after: discord.Message):
        """Dispatch an edited application message to its result handler.

        Messages with embeds are routed by the first embed's title; messages
        without embeds are routed by marker words in their content.
        """
        # Compared as strings, consistent with __init_slash_commands.
        if str(after.author.id) != str(self.application_id):
            return
        print(
            f"message edit, before {before}, after: {after}")
        if after.embeds:
            # An embed may have no title; treat that as "no handler matches"
            # instead of failing on None.startswith.
            title = after.embeds[0].title or ''
            if title.startswith('/gen'):
                await self.handle_gen_result(message=after)
            elif title.startswith('/real'):
                await self.handle_real_result(message=after)
            elif title == '/animate':
                await self.handle_animate_result(message=after)
            elif title == '/video':
                await self.handle_video_result(message=after)
            elif title == '/move':
                await self.handle_move_result(message=after)
        elif 'After:' in after.content and 'Before:' in after.content:
            await self.handle_video_result(message=after)
        elif 'Result:' in after.content and 'Image:' in after.content and 'Video:' in after.content:
            await self.handle_move_result(message=after)

    async def wait_for_generating_message(self, embeds_desc_keyword: str) -> discord.Message:
        """Wait up to 20 seconds for the "in progress" message.

        A message matches when its first mention is this account and the
        description of its first embed contains ``embeds_desc_keyword``.

        :raises asyncio.TimeoutError: when no matching message arrives in time.
        """

        def check(message: discord.Message):
            if not message.embeds or not message.mentions:
                return False
            if message.mentions[0].id != self.bot_user_id:
                return False
            # The description is optional; a missing one simply does not match.
            return embeds_desc_keyword in (message.embeds[0].description or '')

        return await self.wait_for(
            'message',
            check=check,
            timeout=20
        )

    @staticmethod
    def _extract_button_custom_ids(message: discord.Message):
        """Return ``(upscale_ids, vary_ids)`` for the enabled buttons of ``message``.

        Buttons labelled ``U...`` go to the upscale mapping under their
        label. A button whose label starts with ``Vary`` is stored as ``V1``;
        other ``V...`` buttons are stored under their label.
        """
        upscale_custom_ids = {}
        vary_custom_ids = {}
        for row in message.components:
            for component in row.children:
                if component.disabled or not component.label or not component.custom_id:
                    continue
                label = component.label
                if label.startswith("U"):
                    upscale_custom_ids[label] = component.custom_id
                elif label.startswith("Vary"):
                    vary_custom_ids["V1"] = component.custom_id
                elif label.startswith("V"):
                    vary_custom_ids[label] = component.custom_id
        return upscale_custom_ids, vary_custom_ids

    @staticmethod
    def _video_asset_from_message(message: discord.Message, url_pattern: str) -> Optional[TaskAsset]:
        """Return the result asset of ``message``, or None if there is none.

        The first attachment wins; otherwise the first group of
        ``url_pattern`` matched against the message content is used as URL.
        """
        if message.attachments:
            return TaskAsset.from_attachment(message.attachments[0])
        match = re.search(url_pattern, message.content)
        if not match:
            return None
        video_url = match.group(1)
        return TaskAsset(
            url=video_url,
            proxy_url=video_url
        )

    async def _complete_task(self, task_id: str, data: TaskCacheData):
        """Persist the final task data and send the success callback."""
        await self.cache.set_task_id2data(task_id=task_id, data=data)
        await self.event_callback.send_task_success(task_id=task_id, data=data)

    async def _handle_image_result(self, message: discord.Message, command: TaskCommand):
        """Record a finished image task (shared by ``gen`` and ``real``)."""
        if not message.attachments:
            return
        attachment = message.attachments[0]

        task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id))
        if not task_id:
            return

        upscale_custom_ids, vary_custom_ids = self._extract_button_custom_ids(message)
        data = TaskCacheData(
            command=command,
            channel_id=str(message.channel.id),
            guild_id=str(message.guild.id) if message.guild else None,
            message_id=str(message.id),
            images=[TaskAsset.from_attachment(attachment)],
            status=TaskStatus.SUCCESS,
            upscale_custom_ids=upscale_custom_ids,
            vary_custom_ids=vary_custom_ids
        )
        await self._complete_task(task_id=task_id, data=data)

    async def _handle_video_like_result(
            self,
            message: discord.Message,
            command: TaskCommand,
            url_pattern: str
    ):
        """Record a finished video task (shared by ``video`` and ``move``)."""
        task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id))
        if not task_id:
            return

        asset = self._video_asset_from_message(message, url_pattern)
        if asset is None:
            return

        data = TaskCacheData(
            command=command,
            channel_id=str(message.channel.id),
            guild_id=str(message.guild.id) if message.guild else None,
            message_id=str(message.id),
            videos=[asset],
            status=TaskStatus.SUCCESS
        )
        await self._complete_task(task_id=task_id, data=data)

    async def handle_gen_result(
            self,
            message: discord.Message
    ):
        """Record the result of a ``gen`` task, if ``message`` carries one."""
        await self._handle_image_result(message, TaskCommand.GEN)

    async def handle_real_result(
            self,
            message: discord.Message
    ):
        """Record the result of a ``real`` task, if ``message`` carries one."""
        await self._handle_image_result(message, TaskCommand.REAL)

    async def handle_video_result(
            self,
            message: discord.Message
    ):
        """Record the result of a ``video`` task, if ``message`` carries one."""
        await self._handle_video_like_result(message, TaskCommand.VIDEO, r"After:.*?(https:.*)")

    async def handle_animate_result(
            self,
            message: discord.Message
    ):
        """Record the result of an ``animate`` task, if ``message`` carries one."""
        task_id = await self.cache.get_task_id_by_message_id(message_id=str(message.id))
        if not task_id:
            return

        if not message.attachments:
            return

        data = TaskCacheData(
            command=TaskCommand.ANIMATE,
            channel_id=str(message.channel.id),
            guild_id=str(message.guild.id) if message.guild else None,
            message_id=str(message.id),
            videos=[TaskAsset.from_attachment(message.attachments[0])],
            status=TaskStatus.SUCCESS
        )
        await self._complete_task(task_id=task_id, data=data)

    async def handle_move_result(
            self,
            message: discord.Message
    ):
        """Record the result of a ``move`` task, if ``message`` carries one."""
        await self._handle_video_like_result(message, TaskCommand.MOVE, r"Result:.*?(https:.*)")

    async def _upload_and_close(self, file: discord.File):
        """Upload ``file`` to the channel and return the uploaded handle.

        The local file object is closed whether or not the upload succeeds.
        """
        try:
            uploaded = await self.channel.upload_files(file)
        finally:
            file.close()
        return uploaded[0]

    async def gen(
            self,
            prompt: str,
            image: Optional[discord.File] = None,
            mode: Optional[Mode] = None,
            model: Optional[GenModel] = None
    ) -> Optional[discord.Interaction]:
        """Invoke the ``gen`` command; returns None if it is not registered."""
        command = self.commands.get('gen')
        if not command:
            return None
        request_prompt = prompt
        if mode:
            request_prompt += f" --{mode.value}"
        if model:
            request_prompt += f" --{model.value}"
        options = dict(
            prompt=request_prompt
        )
        if image:
            options['img2img'] = await self._upload_and_close(image)

        return await command(self.channel, **options)

    async def real(
            self,
            image: discord.File,
            prompt: Optional[str] = None,
            mode: Optional[Mode] = None
    ) -> Optional[discord.Interaction]:
        """Invoke the ``real`` command; returns None if it is not registered."""
        command = self.commands.get('real')
        if not command:
            return None
        options = dict(
            image=await self._upload_and_close(image)
        )
        request_prompt_parts = []
        if prompt:
            request_prompt_parts.append(prompt)
        if mode:
            request_prompt_parts.append(f'--{mode.value}')
        # The prompt option is omitted entirely when there is nothing to send.
        if request_prompt_parts:
            options['prompt'] = ' '.join(request_prompt_parts)

        return await command(self.channel, **options)

    async def click_button(
            self,
            custom_id: str,
            message_id: int
    ) -> Optional[discord.Interaction]:
        """Press the button ``custom_id`` on ``message_id`` and await the result.

        :raises InvalidData: when no interaction result arrives in time.
        """
        data = {
            "component_type": ComponentType.button.value,
            "custom_id": custom_id
        }

        nonce = _generate_nonce()

        try:
            payload = {
                'application_id': self.application_id,
                'channel_id': self.channel_id,
                'data': data,
                'nonce': nonce,
                # NOTE(review): hard-coded session id kept from the original
                # code - confirm whether the live gateway session is required.
                'session_id': '44efec80d647a97968b4c60d26d3c032',
                'type': InteractionType.component.value,
                'guild_id': self.guild_id,
                'message_flags': 0,
                'message_id': message_id
            }
            await self.http.request(Route('POST', '/interactions'), json=payload, form=[], files=None)
            # The maximum possible time a response can take is 3 seconds,
            # +/- a few milliseconds for network latency
            # However, people have been getting errors because their gateway
            # disconnects while waiting for the interaction, causing the
            # response to be delayed until the gateway is reconnected
            # 12 seconds should be enough to account for this
            return await self.wait_for(
                'interaction_finish',
                check=lambda d: d.nonce == nonce,
                timeout=12,
            )
        except (asyncio.TimeoutError, asyncio.CancelledError) as exc:
            # NOTE(review): this also converts task cancellation into
            # InvalidData; kept unchanged because callers may rely on it.
            raise InvalidData('Did not receive a response from Discord') from exc

    async def move(
            self,
            image: discord.File,
            video: discord.File,
            prompt: str,
            model: MoveModel,
            length: VideoLength,
            mode: Optional[Mode] = None,
            video_key: Optional[VideoKey] = None,
    ) -> Optional[discord.Interaction]:
        """Invoke the ``move`` command; returns None if it is not registered."""
        command = self.commands.get('move')
        if not command:
            return None
        uploaded_image = await self._upload_and_close(image)
        uploaded_video = await self._upload_and_close(video)

        request_prompt = f"{prompt} --{model.value} --length {length.value}"
        if mode:
            request_prompt += f' --{mode.value}'
        if video_key:
            request_prompt += f' --key {video_key.value.lower()}'
        options = dict(
            prompt=request_prompt,
            video=uploaded_video,
            image=uploaded_image
        )
        return await command(self.channel, **options)

    async def video(
            self,
            video: discord.File,
            image: Optional[discord.File],
            prompt: str,
            model: VideoModel,
            refer_mode: VideoReferMode,
            length: VideoLength,
            mode: Optional[Mode] = None,
            video_key: Optional[VideoKey] = None,
            subject_only: Optional[bool] = None,
            lip_sync: Optional[bool] = None,
    ) -> Optional[discord.Interaction]:
        """Invoke the ``video`` command; returns None if it is not registered."""
        command = self.commands.get('video')
        if not command:
            return None
        uploaded_video = await self._upload_and_close(video)
        uploaded_image = None
        if image:
            # The reference image is closed after upload, like the video.
            uploaded_image = await self._upload_and_close(image)

        if refer_mode == VideoReferMode.REFER_TO_MY_PROMPT_MORE:
            refer_mode_value = 'p'
        else:
            refer_mode_value = 'v'
        request_prompt = f"{prompt} --{model.value} --refer {refer_mode_value} --length {length.value}"
        if mode:
            request_prompt += f' --{mode.value}'
        if video_key:
            request_prompt += f' --key {video_key.value.lower()}'
        if subject_only:
            request_prompt += ' --so'
        if lip_sync:
            request_prompt += ' --lips'
        options = dict(
            prompt=request_prompt,
            video=uploaded_video
        )
        if uploaded_image:
            options['image'] = uploaded_image
        return await command(self.channel, **options)

    async def animate(
            self,
            image: discord.File,
            intensity: AnimateIntensity,
            length: AnimateLength,
            prompt: Optional[str] = None,
            mode: Optional[Mode] = None
    ) -> Optional[discord.Interaction]:
        """Invoke the ``animate`` command; returns None if it is not registered."""
        command = self.commands.get('animate')
        if not command:
            return None
        uploaded_image = await self._upload_and_close(image)

        request_prompt = f"--intensity {intensity.value} --length {length.value}"
        if prompt:
            request_prompt = f"{prompt} " + request_prompt
        if mode:
            request_prompt += f' --{mode.value}'
        options = dict(
            prompt=request_prompt,
            image=uploaded_image
        )
        return await command(self.channel, **options)
469 |
--------------------------------------------------------------------------------