├── .gitignore
├── assets
├── strawberry.png
└── strawberry_preview.png
├── readme.md
├── requirements.txt
├── strawberry
├── __init__.py
├── connection
│ ├── __init__.py
│ ├── stream_connection.py
│ └── voice_connection.py
├── gateway.py
├── packetizers
│ ├── __init__.py
│ ├── audio_packetizer.py
│ ├── base_packetizer.py
│ └── h264_packetizer.py
├── sources
│ ├── __init__.py
│ ├── h264_source.py
│ └── opus_source.py
├── streamer.py
└── utils.py
├── strawberry_config.toml.example
└── strawberry_yum.py
/.gitignore:
--------------------------------------------------------------------------------
1 | strawberry_config.toml
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
--------------------------------------------------------------------------------
/assets/strawberry.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justfoolingaround/strawberry/dcbb7a36adff520c6ac7b4669d137377e338c74a/assets/strawberry.png
--------------------------------------------------------------------------------
/assets/strawberry_preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justfoolingaround/strawberry/dcbb7a36adff520c6ac7b4669d137377e338c74a/assets/strawberry_preview.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
Strawberry
2 | A Discord video and audio streaming client designed for user-bots (aka self-bots.)
3 |
4 |
5 | You will most likely get banned for using this with your user token.
6 |
7 | Requirements
8 |
9 | - Python 3.11+
10 | - Mostly because the `match-case` syntax is lovely.
11 | - `aiohttp`, `PyNaCl` (`libsodium`), `toml` (install from requirements.txt)
12 | - `ffmpeg` and `ffprobe` in PATH.
13 |
14 | Usage
15 |
16 | ```sh
17 | # Configure the strawberry_config.toml first, and then:
18 | $ py strawberry_yum.py "{file path or url to stream}"
19 | # or just stream via yt-dlp
20 | $ py strawberry_yum.py "{yt-dlp supported url to stream}" --yt-dlp
21 | ```
22 |
23 |
24 | What can Strawberry do?
25 |
26 | - Auto-infer audio and video information from the source (through a layer of abstraction.)
27 | - If you try a video source, and it contains audio and subtitles, the client will try to embed both in the stream.
28 | - If only audio, and given that the connection is not a stream connection, the client will just stream the audio.
29 | - Stream both audio and video to the said voice channel.
30 | - Streaming can be done at **any** video quality without any Nitro necessary. (Audio quality has not yet been tested.)
31 | - If a stream is not initiated, the video stream will open the user's video, not stream.
32 | - Pause can be achieved by using a `threading.Event` which will then hold the stream in place.
33 | - Listen in on your conversations and watch your streams.
34 | - The max transmission unit (mtu) for Discord voice is `1200 Bytes`. This means that you can do the corresponding UDP read to get the packet.
35 |   - Each of these packets has a header containing the `SSRC`, the timestamp and the sequence number it belongs to.
36 | - This means that this client may also be used to mirror your streams.
37 | - This hurts the connection so it is suggested to keep the user-bot server-deafened.
38 | - Your favorite Discord music bots also do this because it hurts the connection. They don't care about your privacy, it is a bug labelled as a feature.
39 | - Be hosted, away from your home.
40 | - The source encoding and packetization code can be extracted and be hosted as one would host `Lavalink`.
41 | - After all, all you would need would be the UDP server's IP address and the secret key.
42 | - Make sure you identify with your VPS' IP address first and then send it the secret key.
43 | - This means that a single VPS can be used to effectively control a series of user-bots' streams.
44 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.9.0
2 | PyNaCl==1.3.0
3 | toml==0.10.2
--------------------------------------------------------------------------------
/strawberry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justfoolingaround/strawberry/dcbb7a36adff520c6ac7b4669d137377e338c74a/strawberry/__init__.py
--------------------------------------------------------------------------------
/strawberry/connection/__init__.py:
--------------------------------------------------------------------------------
1 | from .stream_connection import StreamConnection
2 | from .voice_connection import UDPConnection, VoiceConnection
3 |
4 | __all__ = ["StreamConnection", "VoiceConnection", "UDPConnection"]
5 |
--------------------------------------------------------------------------------
/strawberry/connection/stream_connection.py:
--------------------------------------------------------------------------------
1 | """
2 | Strawberry Stream Connection
3 | ============================
4 |
5 | A stream connection is practically a voice connection
6 | that depends upon two other gateway events.
7 |
8 | - STREAM_CREATE
9 |
10 | Provides all the necessary information to start a stream connection.
11 | The session_id must be derived from the voice connection where
12 | the stream connection takes place.
13 |
14 | - STREAM_SERVER_UPDATE
15 |
16 | Provides the endpoint and the token to connect to the voice server.
17 |
18 | Similar to the voice connection, the stream connection also
19 | has an underlying UDP connection to send and receive stream packets.
20 |
21 | The audio sent in the stream is not the same as the audio sent
22 | in the voice connection.
23 | """
24 |
25 | import base64
26 | import typing
27 |
28 | from .voice_connection import VoiceConnection, VoiceOpCodes
29 |
30 | if typing.TYPE_CHECKING:
31 | from strawberry.gateway import DiscordGateway
32 |
33 |
class StreamConnection(VoiceConnection):
    """Voice connection variant used for Go-Live (screenshare) streams.

    Behaves like a regular voice connection, except that it identifies
    against the stream's own RTC server and speaks with the soundshare
    flag rather than the plain voice flag.
    """

    def __init__(self, *args, stream_key: str, rtc_server_id: str, **kwargs):
        super().__init__(*args, **kwargs)

        # Sub-logger so stream traffic is distinguishable from plain voice.
        self.logger = self.logger.getChild("stream")
        self.stream_key = stream_key
        # Identify against the stream's RTC server instead of the
        # guild/channel id the parent constructor picked.
        self.server_id = rtc_server_id

    async def set_speaking(self, speaking: bool):
        """Toggle the soundshare speaking flag (2) for this stream."""
        self.ensure_ready()

        flag = 2 if speaking else 0
        payload = {
            "op": VoiceOpCodes.SPEAKING,
            "d": {
                "speaking": flag,
                "delay": 0,
                "ssrc": self.ssrc,
            },
        }
        return await self.ws.send_json(payload)

    async def set_preview(
        self,
        gateway: "DiscordGateway",
        preview: bytes,
        preview_type: str = "image/jpeg",
    ):
        """Upload a preview thumbnail; True when Discord accepts it (204)."""
        if self.stream_key is None:
            raise ValueError("Stream key for the stream connection is not set")

        encoded = base64.b64encode(preview).decode("utf-8")
        url = f"{gateway.DISCORD_API_ENDPOINT}/streams/{self.stream_key}/preview"

        async with self.session.post(
            url,
            headers={
                "Authorization": gateway.token,
            },
            json={"thumbnail": f"data:{preview_type};base64,{encoded}"},
        ) as response:
            return response.status == 204

    @classmethod
    def from_voice_connection(
        cls,
        voice_conn: VoiceConnection,
        *,
        stream_key: str,
        rtc_server_id: str,
        rtc_server_endpoint: str,
        rtc_server_token: str,
    ):
        """Build a stream connection that reuses an active voice connection's state."""
        return cls(
            voice_conn.session,
            channel_id=voice_conn.channel_id,
            user_id=voice_conn.user_id,
            session_id=voice_conn.session_id,
            guild_id=voice_conn.guild_id,
            stream_key=stream_key,
            endpoint=rtc_server_endpoint,
            token=rtc_server_token,
            rtc_server_id=rtc_server_id,
        )
