├── .github
│   └── workflows
│       └── pypi-publish.yml
├── LICENSE
├── README.md
├── nonebot_plugin_aidraw
│   ├── __init__.py
│   ├── config.py
│   ├── count.py
│   ├── database.py
│   ├── draw.py
│   ├── limit.py
│   └── manage.py
└── pyproject.toml
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
name: Publish Python 🐍 distributions 📦 to PyPI

on: push

jobs:
  build-n-publish:
    name: Build and publish Python 🐍 distributions 📦 to PyPI
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@master
      - name: Set up Python 3.8
        uses: actions/setup-python@v1
        with:
          python-version: "3.8"
      - name: Install pypa/build
        run: >-
          python -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: >-
          python -m
          build
          --sdist
          --wheel
          --outdir dist/
          .
      - name: Publish distribution 📦 to PyPI
        if: startsWith(github.ref, 'refs/tags')
        uses: pypa/gh-action-pypi-publish@master
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Akirami

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# nonebot-plugin-aidraw

_✨ Let's draw together with AI! ✨_

## 📖 Introduction

A NovelAI drawing plugin that works through a third-party API.
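
For text-to-image requests, the plugin joins `AI_DRAW_API` with the path `got_image` using `urllib.parse.urljoin` and sends the token plus every non-empty drawing parameter as query parameters. The sketch below approximates that URL construction with the standard library only. `build_draw_url` is a hypothetical helper, not part of the plugin, and the plugin itself passes the parameters to httpx, whose exact query encoding may differ slightly from `urlencode`.

```python
from argparse import Namespace
from urllib.parse import urlencode, urljoin


def build_draw_url(api_url: str, token: str, args: Namespace) -> str:
    """Hypothetical helper: approximate the GET URL of a text-to-image request.

    Falsy parameters (empty strings, None, 0) are dropped, mirroring the
    dict comprehension used when the plugin builds its query parameters.
    """
    params = {"token": token, **{k: v for k, v in vars(args).items() if v}}
    return urljoin(api_url, "got_image") + "?" + urlencode(params)


args = Namespace(
    tags="1girl, smile", shape="Portrait", scale=None, seed=None, steps=None, ntags=""
)
print(build_draw_url("https://lulu.uedbq.xyz", "my-token", args))
# https://lulu.uedbq.xyz/got_image?token=my-token&tags=1girl%2C+smile&shape=Portrait

# urljoin replaces the last path segment unless the base ends with "/",
# so a custom AI_DRAW_API that includes a sub-path should keep its trailing slash.
print(urljoin("https://example.com/api", "got_image"))   # https://example.com/got_image
print(urljoin("https://example.com/api/", "got_image"))  # https://example.com/api/got_image
```

The `example.com` addresses are placeholders used only to show how `urljoin` treats a base URL with and without a trailing slash.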

## 💿 Installation

**Install with nb-cli**

Open a terminal in the root directory of your nonebot2 project and run:

`nb plugin install nonebot-plugin-aidraw`

**Install with a package manager**

Open a terminal in the plugin directory of your nonebot2 project and run the install command for the package manager you use:

- pip: `pip install nonebot-plugin-aidraw`
- pdm: `pdm add nonebot-plugin-aidraw`
- poetry: `poetry add nonebot-plugin-aidraw`
- conda: `conda install nonebot-plugin-aidraw`

Then open the `bot.py` file of your nonebot2 project and add:

`nonebot.load_plugin('nonebot_plugin_aidraw')`

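
## ⚙️ Configuration

Add the required settings from the table below to the `.env` file of your nonebot2 project.

| Setting | Required | Default | Description |
|:-----:|:----:|:----:|:----:|
| AI_DRAW_API | No | https://lulu.uedbq.xyz | Address of the third-party API |
| AI_DRAW_TOKEN | Yes | Empty | Token for the third-party API, [get one here](https://lulu.uedbq.xyz/token) |
| AI_DRAW_COOLDOWN | No | 60 | Cooldown after each use, in seconds |
| AI_DRAW_DAILY | No | 30 | Number of uses allowed per user per day |
| AI_DRAW_TIMEOUT | No | 60 | Timeout for requests to the API, in seconds |
| AI_DRAW_REVOKE | No | 0 | Delay before a sent image is recalled, in seconds; 0 (the default) means images are never recalled |
| AI_DRAW_MESSAGE | No | mix | How results are sent: mix (image and text in one message), part (image and text as separate messages) or image (image only) |
| AI_DRAW_RANK | No | 10 | Maximum number of entries shown in tag rankings; 0 shows all of them |
| AI_DRAW_DATA | No | The plugin's own directory | Path of the folder where the plugin stores its data |
| AI_DRAW_TEXT | No | `\n图像种子: {seed}\n提示标签: {tags}` ("Image seed" / "Prompt tags") | Template for the text message. Supported parameters: tags (prompt tags), steps (iteration steps), seed (image seed), strength (strength), scale (prompt adherence), ntags (negative tags); each must be wrapped in {} |
| AI_DRAW_DATABASE | No | True | Whether to use the database; if False, the database is disabled and tag statistics are unavailable |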
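
`AI_DRAW_TEXT` is filled in with Python's `str.format`, so every placeholder must be one of the parameter names listed above; an unknown name such as `{size}` raises `KeyError` when the caption is built. The simplified sketch below shows that formatting together with how `AI_DRAW_MESSAGE` arranges the reply. Images are represented by plain strings, and `render_caption` and `arrange_reply` are hypothetical names, not plugin functions.

```python
from typing import Dict, List, Union

# Default value of AI_DRAW_TEXT ("Image seed" / "Prompt tags")
DEFAULT_TEMPLATE = "\n图像种子: {seed}\n提示标签: {tags}"


def render_caption(template: str, info: Dict[str, object]) -> str:
    """Fill the template; keys the template does not mention are ignored."""
    return template.format(**info)


def arrange_reply(image: str, text: str, mode: str) -> Union[str, List[str]]:
    """Hypothetical stand-in for the AI_DRAW_MESSAGE switch (strings instead of segments)."""
    if mode == "image":
        return image  # image only
    if mode == "part":
        return [image, text]  # image and text as two separate messages
    return image + text  # "mix": one combined message


info = {"tags": "1girl, smile", "steps": 28, "seed": 42,
        "strength": 0.6, "scale": 11, "ntags": "lowres"}
caption = render_caption(DEFAULT_TEMPLATE, info)
print(repr(caption))  # '\n图像种子: 42\n提示标签: 1girl, smile'
print(arrange_reply("[image]", caption, "part"))

try:
    render_caption("{size}", info)
except KeyError as err:
    print("unknown placeholder:", err)  # unknown placeholder: 'size'
```

The values in `info` are made-up sample data; in the plugin they come from the metadata of the generated image.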
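
## 🎉 Usage
### Commands
| Command | Requires @ | Scope | Description |
|:-----:|:----:|:----:|:----:|
| 绘画/画画/画图/作图/绘图/约稿 | No | Group / private chat | Generate an image from descriptive text. Available parameters: [Text generation parameters](#text-generation-parameters); management parameters: [Drawing management parameters](#drawing-management-parameters) |
| 以图绘图/以图生图/以图作图/以图制图 | No | Group / private chat | Generate an image from descriptive text on top of a base image; can also be used by replying to an image message. Available parameters: [Image generation parameters](#image-generation-parameters) |
| 个人标签排行/我的标签排行 | No | Group / private chat | Show a ranking of all the tags you have used |
| 群标签排行/本群标签排行 | No | Group chat | Show a ranking of all the tags used in this group |

Example:

`/绘图 <descriptive text> -p l --scale 12`

**Note**

By default, commands must be preceded by the command prefix, which is usually `/`.

### Text generation parameters
| Parameter | Short | Long | Default | Description |
|:-----:|:----:|:----:|:----:|:----:|
| shape | -p | --shape | Portrait | Image shape: Portrait, Landscape or Square; can be abbreviated to p, l or s |
| scale | -c | --scale | 11 | How strictly the AI follows the prompt; higher values help the AI stay closer to the overall intent of the text prompt |
| seed | -s | --seed | Random | Random seed. With all other settings unchanged, the same seed produces the same image |
| steps | -t | --steps | 28 | Number of iterations the AI spends refining the image after its initial creation |
| ntags | -n | --ntags | Built-in default | Unwanted content: list anything you want the AI to avoid |

### Image generation parameters
| Parameter | Short | Long | Default | Description |
|:-----:|:----:|:----:|:----:|:----:|
| strength | -e | --strength | 0.6 | How much the AI may change the composition of the image; lower values stay closer to the original image |

### Drawing management parameters

These parameters are passed to the drawing command and are only available to superusers.

| Parameter | Description |
|:-----:|:----:|
| 查看白名单 | Show the groups allowed in whitelist mode |
| 查看黑名单 | Show the groups denied in blacklist mode |
| 添加白名单 + group IDs | Add groups to the whitelist; separate group IDs with commas |
| 添加黑名单 + group IDs | Add groups to the blacklist; separate group IDs with commas |
| 删除白名单 + group IDs | Remove groups from the whitelist; separate group IDs with commas |
| 删除黑名单 + group IDs | Remove groups from the blacklist; separate group IDs with commas |
| 切换白名单 | Switch to whitelist mode: only groups in the whitelist may use the plugin |
| 切换黑名单 | Switch to blacklist mode: only groups in the blacklist are denied |
| 添加屏蔽词 + words | Add words to the blocked-word filter; separate words with commas |
| 删除屏蔽词 + words | Remove words from the blocked-word filter; separate words with commas |
| 查看屏蔽词 | Show the current blocked words |

Example:

`/绘图添加黑名单 123456`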
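
Management parameters are read from the first argument of the drawing command: its first two characters select the action, the next three select the list, and whatever follows, together with any remaining arguments, is split on commas. Groups are then admitted according to the active mode. The sketch below is a simplified, standalone approximation of that logic; `parse_manage_args` and `is_group_allowed` are hypothetical names, accepting the full-width comma `，` is an assumption of this sketch, and the plugin additionally lets superusers and private chats through unconditionally.

```python
from typing import List, Set, Tuple


def parse_manage_args(tokens: List[str]) -> Tuple[str, str, List[str]]:
    """Split e.g. ["添加黑名单", "123,456"] into (action, list name, items)."""
    head, *rest = tokens
    action, target, tail = head[:2], head[2:5], head[5:]
    # Join the leftover text with the other arguments; accept full-width commas too
    joined = " ".join([tail, *rest]).replace("，", ",")
    items = [item.strip() for item in joined.split(",") if item.strip()]
    return action, target, items


def is_group_allowed(mode: str, blacklist: Set[int], whitelist: Set[int], group_id: int) -> bool:
    """Whitelist mode admits only listed groups; blacklist mode rejects only listed groups."""
    if mode == "whitelist":
        return group_id in whitelist
    return group_id not in blacklist


print(parse_manage_args(["添加黑名单", "123456,654321"]))
# ('添加', '黑名单', ['123456', '654321'])
print(parse_manage_args(["删除白名单123456，654321"]))
# ('删除', '白名单', ['123456', '654321'])
print(is_group_allowed("blacklist", {123456}, set(), 123456))  # False
print(is_group_allowed("whitelist", set(), {123456}, 123456))  # True
```

Because the action and list name are taken by fixed character positions, `/绘图添加黑名单123456` without a space is parsed the same way as the example above.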
--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/__init__.py:
--------------------------------------------------------------------------------
from . import draw, count

--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/config.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import Literal

from nonebot import get_driver
from pydantic import BaseModel, Extra, Field, validator


class Config(BaseModel, extra=Extra.ignore):
    ai_draw_api: str = "https://lulu.uedbq.xyz"
    ai_draw_token: str = ""
    ai_draw_cooldown: int = 60
    ai_draw_daily: int = 30
    ai_draw_timeout: int = 60
    ai_draw_revoke: int = 0
    ai_draw_message: Literal["mix", "part", "image"] = "mix"
    ai_draw_rank: int = Field(default=10, ge=0)
    ai_draw_data: Path = Path(__file__).parent
    ai_draw_text: str = "\n图像种子: {seed}\n提示标签: {tags}"
    ai_draw_database: bool = True

    @validator("ai_draw_data")
    def check_path(cls, v: Path):
        # The directory may not exist yet (it is created below), but an existing
        # path must be a directory ("must be a valid directory")
        if v.exists() and not v.is_dir():
            raise ValueError("必须是有效的文件目录")
        return v


plugin_config = Config.parse_obj(get_driver().config)

api_url = plugin_config.ai_draw_api
token = plugin_config.ai_draw_token
cooldown_time = plugin_config.ai_draw_cooldown
daily_times = plugin_config.ai_draw_daily
timeout = plugin_config.ai_draw_timeout
revoke_time = plugin_config.ai_draw_revoke
message_mode = plugin_config.ai_draw_message
rank_number = plugin_config.ai_draw_rank
# Create the data and save folders under AI_DRAW_DATA if they are missing
data_path = plugin_config.ai_draw_data / "data"
data_path.mkdir(parents=True, exist_ok=True)
save_path = plugin_config.ai_draw_data / "save"
save_path.mkdir(parents=True, exist_ok=True)
text_templet = plugin_config.ai_draw_text.replace("\\\\", "\\")
enable_database = plugin_config.ai_draw_database

--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/count.py:
--------------------------------------------------------------------------------
from collections import Counter

from nonebot import on_command
from nonebot.adapters.onebot.v11 import GroupMessageEvent, MessageEvent

from .config import rank_number
from .database import DrawCount


def get_rank(count: Counter) -> str:
    # Show the top AI_DRAW_RANK tags; 0 means "show all" (most_common(None))
    counts = count.most_common(rank_number or None)
    msg = "".join(
        f"\n第{i}位: 『{tag}』 ({times}次)" for i, (tag, times) in enumerate(counts, 1)
    )
    return msg or "还没有任何记录哦"


user_count = on_command("个人标签排行", aliases={"我的标签排行"})


@user_count.handle()
async def user_count_rank(event: MessageEvent):
    count = await DrawCount.get_user_count(event.user_id)
    msg = get_rank(count)
    await user_count.send(msg, at_sender=True)


group_count = on_command("本群标签排行", aliases={"群标签排行"})


@group_count.handle()
async def group_count_rank(event: GroupMessageEvent):
    # The GroupMessageEvent annotation restricts this handler to group chats
    count = await DrawCount.get_group_count(event.group_id)
    msg = get_rank(count)
    await group_count.send(msg, at_sender=True)

--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/database.py:
--------------------------------------------------------------------------------
import re
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Literal, Set

from nonebot import get_driver
from pydantic import BaseModel, Field, root_validator
from tortoise import Tortoise, fields
from tortoise.models import Model
from tortoise.queryset import QuerySet

from .config import data_path

try:
    import ujson as json
except ModuleNotFoundError:
    import json


class Setting(BaseModel):
    type: Literal["blacklist", "whitelist"] = "blacklist"
    """List mode"""
    blacklist: Set[int] = Field(default_factory=set)
    """Blacklisted groups"""
    whitelist: Set[int] = Field(default_factory=set)
    """Whitelisted groups"""
    shield: Set[str] = Field(default_factory=set)
    """Blocked words"""

    __file_path: Path = data_path / "setting.json"

    @property
    def file_path(self) -> Path:
        return self.__class__.__file_path

    @root_validator(pre=True)
    def init(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Load the persisted settings when the file exists; otherwise use the defaults
        if cls.__file_path.is_file():
            return json.loads(cls.__file_path.read_text("utf-8"))
        return values

    def save(self) -> None:
        self.file_path.write_text(self.json(), encoding="utf-8")


setting = Setting()


class DrawCount(Model):
    uid: int = fields.IntField(index=True)
    """User QQ number"""
    gid: int = fields.IntField(index=True)
    """Group number (0 for private chats)"""
    count: Dict[str, int] = fields.JSONField(default=dict)
    """Per-tag usage counts"""

    @classmethod
    async def count_tags(cls, uid: int, gid: int, tags: str) -> None:
        # Strip NovelAI emphasis brackets so "{tag}" and "[tag]" count as "tag",
        # and ignore surrounding spaces and empty entries
        tag = re.sub(r"[{[\]}]", "", tags)
        tag_count = Counter(t.strip() for t in tag.split(",") if t.strip())
        counter, _ = await cls.get_or_create(uid=uid, gid=gid)
        counter.count = dict(Counter(counter.count) + tag_count)
        await counter.save()

    @classmethod
    async def _get_count(cls, queryset: QuerySet) -> Counter:
        # Merge the counters of every row selected by the queryset
        counts = queryset.values_list("count", flat=True)
        counter = Counter()
        async for count in counts:
            counter += Counter(count)
        return counter

    @classmethod
    async def get_user_count(cls, uid: int) -> Counter:
        return await cls._get_count(cls.filter(uid=uid))

    @classmethod
    async def get_group_count(cls, gid: int) -> Counter:
        return await cls._get_count(cls.filter(gid=gid))


driver = get_driver()


@driver.on_startup
async def init():
    # Open the SQLite database and create missing tables when the bot starts
    sqlite_path = data_path / "aidraw.sqlite"

    config = {
        "connections": {
            "aidraw": {
                "engine": "tortoise.backends.sqlite",
                "credentials": {"file_path": str(sqlite_path)},
            },
        },
        "apps": {
            "aidraw": {
                "models": [__name__],
                "default_connection": "aidraw",
            }
        },
    }

    await Tortoise.init(config)
    await Tortoise.generate_schemas()


@driver.on_shutdown
async def finish():
    await Tortoise.close_connections()

--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/draw.py:
--------------------------------------------------------------------------------
import base64
from argparse import Namespace
from io import BytesIO
from typing import List, Union
from urllib.parse import urljoin

import httpx
from httpx import TimeoutException
from nonebot import on_shell_command
from nonebot.adapters.onebot.v11 import (
    Bot,
    GroupMessageEvent,
    Message,
    MessageEvent,
    MessageSegment,
)
from nonebot.adapters.onebot.v11.helpers import (
    Cooldown,
    CooldownIsolateLevel,
    autorevoke_send,
    extract_image_urls,
)
from nonebot.exception import ParserExit
from nonebot.log import logger
from nonebot.matcher import Matcher
from nonebot.params import Arg, ArgPlainText, ShellCommandArgs
from nonebot.rule import ArgumentParser
from nonebot.typing import T_State
from PIL import Image, UnidentifiedImageError

from .config import *
from .database import DrawCount
from .limit import daily_limiter, limiter
from .manage import group_checker, group_manager, shield_filter, shield_manager

try:
    import ujson as json
except ImportError:
    import json

TAGS_PROMPT = "请输入描述性的单词或短句"


cooldown = Cooldown(
    cooldown=cooldown_time, prompt="AI绘图冷却中……", isolate_level=CooldownIsolateLevel.USER
)


async def get_tags(state: T_State, tags: str = ArgPlainText()):
    # ArgPlainText yields the plain text of the "tags" argument as a str
    state["tags"] = tags


async def filter_tags(event: MessageEvent, matcher: Matcher, state: T_State):
    # Remove blocked words and tell the user how many uses remain today
    filter_tags, state["tags"] = shield_filter(state["tags"])
    msg = f"正在努力绘图中……(今日剩余{limiter.last(event.user_id)}次)"
    if filter_tags:
        msg += f"\n已过滤屏蔽词: {filter_tags}"
    await matcher.send(msg, at_sender=True)


async def count_tags(event: MessageEvent, state: T_State):
    # Record tag usage for the rankings; private chats are stored with gid 0
    if enable_database:
        await DrawCount.count_tags(
            event.user_id,
            event.group_id if isinstance(event, GroupMessageEvent) else 0,
            state["tags"],
        )


async def send_msg(
    bot: Bot,
    event: MessageEvent,
    message: Union[List[Message], Message],
):
    # Send one or more messages (auto-recalled if AI_DRAW_REVOKE is set),
    # then count the draw against the user's daily quota
    if isinstance(message, Message):
        message = [message]
    for msg in message:
        if revoke_time:
            await autorevoke_send(
                bot, event, msg, revoke_interval=revoke_time, at_sender=True
            )
        else:
            await bot.send(event, msg, at_sender=True)
    limiter.increase(event.user_id)


novel_parser = ArgumentParser()
novel_parser.add_argument("tags", default="", nargs="*", help="描述标签")
novel_parser.add_argument("-p", "--shape", default="", help="画布形状")
novel_parser.add_argument("-c", "--scale", type=float, help="规模")
novel_parser.add_argument("-s", "--seed", type=int, help="种子")
novel_parser.add_argument("-t", "--steps", type=int, help="步骤")
novel_parser.add_argument("-n", "--ntags", default="", nargs="*", help="负面标签")


ai_novel = on_shell_command(
    "绘画",
    aliases={"画画", "画图", "作图", "绘图", "约稿"},
    parser=novel_parser,
    rule=group_checker,
    handlers=[shield_manager, group_manager],
)


@ai_novel.handle()
async def _(args: ParserExit = ShellCommandArgs()):
    # Argument parsing failed: reply with argparse's usage/error message
    await ai_novel.finish(args.message)


@ai_novel.handle([cooldown, daily_limiter()])
async def novel_draw(
    matcher: Matcher,
    state: T_State,
    args: Namespace = ShellCommandArgs(),
):
    shape = args.shape.lower()
    # Any prefix selects the first matching shape; "portrait" comes first so
    # that an omitted --shape resolves to the documented default
    shape_list = ["portrait", "landscape", "square"]

    for s in shape_list:
        if s.startswith(shape):
            shape = s
            break

    if shape not in shape_list:
        await ai_novel.finish("shape 的输入值不正确, 应为 landscape, portrait 或 square")

    args.shape = shape.capitalize()
    args.tags = " ".join(args.tags)
    args.ntags = " ".join(args.ntags)
    state["args"] = args
    if args.tags:
        matcher.set_arg("tags", Message(args.tags))


ai_novel.got("tags", TAGS_PROMPT)(get_tags)

ai_novel.handle()(filter_tags)

ai_novel.handle()(count_tags)


@ai_novel.handle()
async def novel_draw_handle(bot: Bot, event: MessageEvent, state: T_State):
    args = state["args"]
    args.tags = state["tags"]

    try:
        async with httpx.AsyncClient() as client:
            res = await client.get(
                urljoin(api_url, "got_image"),
                # Only forward parameters that are set (falsy values are dropped)
                params={"token": token, **{k: v for k, v in vars(args).items() if v}},
                timeout=timeout,
            )
    except TimeoutException:
        await ai_novel.finish("绘图请求超时, 请稍后重试")

    if res.is_error:
        logger.error(f"{res.url} {res.status_code}")
        await ai_novel.finish("出现意外的网络错误")
    try:
        info = Image.open(BytesIO(res.content)).info
    except UnidentifiedImageError:
        await ai_novel.finish("API 返回图像异常, 请稍后重试")
    # An image without metadata is taken to mean the token was rejected
    if not info:
        await ai_novel.finish("token失效, 请更换token后重试")
    image = "\n" + MessageSegment.image(res.content)
    comment = json.loads(info["Comment"])
    text = Message(
        text_templet.format(
            **{
                "tags": info["Description"],
                "steps": comment["steps"],
                "seed": comment["seed"],
                "strength": comment["strength"],
                "scale": comment["scale"],
                "ntags": comment["uc"],
            }
        )
    )
    if message_mode == "image":
        msg = image
    elif message_mode == "part":
        msg = [image, text]
    else:
        msg = image + text
    await send_msg(bot, event, msg)


image_parser = ArgumentParser()
image_parser.add_argument("tags", default="", nargs="*", help="描述标签")
image_parser.add_argument(
    "-e", "--strength", type=float, help="允许 AI 改变图像的构成, 降低该值会产生更接近原始图像的效果"
)

ai_image = on_shell_command(
    "以图绘图", aliases={"以图生图", "以图作图", "以图制图"}, parser=image_parser, rule=group_checker
)


@ai_image.handle()
async def _(args: ParserExit = ShellCommandArgs()):
    # Argument parsing failed: reply instead of silently continuing without args
    await ai_image.finish(args.message)


@ai_image.handle([cooldown, daily_limiter()])
async def image_draw(
    event: MessageEvent,
    matcher: Matcher,
    state: T_State,
    args: Namespace = ShellCommandArgs(),
):
    # Take images from the replied-to message if there is one
    message = reply.message if (reply := event.reply) else event.message

    state["args"] = args

    if imgs := message["image"]:
        matcher.set_arg("imgs", Message(imgs))

    if args.tags:
        args.tags = " ".join(args.tags)
        matcher.set_arg("tags", Message(args.tags))


@ai_image.got("imgs", prompt="请发送基准图片")
async def get_image(state: T_State, imgs: Message = Arg()):
    urls = extract_image_urls(imgs)
    if not urls:
        await ai_image.reject("没有找到图片, 请重新发送")
    async with httpx.AsyncClient() as client:
        res = await client.get(urls[0])
    if res.is_error:
        await ai_image.finish("获取图片失败, 请更换图片重试")
    try:
        base_img = Image.open(BytesIO(res.content)).convert("RGB")
    except UnidentifiedImageError:
        await ai_image.finish("获取图片失败, 请更换图片重试")

    # Match the canvas orientation to the base image
    if base_img.width > base_img.height:
        state["shape"] = "Landscape"
    elif base_img.width < base_img.height:
        state["shape"] = "Portrait"
    else:
        state["shape"] = "Square"

    # Re-encode as JPEG in memory for the upload
    state["image_data"] = BytesIO()
    base_img.save(state["image_data"], format="JPEG")


ai_image.got("tags", TAGS_PROMPT)(get_tags)

ai_image.handle()(filter_tags)

ai_image.handle()(count_tags)


@ai_image.handle()
async def image_draw_handle(bot: Bot, event: MessageEvent, state: T_State):
    args = state["args"]

    try:
        async with httpx.AsyncClient() as client:
            res = await client.post(
                urljoin(api_url, "got_image2image"),
                # The base image is uploaded as a base64-encoded raw request body
                content=base64.b64encode(state["image_data"].getvalue()),
                params={
                    "token": token,
                    "tags": state["tags"],
                    "shape": state["shape"],
                    "strength": args.strength or 0.6,
                },
                timeout=timeout,
            )
    except TimeoutException:
        await ai_image.finish("绘图请求超时, 请稍后重试")

    if res.is_error:
        logger.error(f"{res.url} {res.status_code}")
        await ai_image.finish("出现意外的网络错误")
    msg = "\n" + MessageSegment.image(res.content)
    await send_msg(bot, event, msg)

--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/limit.py:
--------------------------------------------------------------------------------
from collections import defaultdict
from datetime import date

from nonebot.adapters.onebot.v11 import MessageEvent
from nonebot.matcher import Matcher
from nonebot.params import Depends
from pydantic import BaseModel, Field

from .config import daily_times


class DailyLimiter(BaseModel):
    max: int
    day: int = Field(default=0, init=False)
    count: defaultdict = Field(defaultdict(int), init=False)

    def last(self, key: int) -> int:
        """Remaining uses today for the given user."""
        return self.max - self.count[key]

    def increase(self, key: int) -> None:
        self.count[key] += 1

    def check(self, key: int) -> bool:
        # Reset all counters when the date changes; the date ordinal is used
        # because the day of the month alone repeats from month to month
        today = date.today().toordinal()
        if self.day != today:
            self.day = today
            self.count.clear()
        return self.count[key] < self.max


limiter = DailyLimiter(max=daily_times)


def daily_limiter():
    # Dependency that ends the matcher once the user's daily quota is used up
    async def _daily_limiter(matcher: Matcher, event: MessageEvent):
        if not limiter.check(event.user_id):
            await matcher.finish("今日画图次数已用完, 明天再来吧~")

    return Depends(_daily_limiter)

--------------------------------------------------------------------------------
/nonebot_plugin_aidraw/manage.py:
--------------------------------------------------------------------------------
from argparse import Namespace
from typing import List, Literal, Set, Tuple, Union

from nonebot.adapters.onebot.v11 import (
    Bot,
    GroupMessageEvent,
    Message,
    MessageEvent,
    PrivateMessageEvent,
)
from nonebot.matcher import Matcher
from nonebot.params import ShellCommandArgs
from nonebot.permission import SUPERUSER

from .database import setting

add_word = {"添加", "增加", "设置"}
del_word = {"删除", "移除", "解除"}
see_word = {"查看", "检查"}
change_word = {"切换", "管理"}
shield_word = {"屏蔽词", "过滤词"}


def parm_trim(parm: List[str]) -> List[str]:
    # Join the arguments and split on ASCII or full-width commas
    return " ".join(parm).replace("，", ",").split(",")


async def group_checker(
    bot: Bot, event: Union[GroupMessageEvent, PrivateMessageEvent]
) -> bool:
    # Superusers and private chats always pass; groups are checked against
    # the list of the current mode
    if await SUPERUSER(bot, event) or isinstance(event, PrivateMessageEvent):
        return True
    group_list: Set[int] = getattr(setting, setting.type)
    check = event.group_id in group_list
    return check if setting.type == "whitelist" else not check


def handle_namelist(
    action: Literal["add", "del"],
    type_: Literal["blacklist", "whitelist"],
    groups: Set[int],
) -> str:
    group_list: Set[int] = getattr(setting, type_)
    if action == "add":
        group_list.update(groups)
        _mode = "添加"
    elif action == "del":
        group_list.difference_update(groups)
        _mode = "删除"
    setattr(setting, type_, group_list)
    setting.save()
    _type = "黑" if type_ == "blacklist" else "白"
    return f"已{_mode} {len(groups)} 个{_type}名单: {','.join(map(str, groups))}"


async def group_manager(
    bot: Bot,
    event: MessageEvent,
    matcher: Matcher,
    args: Namespace = ShellCommandArgs(),
):
    # Only superusers manage the lists; a command without arguments is a draw request
    if not args.tags or not await SUPERUSER(bot, event):
        matcher.skip()

    # e.g. "添加黑名单123": action "添加", list "黑名单", first group "123"
    manage_type, *groups = args.tags
    action, class_, group = manage_type[:2], manage_type[2:5], manage_type[5:]

    if class_ == "黑名单":
        type_ = "blacklist"
    elif class_ == "白名单":
        type_ = "whitelist"
    else:
        matcher.skip()

    if action in add_word:
        action = "add"
    elif action in del_word:
        action = "del"
    elif action in see_word:
        group_list = getattr(setting, type_)
        msg = (
            f"当前{class_}: {','.join(map(str, group_list))}"
            if group_list
            else f"当前没有{class_}"
        )
        await matcher.finish(msg)
    elif action in change_word:
        setting.type = type_
        setting.save()
        await matcher.finish(f"已切换为 {class_} 模式")
    else:
        matcher.skip()

    groups.insert(0, group)
    groups = parm_trim(groups)

    msg = handle_namelist(
        action, type_, {int(group) for group in groups if group.strip().isdigit()}
    )
    await matcher.finish(msg)


def shield_filter(tags: Union[str, Message]) -> Tuple[str, str]:
    # Normalise the tags, drop empty entries and duplicates while keeping the
    # user's tag order, then separate blocked tags from the rest
    tag_list = str(tags).lower().replace("，", ",").split(",")
    ordered = list(dict.fromkeys(tag.strip() for tag in tag_list if tag.strip()))
    filter_tags = ",".join(tag for tag in ordered if tag in setting.shield)
    safe_tags = ",".join(tag for tag in ordered if tag not in setting.shield)
    return filter_tags, safe_tags


async def shield_manager(
    bot: Bot,
    event: MessageEvent,
    matcher: Matcher,
    args: Namespace = ShellCommandArgs(),
):
    if not args.tags or not await SUPERUSER(bot, event):
        matcher.skip()

    manage_type, *words = args.tags
    action, class_, word = manage_type[:2], manage_type[2:5], manage_type[5:]

    if class_ not in shield_word:
        matcher.skip()

    words.insert(0, word)
    # Store blocked words lower-cased, since shield_filter lower-cases the tags
    words = [w.strip().lower() for w in parm_trim(words) if w.strip()]

    if action in add_word:
        setting.shield.update(words)
    elif action in del_word:
        setting.shield.difference_update(words)
    elif action in see_word:
        msg = f"当前屏蔽词: {','.join(setting.shield)}" if setting.shield else "当前没有设置屏蔽词"
        await matcher.finish(msg)
    else:
        matcher.skip()

    setting.save()
    await matcher.finish(f"已{action}屏蔽词: {','.join(words)}")

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "nonebot-plugin-aidraw"
version = "0.7.1"
description = "NoneBot2 plugin for using AI to draw images"
authors = [
    {name = "Akirami", email = "Akiramiaya@outlook.com"},
]
license = {text = "MIT"}
dependencies = ["nonebot2>=2.0.0rc1", "nonebot-adapter-onebot>=2.1.3", "httpx>=0.23.0", "Pillow<10.0.0,>=9.2.0", "tortoise-orm[aiosqlite]>=0.19.2"]
requires-python = ">=3.8"
readme = "README.md"

[project.urls]
Homepage = "https://github.com/A-kirami/nonebot-plugin-aidraw"
Repository = "https://github.com/A-kirami/nonebot-plugin-aidraw"

[build-system]
requires = ["pdm-pep517>=0.12.0"]
build-backend = "pdm.pep517.api"

--------------------------------------------------------------------------------