├── requirements.txt
├── core
├── miscellaneous.py
├── exceptions.py
├── __init__.py
├── load_balancing.py
├── events.py
├── node__DEPRECATED.py
├── nodews__DEPRECATED.py
├── lavalink.py
├── nodeaio.py
└── player.py
├── README.md
├── LICENSE
├── examples
└── README.md
├── .gitignore
└── setup.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py[voice]
2 |
--------------------------------------------------------------------------------
/core/miscellaneous.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
def format_time(millis):
    """Format a duration given in milliseconds as ``HH:MM:SS``.

    Fractional seconds are truncated. Hours are not wrapped at 24, so a
    25 hour duration renders as ``25:00:00`` rather than ``01:00:00``.
    Negative durations are clamped to ``00:00:00``.

    Pure integer arithmetic is used instead of ``time.gmtime`` so that very
    large inputs (e.g. effectively-infinite stream lengths) cannot raise
    ``OverflowError``/``OSError``.
    """
    total_seconds = max(int(millis // 1000), 0)
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
6 |
--------------------------------------------------------------------------------
/core/exceptions.py:
--------------------------------------------------------------------------------
1 |
class LavalinkException(Exception):
    """Base class for every exception raised by Magma."""

    def __init__(self, msg):
        # Pass the message to Exception so str(exc), exc.args and pickling work.
        super().__init__(msg)
        # Kept for backwards compatibility: callers may read ``exc.msg`` directly.
        self.msg = msg


class IllegalAction(LavalinkException):
    """Raised when an operation cannot be performed, e.g. no usable node is available."""


class NodeException(LavalinkException):
    """Raised when talking to a Lavalink node fails, e.g. its websocket is not ready."""
13 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
# Library convention: attach a do-nothing handler so applications that never
# configure logging see no output from Magma; they can opt in with their own
# logging configuration.
_null_handler = logging.NullHandler()
logger = logging.getLogger("magma")
logger.addHandler(_null_handler)
5 |
6 | from .events import *
7 | from .exceptions import *
8 | from .lavalink import *
9 | from .player import *
10 | from .miscellaneous import *
11 | from .nodeaio import *
12 |
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Magma
2 |
3 | A Lavalink connector written in Python, translated from Fre_d's Java Lavalink client.
4 | Magma is the best Python Lavalink connector right now, as it is flawlessly incorporated into Himebot, a music bot with over 160k servers.
5 |
6 | ### Requirements:
7 | * Lavalink v3
8 | * discord.py, rewrite branch (other versions may work but are untested)
9 | * asyncio
10 | * websockets
11 | * aiohttp
12 |
13 | **Magma depends on discord.py rewrite**
14 |
15 | More info in requirements.txt
16 |
17 | ### Help and support:
18 | You can join the official Discord server for Magma here:
19 | https://discord.gg/JpPAMYD
20 | A very basic working example of a cog implementing Magma can be found in the `examples` directory for reference.
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Initzx
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Implementation
2 | ```python
3 | class MusicPlayer(AbstractPlayerEventAdapter):
4 | def __init__(self):
5 | pass  # business rules go here
6 |
7 | async def track_pause(self, event: TrackPauseEvent):
8 | pass
9 |
10 | async def track_resume(self, event: TrackResumeEvent):
11 | pass
12 |
13 | async def track_start(self, event: TrackStartEvent):
14 | pass
15 |
16 | async def track_end(self, event: TrackEndEvent):
17 | pass
18 |
19 | async def track_exception(self, event: TrackExceptionEvent):
20 | pass
21 |
22 | async def track_stuck(self, event: TrackStuckEvent):
23 | pass
24 | ```
25 |
26 | ### Basic implementation
27 | You should subclass `AbstractPlayerEventAdapter` to handle all business logic and other components related to your bot. I also recommend using a manager that manages your inherited adapters in order to allow more control over the different adapters.
28 |
29 | ### Advanced implementation
30 | A more advanced implementation can be found in Himebot's code:
31 | [Player manager and player](https://github.com/initzx/rewrite/tree/multiprocessing/audio)
32 | [Commands and such](https://github.com/initzx/rewrite/blob/multiprocessing/commands/music.py)
33 |
34 | ### Logging
35 | The handler of Magma's logger is `logging.NullHandler` by default, though you can choose to receive logging messages by doing for example:
36 | ```python
37 | import logging
38 | logging.basicConfig(format="%(levelname)s -- %(name)s.%(funcName)s : %(message)s", level=logging.INFO)
39 | ```
40 | Place the code above somewhere where you initialize the bot.
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/core/load_balancing.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .exceptions import IllegalAction
4 |
# Shared package-wide "magma" logger.
logger = logging.getLogger("magma")
# Sentinel penalty returned for unusable nodes. It is also the initial "best so
# far" score in LoadBalancer.determine_best_node; because that comparison is a
# strict `<`, a node scoring exactly big_number is never selected.
big_number = 9e30
7 |
8 |
class LoadBalancer:

    """
    The load balancer is copied from Fre_d's Java client, and works in somewhat the same way.

    It scores every node with :class:`Penalties` (lower is better) and moves
    links between nodes when nodes connect or disconnect.
    """

    def __init__(self, lavalink):
        self.lavalink = lavalink

    async def determine_best_node(self):
        """Return the connected node with the lowest penalty total.

        Raises:
            IllegalAction: if no nodes are registered, or none of them is usable.
        """
        nodes = list(self.lavalink.nodes.values())
        if not nodes:
            raise IllegalAction("No nodes found!")
        best_node = None
        record = big_number
        for node in nodes:
            total = await Penalties(node, self.lavalink).get_total()
            if total < record:
                best_node = node
                record = total

        if not best_node or not best_node.connected:
            raise IllegalAction(f"No available nodes! record: {record}")
        return best_node

    async def on_node_disconnect(self, node):
        """Move every link of the disconnected *node* to the best remaining node.

        Raises:
            IllegalAction: propagated from :meth:`determine_best_node` when no
                other node is usable (``node.links`` is then left untouched).
        """
        logger.info(f"Node disconnected: {node.name}")
        new_node = await self.determine_best_node()
        # Iterate over a snapshot: change_node() is awaited, and other tasks may
        # add or remove links while we are suspended. Iterating the live dict
        # would then raise "dictionary changed size during iteration".
        for key, link in list(node.links.items()):
            if node.links.get(key) is not link:
                continue  # removed or replaced while a previous change_node() ran
            await link.change_node(new_node)
        node.links = {}

    async def on_node_connect(self, node):
        """Attach links to the newly connected *node*.

        Links already on *node*, without a node, or on a disconnected node are
        (re)assigned to it.
        """
        logger.info(f"Node connected: {node.name}")
        links = self.lavalink.links
        # Snapshot for the same reason as in on_node_disconnect.
        for key, link in list(links.items()):
            if links.get(key) is not link:
                continue  # removed or replaced while a previous change_node() ran
            if link.node is node or not link.node or not link.node.connected:
                await link.change_node(node)
47 |
48 |
class Penalties:
    """Load score for one node, modelled on Fre_d's Java client; lower is better.

    The individual components are kept on the instance after ``get_total``
    runs so they can be inspected.
    """

    def __init__(self, node, lavalink):
        self.node, self.lavalink = node, lavalink

        self.player_penalty = self.cpu_penalty = 0
        self.deficit_frame_penalty = self.null_frame_penalty = 0

    async def get_total(self):
        """Return the summed penalty, or ``big_number`` if the node is unusable.

        A node is unusable when it is not connected or has no stats yet.
        Frame penalties are only computed when ``avg_frame_deficit`` is not -1.
        """
        stats = self.node.stats
        if not (self.node.connected and stats):
            return big_number

        # One point per playing player.
        self.player_penalty = stats.playing_players
        # Exponential in system load.
        self.cpu_penalty = 1.05 ** (100 * stats.system_load) * 10 - 10

        if stats.avg_frame_deficit != -1:
            deficit_ratio = stats.avg_frame_deficit / 3000
            self.deficit_frame_penalty = (1.03 ** (500 * deficit_ratio)) * 600 - 600
            nulled_ratio = stats.avg_frame_nulled / 3000
            # Nulled frames weigh double.
            self.null_frame_penalty = ((1.03 ** (500 * nulled_ratio)) * 300 - 300) * 2

        return (self.player_penalty + self.cpu_penalty
                + self.deficit_frame_penalty + self.null_frame_penalty)
74 |
--------------------------------------------------------------------------------
/core/events.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
class Event(ABC):
    """Base class of all player events, modelled on Lavaplayer's event classes.

    Business rules should be implemented by reacting to these events through
    an event adapter. ``__init__`` is abstract, so every concrete event defines
    its own constructor and chains up to store ``player``.
    """

    @abstractmethod
    def __init__(self, player):
        self.player = player


class TrackPauseEvent(Event):
    """Pause event for ``player``."""

    def __init__(self, player):
        super().__init__(player=player)


class TrackResumeEvent(Event):
    """Resume event for ``player``."""

    def __init__(self, player):
        super().__init__(player=player)


class TrackStartEvent(Event):
    """Start of ``track`` on ``player``."""

    def __init__(self, player, track):
        self.track = track
        super().__init__(player=player)


class TrackEndEvent(Event):
    """End of ``track``; ``reason`` is the value reported by the node."""

    def __init__(self, player, track, reason):
        self.track, self.reason = track, reason
        super().__init__(player=player)


class TrackExceptionEvent(Event):
    """``track`` failed; ``exception`` carries the error reported by the node."""

    def __init__(self, player, track, exception):
        self.track, self.exception = track, exception
        super().__init__(player=player)


class TrackStuckEvent(Event):
    """``track`` got stuck; ``threshold_ms`` is the threshold reported by the node."""

    def __init__(self, player, track, threshold_ms):
        self.track, self.threshold_ms = track, threshold_ms
        super().__init__(player=player)
49 |
50 |
class AbstractPlayerEventAdapter(ABC):
    """
    This is a base EventAdapter people can inherit to put on their players by doing:

    player.event_adapter = event_adapter

    Subclasses implement one coroutine per track event; ``on_event`` routes an
    incoming event to the matching coroutine.
    """
    @abstractmethod
    async def track_pause(self, event: TrackPauseEvent):
        """Handle a :class:`TrackPauseEvent`."""

    @abstractmethod
    async def track_resume(self, event: TrackResumeEvent):
        """Handle a :class:`TrackResumeEvent`."""

    @abstractmethod
    async def track_start(self, event: TrackStartEvent):
        """Handle a :class:`TrackStartEvent`."""

    @abstractmethod
    async def track_end(self, event: TrackEndEvent):
        """Handle a :class:`TrackEndEvent`."""

    @abstractmethod
    async def track_exception(self, event: TrackExceptionEvent):
        """Handle a :class:`TrackExceptionEvent`."""

    @abstractmethod
    async def track_stuck(self, event: TrackStuckEvent):
        """Handle a :class:`TrackStuckEvent`."""

    async def destroy(self):
        """Release adapter resources; a no-op by default."""

    async def on_event(self, event):
        """Dispatch *event* to the handler coroutine for its type.

        ``Event`` subclasses without a dedicated handler are ignored.

        Raises:
            TypeError: if *event* is not an :class:`Event` instance.
        """
        if not isinstance(event, Event):
            raise TypeError(f"on_event expected an Event instance, got {type(event).__name__}")
        # Checked in order, so a subclass of one of these types is routed like its base.
        dispatch = (
            (TrackPauseEvent, self.track_pause),
            (TrackResumeEvent, self.track_resume),
            (TrackStartEvent, self.track_start),
            (TrackEndEvent, self.track_end),
            (TrackExceptionEvent, self.track_exception),
            (TrackStuckEvent, self.track_stuck),
        )
        for event_type, handler in dispatch:
            if isinstance(event, event_type):
                await handler(event)
                return
99 |
100 |
class InternalEventAdapter(AbstractPlayerEventAdapter):
    """
    Default internal EventAdapter.

    Mirrors pause/resume onto ``player.paused`` and calls ``player.reset()``
    when a track ends; start, exception and stuck events are ignored.
    """

    async def track_pause(self, event: TrackPauseEvent):
        player = event.player
        player.paused = True

    async def track_resume(self, event: TrackResumeEvent):
        player = event.player
        player.paused = False

    async def track_start(self, event: TrackStartEvent):
        """Ignored."""

    async def track_end(self, event: TrackEndEvent):
        player = event.player
        player.reset()

    async def track_exception(self, event: TrackExceptionEvent):
        """Ignored."""

    async def track_stuck(self, event: TrackStuckEvent):
        """Ignored."""
123 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Note: To use the 'upload' functionality of this file, you must:
5 | # $ pip install twine
6 |
7 | import io
8 | import os
9 | import sys
10 | from shutil import rmtree
11 |
12 | from setuptools import find_packages, setup, Command
13 |
# Package meta-data.
NAME = 'lavalink-magma'
DESCRIPTION = 'The Python connector for Lavalink'
URL = 'https://github.com/initzx/magma'
AUTHOR = 'init0'
# Must be a PEP 440 specifier (a bare '3.6' is invalid); the code uses f-strings.
REQUIRES_PYTHON = '>=3.6'
VERSION = 'v3'

# What packages are required for this module to be executed?
REQUIRED = [
    "discord.py",  # rewrite branch, see requirements.txt
]
26 |
27 | # The rest you shouldn't have to touch too much :)
28 | # ------------------------------------------------
29 | # Except, perhaps the License and Trove Classifiers!
30 | # If you do change the License, remember to change the Trove Classifier for that!
31 |
here = os.path.abspath(os.path.dirname(__file__))

# The README doubles as the long description shown on PyPI.
# Note: an sdist only contains README.md if it is packaged (e.g. via MANIFEST.in).
readme_path = os.path.join(here, 'README.md')
with io.open(readme_path, encoding='utf-8') as readme:
    long_description = '\n' + readme.read()
38 |
# Resolve the package version: use VERSION when set, otherwise execute
# <NAME>/__version__.py and take the __version__ it defines.
about = {}
if VERSION:
    about['__version__'] = VERSION
else:
    version_file = os.path.join(here, NAME, '__version__.py')
    with open(version_file) as f:
        exec(f.read(), about)
46 |
47 |
class UploadCommand(Command):
    """Support ``setup.py upload``: build, upload to PyPI via twine, then tag and push."""

    description = 'Build and publish the package.'
    user_options = []

    @staticmethod
    def status(s):
        """Prints things in bold."""
        print('\033[1m{0}\033[0m'.format(s))

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        # Local imports keep the module-level import block untouched.
        import glob
        import subprocess

        try:
            self.status('Removing previous builds…')
            rmtree(os.path.join(here, 'dist'))
        except OSError:
            pass

        # check_call raises CalledProcessError on a non-zero exit status, so a
        # failed build is never uploaded and a failed upload is never tagged.
        self.status('Building Source and Wheel distribution…')
        # No --universal: the code uses f-strings, so the wheel is Python 3 only.
        subprocess.check_call([sys.executable, 'setup.py', 'sdist', 'bdist_wheel'], cwd=here)

        self.status('Uploading the package to PyPi via Twine…')
        subprocess.check_call(['twine', 'upload'] + glob.glob(os.path.join(here, 'dist', '*')))

        self.status('Pushing git tags…')
        version = about['__version__']
        # VERSION may already carry a "v" prefix (e.g. 'v3'); avoid tagging "vv3".
        tag = version if version.startswith('v') else 'v{0}'.format(version)
        subprocess.check_call(['git', 'tag', tag], cwd=here)
        subprocess.check_call(['git', 'push', '--tags'], cwd=here)

        sys.exit()
83 |
84 |
# Where the magic happens:
setup(
    name=NAME,
    version=about['__version__'],
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    # The code base uses f-strings, so refuse to install on Python < 3.6
    # (consistent with the 3.6 classifier below).
    python_requires='>=3.6',
    url=URL,
    packages=find_packages(exclude=('tests',)),
    install_requires=REQUIRED,
    include_package_data=True,
    license='MIT',
    classifiers=[
        # Trove classifiers
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
    ],
    # `python setup.py upload` support.
    cmdclass={
        'upload': UploadCommand,
    },
)
121 |
--------------------------------------------------------------------------------
/core/node__DEPRECATED.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import asyncio
4 | import threading
5 |
6 | import aiohttp
7 | import logging
8 | import websockets
9 | from discord.backoff import ExponentialBackoff
10 |
11 | from .events import TrackEndEvent, TrackStuckEvent, TrackExceptionEvent, TrackStartEvent
12 | from .exceptions import NodeException
13 |
# Shared package-wide "magma" logger.
logger = logging.getLogger("magma")
15 |
16 |
class NodeStats:
    """Parsed payload of a Lavalink ``stats`` websocket message.

    Missing sections are tolerated: absent fields become ``None``. Absent frame
    statistics become ``-1``; Penalties skips frame penalties when
    ``avg_frame_deficit`` is -1.
    """

    def __init__(self, msg):
        self.msg = msg

        self.players = msg.get("players")
        self.playing_players = msg.get("playingPlayers")
        self.uptime = msg.get("uptime")

        mem = msg.get("memory") or {}
        self.mem_free = mem.get("free")
        self.mem_used = mem.get("used")
        self.mem_allocated = mem.get("allocated")
        # Lavalink sends "reservable"; the old misspelling is kept as a fallback.
        self.mem_reservable = mem.get("reservable", mem.get("reserveable"))

        cpu = msg.get("cpu") or {}
        self.cpu_cores = cpu.get("cores")
        self.system_load = cpu.get("systemLoad")
        self.lavalink_load = cpu.get("lavalinkLoad")

        frames = msg.get("frameStats")
        if frames:
            # These are per minute
            self.avg_frame_sent = frames.get("sent")
            self.avg_frame_nulled = frames.get("nulled")
            self.avg_frame_deficit = frames.get("deficit")
        else:
            self.avg_frame_sent = -1
            self.avg_frame_nulled = -1
            self.avg_frame_deficit = -1
46 |
47 |
class KeepAlive(threading.Thread):
    """Daemon thread that pings the node's websocket every ``interval`` seconds.

    Pings are scheduled onto the websocket's event loop with
    ``run_coroutine_threadsafe``. When the connection turns out to be closed,
    the node is marked unavailable. The thread then either finishes closing
    (when ``node.closing`` is set) or asks the node to reconnect.
    """

    def __init__(self, node, interval, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = f"{node.name}-KeepAlive"
        self.daemon = True
        self.node = node
        self.ws = node.ws
        self.loop = node.ws.loop
        self.interval = interval
        self._stop_ev = threading.Event()

    def run(self):
        try:
            # wait() returns True once stop() was called, ending the loop.
            while not self._stop_ev.wait(self.interval):
                future = asyncio.run_coroutine_threadsafe(self.ws.ping(), loop=self.loop)
                future.result()
        except websockets.ConnectionClosed as e:
            logger.warning(f"Connection to `{self.node.name}` was closed! Reason: {e.code}, {e.reason}")
            self.node.available = False
            if self.node.closing:
                asyncio.run_coroutine_threadsafe(self.node.on_close(e.code, e.reason), loop=self.loop)
                return

            logger.info(f"Attempting to reconnect `{self.node.name}`")
            future = asyncio.run_coroutine_threadsafe(self.node.connect(), loop=self.loop)
            future.result()

    def stop(self):
        """Ask the thread to exit after its current wait."""
        self._stop_ev.set()
77 |
78 |
class Node:
    """Connection to a single Lavalink node (deprecated implementation).

    Holds the websocket used for player commands and events, the REST base
    URI used for track loading, the latest ``NodeStats`` and the links
    currently assigned to this node. Liveness is monitored by a ``KeepAlive``
    thread started on every successful connection.
    """

    def __init__(self, lavalink, name, host, port, headers):
        self.name = name
        self.lavalink = lavalink
        self.links = {}
        self.headers = headers
        self.keep_alive = None
        self.stats = None
        self.ws = None
        self.available = False
        self.closing = False
        # Strong reference to the reader task: the event loop only keeps weak
        # references to tasks, so an unreferenced task can be garbage collected.
        self._listen_task = None

        self.uri = f"ws://{host}:{port}"
        self.rest_uri = f"http://{host}:{port}"

    async def _connect(self):
        """Open the websocket, retrying with exponential backoff on OSError."""
        backoff = ExponentialBackoff(2)
        while not (self.ws and self.ws.open):
            try:
                self.ws = await websockets.connect(self.uri, extra_headers=self.headers)
                self._listen_task = asyncio.create_task(self.listen())
                self.keep_alive = KeepAlive(self, 3)
                self.keep_alive.start()
            except OSError:
                delay = backoff.delay()
                logger.error(f"Connection refused, trying again in {delay:.2f}s")
                await asyncio.sleep(delay)

    async def connect(self):
        """Connect (retrying until it succeeds), then notify the load balancer."""
        await self._connect()
        await self.on_open()

    async def disconnect(self):
        """Close the websocket on purpose.

        The KeepAlive thread sees the closure and, since ``closing`` is set,
        calls ``on_close``. Does nothing if the node never connected.
        """
        logger.info(f"Closing websocket connection for node: {self.name}")
        if not self.ws:
            return
        self.closing = True
        await self.ws.close()

    async def _keep_alive(self):
        """
        **THIS IS VERY IMPORTANT**

        Lavalink will sometimes fail to recognize the client connection if
        a ping is not sent frequently. Websockets sends by default, a ping
        every 5-6 seconds, but this is not enough to maintain the connection.

        This is likely due to the deprecated ws draft: RFC 6455

        NOTE: not scheduled anywhere in this module; the KeepAlive thread
        performs the same job.
        """
        try:
            while True:
                await self.ws.ping()
                await asyncio.sleep(2)
        except websockets.ConnectionClosed as e:
            logger.warning(f"Connection to `{self.name}` was closed! Reason: {e.code}, {e.reason}")
            self.available = False
            if self.closing:
                await self.on_close(e.code, e.reason)
                return

            try:
                logger.info(f"Attempting to reconnect `{self.name}`")
                await self.connect()
            except NodeException:
                await self.on_close(e.code, e.reason)

    async def listen(self):
        """Read websocket frames and hand each decoded JSON message to ``on_message``."""
        try:
            while True:
                msg = await self.ws.recv()
                logger.debug(f"Received websocket message from `{self.name}`: {msg}")
                await self.on_message(json.loads(msg))
        except websockets.ConnectionClosed:
            pass  # the KeepAlive thread's ping handles this for us, no need to hear it twice

    async def on_open(self):
        self.available = True
        await self.lavalink.load_balancer.on_node_connect(self)

    async def on_close(self, code=None, reason=None):
        """Stop the keep-alive thread, log the close and let the load balancer move links away."""
        self.closing = False
        if self.keep_alive:
            self.keep_alive.stop()

        if not reason:
            reason = ""

        if code == 1000:
            logger.info(f"Connection to {self.name} closed gracefully with reason: {reason}")
        else:
            logger.warning(f"Connection to {self.name} closed unexpectedly with code: {code}, reason: {reason}")

        await self.lavalink.load_balancer.on_node_disconnect(self)

    async def on_message(self, msg):
        # We receive Lavalink responses here
        op = msg.get("op")
        if op == "playerUpdate":
            link = self.lavalink.get_link(msg.get("guildId"))
            if link:
                await link.player.provide_state(msg.get("state"))
        elif op == "stats":
            self.stats = NodeStats(msg)
        elif op == "event":
            await self.handle_event(msg)
        else:
            logger.info(f"Received unknown op: {op}")

    async def send(self, msg):
        """Serialize *msg* as JSON and send it.

        Raises:
            NodeException: if the websocket is missing or not open.
        """
        if not self.ws or not self.ws.open:
            self.available = False
            raise NodeException("Websocket is not ready, cannot send message")
        logger.debug(f"Sending websocket message: {msg}")
        await self.ws.send(json.dumps(msg))

    async def get_tracks(self, query):
        # Fetch tracks from the Lavalink node using its REST API
        params = {"identifier": query}
        headers = {"Authorization": self.headers["Authorization"]}
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(self.rest_uri+"/loadtracks", params=params) as resp:
                return await resp.json()

    async def handle_event(self, msg):
        # Lavalink sends us track end event types
        link = self.lavalink.get_link(msg.get("guildId"))
        if not link:
            return  # the link got destroyed

        player = link.player
        event = None
        event_type = msg.get("type")

        if event_type == "TrackEndEvent":
            event = TrackEndEvent(player, player.current, msg.get("reason"))
        elif event_type == "TrackStartEvent":
            event = TrackStartEvent(player, player.current)
        elif event_type == "TrackExceptionEvent":
            event = TrackExceptionEvent(player, player.current, msg.get("error"))
        elif event_type == "TrackStuckEvent":
            event = TrackStuckEvent(player, player.current, msg.get("thresholdMs"))
        elif event_type == "WebSocketClosedEvent":
            # Close code 4006 reported by the remote side: the link is destroyed.
            if msg.get("code") == 4006 and msg.get("byRemote"):
                await link.destroy()

        elif event_type:
            logger.info(f"Received unknown event: {event_type}")

        if event:
            await player.trigger_event(event)
229 |
--------------------------------------------------------------------------------
/core/nodews__DEPRECATED.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import asyncio
4 | import threading
5 | from collections import deque
6 |
7 | import aiohttp
8 | import logging
9 | import websockets
10 | from discord.backoff import ExponentialBackoff
11 |
12 | from .events import TrackEndEvent, TrackStuckEvent, TrackExceptionEvent, TrackStartEvent
13 | from .exceptions import NodeException
14 |
# Shared package-wide "magma" logger.
logger = logging.getLogger("magma")
16 |
17 |
class NodeStats:
    """Parsed payload of a Lavalink ``stats`` websocket message.

    Missing sections are tolerated: absent fields become ``None``. Absent frame
    statistics become ``-1``; Penalties skips frame penalties when
    ``avg_frame_deficit`` is -1.
    """

    def __init__(self, msg):
        self.msg = msg

        self.players = msg.get("players")
        self.playing_players = msg.get("playingPlayers")
        self.uptime = msg.get("uptime")

        mem = msg.get("memory") or {}
        self.mem_free = mem.get("free")
        self.mem_used = mem.get("used")
        self.mem_allocated = mem.get("allocated")
        # Lavalink sends "reservable"; the old misspelling is kept as a fallback.
        self.mem_reservable = mem.get("reservable", mem.get("reserveable"))

        cpu = msg.get("cpu") or {}
        self.cpu_cores = cpu.get("cores")
        self.system_load = cpu.get("systemLoad")
        self.lavalink_load = cpu.get("lavalinkLoad")

        frames = msg.get("frameStats")
        if frames:
            # These are per minute
            self.avg_frame_sent = frames.get("sent")
            self.avg_frame_nulled = frames.get("nulled")
            self.avg_frame_deficit = frames.get("deficit")
        else:
            self.avg_frame_sent = -1
            self.avg_frame_nulled = -1
            self.avg_frame_deficit = -1
47 |
48 |
class ListenerThread(threading.Thread):
    """Daemon thread that drives ``node.poll()`` on the websocket's event loop.

    Currently unused: the Node schedules ``poll`` as an asyncio task instead.
    """

    def __init__(self, node, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = f"{node.name}-Listener"
        self.daemon = True
        self.node = node
        self.ws = node.ws
        self.loop = node.ws.loop
        self._stop_ev = threading.Event()

    def run(self):
        try:
            while not self._stop_ev.is_set():
                future = asyncio.run_coroutine_threadsafe(self.node.poll(), loop=self.loop)
                future.result()
        except websockets.ConnectionClosed as e:
            logger.warning(f"Connection to `{self.node.name}` was closed! Reason: {e.code}, {e.reason}")
            self.node.available = False
            if self.node.closing:
                asyncio.run_coroutine_threadsafe(self.node.on_close(e.code, e.reason), loop=self.loop)
                return

            logger.info(f"Attempting to reconnect `{self.node.name}`")
            future = asyncio.run_coroutine_threadsafe(self.node.connect(), loop=self.loop)
            future.result()

    def stop(self):
        """Ask the thread to exit before its next poll."""
        self._stop_ev.set()
77 |
78 |
class Node:
    """Connection to a single Lavalink node (deprecated queue-based implementation).

    Incoming frames are read by a ``poll`` task. Outgoing frames go through
    ``send_queue`` and are written by a single ``_send_task`` task, so only one
    coroutine writes to the websocket at a time. websockets itself pings every
    5 seconds (``ping_interval=5``).
    """

    def __init__(self, lavalink, name, host, port, headers):
        self.name = name
        self.lavalink = lavalink
        self.links = {}
        self.headers = headers
        self.keep_alive = None
        self.stats = None
        self.ws = None
        self.send_queue = asyncio.Queue()
        self.listener_thread = None
        self.available = False
        self.closing = False
        # Strong references to the background tasks: the event loop only keeps
        # weak references, so unreferenced tasks can be garbage collected.
        self._poll_task = None
        self._sender_task = None

        self.uri = f"ws://{host}:{port}"
        self.rest_uri = f"http://{host}:{port}"

    async def _connect(self):
        """Open the websocket (retrying with exponential backoff) and start the reader/sender tasks."""
        backoff = ExponentialBackoff(2)
        while not (self.ws and self.ws.open):
            try:
                self.ws = await websockets.client.connect(self.uri,
                                                          ping_interval=5,
                                                          max_size=None,
                                                          max_queue=None,
                                                          extra_headers=self.headers)
            except OSError:
                delay = backoff.delay()
                logger.error(f"Connection refused, trying again in {delay:.2f}s")
                await asyncio.sleep(delay)
                continue

            # A sender from a previous connection may still be waiting on the
            # queue; cancel it so exactly one task consumes send_queue.
            if self._sender_task and not self._sender_task.done():
                self._sender_task.cancel()
            self._poll_task = asyncio.create_task(self.poll())
            self._sender_task = asyncio.create_task(self._send_task())

    async def connect(self):
        """Connect (retrying until it succeeds), then notify the load balancer."""
        await self._connect()
        await self.on_open()

    async def disconnect(self):
        """Close the websocket on purpose; ``poll`` then sees the closure and calls ``on_close``.

        Does nothing if the node never connected.
        """
        logger.info(f"Closing websocket connection for node: {self.name}")
        if not self.ws:
            return
        self.closing = True
        await self.ws.close()

    async def _handle_connection_closed(self, e):
        """Mark the node unavailable, then finish closing (if requested) or reconnect."""
        logger.warning(f"Connection to `{self.name}` was closed! Reason: {e.code}, {e.reason}")
        self.available = False
        if self.closing:
            await self.on_close(e.code, e.reason)
            return

        try:
            logger.info(f"Attempting to reconnect `{self.name}`")
            await self.connect()
        except NodeException:
            await self.on_close(e.code, e.reason)

    async def _keep_alive(self):
        """
        **THIS IS VERY IMPORTANT**

        Lavalink will sometimes fail to recognize the client connection if
        a ping is not sent frequently. Websockets sends by default, a ping
        every 5-6 seconds, but this is not enough to maintain the connection.

        This is likely due to the deprecated ws draft: RFC 6455

        NOTE: not scheduled anywhere in this module; ``poll`` handles closures.
        """
        try:
            while True:
                await self.ws.ping()
                await asyncio.sleep(2)
        except websockets.ConnectionClosed as e:
            await self._handle_connection_closed(e)

    async def poll(self):
        """Read frames until the connection closes, then hand over to the close/reconnect logic."""
        try:
            while True:
                msg = await self.ws.recv()
                logger.debug(f"Received websocket message from `{self.name}`: {msg}")
                await self.on_message(json.loads(msg))
        except websockets.ConnectionClosed as e:
            # Nothing else watches the connection here, so poll must react to it.
            await self._handle_connection_closed(e)

    async def _send_task(self):
        """Write queued messages to the websocket, one at a time."""
        try:
            while True:
                msg = await self.send_queue.get()
                logger.debug(f"Sending websocket message: {msg}")
                await self.ws.send(json.dumps(msg))
        except websockets.ConnectionClosed:
            # poll() handles the closure; a new sender starts after reconnecting.
            logger.debug(f"Send task for `{self.name}` stopped: connection closed")

    async def send(self, msg):
        """Queue *msg* for sending.

        Raises:
            NodeException: if the websocket is missing or not open.
        """
        if not self.ws or not self.ws.open:
            self.available = False
            raise NodeException("Websocket is not ready, cannot send message")
        await self.send_queue.put(msg)

    async def on_open(self):
        self.available = True
        await self.lavalink.load_balancer.on_node_connect(self)

    async def on_close(self, code=None, reason=None):
        """Log the close and let the load balancer move this node's links away."""
        self.closing = False
        if self.keep_alive:
            self.keep_alive.stop()

        if not reason:
            reason = ""

        if code == 1000:
            logger.info(f"Connection to {self.name} closed gracefully with reason: {reason}")
        else:
            logger.warning(f"Connection to {self.name} closed unexpectedly with code: {code}, reason: {reason}")

        await self.lavalink.load_balancer.on_node_disconnect(self)

    async def on_message(self, msg):
        # We receive Lavalink responses here
        op = msg.get("op")
        if op == "playerUpdate":
            link = self.lavalink.get_link(msg.get("guildId"))
            if link:
                await link.player.provide_state(msg.get("state"))
        elif op == "stats":
            self.stats = NodeStats(msg)
        elif op == "event":
            await self.handle_event(msg)
        else:
            logger.info(f"Received unknown op: {op}")

    async def get_tracks(self, query):
        # Fetch tracks from the Lavalink node using its REST API
        params = {"identifier": query}
        headers = {"Authorization": self.headers["Authorization"]}
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(self.rest_uri + "/loadtracks", params=params) as resp:
                return await resp.json()

    async def handle_event(self, msg):
        # Lavalink sends us track end event types
        link = self.lavalink.get_link(msg.get("guildId"))
        if not link:
            return  # the link got destroyed

        player = link.player
        event = None
        event_type = msg.get("type")

        if event_type == "TrackEndEvent":
            event = TrackEndEvent(player, player.current, msg.get("reason"))
        elif event_type == "TrackStartEvent":
            event = TrackStartEvent(player, player.current)
        elif event_type == "TrackExceptionEvent":
            event = TrackExceptionEvent(player, player.current, msg.get("error"))
        elif event_type == "TrackStuckEvent":
            event = TrackStuckEvent(player, player.current, msg.get("thresholdMs"))
        elif event_type == "WebSocketClosedEvent":
            # Close code 4006 reported by the remote side: the link is destroyed.
            if msg.get("code") == 4006 and msg.get("byRemote"):
                await link.destroy()

        elif event_type:
            logger.info(f"Received unknown event: {event_type}")

        if event:
            await player.trigger_event(event)
241 |
--------------------------------------------------------------------------------
/core/lavalink.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import time
4 | from enum import Enum
5 | from typing import Optional
6 |
7 | from discord import InvalidArgument
8 | from discord.ext import commands
9 | from discord.ext.commands import BotMissingPermissions
10 | from discord.gateway import DiscordWebSocket
11 |
12 | from .exceptions import IllegalAction
13 | from .load_balancing import LoadBalancer
14 | from .nodeaio import Node
15 | from .player import Player, AudioTrackPlaylist
16 |
17 | logger = logging.getLogger("magma")
18 |
19 |
class State(Enum):
    """
    Lifecycle states of a Link.

    Once a Link reaches DESTROYING or DESTROYED, Link.set_state only allows a move to DESTROYED.
    """
    NOT_CONNECTED = 0
    CONNECTING = 1
    CONNECTED = 2
    DISCONNECTING = 3
    # terminal states
    DESTROYING = 4
    DESTROYED = 5
28 |
29 |
class Lavalink:
    """
    Entry point of the client: owns the Lavalink nodes, the per-guild Links and
    the load balancer that picks a node for each Link.
    """

    def __init__(self, user_id, shard_count):
        """
        :param user_id: The bot's user id, sent to nodes in the ``User-Id`` header
        :param shard_count: The bot's shard count, sent to nodes in the ``Num-Shards`` header
        """
        self.user_id = user_id
        self.shard_count = shard_count
        self.loop = asyncio.get_event_loop()
        self.load_balancer = LoadBalancer(self)
        self.nodes = {}  # node name -> Node
        self.links = {}  # guild id (int) -> Link

    @property
    def playing_guilds(self):
        """Mapping of node name -> playing player count, for nodes that have reported stats."""
        return {name: node.stats.playing_players for name, node in self.nodes.items() if node.stats}

    @property
    def total_playing_guilds(self):
        """Sum of playing players over all nodes that have reported stats."""
        return sum(self.playing_guilds.values())

    async def on_socket_response(self, data):
        """
        Forward Discord voice gateway events to the Link of the matching guild.

        YOU MUST ADD THIS WITH `bot.add_listener(lavalink.on_socket_response)`

        :param data: A raw gateway payload; only VOICE_SERVER_UPDATE and VOICE_STATE_UPDATE are used
        """
        if data.get("t") not in ("VOICE_SERVER_UPDATE", "VOICE_STATE_UPDATE"):
            return

        guild_id = data["d"].get("guild_id")
        if guild_id is None:
            return  # not a guild voice event, no Link can own it

        link = self.links.get(int(guild_id))
        if link:
            await link.update_voice(data)

    def get_link(self, guild_id: int, bot=None):
        """
        Return the Link for the specified guild, creating it when a bot is given

        :param guild_id: The guild id for the Link
        :param bot: The bot/shard where this was invoked; without it no Link is created
        :return: A Link, or None if none exists and no bot was given
        """
        guild_id = int(guild_id)

        if guild_id in self.links or not bot:
            return self.links.get(guild_id)

        self.links[guild_id] = Link(self, guild_id, bot)
        return self.links[guild_id]

    async def add_node(self, name, host, port, password):
        """
        Add a Lavalink node and connect to it

        :param name: A unique name for the node
        :param host: The hostname of the node, e.g. "localhost"
        :param port: The port of the node, e.g. 2333; used for both the websocket and the REST API
        :param password: The password to connect to the node
        :return: The Node that was added
        """
        headers = {
            "Authorization": password,
            "Num-Shards": self.shard_count,
            "User-Id": self.user_id
        }

        node = Node(self, name, host, port, headers)
        await node.connect()
        self.nodes[name] = node
        return node

    async def get_best_node(self):
        """
        Determines the best Node based on penalty calculations

        :return: A Node
        """
        return await self.load_balancer.determine_best_node()
100 |
101 |
class Link:
    """
    Ties one guild's Discord voice connection to a Lavalink node.

    A Link relays the bot's voice events to its node, owns the guild's Player
    and joins/leaves voice channels through the bot's gateway websocket.
    """

    def __init__(self, lavalink, guild_id, bot):
        self.lavalink = lavalink
        self.guild_id = guild_id
        self.bot = bot
        self.state = State.NOT_CONNECTED
        self.last_voice_update = {}  # last voiceUpdate payload, replayed when the node changes
        self.last_session_id = None  # session id from the bot's own VOICE_STATE_UPDATE
        self._player = None
        self.node = None

    @property
    def player(self):
        """The guild's Player, created lazily on first access."""
        if not self._player:
            self._player = Player(self)
        return self._player

    def set_state(self, state):
        """
        Change the Link state

        :param state: The new State
        :raises IllegalAction: if the link is DESTROYING/DESTROYED and the new state is not DESTROYED
        """
        if self.state in (State.DESTROYING, State.DESTROYED) and state is not State.DESTROYED:
            raise IllegalAction(f"Cannot change the state to {state} when the state is {self.state}")
        self.state = state

    async def update_voice(self, data):
        """
        Handle a raw VOICE_SERVER_UPDATE / VOICE_STATE_UPDATE gateway payload for this guild

        :param data: The raw gateway payload
        """
        logger.debug(f"Received voice update data: {data}")
        if not self.guild_id:
            raise IllegalAction("Attempted to start audio connection with a guild that doesn't exist")

        if data["t"] == "VOICE_SERVER_UPDATE":
            self.last_voice_update.update({
                "op": "voiceUpdate",
                "event": data["d"],
                "guildId": data["d"]["guild_id"],
                "sessionId": self.last_session_id
            })
            # get_node(True) calls change_node() when no connected node is assigned, and
            # change_node() already sends last_voice_update; only send it here when the node is kept.
            node_was_ready = bool(self.node and self.node.connected)
            node = await self.get_node(True)
            if node_was_ready:
                await node.send(self.last_voice_update)
            self.set_state(State.CONNECTED)
            return

        # VOICE_STATE_UPDATE: we're selfish and only care about ourselves
        if int(data["d"]["user_id"]) != self.bot.user.id:
            return

        channel_id = data["d"]["channel_id"]
        self.last_session_id = data["d"]["session_id"]
        if not channel_id and self.state != State.DESTROYED:
            self.state = State.NOT_CONNECTED

    def _get_shard_socket(self, shard_id: Optional[int]) -> Optional[DiscordWebSocket]:
        """
        Return the gateway websocket of the given shard, or None if this bot doesn't run it

        For an AutoShardedBot a shard_id of None (what ``bot.shard_id`` holds there) is
        resolved to the shard owning this guild: ``(guild_id >> 22) % shard_count``.
        """
        if isinstance(self.bot, commands.AutoShardedBot):
            if shard_id is None:
                shard_id = (self.guild_id >> 22) % self.bot.shard_count
            try:
                return self.bot.shards[shard_id].ws
            except AttributeError:
                # newer discord.py versions wrap shards in ShardInfo objects
                return self.bot.shards[shard_id]._parent.ws

        if self.bot.shard_id is None or self.bot.shard_id == shard_id:
            return self.bot.ws
        return None

    def _require_shard_socket(self):
        """Return the gateway websocket for this guild's shard, raising IllegalAction if there is none."""
        ws = self._get_shard_socket(self.bot.shard_id)
        if ws is None:
            raise IllegalAction(f"Cannot find the gateway websocket for guild {self.guild_id}")
        return ws

    async def get_tracks(self, query):
        """
        Get a list of AudioTracks from a query

        :param query: The query to pass to the Node
        :return: An AudioTrackPlaylist
        """
        node = await self.get_node(True)
        results = await node.get_tracks(query)
        return AudioTrackPlaylist(results)

    async def get_tracks_yt(self, query):
        """Search YouTube for *query*."""
        return await self.get_tracks("ytsearch:" + query)

    async def get_tracks_sc(self, query):
        """Search SoundCloud for *query*."""
        return await self.get_tracks("scsearch:" + query)

    async def get_node(self, select_if_absent=False):
        """
        Gets the Node of the link

        :param select_if_absent: Whether to assign the best node when the link has no connected node
        :return: A Node, or None
        """
        if select_if_absent and not (self.node and self.node.connected):
            await self.change_node(await self.lavalink.get_best_node())
        return self.node

    async def change_node(self, node):
        """
        Move the link (and its player state) to another node

        :param node: The Node to change to
        :return:
        """
        previous = self.node
        if previous is not None and previous is not node:
            previous.links.pop(self.guild_id, None)

        self.node = node
        self.node.links[self.guild_id] = self
        if self.last_voice_update:
            await node.send(self.last_voice_update)
        if self._player:
            await self._player.node_changed()

    async def connect(self, channel, timeout=10):
        """
        Connect to a voice channel through the bot's gateway websocket

        :param channel: The voice channel to connect to
        :param timeout: Seconds to wait for the bot to show up in a voice channel
        :raises InvalidArgument: if the channel is in another guild
        :raises BotMissingPermissions: if the bot may not join the channel
        :raises IllegalAction: if the guild is unavailable, the shard socket is missing,
            or the connection doesn't happen within *timeout*
        """
        if channel.guild.id != self.guild_id:
            raise InvalidArgument("The guild of the channel isn't the same as the link's!")
        if channel.guild.unavailable:
            raise IllegalAction("Cannot connect to guild that is unavailable!")

        me = channel.guild.me
        permissions = me.permissions_in(channel)
        # move_members skips both the connect permission and the user-limit check
        channel_full = len(channel.members) >= channel.user_limit >= 1
        if (not permissions.connect or channel_full) and not permissions.move_members:
            raise BotMissingPermissions(["connect"])

        ws = self._require_shard_socket()
        self.set_state(State.CONNECTING)
        await ws.voice_state(self.guild_id, str(channel.id))

        start = time.monotonic()
        while not (me.voice and me.voice.channel):
            await asyncio.sleep(0.1)
            if time.monotonic() - start >= timeout:
                raise IllegalAction("Couldn't connect to the channel within a reasonable timeframe!")

    async def disconnect(self):
        """
        Disconnect from the current voice channel through the bot's gateway websocket

        :return:
        """
        ws = self._require_shard_socket()
        self.set_state(State.DISCONNECTING)
        await ws.voice_state(self.guild_id, None)

    async def destroy(self):
        """
        Unregister the link from the client and its node, and destroy its player.

        Safe to call more than once.
        """
        self.lavalink.links.pop(self.guild_id, None)
        if self.node:
            self.node.links.pop(self.guild_id, None)
        if self._player:
            # Player.destroy only messages the node when one is connected
            await self._player.destroy()
        self._player = None
269 |
--------------------------------------------------------------------------------
/core/nodeaio.py:
--------------------------------------------------------------------------------
1 | # PARTS OF THIS CODE ARE TAKEN FROM: https://github.com/Devoxin/Lavalink.py/blob/master/lavalink/websocket.py
2 | # MIT License
3 | #
4 | # Copyright (c) 2019 Luke & William
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 | #
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 | #
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 |
24 | import asyncio
25 | import logging
26 | import traceback
27 |
28 | import aiohttp
29 | from discord.backoff import ExponentialBackoff
30 |
31 | from . import IllegalAction
32 | from .events import TrackEndEvent, TrackStuckEvent, TrackExceptionEvent, TrackStartEvent
33 |
34 | logger = logging.getLogger("magma")
35 | logging.getLogger('aiohttp').setLevel(logging.DEBUG)
36 |
37 |
class NodeStats:
    """
    Parsed ``stats`` payload of a Lavalink node.

    Missing sections leave the corresponding attributes as None. Frame stats are
    -1 when the node did not send them.
    """

    def __init__(self, msg):
        self.msg = msg  # the raw payload

        self.players = msg.get("players")
        self.playing_players = msg.get("playingPlayers")
        self.uptime = msg.get("uptime")

        mem = msg.get("memory") or {}
        self.mem_free = mem.get("free")
        self.mem_used = mem.get("used")
        self.mem_allocated = mem.get("allocated")
        # Lavalink sends "reservable"; the misspelled key is kept as a fallback
        self.mem_reservable = mem.get("reservable", mem.get("reserveable"))

        cpu = msg.get("cpu") or {}
        self.cpu_cores = cpu.get("cores")
        self.system_load = cpu.get("systemLoad")
        self.lavalink_load = cpu.get("lavalinkLoad")

        frames = msg.get("frameStats")
        if frames:
            # These are per minute
            self.avg_frame_sent = frames.get("sent")
            self.avg_frame_nulled = frames.get("nulled")
            self.avg_frame_deficit = frames.get("deficit")
        else:
            self.avg_frame_sent = -1
            self.avg_frame_nulled = -1
            self.avg_frame_deficit = -1
67 |
68 |
class Node:
    """A Lavalink server: its websocket connection, REST client and event dispatch."""

    def __init__(self, lavalink, name, host, port, headers):
        """
        :param lavalink: The owning Lavalink client
        :param name: The node's name
        :param host: The node's hostname
        :param port: The node's port, used for both ws:// and http://
        :param headers: Handshake headers; must contain "Authorization"
        """
        self.name = name
        self.lavalink = lavalink
        self.links = {}  # guild id -> Link using this node
        self.headers = {str(k): str(v) for k, v in headers.items()}
        self.stats = None  # latest NodeStats, None until the first stats op
        self.session = aiohttp.ClientSession(headers={"Authorization": self.headers["Authorization"]})
        self.ws = None
        self.listen_task = None
        self.closing = False  # True while a close requested via disconnect() is in progress

        self.uri = f"ws://{host}:{port}"
        self.rest_uri = f"http://{host}:{port}"

    @property
    def connected(self):
        """Whether a websocket exists and is not closed."""
        return self.ws is not None and not self.ws.closed

    async def _connect(self):
        """
        Open the websocket, retrying with exponential backoff until it succeeds.

        Gives up (leaving the node disconnected) on a 401/403 handshake response.
        Starts the listen() task on success.
        """
        backoff = ExponentialBackoff(5, integral=True)
        while not self.connected:
            try:
                logger.info(f'Attempting to establish websocket connection to {self.name}')
                self.ws = await self.session.ws_connect(self.uri, headers=self.headers)
            except aiohttp.ClientConnectorError:
                logger.warning(f'[{self.name}] Invalid response received; this may indicate that '
                               'Lavalink is not running, or is running on a port different '
                               'to the one you passed to `add_node`.')
            except aiohttp.WSServerHandshakeError as ce:
                if ce.status in (401, 403):
                    logger.error(f'Authentication failed when establishing a connection to {self.name}')
                    return

                logger.warning(f'{self.name} returned a code {ce.status} which was unexpected')

            else:
                logger.info(f'Connection established to {self.name}')
                self.listen_task = asyncio.create_task(self.listen())
                return

            delay = backoff.delay()
            logger.error(f"Connection refused, trying again in {delay}s")
            await asyncio.sleep(delay)

    async def connect(self):
        """Connect to the node and, if that succeeded, register it with the load balancer."""
        await self._connect()
        if self.connected:
            await self.on_open()

    async def disconnect(self):
        """Close the websocket deliberately; the listener will not reconnect afterwards."""
        logger.info(f"Closing websocket connection for node: {self.name}")
        if self.ws is None:
            return
        self.closing = True
        await self.ws.close()

    async def listen(self):
        """
        Read websocket messages until the connection ends, then run close handling.

        Reconnects afterwards unless the close was requested through disconnect().
        """
        async for msg in self.ws:
            logger.debug(f"Received websocket message from `{self.name}`: {msg.data}")
            if msg.type == aiohttp.WSMsgType.TEXT:
                try:
                    await self.on_message(msg.json())
                except Exception:
                    # A bad payload or a failing handler must not kill the listener task
                    logger.exception(f"Error while handling a message from `{self.name}`")
            elif msg.type == aiohttp.WSMsgType.ERROR:
                exc = self.ws.exception()
                logger.error(f'Received an error from `{self.name}`: {exc}')
                await self.on_close(reason=exc, connect_again=not self.closing)
                return
        # aiohttp's async iterator stops (instead of yielding) on CLOSE/CLOSING/CLOSED frames,
        # so a closed connection always ends up here; the close code is kept on the websocket.
        await self.on_close(self.ws.close_code, connect_again=not self.closing)

    async def send(self, msg):
        """
        Send a JSON payload to the node, reconnecting first if the websocket is down.

        :param msg: The JSON-serialisable payload
        """
        if not self.connected:
            await self.on_close(connect_again=True)

        logger.debug(f"Sending websocket message: {msg}")
        await self.ws.send_json(msg)

    async def get_tracks(self, query, tries=5, retry_on_failure=True):
        """
        Fetch tracks from the node's REST API (``/loadtracks``)

        :param query: The identifier to load, e.g. a URL or "ytsearch:..."
        :param tries: The maximum number of requests made when retrying
        :param retry_on_failure: Whether to retry, with exponential backoff, on a non-200 response
        :return: The decoded JSON response, or ``{}`` if no request succeeded
        """
        params = {"identifier": query}
        backoff = ExponentialBackoff(base=1)
        for attempt in range(tries):
            async with self.session.get(self.rest_uri + "/loadtracks", params=params) as resp:
                if resp.status == 200:
                    return await resp.json()
                status = resp.status

            if not retry_on_failure:
                logger.error(f"Received status code ({status}) while retrieving tracks, not retrying.")
                return {}
            if attempt + 1 >= tries:
                logger.error(f"Received status code ({status}) while retrieving tracks, "
                             f"giving up after {tries} attempts.")
                break

            retry = backoff.delay()
            logger.error(f"Received status code ({status}) while retrieving tracks, retrying in {retry} seconds. Attempt {attempt+1}/{tries}")
            # Sleep outside the request context so the connection is released while waiting
            await asyncio.sleep(retry)
        return {}

    async def on_open(self):
        """Notify the load balancer that this node is connected."""
        await self.lavalink.load_balancer.on_node_connect(self)

    async def on_close(self, code=None, reason=None, connect_again=False):
        """
        Handle a closed connection: log it, notify the load balancer and optionally reconnect.

        :param code: The websocket close code; 1000 is logged as a graceful close
        :param reason: Optional close reason
        :param connect_again: Whether to reconnect afterwards
        """
        self.closing = False

        if not reason:
            reason = ""

        if code == 1000:
            logger.info(f"Connection to {self.name} closed gracefully with reason: {reason}")
        else:
            logger.warning(f"Connection to {self.name} closed unexpectedly with code: {code}, reason: {reason}")

        try:
            await self.lavalink.load_balancer.on_node_disconnect(self)
        except IllegalAction:
            logger.exception(f"Load balancer failed to handle the disconnect of {self.name}")

        if connect_again:
            logger.info(f"Attempting to reconnect to {self.name}...")
            await self.connect()

    async def on_message(self, msg):
        """Dispatch a decoded Lavalink payload according to its ``op`` field."""
        op = msg.get("op")
        if op == "playerUpdate":
            link = self.lavalink.get_link(msg.get("guildId"))
            if link:
                await link.player.provide_state(msg.get("state"))
        elif op == "stats":
            self.stats = NodeStats(msg)
        elif op == "event":
            await self.handle_event(msg)
        else:
            logger.info(f"Received unknown op: {op}")

    async def handle_event(self, msg):
        """
        Turn a Lavalink ``event`` payload into a player event and fire it.

        A WebSocketClosedEvent with code 4006 closed by the remote destroys the link.
        """
        link = self.lavalink.get_link(msg.get("guildId"))
        if not link:
            return  # the link got destroyed

        player = link.player
        event = None
        event_type = msg.get("type")

        if event_type == "TrackEndEvent":
            event = TrackEndEvent(player, player.current, msg.get("reason"))
        elif event_type == "TrackStartEvent":
            event = TrackStartEvent(player, player.current)
        elif event_type == "TrackExceptionEvent":
            event = TrackExceptionEvent(player, player.current, msg.get("error"))
        elif event_type == "TrackStuckEvent":
            event = TrackStuckEvent(player, player.current, msg.get("thresholdMs"))
        elif event_type == "WebSocketClosedEvent":
            if msg.get("code") == 4006 and msg.get("byRemote"):
                await link.destroy()
        elif event_type:
            logger.info(f"Received unknown event: {event_type}")

        if event:
            await player.trigger_event(event)
229 |
--------------------------------------------------------------------------------
/core/player.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from enum import Enum
3 | from time import time
4 |
5 | from .exceptions import IllegalAction
6 | from .events import InternalEventAdapter, TrackPauseEvent, TrackResumeEvent, TrackStartEvent
7 |
8 |
class LoadTypes(Enum):
    """
    Lavalink ``loadType`` values, looked up by name from a loadtracks response.

    AudioTrackPlaylist.is_empty treats the negative values as "nothing loaded".
    """
    # nothing usable
    NO_MATCHES = -2
    LOAD_FAILED = -1
    UNKNOWN = 0
    # something was loaded
    TRACK_LOADED = 1
    PLAYLIST_LOADED = 2
    SEARCH_RESULT = 3
16 |
17 |
class BassModes(Enum):
    """
    Bass boost presets.

    Each value is the attribute name under which Equalizer.bassboost() stores that preset's gains.
    """
    OFF = "off"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    EXTREME = "extreme"
    SICKO = "SICKO"  # note: upper-case value, unlike the others
25 |
26 |
class AudioTrack:
    """
    A single track the player can play.

    :param track: One entry of a loadtracks ``tracks`` list: the encoded ``track``
        string plus an ``info`` dict (isStream, uri, title, author, identifier,
        isSeekable, length)
    """
    def __init__(self, track):
        self.encoded_track = track['track']
        info = track['info']
        self.stream = info['isStream']
        self.uri = info['uri']
        self.title = info['title']
        self.author = info['author']
        self.identifier = info['identifier']
        self.seekable = info['isSeekable']
        self.duration = info['length']  # compared against positions in milliseconds by Player.position
        self.user_data = None  # free slot for callers to attach their own data
41 |
42 |
class AudioTrackPlaylist:
    """
    The result of a loadtracks query: load type, playlist info and the AudioTracks.

    Iterable, sized and indexable like the underlying list of tracks.
    """
    def __init__(self, results):
        """
        :param results: The decoded loadtracks JSON response
        :raises IllegalAction: if *results* is missing, is not a mapping, or lacks the expected keys
        """
        try:
            self.playlist_info = results["playlistInfo"]
            self.playlist_name = self.playlist_info.get("name")
            self.selected_track = self.playlist_info.get("selectedTrack")
            self.load_type = LoadTypes[results["loadType"]]
            self.tracks = [AudioTrack(track) for track in results["tracks"]]
        except (KeyError, TypeError, AttributeError) as e:
            # TypeError: results is None/not subscriptable; AttributeError: playlistInfo is not a dict
            raise IllegalAction(f"Results invalid!, received: {results}") from e

    @property
    def is_playlist(self):
        """True for a PLAYLIST_LOADED result holding more than one track."""
        return self.load_type is LoadTypes.PLAYLIST_LOADED and len(self) > 1

    @property
    def is_empty(self):
        """True when loading failed/matched nothing or no tracks were returned."""
        return self.load_type.value < 0 or len(self) == 0

    def __iter__(self):
        return iter(self.tracks)

    def __len__(self):
        return len(self.tracks)

    def __getitem__(self, item):
        return self.tracks[item]
71 |
72 |
class Equalizer:
    """
    A set of named equalizer presets.

    Each preset is stored as an instance attribute named after its mode's ``value``
    and holds a list of ``(band, gain)`` tuples.
    """
    def __init__(self, options):
        """
        :param options: Mapping of Enum member -> list of (band, gain) tuples
        """
        # The old instance-level ``__slots__`` assignment was a no-op (slots only work as a
        # class attribute) that just added a stray "__slots__" entry to __dict__; it is gone.
        for mode, gains in options.items():
            setattr(self, mode.value, gains)

    @classmethod
    def bassboost(cls):
        """Presets boosting bands 0 and 1 for every BassModes member."""
        return cls(
            {
                BassModes.OFF: [(0, 0), (1, 0)],
                BassModes.LOW: [(0, 0.25), (1, 0.15)],
                BassModes.MEDIUM: [(0, 0.50), (1, 0.25)],
                BassModes.HIGH: [(0, 0.75), (1, 0.50)],
                BassModes.EXTREME: [(0, 1), (1, 0.75)],
                BassModes.SICKO: [(0, 1), (1, 1)],
            }
        )
91 |
92 |
class Player:
    """
    The audio player of one guild.

    Commands go to the node of the owning Link. The player also keeps the state it
    last requested (track, pause, volume, equalizer) so node_changed() can replay
    it onto a new node.
    """

    # Shared by every Player; receives each event before the player's own adapter
    internal_event_adapter = InternalEventAdapter()

    def __init__(self, link):
        self.link = link
        self.current = None  # the AudioTrack being played, or None
        self.event_adapter = None  # optional adapter receiving this player's events
        self.paused = False
        self.volume = 100
        self.equalizer = {band: 0 for band in range(15)}  # band -> gain
        self.bass_mode = BassModes.OFF
        self.update_time = -1  # time() in seconds when _position was last set, -1 if never
        self._position = -1  # last known position in milliseconds, -1 if unknown

    @property
    def is_playing(self):
        return self.current is not None

    @property
    def position(self):
        """
        Estimated position of the current track in milliseconds, clamped to its duration.

        Lavalink only sends periodic updates, so while not paused the time elapsed since
        the last update is added. Returns 0 when nothing is playing.
        """
        if self.current is None:
            return 0
        if self.paused:
            return min(self._position, self.current.duration)
        elapsed_ms = round((time() - self.update_time) * 1000)
        return min(self._position + elapsed_ms, self.current.duration)

    def reset(self):
        """Forget the current track and its position."""
        self.current = None
        self.update_time = -1
        self._position = -1

    async def provide_state(self, state):
        """
        Apply a ``playerUpdate`` state from Lavalink; a state without "position" resets the player.
        """
        self.update_time = time()
        if "position" in state:
            self._position = state["position"]
            return
        self.reset()

    async def seek_to(self, position):
        """
        Sends a request to the Lavalink node to seek to a specific position
        :param position: The position in millis
        :raises IllegalAction: if nothing is playing or the track isn't seekable
        """
        if not self.current:
            raise IllegalAction("Not playing anything right now")
        if not self.current.seekable:
            raise IllegalAction("Cannot seek for this track")

        payload = {
            "op": "seek",
            "guildId": str(self.link.guild_id),
            "position": position
        }

        node = await self.link.get_node()
        await node.send(payload)

    async def set_paused(self, pause):
        """
        Sends a request to the Lavalink node to set the paused state, records it
        and fires a TrackPauseEvent or TrackResumeEvent
        :param pause: A boolean that indicates the pause state
        """
        payload = {
            "op": "pause",
            "guildId": str(self.link.guild_id),
            "pause": pause,
        }

        node = await self.link.get_node()
        await node.send(payload)
        self.paused = pause

        if pause:
            await self.trigger_event(TrackPauseEvent(self))
        else:
            await self.trigger_event(TrackResumeEvent(self))

    async def set_volume(self, volume):
        """
        Sends a request to the Lavalink node to set the volume
        :param volume: An integer from 0-150
        :raises IllegalAction: if the volume is out of range
        """
        if not 0 <= volume <= 150:
            raise IllegalAction("Volume must be between 0-150")

        payload = {
            "op": "volume",
            "guildId": str(self.link.guild_id),
            "volume": volume,
        }

        node = await self.link.get_node()
        await node.send(payload)
        self.volume = volume

    async def set_eq(self, gains_list):
        """
        Sets gain for multiple bands
        :param gains_list: an iterable of (band, gain) tuples; bands outside 0-14 are
            skipped and gains are clamped to [-0.25, 1.0]
        """
        bands = []
        for band, gain in gains_list:

            if not -1 < band < 15:
                continue

            gain = max(min(float(gain), 1.0), -0.25)
            bands.append({"band": band, "gain": gain})
            self.equalizer[band] = gain

        payload = {
            "op": "equalizer",
            "guildId": str(self.link.guild_id),
            "bands": bands
        }

        node = await self.link.get_node()
        await node.send(payload)

    async def set_gain(self, band, gain):
        """
        Sets the gain for 1 band
        :param band: a band from 0 to 14
        :param gain: a value from -0.25 to 1
        """
        # set_eq expects an iterable of pairs, so wrap the single pair in a list
        await self.set_eq([(band, gain)])

    async def set_bass(self, bass_mode):
        """
        Sets which bass mode the player is in
        :param bass_mode: a BassModes enum value
        """
        gains = getattr(Equalizer.bassboost(), bass_mode.value)
        self.bass_mode = bass_mode
        await self.set_eq(gains)

    async def play(self, track, position=0, no_replace=True):
        """
        Sends a request to the Lavalink node to play an AudioTrack
        :param track: The AudioTrack to play
        :param position: Optional; the position in millis to start the song at
        :param no_replace: if the current track should NOT be replaced
        """
        payload = {
            "op": "play",
            "guildId": str(self.link.guild_id),
            "track": track.encoded_track,
            "startTime": position,
            "noReplace": no_replace
        }
        node = await self.link.get_node(True)
        await node.send(payload)
        # seconds, matching provide_state() and the position property
        self.update_time = time()
        self._position = position
        self.current = track

    async def stop(self):
        """
        Sends a request to the Lavalink node to stop the current playing song
        """
        payload = {
            "op": "stop",
            "guildId": str(self.link.guild_id),
        }

        node = await self.link.get_node()
        await node.send(payload)

    async def destroy(self):
        """
        Sends a request to the Lavalink node (if connected) to destroy the player
        and destroys the event adapter
        """
        payload = {
            "op": "destroy",
            "guildId": str(self.link.guild_id),
        }
        node = await self.link.get_node()
        if node and node.connected:
            await node.send(payload)

        if self.event_adapter:
            await self.event_adapter.destroy()
            self.event_adapter = None

    async def node_changed(self):
        """Replay the current track, pause state, volume and equalizer onto a new node."""
        if self.current:
            # _position is -1 when no position is known yet; start from the beginning then
            await self.play(self.current, max(self._position, 0))

        if self.paused:
            await self.set_paused(True)

        if self.volume != 100:
            await self.set_volume(self.volume)

        if any(self.equalizer.values()):
            # copy first: set_eq writes back into self.equalizer while iterating
            await self.set_eq(list(self.equalizer.items()))

    async def trigger_event(self, event):
        """Deliver an event to the shared internal adapter, then to this player's adapter."""
        await Player.internal_event_adapter.on_event(event)
        if self.event_adapter:  # If we defined our own adapter
            try:
                await self.event_adapter.on_event(event)
            except Exception:
                # user adapter errors must not break playback handling
                traceback.print_exc()
305 |
--------------------------------------------------------------------------------