96 |
--------------------------------------------------------------------------------
/strawberry/connection/voice_connection.py:
--------------------------------------------------------------------------------
1 | """
2 | Strawberry Voice Connection
3 | ===========================
4 |
5 | Establishes a proper websocket connection to Discord
6 | according to the information provided by a client's
7 | gateway.
8 |
9 | The gateway provides such information whenever a
10 | voice state update is requested by our client.
11 |
12 | SVC depends upon two gateway events:
13 |
14 | - VOICE_STATE_UPDATE
15 |
16 | This event provides necessary session information.
17 | This will be used to create a class instance of SVC.
18 |
19 | - VOICE_SERVER_UPDATE
20 |
21 | This event provides necessary server information.
22 | This will be used to "prepare" the SVC instance.
23 |
24 | Only after the preparation, the SVC instance can
25 | be started (i.e. the websocket connection can be
26 | established).
27 |
28 | The underlying UDP connection is responsible
29 | for the voice channel audio and video transmission.
30 | """
31 |
32 |
33 | import asyncio
34 | import enum
35 | import logging
36 | import socket
37 | import struct
38 | import typing
39 |
40 | import aiohttp
41 | import nacl.secret
42 | import nacl.utils
43 |
44 | from strawberry.utils import checked_add
45 |
46 | from ..packetizers import audio_packetizer, h264_packetizer
47 |
48 |
class VoiceOpCodes(enum.IntEnum):
    """Opcodes for payloads on the voice-server websocket (gateway v7)."""

    IDENTIFY = 0
    SELECT_PROTOCOL = 1
    READY = 2
    HEARTBEAT = 3
    SELECT_PROTOCOL_ACK = 4
    SPEAKING = 5
    HEARTBEAT_ACK = 6
    RESUME = 7
    HELLO = 8
    RESUMED = 9
    VIDEO = 12
    CLIENT_DISCONNECT = 13
    SESSION_UPDATE = 14
    MEDIA_SINK_WANTS = 15
    VOICE_BACKEND_VERSION = 16
    CHANNEL_OPTIONS_UPDATE = 17
    FLAGS = 18
    SPEED_TEST = 19
    PLATFORM = 20
69 |
70 |
class VoiceConnection:
    """Websocket ("SVC") connection to a Discord voice server.

    Built from the session/server data that VOICE_STATE_UPDATE and
    VOICE_SERVER_UPDATE gateway events provide (see module docstring).
    Owns a :class:`UDPConnection` that carries the actual audio/video.
    """

    def __init__(
        self,
        session: aiohttp.ClientSession,
        *,
        channel_id: str,
        user_id: str,
        session_id: str,
        endpoint: str,
        token: str,
        guild_id: "str | None" = None,
        encryption_mode: str = "xsalsa20_poly1305_lite",
        audio_packetizer=audio_packetizer.AudioPacketizer,
        video_packetizer=h264_packetizer.H264Packetizer,
    ):
        """
        :param session: aiohttp session used for the websocket connection.
        :param session_id: Voice session id from VOICE_STATE_UPDATE.
        :param endpoint: Voice server host from VOICE_SERVER_UPDATE.
        :param token: Voice server token from VOICE_SERVER_UPDATE.
        :param guild_id: None for DM/group calls.
        :param encryption_mode: Must be a key of ``UDPConnection.encryptors``.
        :param audio_packetizer: Packetizer *class* used for audio frames.
        :param video_packetizer: Packetizer *class* used for video frames.
        """
        self.logger = logging.getLogger("voice_connection")

        self.session: aiohttp.ClientSession = session
        self.loop = asyncio.get_event_loop()
        self.encryption_mode = encryption_mode

        # Loop-time of the last heartbeat sent; used for the latency log.
        self.last_heartbeat_at: int = 0

        self.udp_connection = UDPConnection(
            self, audio_packetizer=audio_packetizer, video_packetizer=video_packetizer
        )

        self.guild_id = guild_id
        self.channel_id = channel_id

        # DM calls have no guild, so the channel id doubles as the server id.
        self.server_id = self.guild_id or self.channel_id

        self.user_id = user_id
        self.session_id = session_id
        self.endpoint = endpoint
        self.token = token

        self.ws: typing.Optional[aiohttp.ClientWebSocketResponse] = None

        # Our externally visible address, filled in by UDP IP discovery.
        self.our_ip: typing.Optional[str] = None
        self.our_port: typing.Optional[int] = None

        # The voice server's UDP address, from the READY payload.
        self.ip: typing.Optional[str] = None
        self.port: typing.Optional[int] = None

        # Audio SSRC plus the video / retransmission SSRCs derived from it.
        self.ssrc: typing.Optional[int] = None
        self.video_ssrc: typing.Optional[int] = None
        self.rtx_ssrc: typing.Optional[int] = None

        # Encryption key from SELECT_PROTOCOL_ACK (stored as bytes there,
        # despite the str annotation here).
        self.secret_key: typing.Optional[str] = None
        self.ws_handler_task: typing.Optional[asyncio.Task] = None

    @property
    def own_identity(self):
        """Our externally visible ``ip:port``, or None before IP discovery."""
        if not (self.our_ip or self.our_port):
            return None

        return f"{self.our_ip}:{self.our_port}"

    @own_identity.setter
    def own_identity(self, value: tuple[str, int]):
        # Set as a single (ip, port) pair by UDPConnection.create_udp_socket.
        self.our_ip, self.our_port = value

    def set_server_address(self, ip: str, port: int):
        """Record the voice server's UDP address (from the READY payload)."""
        self.ip = ip
        self.port = port

    def set_ssrc(self, ssrc: int):
        """Set the audio SSRC and derive the video (+1) and RTX (+2) SSRCs."""
        self.ssrc = ssrc
        self.video_ssrc = ssrc + 1
        self.rtx_ssrc = ssrc + 2

        self.udp_connection.set_ssrc(self.ssrc, self.video_ssrc)

    @property
    def is_ready(self):
        """True once both the server address and our SSRC are known."""
        return all((self.endpoint, self.token, self.ip, self.port, self.ssrc))

    def ensure_ready(self):
        """Raise if the voice handshake has not completed yet."""
        if not self.is_ready:
            raise RuntimeError("Voice connection is not ready yet.")

    async def setup_heartbeat(self, interval):
        """Send a heartbeat every ``interval`` milliseconds while the ws is open."""
        self.logger.debug(f"Setting up heartbeat with interval {interval}ms.")

        while self.ws is not None and not self.ws.closed:
            await asyncio.sleep(interval / 1000)

            try:
                self.last_heartbeat_at = self.loop.time()
                self.logger.debug("Sending heartbeat.")
                # 1337 is an arbitrary client-chosen nonce.
                await self.ws.send_json({"op": VoiceOpCodes.HEARTBEAT, "d": 1337})
            except ConnectionResetError:
                return await self.ws.close()

    async def set_video_state(
        self,
        state: bool,
        width: int = 1280,
        height: int = 720,
        framerate: int = 30,
        bitrate: int = 25 * 1024,
    ):
        """
        Announce the video stream parameters (VIDEO opcode).

        :param state: Whether the video stream is active.
        :param bitrate: Maximum bitrate hint sent to the server.
        """
        self.ensure_ready()

        return await self.ws.send_json(
            {
                "op": VoiceOpCodes.VIDEO,
                "d": {
                    "audio_ssrc": self.ssrc,
                    "video_ssrc": self.video_ssrc,
                    "rtx_ssrc": self.rtx_ssrc,
                    "streams": [
                        {
                            "type": "video",
                            "rid": "100",
                            "ssrc": self.video_ssrc,
                            "active": state,
                            "quality": 100,
                            "rtx_ssrc": self.rtx_ssrc,
                            "max_bitrate": bitrate,
                            "max_framerate": framerate,
                            "max_resolution": {
                                "type": "fixed",
                                "width": width,
                                "height": height,
                            },
                        }
                    ],
                },
            }
        )

    async def handle_ws_events(
        self,
        ws: aiohttp.ClientWebSocketResponse,
        *,
        udp_socket_preparation_event: typing.Optional[asyncio.Event] = None,
    ):
        """
        Receive loop for the voice websocket.

        Drives the handshake: READY provides addresses/SSRC and triggers
        UDP socket creation, SELECT_PROTOCOL_ACK yields the secret key and
        signals ``udp_socket_preparation_event``.

        NOTE(review): ``udp_socket_preparation_event`` defaults to None but
        ``.set()`` is called on it unconditionally when SELECT_PROTOCOL_ACK
        arrives — callers must always pass the event; confirm.
        """
        async for msg in ws:
            payload = msg.json()
            data = payload["d"]

            match payload["op"]:
                case VoiceOpCodes.READY:
                    self.set_ssrc(data["ssrc"])
                    self.set_server_address(data["ip"], data["port"])

                    # IP discovery + SELECT_PROTOCOL happen inside here.
                    self.udp_connection.create_udp_socket()
                    await self.set_video_state(False)

                case VoiceOpCodes.HELLO:
                    self.loop.create_task(
                        self.setup_heartbeat(data["heartbeat_interval"])
                    )

                case VoiceOpCodes.HEARTBEAT_ACK:
                    latency = (self.loop.time() - self.last_heartbeat_at) * 1000

                    if self.ip and self.port:
                        addr = f"{self.ip}:{self.port}"
                    else:
                        addr = "Unknown"

                    self.logger.debug(
                        f"Heartbeat ACK was received. Latency: {latency:.2f}ms. Address: {addr}"
                    )

                case VoiceOpCodes.SPEAKING:
                    ...

                case VoiceOpCodes.SELECT_PROTOCOL_ACK:
                    self.secret_key = bytes(data["secret_key"])
                    udp_socket_preparation_event.set()

                case VoiceOpCodes.RESUMED:
                    ...

    async def set_speaking(self, speaking: bool):
        """Toggle the plain voice speaking flag (1/0)."""
        self.ensure_ready()

        return await self.ws.send_json(
            {
                "op": VoiceOpCodes.SPEAKING,
                "d": {
                    "speaking": int(speaking),
                    "delay": 0,
                    "ssrc": self.ssrc,
                },
            }
        )

    async def set_protocols(self):
        """
        Send SELECT_PROTOCOL with our discovered address, the codec list
        (audio payload type 120, video 101 / RTX 102) and encryption mode.
        """
        self.ensure_ready()

        return await self.ws.send_json(
            {
                "op": VoiceOpCodes.SELECT_PROTOCOL,
                "d": {
                    "protocol": "udp",
                    "codecs": [
                        {
                            "name": self.udp_connection.audio_packetizer.codec,
                            "type": "audio",
                            "priority": 1000,
                            "payload_type": 120,
                        },
                        {
                            "name": self.udp_connection.video_packetizer.codec,
                            "type": "video",
                            "priority": 1000,
                            "payload_type": 101,
                            "rtx_payload_type": 102,
                            "encode": True,
                            "decode": True,
                        },
                    ],
                    "data": {
                        "address": self.our_ip,
                        "port": self.our_port,
                        "mode": self.encryption_mode,
                    },
                },
            }
        )

    async def start(self):
        """
        Start the SVC websocket connection and wait until
        the internal UDP connection receives protocol acknowledgement.
        """

        # is_ready implies the READY handshake already happened.
        if self.is_ready:
            raise RuntimeError("Media connection has already started.")

        self.ws = await self.session.ws_connect(
            f"wss://{self.endpoint}/", params={"v": 7}
        )
        # Do resume here in the future if it is that crucial.

        await self.ws.send_json(
            {
                "op": VoiceOpCodes.IDENTIFY,
                "d": {
                    "server_id": self.server_id or self.guild_id or self.channel_id,
                    "user_id": self.user_id,
                    "session_id": self.session_id,
                    "token": self.token,
                    "video": True,
                    "streams": [{"type": "screen", "rid": "100", "quality": 100}],
                },
            }
        )

        # Resolved by handle_ws_events once SELECT_PROTOCOL_ACK delivers
        # the secret key.
        udp_socket_preparation_event = asyncio.Event()

        self.ws_handler_task = self.loop.create_task(
            self.handle_ws_events(
                self.ws, udp_socket_preparation_event=udp_socket_preparation_event
            )
        )
        return await udp_socket_preparation_event.wait()
333 |
334 |
class UDPConnection:
    """UDP media transport owned by a :class:`VoiceConnection`.

    Performs Discord's IP discovery handshake, encrypts RTP payloads with
    the negotiated mode, and ships packets to the voice server.
    """

    # Wrap-around bounds for RTP counters (sequence numbers / nonces).
    MAX_INT_16 = 1 << 16
    MAX_INT_32 = 1 << 32

    def __init__(
        self,
        conn: VoiceConnection,
        *,
        audio_packetizer=audio_packetizer.AudioPacketizer,
        video_packetizer=h264_packetizer.H264Packetizer,
    ) -> None:
        # Incrementing nonce counter for the *_lite encryption mode.
        self.nonce = 0

        self.conn = conn

        self.logger = logging.getLogger("udp_connection")

        # The packetizer *classes* are instantiated here with this
        # connection as their transport.
        self.audio_packetizer = audio_packetizer(self)
        self.video_packetizer = video_packetizer(self)

        self.udp_socket = None

    def set_ssrc(self, audio: int, video: int):
        """Assign the audio/video SSRCs to their respective packetizers."""
        self.audio_packetizer.ssrc = audio
        self.video_packetizer.ssrc = video

    def send_audio_frame(self, frame: bytearray):
        """Packetize and send one encoded audio frame."""
        return self.audio_packetizer.send_frame(frame)

    def send_video_frame(self, frame: bytearray):
        """Packetize and send one encoded video frame."""
        return self.video_packetizer.send_frame(frame)

    def create_udp_socket(self):
        """
        Create the UDP socket and run Discord's IP discovery handshake,
        then kick off SELECT_PROTOCOL with the discovered address.

        NOTE(review): the sendto/recv pair below is blocking, so the event
        loop stalls until the discovery reply arrives — confirm acceptable.
        """
        self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        payload = bytearray(74)

        # Discovery request: type 0x1, length 70 (0x46), then our SSRC.
        struct.pack_into(">H", payload, 0, 0x1)
        struct.pack_into(">H", payload, 2, 0x46)
        struct.pack_into(">I", payload, 4, self.conn.ssrc)
        self.udp_socket.sendto(payload, (self.conn.ip, self.conn.port))

        data = self.udp_socket.recv(74)

        # Response type must be 0x2 (discovery response).
        (handshake,) = struct.unpack(">H", data[:2])

        if handshake != 2:
            raise ValueError("Invalid handshake payload received from the server")

        # Our external IP is a NUL-terminated string starting at offset 8;
        # the port sits in the last two bytes.
        our_ip = data[8 : data.find(0, 8)].decode("utf-8")
        (our_port,) = struct.unpack(">H", data[-2:])

        self.conn.own_identity = our_ip, our_port

        self.conn.loop.create_task(self.conn.set_protocols())

    def send_packet(self, packet: bytearray):
        """Send a raw (already encrypted) packet to the voice server."""
        if self.udp_socket is None:
            raise ValueError("UDP socket not created")

        self.udp_socket.sendto(packet, (self.conn.ip, self.conn.port))

    def close(self):
        """Close and drop the UDP socket, if one exists."""
        if self.udp_socket is not None:
            self.udp_socket.close()
            self.udp_socket = None

    def encrypt_data_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
        # Nonce is the 12-byte RTP header zero-padded to 24 bytes.
        box = nacl.secret.SecretBox(bytes(self.conn.secret_key))
        nonce = bytearray(24)
        nonce[: len(header)] = header

        return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext

    def encrypt_data_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes:
        # Random 24-byte nonce, appended in full after the ciphertext.
        box = nacl.secret.SecretBox(bytes(self.conn.secret_key))
        nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)

        return header + box.encrypt(bytes(data), nonce).ciphertext + nonce

    def encrypt_data_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes:
        # 32-bit incrementing counter (wraps at MAX_INT_32), zero-padded to
        # 24 bytes; only the 4 counter bytes travel after the ciphertext.
        self.nonce = checked_add(self.nonce, 1, self.MAX_INT_32)

        box = nacl.secret.SecretBox(bytes(self.conn.secret_key))
        nonce = bytearray(24)
        nonce[:4] = struct.pack(">I", self.nonce)

        return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]

    # Class-level mapping of mode name -> *unbound* function; encrypt_data
    # therefore passes `self` explicitly when invoking one of these.
    encryptors = {
        "xsalsa20_poly1305": encrypt_data_xsalsa20_poly1305,
        "xsalsa20_poly1305_suffix": encrypt_data_xsalsa20_poly1305_suffix,
        "xsalsa20_poly1305_lite": encrypt_data_xsalsa20_poly1305_lite,
    }

    def encrypt_data(self, header: bytes, data: bytes) -> bytes:
        """Encrypt ``data`` behind ``header`` using the connection's mode."""
        if self.conn.encryption_mode not in self.encryptors:
            raise ValueError(
                f"Unsupported encryption mode: {self.conn.encryption_mode}"
            )

        return self.encryptors[self.conn.encryption_mode](self, header, data)
--------------------------------------------------------------------------------
/strawberry/gateway.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import base64
3 | import enum
4 |
5 | import aiohttp
6 |
7 | from .connection import StreamConnection, VoiceConnection
8 |
# Gateway IDENTIFY capability bitmask; bit 7 enables voice support.
voice_capabilities = 1 << 7
10 |
11 |
class DiscordGatewayOPCodes(enum.IntEnum):
    """Opcodes for payloads on the main Discord gateway websocket."""

    DISPATCH = 0
    HEARTBEAT = 1
    IDENTIFY = 2
    PRESENCE_UPDATE = 3
    VOICE_STATE_UPDATE = 4
    VOICE_SERVER_PING = 5
    RESUME = 6
    RECONNECT = 7
    REQUEST_GUILD_MEMBERS = 8
    INVALID_SESSION = 9
    HELLO = 10
    HEARTBEAT_ACK = 11
    CALL_CONNECT = 13
    GUILD_SUBSCRIPTIONS = 14
    LOBBY_CONNECT = 15
    LOBBY_DISCONNECT = 16
    LOBBY_VOICE_STATES_UPDATE = 17
    STREAM_CREATE = 18
    STREAM_DELETE = 19
    STREAM_WATCH = 20
    STREAM_PING = 21
    STREAM_SET_PAUSED = 22
    REQUEST_GUILD_APPLICATION_COMMANDS = 24
    EMBEDDED_ACTIVITY_LAUNCH = 25
    EMBEDDED_ACTIVITY_CLOSE = 26
    EMBEDDED_ACTIVITY_UPDATE = 27
    REQUEST_FORUM_UNREADS = 28
    REMOTE_COMMAND = 29
    GET_DELETED_ENTITY_IDS_NOT_MATCHING_HASH = 30
    REQUEST_SOUNDBOARD_SOUNDS = 31
    SPEED_TEST_CREATE = 32
    SPEED_TEST_DELETE = 33
    REQUEST_LAST_MESSAGES = 34
    SEARCH_RECENT_MEMBERS = 35
47 |
48 |
49 | class DiscordGateway:
50 | VOICE_CAPABILITIES = 1 << 7
51 | DISCORD_API_ENDPOINT = "https://discord.com/api/v9"
52 |
53 | GATEWAY_VERSION = 9
54 |
55 | def __init__(self, token: str, *, session=None):
56 | self.loop = asyncio.get_event_loop()
57 |
58 | if token[:4] == "Bot ":
59 | raise ValueError("Invalid token: Bot tokens are not supported.")
60 |
61 | uid_payload, _ = token.split(".", 1)
62 |
63 | self.token = token
64 |
65 | self.session = session or aiohttp.ClientSession()
66 | self.user_id = base64.b64decode(uid_payload + "===").decode("utf-8")
67 |
68 | self.sequence = None
69 | self.ws_handler_task = None
70 |
71 | self.interceptors = []
72 |
73 | self.voice_connection: asyncio.Future[VoiceConnection] = asyncio.Future()
74 | self.stream_connection: asyncio.Future[StreamConnection] = asyncio.Future()
75 |
76 | self.pending_joins = {}
77 | self.ws = None
78 |
79 | self.latency = 0
80 | self.last_heartbeat_sent = 0
81 |
82 | async def join_voice_channel(self, channel_id: str, guild_id=None, region=None):
83 | await self.ws.send_json(
84 | {
85 | "op": DiscordGatewayOPCodes.VOICE_STATE_UPDATE,
86 | "d": {
87 | "guild_id": guild_id,
88 | "channel_id": channel_id,
89 | "self_mute": False,
90 | "self_deaf": False,
91 | "self_video": False,
92 | "preferred_region": region,
93 | },
94 | }
95 | )
96 |
97 | state_update, server_update = await self.create_ws_interceptor(
98 | (
99 | lambda data: data["t"] == "VOICE_STATE_UPDATE"
100 | and data["d"]["channel_id"] == channel_id
101 | and data["d"]["user_id"] == self.user_id
102 | ),
103 | (lambda data: data["t"] == "VOICE_SERVER_UPDATE"),
104 | )
105 |
106 | voice_conn = VoiceConnection(
107 | self.session,
108 | guild_id=guild_id,
109 | channel_id=channel_id,
110 | user_id=self.user_id,
111 | session_id=state_update["d"]["session_id"],
112 | endpoint=server_update["d"]["endpoint"],
113 | token=server_update["d"]["token"],
114 | )
115 |
116 | await voice_conn.start()
117 | return voice_conn
118 |
119 | async def heartbeat(self, interval):
120 | while self.ws is not None and not self.ws.closed:
121 | await asyncio.sleep(interval / 1000)
122 | await self.ws.send_json({"op": DiscordGatewayOPCodes.HEARTBEAT, "d": 1337})
123 | self.last_heartbeat_sent = self.loop.time()
124 |
125 | async def create_ws_interceptor(self, *predicates):
126 | """
127 | Intercepts the next message that satisfies the predicate
128 | if the predicate fails, the regular handling is done.
129 | """
130 | interception = asyncio.Future()
131 | unmatched_predicates = list(predicates)
132 | predicate_mapping = dict.fromkeys(predicates)
133 |
134 | async def interceptor(data):
135 | for predicate in unmatched_predicates:
136 | if predicate(data):
137 | unmatched_predicates.remove(predicate)
138 | predicate_mapping[predicate] = data
139 |
140 | if not unmatched_predicates:
141 | interception.set_result(list(predicate_mapping.values()))
142 | self.interceptors.remove(interceptor)
143 |
144 | self.interceptors.append(interceptor)
145 | return await interception
146 |
147 | async def handle_incoming(self):
148 | if self.ws is None:
149 | return
150 |
151 | async for message in self.ws:
152 | data = message.json()
153 |
154 | for interceptor in self.interceptors:
155 | await interceptor(data)
156 |
157 | match data["op"]:
158 | case DiscordGatewayOPCodes.HELLO:
159 | self.loop.create_task(
160 | self.heartbeat(data["d"]["heartbeat_interval"])
161 | )
162 |
163 | case DiscordGatewayOPCodes.DISPATCH:
164 | match data["t"]:
165 | case "READY":
166 | ...
167 |
168 | case DiscordGatewayOPCodes.HEARTBEAT_ACK:
169 | self.latency = (self.loop.time() - self.last_heartbeat_sent) * 1000
170 |
171 | async def ws_connect(self):
172 | async with self.session.get(f"{self.DISCORD_API_ENDPOINT}/gateway") as response:
173 | gateway_endpoint = (await response.json())["url"]
174 |
175 | self.ws = await self.session.ws_connect(
176 | gateway_endpoint,
177 | params={
178 | "v": self.GATEWAY_VERSION,
179 | "encoding": "json",
180 | },
181 | )
182 | self.ws_handler_task = self.loop.create_task(self.handle_incoming())
183 |
184 | await self.ws.send_json(
185 | {
186 | "op": DiscordGatewayOPCodes.IDENTIFY,
187 | "d": {
188 | "token": self.token,
189 | "capabilities": voice_capabilities,
190 | "properties": {},
191 | "compress": False,
192 | },
193 | }
194 | )
195 |
196 | async def wait(self):
197 | await self.ws_handler_task
198 |
199 | async def update_voice_state(self, muted=False, deafened=False, video=False):
200 | voice_conn = await self.voice_connection
201 |
202 | await self.ws.send_json(
203 | {
204 | "op": DiscordGatewayOPCodes.VOICE_STATE_UPDATE,
205 | "d": {
206 | "guild_id": voice_conn.guild_id,
207 | "channel_id": voice_conn.channel_id,
208 | "self_mute": muted,
209 | "self_deaf": deafened,
210 | "self_video": video,
211 | },
212 | }
213 | )
214 |
215 | async def create_stream(self, voice_conn: VoiceConnection, preferred_region=None):
216 | payload = {
217 | "op": DiscordGatewayOPCodes.STREAM_CREATE,
218 | "d": {
219 | "type": "guild",
220 | "guild_id": voice_conn.guild_id,
221 | "channel_id": voice_conn.channel_id,
222 | "preferred_region": preferred_region,
223 | },
224 | }
225 |
226 | if voice_conn.guild_id is None:
227 | payload["d"]["type"] = "call"
228 |
229 | await self.ws.send_json(payload)
230 |
231 | (
232 | stream_create_data,
233 | stream_server_update_data,
234 | ) = await self.create_ws_interceptor(
235 | (lambda data: data["t"] == "STREAM_CREATE"),
236 | (lambda data: data["t"] == "STREAM_SERVER_UPDATE"),
237 | )
238 |
239 | stream_conn = StreamConnection.from_voice_connection(
240 | voice_conn,
241 | stream_key=stream_create_data["d"]["stream_key"],
242 | rtc_server_id=stream_create_data["d"]["rtc_server_id"],
243 | rtc_server_endpoint=stream_server_update_data["d"]["endpoint"],
244 | rtc_server_token=stream_server_update_data["d"]["token"],
245 | )
246 |
247 | await self.set_stream_pause(stream_conn, False)
248 | await stream_conn.start()
249 | return stream_conn
250 |
251 | async def set_stream_pause(self, stream_conn: StreamConnection, paused: bool):
252 | await self.ws.send_json(
253 | {
254 | "op": DiscordGatewayOPCodes.STREAM_SET_PAUSED,
255 | "d": {
256 | "stream_key": stream_conn.stream_key,
257 | "paused": paused,
258 | },
259 | }
260 | )
261 |
262 | async def delete_stream(self, stream_conn: StreamConnection):
263 | await self.ws.send_json(
264 | {
265 | "op": DiscordGatewayOPCodes.STREAM_DELETE,
266 | "d": {"stream_key": stream_conn.stream_key},
267 | }
268 | )
269 |
--------------------------------------------------------------------------------
/strawberry/packetizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justfoolingaround/strawberry/dcbb7a36adff520c6ac7b4669d137377e338c74a/strawberry/packetizers/__init__.py
--------------------------------------------------------------------------------
/strawberry/packetizers/audio_packetizer.py:
--------------------------------------------------------------------------------
1 | from .base_packetizer import BaseMediaPacketizer
2 |
3 |
class AudioPacketizer(BaseMediaPacketizer):
    """RTP packetizer for 20 ms Opus audio frames (payload type 0x78)."""

    codec = "opus"

    def __init__(self, conn):
        super().__init__(conn, 0x78, False)
        # 48 kHz clock * 20 ms = 960 timestamp ticks per Opus frame.
        self.frame_size = 48000 // 1000 * 20

    def send_frame(self, frame: bytearray):
        """Encrypt and send a single Opus frame, then advance the RTP clock."""
        encrypted = self.conn.encrypt_data(self.get_rtp_header(), frame)
        self.conn.send_packet(encrypted)
        self.increment_timestamp(self.frame_size)
14 |
--------------------------------------------------------------------------------
/strawberry/packetizers/base_packetizer.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import typing
3 |
4 | from strawberry.utils import checked_add
5 |
6 | if typing.TYPE_CHECKING:
7 | from strawberry.connection import UDPConnection
8 |
9 |
class BaseMediaPacketizer:
    """Shared RTP packetizer state: sequence number, timestamp and SSRC.

    Codec-specific subclasses implement send_frame().
    """

    codec: str

    MAX_INT_16 = 1 << 16
    MAX_INT_32 = 1 << 32

    def __init__(
        self,
        conn: "UDPConnection",
        payload_type: int,
        extensions_enabled: bool = False,
    ):
        self.conn = conn
        self.payload_type = payload_type
        self.extensions_enabled = extensions_enabled

        # Wrap-around RTP counters.
        self.sequence = 0
        self.timestamp = 0

        self.mtu = 1200
        self.ssrc = 0

    def send_frame(self, _: bytearray):
        """Transmit one media frame; implemented by codec subclasses."""
        raise NotImplementedError

    def get_new_sequence(self):
        """Advance and return the 16-bit RTP sequence number."""
        self.sequence = checked_add(self.sequence, 1, self.MAX_INT_16)
        return self.sequence

    def increment_timestamp(self, increment):
        """Advance the 32-bit RTP timestamp by *increment* ticks."""
        self.timestamp = checked_add(self.timestamp, int(increment), self.MAX_INT_32)

    def get_rtp_header(self, is_last: bool = True):
        """Build the fixed 12-byte RTP header for the next outgoing packet.

        The extension flag mirrors self.extensions_enabled and the marker
        bit is set when *is_last* is true.
        """
        version_and_flags = 0x90 if self.extensions_enabled else 0x80
        marked_payload_type = self.payload_type | (0x80 if is_last else 0x00)

        return struct.pack(
            ">BBHII",
            version_and_flags,
            marked_payload_type,
            self.get_new_sequence(),
            self.timestamp,
            self.ssrc,
        )

    def get_header_extension(self):
        """Build a one-byte-header RTP extension block (0xBEDE profile)."""
        extensions = [
            {
                "id": 5,
                "len": 2,
                "val": 0,
            }
        ]

        header = bytearray(struct.pack(">HH", 0xBEDE, len(extensions)))

        for extension in extensions:
            id_and_len = ((extension["id"] & 0b00001111) << 4) | (
                (extension["len"] - 1) & 0b00001111
            )

            # One id/len byte, a big-endian value, then a padding byte so
            # each entry occupies a full 32-bit word.
            header += struct.pack(">BHx", id_and_len, extension["val"])

        return header
75 |
--------------------------------------------------------------------------------
/strawberry/packetizers/h264_packetizer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import typing
3 |
4 | from strawberry.utils import partition_chunks
5 |
6 | from .base_packetizer import BaseMediaPacketizer
7 |
8 | if typing.TYPE_CHECKING:
9 | from strawberry.connection import UDPConnection
10 |
11 |
class H264Packetizer(BaseMediaPacketizer):
    """RTP packetizer for H.264 video (payload type 0x65, extensions on).

    NAL units that fit within the MTU are sent as single NAL unit packets;
    larger units are fragmented following the FU-A scheme of RFC 6184.
    """

    codec = "H264"

    def __init__(self, conn: "UDPConnection"):
        super().__init__(conn, 0x65, True)

        # Frames per second; drives how far the 90 kHz RTP clock advances
        # per access unit (see send_frame).
        self.fps = 30

    def send_frame(self, nalus: list[bytes]):
        """Send one access unit (a list of NAL units) as RTP packets.

        The RTP marker bit is set only on the final packet of the access
        unit; afterwards the timestamp advances by one frame interval.
        """
        for i, nalu in enumerate(nalus):
            is_last = i == len(nalus) - 1

            if len(nalu) <= self.mtu:
                # Single NAL unit packet: the unit fits in one datagram.
                self.conn.send_packet(
                    self.conn.encrypt_data(
                        self.get_rtp_header(is_last),
                        self.get_header_extension() + nalu,
                    )
                )

            else:
                # FU-A fragmentation: split everything after the NAL header
                # byte into MTU-sized chunks, each with a 2-byte FU header.
                nal0 = nalu[0]
                chunks_count = math.ceil((len(nalu) - 1) / self.mtu)

                # Low 5 bits: NAL unit type; high 3 bits: F + NRI, which are
                # carried over into the FU indicator.
                nal_type = nal0 & 0x1F
                fnri = nal0 & 0xE0

                # FU indicator: original F/NRI bits with type 28 (FU-A).
                default_header = bytes((0x1C | fnri,))

                for j, nal_fragment in enumerate(partition_chunks(nalu[1:], self.mtu)):
                    chunk_header = default_header
                    is_final_chunk = j == chunks_count - 1

                    if j == 0:
                        # FU header with the start bit set.
                        chunk_header += bytes((0x80 | nal_type,))
                    else:
                        if is_final_chunk:
                            # FU header with the end bit set.
                            chunk_header += bytes((0x40 | nal_type,))
                        else:
                            # Middle fragment: no start/end bits.
                            chunk_header += bytes((nal_type,))

                    self.conn.send_packet(
                        self.conn.encrypt_data(
                            self.get_rtp_header(is_final_chunk and is_last),
                            self.get_header_extension() + chunk_header + nal_fragment,
                        )
                    )

        # One video frame elapsed on the 90 kHz RTP clock.
        self.increment_timestamp(90000 / self.fps)
61 |
--------------------------------------------------------------------------------
/strawberry/sources/__init__.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import typing
3 |
4 | from .h264_source import VideoSource
5 | from .opus_source import AudioSource
6 |
7 | if typing.TYPE_CHECKING:
8 | pass
9 |
10 | __all__ = [
11 | "VideoSource",
12 | "AudioSource",
13 | "try_probe_source",
14 | "create_av_sources_from_single_process",
15 | ]
16 |
17 |
def create_av_sources_from_single_process(
    source: str,
    width: int = 1280,
    height: int = 720,
    has_burned_in_subtitles: bool = False,
    *,
    framerate: "int | None" = None,
    crf: "int | None" = None,
    audio_source: "str | None" = None,
):
    """Spawn one ffmpeg process producing H.264 on stdout and Opus on stderr.

    Intended for sources that can only be opened once (e.g. udp://): the
    raw H.264 elementary stream goes to pipe:1 (stdout) and the Ogg/Opus
    audio to pipe:2 (stderr). Returns a (VideoSource, AudioSource) pair.
    """
    command = ["ffmpeg", "-hide_banner", "-loglevel", "quiet", "-i", source]

    if crf is not None:
        command += ["-crf", str(crf)]

    if framerate is not None:
        command += ["-r", str(framerate)]

    video_filter = f"scale={width}:{height}"

    if has_burned_in_subtitles:
        # Escape the source for use inside the subtitles filter string.
        escaped = source.replace(":", "\\:").replace("'", "\\'")
        video_filter += f",subtitles='{escaped}':si=0"

    # Video output: raw H.264 with access unit delimiters, to stdout.
    command += [
        "-f", "h264",
        "-reconnect", "1",
        "-reconnect_streamed", "1",
        "-reconnect_delay_max", "5",
        "-vf", video_filter,
        "-tune", "zerolatency",
        "-pix_fmt", "yuv420p",
        "-preset", "ultrafast",
        "-profile:v", "baseline",
        "-bsf:v", "h264_metadata=aud=insert",
        "pipe:1",
    ]

    if audio_source is not None:
        command += ["-i", audio_source]

    # Audio output: Ogg/Opus, 48 kHz stereo, to stderr.
    command += [
        "-map_metadata", "-1",
        "-reconnect", "1",
        "-reconnect_streamed", "1",
        "-reconnect_delay_max", "5",
        "-f", "opus",
        "-c:a", "libopus",
        "-ar", "48000",
        "-ac", "2",
        "-b:a", f"{AudioSource.bitrate}k",
        "pipe:2",
    ]

    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )

    return VideoSource(process.stdout), AudioSource(process.stderr)
103 |
104 |
105 | def try_probe_source(source: str):
106 | ffprobe = subprocess.Popen(
107 | (
108 | "ffprobe",
109 | "-v",
110 | "error",
111 | "-show_entries",
112 | "stream=width,height,avg_frame_rate,duration,codec_type,bit_rate",
113 | "-of",
114 | "default=noprint_wrappers=1",
115 | source,
116 | ),
117 | stdout=subprocess.PIPE,
118 | stderr=subprocess.PIPE,
119 | )
120 |
121 | stdout, stderr = ffprobe.communicate()
122 |
123 | assert ffprobe.returncode == 0, stderr.decode("utf-8")
124 | stdout_text = stdout.decode("utf-8")
125 |
126 | video_probes = []
127 | audio_probes = []
128 | subtitle_probes = []
129 | attachment_probes = []
130 |
131 | curr_probe = None
132 |
133 | for line in stdout_text.splitlines():
134 | key, value = line.split("=")
135 |
136 | if key == "codec_type":
137 | curr_probe = {}
138 |
139 | match value:
140 | case "video":
141 | video_probes.append(curr_probe)
142 | case "audio":
143 | audio_probes.append(curr_probe)
144 | case "subtitle":
145 | subtitle_probes.append(curr_probe)
146 | case "attachment":
147 | attachment_probes.append(curr_probe)
148 | case _:
149 | curr_probe = None
150 | else:
151 | if curr_probe is not None:
152 | curr_probe[key] = value if value != "N/A" else None
153 |
154 | return {
155 | "video": video_probes,
156 | "audio": audio_probes,
157 | "subtitle": subtitle_probes,
158 | "attachment": attachment_probes,
159 | }
160 |
--------------------------------------------------------------------------------
/strawberry/sources/h264_source.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import functools
3 | import io
4 | import struct
5 | import subprocess
6 |
EPB_PREFIX = b"\x00\x00\x03"
NAL_SUFFIX = b"\x00\x00\x01"


class NalUnitTypes(enum.IntEnum):
    """H.264 NAL unit type codes (low 5 bits of the NAL header byte)."""

    Unspecified = 0
    CodedSliceNonIDR = enum.auto()
    CodedSlicePartitionA = enum.auto()
    CodedSlicePartitionB = enum.auto()
    CodedSlicePartitionC = enum.auto()
    CodedSliceIdr = enum.auto()
    SEI = enum.auto()
    SPS = enum.auto()
    PPS = enum.auto()
    AccessUnitDelimiter = enum.auto()
    EndOfSequence = enum.auto()
    EndOfStream = enum.auto()
    FillerData = enum.auto()
    SEIExtenstion = enum.auto()
    PrefixNalUnit = enum.auto()
    SubsetSPS = enum.auto()


@functools.lru_cache()
def get_raw_byte_sequence_payload(frame: bytes):
    """Strip H.264 emulation-prevention 0x03 bytes from *frame*.

    A 0x03 following 0x00 0x00 is an emulation-prevention byte only when
    the byte after it is 0x00-0x03; such bytes are removed, everything
    else is kept verbatim.
    """
    raw = b""

    while (epbs_pos := frame.find(EPB_PREFIX)) != -1:
        size = 3

        # BUG FIX: guard the lookahead — when the 00 00 03 sequence sits at
        # the very end of the frame there is no following byte and the
        # original code raised IndexError. A trailing 03 is kept as-is.
        if epbs_pos + 3 < len(frame) and frame[epbs_pos + 3] <= 0x03:
            size -= 1

        raw += frame[: epbs_pos + size]
        frame = frame[epbs_pos + 3 :]

    return raw + frame


class H264NalPacketIterator:
    """Incrementally splits an Annex B H.264 byte stream into access units."""

    def __init__(self):
        # Bytes carried over after the last complete start code.
        self.buffer = b""
        # NAL units collected since the previous access unit delimiter.
        self.access_unit = []

    def iter_access_units(self, chunk: bytes):
        """Feed *chunk* and yield completed access units (lists of NAL units).

        An access unit is emitted when an access unit delimiter NAL arrives
        and units have been collected since the previous one.
        """
        self.buffer += chunk

        *frames, self.buffer = self.buffer.split(NAL_SUFFIX)

        for frame in frames:
            # A 4-byte start code (00 00 00 01) leaves one trailing zero on
            # the preceding frame; drop it. BUG FIX: endswith() is safe on
            # an empty frame, where the original frame[-1] raised
            # IndexError (e.g. a chunk beginning with a 3-byte start code).
            if frame.endswith(b"\x00"):
                frame = frame[:-1]

            if not frame:
                continue

            unit_type = frame[0] & 0x1F

            if unit_type == NalUnitTypes.AccessUnitDelimiter:
                if self.access_unit:
                    yield self.access_unit
                    # BUG FIX: hand ownership of the yielded list to the
                    # consumer. The original called .clear(), emptying the
                    # very list the caller had just received once iteration
                    # resumed.
                    self.access_unit = []
            else:
                if unit_type in (NalUnitTypes.SPS, NalUnitTypes.SEI):
                    self.access_unit.append(get_raw_byte_sequence_payload(frame))
                else:
                    self.access_unit.append(frame)

    def iter_packets(self, chunk: bytes):
        """Yield access units with every NAL unit length-prefixed (uint32 BE)."""
        for access_unit in self.iter_access_units(chunk):
            yield b"".join(struct.pack(">I", len(nalu)) + nalu for nalu in access_unit)
78 |
79 |
class VideoSource:
    """Wraps a raw Annex B H.264 byte stream and yields access units."""

    def __init__(
        self,
        input_stream: "io.BufferedIOBase",
    ):
        self.input = input_stream

        self.packet_iter = H264NalPacketIterator()

    def iter_packets(self):
        """Read the stream in 8 KiB chunks and yield access units.

        Despite the name, each yielded item is a list of NAL units (see
        H264NalPacketIterator.iter_access_units), which is the shape the
        H.264 packetizer consumes.
        """
        for chunk in iter(lambda: self.input.read(8192), b""):
            yield from self.packet_iter.iter_access_units(chunk)

    @classmethod
    def from_source(
        cls,
        source: "str | io.BufferedIOBase",
        width: int = 1280,
        height: int = 720,
        has_burned_in_subtitles: bool = False,
        *,
        framerate: "int | None" = None,
        crf: "int | None" = None,
    ):
        """Spawn ffmpeg to transcode *source* into raw H.264 and wrap stdout.

        source: a URL/path, or a readable binary stream that is fed to
        ffmpeg via stdin. width/height set the output scale; crf the
        quality; framerate additionally forces a keyframe once per second.
        """
        subprocess_kwargs = {
            "stdout": subprocess.PIPE,
        }

        if isinstance(source, str):
            args = ("ffmpeg", "-i", source)
        else:
            # Stream input: ffmpeg reads the object through stdin.
            subprocess_kwargs["stdin"] = source
            args = ("ffmpeg", "-i", "pipe:0")

        if crf is not None:
            args += ("-crf", str(crf))

        if framerate is not None:
            # Force a fixed GOP of one second at the requested rate.
            args += (
                "-r",
                str(framerate),
                "-x264opts",
                f"keyint={framerate}:min-keyint={framerate}",
                "-g",
                str(framerate),
            )

        vf = f"scale={width}:{height}"

        if has_burned_in_subtitles:
            if isinstance(source, str):
                # Escape the path for use inside the subtitles filter string.
                escaped_source = (
                    source.replace("\\", "/").replace(":", "\\:").replace("'", "\\'")
                )

                vf += ",subtitles=" + f"'{escaped_source}'" + ":si=0"
            else:
                vf += ",subtitles=pipe\\:0:si=0"

        # h264_metadata=aud=insert guarantees access unit delimiters, which
        # H264NalPacketIterator relies on to find frame boundaries.
        # NOTE(review): the -reconnect* options are input protocol options;
        # placed after -i they may not apply to this input — confirm.
        args += (
            "-f",
            "h264",
            "-reconnect",
            "1",
            "-reconnect_streamed",
            "1",
            "-reconnect_delay_max",
            "5",
            "-vf",
            vf,
            "-tune",
            "zerolatency",
            "-pix_fmt",
            "yuv420p",
            "-preset",
            "ultrafast",
            "-profile:v",
            "baseline",
            "-bsf:v",
            "h264_metadata=aud=insert",
            "-an",
            "-loglevel",
            "warning",
            "pipe:1",
        )

        process = subprocess.Popen(args, **subprocess_kwargs)

        return cls(process.stdout)
169 |
--------------------------------------------------------------------------------
/strawberry/sources/opus_source.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import subprocess
3 | from typing import IO
4 |
5 |
6 | class OggPage:
7 | flag: int
8 | gran_pos: int
9 | serial: int
10 | pagenum: int
11 | crc: int
12 | segnum: int
13 |
14 | def __init__(self, stream: IO[bytes]) -> None:
15 | header = stream.read(0x17)
16 |
17 | (
18 | self.flag,
19 | self.gran_pos,
20 | self.serial,
21 | self.pagenum,
22 | self.crc,
23 | self.segnum,
24 | ) = struct.unpack(" None:
50 | self.stream: IO[bytes] = stream
51 |
    def __iter__(self):
        """Yield complete packets reassembled from consecutive Ogg pages."""
        # Packet data may span several page segments; accumulate until the
        # page parser flags completion.
        buffer = b""

        # Scan for the 4-byte "OggS" capture pattern that starts each page.
        # NOTE(review): reads advance 4 bytes at a time, so a stream that is
        # not 4-byte aligned on a page boundary would never resynchronize —
        # confirm the ffmpeg output always aligns.
        for frame in iter(lambda: self.stream.read(4), b""):
            if frame == b"OggS":
                # presumably is_complete marks the final segment of a
                # packet — verify against OggPage.iter_packets.
                for data, is_complete in OggPage(self.stream).iter_packets():
                    buffer += data
                    if is_complete:
                        yield buffer
                        buffer = b""
62 |
63 |
class AudioSource:
    """Opus audio source that yields packets parsed from an Ogg stream."""

    # Target Opus bitrate in kbit/s, used when spawning ffmpeg.
    bitrate = 128

    def __init__(self, source: "IO[bytes]"):
        self.packet_iter = OggStream(source)

    def iter_packets(self):
        """Yield raw Opus packets from the underlying Ogg stream."""
        yield from self.packet_iter

    @classmethod
    def from_source(
        cls,
        source: "str | IO[bytes]",
    ):
        """Spawn ffmpeg to transcode *source* to Ogg/Opus and wrap stdout."""
        popen_kwargs = {"stdout": subprocess.PIPE}

        if isinstance(source, str):
            command = ["ffmpeg", "-i", source]
        else:
            # Stream input: ffmpeg reads the object through stdin.
            popen_kwargs["stdin"] = source
            command = ["ffmpeg", "-i", "pipe:0"]

        command += [
            "-map_metadata", "-1",
            "-reconnect", "1",
            "-reconnect_streamed", "1",
            "-reconnect_delay_max", "5",
            "-f", "opus",
            "-c:a", "libopus",
            "-ar", "48000",
            "-ac", "2",
            "-b:a", f"{AudioSource.bitrate}k",
            "-vn",
            "-loglevel", "warning",
            "pipe:1",
        ]

        process = subprocess.Popen(command, **popen_kwargs)

        return cls(process.stdout)
116 |
--------------------------------------------------------------------------------
/strawberry/streamer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import threading
3 | import time
4 |
5 | from .connection import StreamConnection, UDPConnection, VoiceConnection
6 | from .sources import (
7 | AudioSource,
8 | VideoSource,
9 | create_av_sources_from_single_process,
10 | try_probe_source,
11 | )
12 |
13 |
def invoke_source_stream(
    source,
    udp: UDPConnection,
    in_between_delay: float,
    pause_event: "threading.Event | None" = None,
):
    """Push packets from *source* to *udp* at a fixed cadence.

    source: an AudioSource or video source exposing iter_packets().
    in_between_delay: seconds between consecutive packets.
    pause_event: while set, transmission is suspended; the paused time is
    excluded from the pacing schedule.

    Blocks until the source is exhausted; intended to run on its own thread.
    """
    logger = logging.getLogger("streamer")

    if isinstance(source, AudioSource):
        sender = udp.send_audio_frame
        logger = logger.getChild("audio")
    else:
        sender = udp.send_video_frame
        logger = logger.getChild("video")

    paused_duration = 0
    start = None

    for loops, packet in enumerate(source.iter_packets(), 1):
        if start is None:
            start = time.perf_counter()

        if pause_event is not None and pause_event.is_set():
            paused_start = time.perf_counter()

            # BUG FIX: Event.wait() returns immediately when the event is
            # already set, so the original code never actually paused.
            # Poll until the controller clears the event.
            while pause_event.is_set():
                time.sleep(0.05)

            paused_duration += time.perf_counter() - paused_start

        sender(packet)

        # Target send time for this packet, shifted by any time spent paused.
        delay = start + in_between_delay * loops - time.perf_counter() - paused_duration

        if delay < 0:
            behind_by = -delay * 1000
            if behind_by > 1000:
                logger.warning(
                    "Stream is lagging by %.2f ms, experiencing poor connection.",
                    behind_by,
                )

        time.sleep(max(0, delay))
54 |
55 |
def ffmpeg_fps_eval(fps: str):
    """Evaluate an ffprobe frame-rate fraction such as "30000/1001".

    Returns the rate as a float, or None when the denominator is zero
    (ffprobe reports "0/0" for streams without a known rate).
    """
    numerator_text, _, denominator_text = fps.partition("/")
    denominator = int(denominator_text)

    if not denominator:
        return None

    return int(numerator_text) / denominator
63 |
64 |
async def stream(
    conn: VoiceConnection,
    source: str,
    *,
    audio_source: "str | None" = None,
    forced_width: int = 0,
    forced_height: int = 0,
    pause_event: "threading.Event | None" = None,
):
    """Probe *source*, build A/V sources and start one sender thread each.

    conn: the voice/stream connection whose UDP transport receives frames.
    audio_source: optional separate audio input; forced_width/forced_height
    override the probed video resolution when non-zero; pause_event is
    forwarded to every sender thread.

    Returns the list of started threads.
    """
    probes = try_probe_source(source)

    has_audio_in_source = probes["audio"]
    has_audio = audio_source or has_audio_in_source

    sources = []
    is_udp_source = source[:6] == "udp://"

    if probes["video"]:
        # Stream at the resolution of the largest probed video stream
        # unless the caller forces dimensions.
        max_video_res = max(
            probes["video"], key=lambda x: int(x["width"]) * int(x["height"])
        )

        width = forced_width or int(max_video_res["width"])
        height = forced_height or int(max_video_res["height"])

        fps = round(ffmpeg_fps_eval(max_video_res["avg_frame_rate"]) or 30)

        await conn.set_video_state(
            True,
            width,
            height,
            fps,
        )
        conn.udp_connection.video_packetizer.fps = fps

        if has_audio_in_source and is_udp_source:
            # You can only open 1 udp server at a time.
            # For some reason the stderr source has serious
            # latency (>1s).
            video, audio = create_av_sources_from_single_process(
                source,
                has_burned_in_subtitles=bool(probes["subtitle"]),
                width=width,
                height=height,
                audio_source=audio_source,
                framerate=fps,
            )

            # Pair each source with its send cadence: one frame interval
            # for video, 20 ms per Opus frame for audio.
            sources.extend(
                (
                    (
                        video,
                        1 / fps,
                    ),
                    (
                        audio,
                        20 / 1000,
                    ),
                )
            )
        else:
            sources.append(
                (
                    VideoSource.from_source(
                        source,
                        has_burned_in_subtitles=bool(probes["subtitle"]),
                        width=width,
                        height=height,
                        framerate=int(fps),
                    ),
                    1 / fps,
                )
            )

    else:
        await conn.set_video_state(False)

        # NOTE(review): *source* is annotated str, so this isinstance check
        # looks unreachable — confirm whether callers can actually pass a
        # StreamConnection here.
        if isinstance(source, StreamConnection):
            raise ValueError("StreamConnection requires a video source")

    # Audio handled here unless the single-process UDP path above already
    # produced an audio source.
    if not (has_audio_in_source and is_udp_source) and (
        audio_source is not None or has_audio
    ):
        if audio_source is not None:
            asrc = AudioSource.from_source(audio_source)
        else:
            asrc = AudioSource.from_source(source)

        # 1/50 s = 20 ms, matching the Opus frame duration.
        sources.append(
            (
                asrc,
                1 / 50,
            )
        )

    threads = [
        threading.Thread(
            target=invoke_source_stream,
            args=(
                src,
                conn.udp_connection,
                delay,
                pause_event,
            ),
        )
        for src, delay in sources
    ]

    for thread in threads:
        thread.start()

    return threads
177 |
--------------------------------------------------------------------------------
/strawberry/utils.py:
--------------------------------------------------------------------------------
def checked_add(integer: int, quantity: int, limit: int):
    """Add *quantity* to *integer*, wrapping the result into [0, limit)."""
    total = integer + quantity
    return total % limit
3 |
4 |
def partition_chunks(data: bytes, chunk_size: int):
    """Yield consecutive *chunk_size*-byte slices of *data*.

    The final chunk may be shorter; an empty input yields nothing.
    """
    offset = 0

    while offset < len(data):
        yield data[offset : offset + chunk_size]
        offset += chunk_size
8 |
--------------------------------------------------------------------------------
/strawberry_config.toml.example:
--------------------------------------------------------------------------------
1 | [user]
2 | token = "..."
3 |
4 | [voice]
5 |
# Comment out the guild_id entry if the channel is a DM channel.
7 | guild_id = "..."
8 | channel_id = "..."
9 |
10 | # Optional
11 | preferred_region = ""
12 |
--------------------------------------------------------------------------------
/strawberry_yum.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import shutil
3 | import subprocess
4 | import sys
5 |
6 | import toml
7 |
8 | from strawberry.gateway import DiscordGateway
9 | from strawberry.streamer import stream
10 |
11 | with open("strawberry_config.toml") as f:
12 | config = toml.load(f)
13 |
14 | with open("assets/strawberry_preview.png", "rb") as f:
15 | thumbnail = f.read()
16 |
17 |
def invoke_ytdlp(query: str):
    """Resolve *query* to direct media URLs with yt-dlp.

    Returns None when yt-dlp is not installed or produced no URLs;
    otherwise a dict with a "video" key and, when a separate audio stream
    exists, an "audio" key. Raises ValueError with yt-dlp's stderr when it
    exits non-zero with error output.
    """
    if not shutil.which("yt-dlp"):
        return None

    proc = subprocess.Popen(
        [
            "yt-dlp",
            "-g",
            query,
            "--format",
            "bestvideo[height<=720]+bestaudio/best[height<=720]",
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stdout, stderr = proc.communicate()

    if proc.returncode != 0 and stderr:
        raise ValueError(stderr.decode("utf-8"))

    # BUG FIX: "".split("\n") == [""], so empty yt-dlp output previously
    # slipped past the guard and produced {"video": ""}. splitlines() plus
    # filtering yields a genuinely empty list instead.
    streams = [line for line in stdout.decode("utf-8").splitlines() if line.strip()]

    if not streams:
        return None

    mapped = {"video": streams[0]}

    if len(streams) > 1:
        mapped["audio"] = streams[1]

    return mapped
48 |
49 |
async def main():
    """Script entry point: resolve the source, join voice and go live.

    Usage: strawberry_yum.py [--yt-dlp] <source>. Connection settings come
    from the module-level strawberry_config.toml data.
    """
    args_copy = sys.argv.copy()

    stream_with_ytdlp = "--yt-dlp" in args_copy

    if stream_with_ytdlp:
        args_copy.remove("--yt-dlp")

    # BUG FIX: take the source from the flag-stripped argument list; reading
    # sys.argv[1] directly made "--yt-dlp" itself the query whenever the
    # flag came first on the command line.
    stream_what = args_copy[1]

    guild_id, channel_id, region = (
        config["voice"].get("guild_id"),
        config["voice"]["channel_id"],
        config["voice"]["preferred_region"],
    )

    gateway_ws = DiscordGateway(
        config["user"]["token"],
    )

    # Resolve the source through yt-dlp when requested, falling back to the
    # raw argument when yt-dlp is unavailable or yields nothing.
    if stream_with_ytdlp:
        media = invoke_ytdlp(stream_what)

        if media is None:
            kwargs = {
                "source": stream_what,
            }
        elif "audio" in media:
            kwargs = {
                "source": media["video"],
                "audio_source": media["audio"],
            }
        else:
            kwargs = {
                "source": media["video"],
            }
    else:
        kwargs = {
            "source": stream_what,
        }

    await gateway_ws.ws_connect()
    conn = await gateway_ws.join_voice_channel(channel_id, guild_id, region or None)
    stream_conn = await gateway_ws.create_stream(conn)
    # Sender threads run in the background; keep the reference if you need
    # to join or monitor them.
    threads = await stream(stream_conn, **kwargs)
    await stream_conn.set_preview(gateway_ws, thumbnail, "image/png")
    await gateway_ws.wait()
99 |
100 |
101 | asyncio.run(main())
102 |
--------------------------------------------------------------------------------