├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── backend.py ├── config.py ├── galaxy ├── __init__.py ├── api │ ├── __init__.py │ ├── consts.py │ ├── errors.py │ ├── jsonrpc.py │ ├── plugin.py │ └── types.py ├── http.py ├── proc_tools.py ├── reader.py ├── task_manager.py ├── tools.py └── unittest │ ├── __init__.py │ └── mock.py ├── manifest.json ├── plugin.py ├── tests ├── __init__.py ├── conftest.py ├── test_achievements.py ├── test_authenticate.py ├── test_chunk_messages.py ├── test_features.py ├── test_friends.py ├── test_game_times.py ├── test_http.py ├── test_install_game.py ├── test_internal.py ├── test_launch_game.py ├── test_launch_platform_client.py ├── test_local_games.py ├── test_owned_games.py ├── test_persistent_cache.py ├── test_shutdown_platform_client.py ├── test_stream_line_reader.py └── test_uninstall_game.py ├── trophy.py ├── version.py └── yaml ├── __init__.py ├── composer.py ├── constructor.py ├── cyaml.py ├── dumper.py ├── emitter.py ├── error.py ├── events.py ├── loader.py ├── nodes.py ├── parser.py ├── reader.py ├── representer.py ├── resolver.py ├── scanner.py ├── serializer.py └── tokens.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help improve the plugin 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | 14 | **When does the bug occur?** 15 | Does it occur when first trying to connect the plugin, or some time after? 16 | 17 | **Screenshots** 18 | If applicable, add screenshots to help explain your problem. 
19 | 20 | 21 | **Your setup** 22 | - OS and Version: 23 | - RPCS3 Version: 24 | - Game ID's causing trouble: 25 | - Python Version (if applicable): 26 | 27 | 28 | **Logs** 29 | Upload the following logs and files: 30 | - `plugin-vision-80F9D16B-5D72-4B95-9D46-2A1EF417C1FC.log`: 31 | - `GalaxyClient.log`: 32 | - `config.py`: 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | */__pycache__ 3 | 4 | *.lnk 5 | game_times.json -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "devita"] 2 | path = devita 3 | url = https://github.com/mpm11011/devita.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 mpm11011 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GOG Galaxy 2.0 Integration with RPCS3 2 | 3 | ## Supported Features 4 | * Finds all Disc and PSN games known to RPCS3. 5 | * Launches RPCS3 and loads games. 6 | * Tracks game time. 7 | * Imports and tracks trophies/achievements. 8 | 9 | ## Prerequisites 10 | 11 | * [RPCS3 v0.0.7](https://rpcs3.net/) or later 12 | 13 | ## Installation 14 | 15 | ### With git (from command line) 16 | 17 | - Windows: 18 | ``` 19 | cd %localappdata%\GOG.com\Galaxy\plugins\installed 20 | git clone https://github.com/mpm11011/galaxy-integration-rpcs3.git 21 | rename galaxy-integration-rpcs3 rpcs3_80F9D16B-5D72-4B95-9D46-2A1EF417C1FC 22 | cd rpcs3_80F9D16B-5D72-4B95-9D46-2A1EF417C1FC 23 | git submodule update --init 24 | ``` 25 | - macOS: 26 | ``` 27 | cd ~/Library/Application Support/GOG.com/Galaxy/plugins/installed 28 | git clone https://github.com/mpm11011/galaxy-integration-rpcs3.git 29 | mv galaxy-integration-rpcs3 rpcs3_80F9D16B-5D72-4B95-9D46-2A1EF417C1FC 30 | cd rpcs3_80F9D16B-5D72-4B95-9D46-2A1EF417C1FC 31 | git submodule update --init 32 | ``` 33 | 34 | ### Without git 35 | 36 | 1. Download a zip of this repository this directory and unpack it in: 37 | - Windows: `%localappdata%\GOG.com\Galaxy\plugins\installed` 38 | - macOS: `~/Library/Application Support/GOG.com/Galaxy/plugins/installed` 39 | 40 | 2. Download a zip of the [devita repository](https://github.com/mpm11011/devita.git) 41 | and unpack the contents into `galaxy-integration-rpcs3/devita`. 42 | 43 | 3. 
class BackendClient:
    """Backend helper for the RPCS3 Galaxy integration.

    Discovers games known to RPCS3 (via games.yml and the dev_hdd0 game
    directory) and tracks the duration of play sessions.
    """

    def __init__(self, config):
        self.config = config
        self.start_time = 0
        self.end_time = 0

    # Returns an array of Title ID, Title pairs.
    def get_games(self):
        """Return a list of [title_id, title] pairs for every game found.

        Walks every path listed in games.yml plus the main game directory,
        looking for EBOOT.BIN files and reading the adjacent PARAM.SFO.
        """
        # May still be useful if more info on a game is needed:
        # url = 'https://rpcs3.net/compatibility?api=v1&g='

        results = []
        game_paths = []
        with open(self.config.games_yml) as games_file:
            games_yml = yaml.load(games_file, Loader=yaml.SafeLoader)
            if games_yml:
                for game in games_yml:
                    game_paths.append(games_yml[game])

        game_paths.append(self.config.game_directory)

        # This is probably crazy inefficient (full recursive walk of every
        # candidate directory), but game libraries are small enough.
        for search_dir in game_paths:
            for root, dirs, files in os.walk(search_dir):
                for filename in files:
                    if filename == 'EBOOT.BIN':
                        # EBOOT.BIN marks an actual game; PARAM.SFO is one
                        # level up from EBOOT.BIN.
                        bin_path = self.config.joinpath(root, filename)
                        sfo_path = bin_path.replace(
                            os.path.join('USRDIR', 'EBOOT.BIN'),
                            'PARAM.SFO')

                        # PARAM.SFO is read as a binary file, so all keys
                        # must also be bytes.
                        param_sfo = sfo(sfo_path)
                        title_id = param_sfo.params[bytes('TITLE_ID', 'utf-8')]
                        title = param_sfo.params[bytes('TITLE', 'utf-8')]

                        # Convert results to strings before return.
                        results.append([
                            str(title_id, 'utf-8'),
                            str(title, 'utf-8')])

        return results

    def get_game_path(self, game_id):
        """Return the install path for *game_id*.

        Looks in games.yml first, then scans the PSN game directory for a
        PARAM.SFO whose TITLE_ID matches.

        :raises FileNotFoundError: if the game cannot be located anywhere.
        """
        with open(self.config.games_yml) as games_file:
            games_yml = yaml.load(games_file, Loader=yaml.SafeLoader)
            if games_yml:
                try:
                    game_path = games_yml[game_id]
                    return self.config.joinpath(game_path, 'PS3_GAME')
                except KeyError:
                    # Not in games.yml - fall through to the directory scan.
                    pass

        for folder in os.listdir(self.config.game_directory):
            check_path = self.config.joinpath(self.config.game_directory, folder)

            # Check that PARAM.SFO exists before loading it.
            sfo_path = self.config.joinpath(check_path, 'PARAM.SFO')
            if os.path.exists(sfo_path):
                param_sfo = sfo(sfo_path)

                # If PARAM.SFO contains game_id, we found the right path.
                if bytes(game_id, 'utf-8') in param_sfo.params[bytes('TITLE_ID', 'utf-8')]:
                    return check_path

        # Nothing matched - raise with context instead of a bare error.
        raise FileNotFoundError(game_id)

    def get_state_changes(self, old_list, new_list):
        """Diff two LocalGame lists and return the LocalGame entries that
        were removed, added, or changed state between them.
        """
        # Fix: these names were never imported anywhere in this module, so
        # calling this method raised NameError. Imported locally to keep the
        # module importable without the galaxy package on sys.path.
        from galaxy.api.consts import LocalGameState
        from galaxy.api.types import LocalGame

        old_dict = {x.game_id: x.local_game_state for x in old_list}
        new_dict = {x.game_id: x.local_game_state for x in new_list}
        result = []

        # Removed games are reported with the None_ state.
        result.extend(
            LocalGame(game_id, LocalGameState.None_)
            for game_id in old_dict.keys() - new_dict.keys())

        # Added games.
        result.extend(
            local_game for local_game in new_list
            if local_game.game_id in new_dict.keys() - old_dict.keys())

        # State changed.
        result.extend(
            LocalGame(game_id, new_dict[game_id])
            for game_id in new_dict.keys() & old_dict.keys()
            if new_dict[game_id] != old_dict[game_id])

        return result

    def start_game_time(self):
        """Record the session start timestamp."""
        self.start_time = time.time()

    def end_game_time(self):
        """Record the session end timestamp."""
        self.end_time = time.time()

    def get_session_duration(self):
        """Return the last session length in whole minutes (rounded)."""
        delta = self.end_time - self.start_time
        minutes = int(round(delta / 60))
        return minutes
16 | self.no_gui = True 17 | 18 | # You should not have to modify anything beyond this point. 19 | 20 | 21 | 22 | # Just in case... 23 | self.main_directory = os.path.normpath(self.main_directory) 24 | 25 | # Important files that should be in main_directory. 26 | self.rpcs3_exe = self.joinpath(self.main_directory, 'rpcs3.exe') 27 | self.games_yml = self.joinpath(self.main_directory, 'games.yml') 28 | self.config_yml = self.joinpath(self.main_directory, 'config.yml') 29 | 30 | # Make sure these exist! 31 | self.check_files([ 32 | self.rpcs3_exe, 33 | self.games_yml, 34 | self.config_yml]) 35 | 36 | # By default dev_hdd0 is just below main_directory, 37 | # but since you can move it anywhere, we need to find it. 38 | self.dev_hdd0_directory = os.path.normpath(self.find_hdd0()) 39 | 40 | # games.yml holds the games you added, 41 | # but this is the main game directory (for PSN games, etc). 42 | self.game_directory = self.joinpath(self.dev_hdd0_directory, 'game') 43 | 44 | # Path to RPCS3's user directory and related files. 45 | self.user_directory = self.joinpath(self.dev_hdd0_directory, 'home', '00000001') 46 | self.trophy_directory = self.joinpath(self.user_directory, 'trophy') 47 | self.localusername = self.joinpath(self.user_directory, 'localusername') 48 | 49 | # Normalizes and joins paths for OS file handling. 50 | def joinpath(self, path, *paths): 51 | return os.path.normpath(os.path.join(path, *paths)) 52 | 53 | # Find dev_hdd0 from config.yml. 54 | def find_hdd0(self): 55 | hdd0 = None 56 | with open(self.config_yml, 'r') as config_file: 57 | 58 | config = yaml.load(config_file, Loader=yaml.SafeLoader) 59 | hdd0 = config['VFS']['/dev_hdd0/'] 60 | 61 | # I think it's a safe guess to call this main_directory. 62 | if '$(EmulatorDir)' in hdd0: 63 | hdd0 = self.joinpath( 64 | self.main_directory, 65 | hdd0.replace('$(EmulatorDir)', '')) 66 | 67 | return hdd0 68 | 69 | # Raises exception if required files don't exist. 
70 | def check_files(self, files): 71 | should_raise = False 72 | 73 | for file in files: 74 | filepath = os.path.normpath(file) 75 | if not os.path.exists(filepath): 76 | logging.error(filepath + ' not found in RPCS3 directory.') 77 | should_raise = True 78 | 79 | if should_raise: 80 | raise FileNotFoundError() 81 | -------------------------------------------------------------------------------- /galaxy/__init__.py: -------------------------------------------------------------------------------- 1 | __path__: str = __import__('pkgutil').extend_path(__path__, __name__) 2 | -------------------------------------------------------------------------------- /galaxy/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massimilianodelliubaldini/galaxy-integration-rpcs3/e18319c503dc9a97c315375b9570821ce5636544/galaxy/api/__init__.py -------------------------------------------------------------------------------- /galaxy/api/consts.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, Flag 2 | 3 | 4 | class Platform(Enum): 5 | """Supported gaming platforms""" 6 | Unknown = "unknown" 7 | Gog = "gog" 8 | Steam = "steam" 9 | Psn = "psn" 10 | XBoxOne = "xboxone" 11 | Generic = "generic" 12 | Origin = "origin" 13 | Uplay = "uplay" 14 | Battlenet = "battlenet" 15 | Epic = "epic" 16 | Bethesda = "bethesda" 17 | ParadoxPlaza = "paradox" 18 | HumbleBundle = "humble" 19 | Kartridge = "kartridge" 20 | ItchIo = "itch" 21 | NintendoSwitch = "nswitch" 22 | NintendoWiiU = "nwiiu" 23 | NintendoWii = "nwii" 24 | NintendoGameCube = "ncube" 25 | RiotGames = "riot" 26 | Wargaming = "wargaming" 27 | NintendoGameBoy = "ngameboy" 28 | Atari = "atari" 29 | Amiga = "amiga" 30 | SuperNintendoEntertainmentSystem = "snes" 31 | Beamdog = "beamdog" 32 | Direct2Drive = "d2d" 33 | Discord = "discord" 34 | DotEmu = "dotemu" 35 | GameHouse = "gamehouse" 36 | GreenManGaming 
= "gmg" 37 | WePlay = "weplay" 38 | ZxSpectrum = "zx" 39 | ColecoVision = "vision" 40 | NintendoEntertainmentSystem = "nes" 41 | SegaMasterSystem = "sms" 42 | Commodore64 = "c64" 43 | PcEngine = "pce" 44 | SegaGenesis = "segag" 45 | NeoGeo = "neo" 46 | Sega32X = "sega32" 47 | SegaCd = "segacd" 48 | _3Do = "3do" 49 | SegaSaturn = "saturn" 50 | PlayStation = "psx" 51 | PlayStation2 = "ps2" 52 | Nintendo64 = "n64" 53 | AtariJaguar = "jaguar" 54 | SegaDreamcast = "dc" 55 | Xbox = "xboxog" 56 | Amazon = "amazon" 57 | GamersGate = "gg" 58 | Newegg = "egg" 59 | BestBuy = "bb" 60 | GameUk = "gameuk" 61 | Fanatical = "fanatical" 62 | PlayAsia = "playasia" 63 | Stadia = "stadia" 64 | Arc = "arc" 65 | ElderScrollsOnline = "eso" 66 | Glyph = "glyph" 67 | AionLegionsOfWar = "aionl" 68 | Aion = "aion" 69 | BladeAndSoul = "blade" 70 | GuildWars = "gw" 71 | GuildWars2 = "gw2" 72 | Lineage2 = "lin2" 73 | FinalFantasy11 = "ffxi" 74 | FinalFantasy14 = "ffxiv" 75 | TotalWar = "totalwar" 76 | WindowsStore = "winstore" 77 | EliteDangerous = "elites" 78 | StarCitizen = "star" 79 | PlayStationPortable = "psp" 80 | PlayStationVita = "psvita" 81 | NintendoDs = "nds" 82 | Nintendo3Ds = "3ds" 83 | PathOfExile = "pathofexile" 84 | Twitch = "twitch" 85 | Minecraft = "minecraft" 86 | GameSessions = "gamesessions" 87 | Nuuvem = "nuuvem" 88 | FXStore = "fxstore" 89 | IndieGala = "indiegala" 90 | Playfire = "playfire" 91 | Oculus = "oculus" 92 | Test = "test" 93 | 94 | 95 | class Feature(Enum): 96 | """Possible features that can be implemented by an integration. 97 | It does not have to support all or any specific features from the list. 
98 | """ 99 | Unknown = "Unknown" 100 | ImportInstalledGames = "ImportInstalledGames" 101 | ImportOwnedGames = "ImportOwnedGames" 102 | LaunchGame = "LaunchGame" 103 | InstallGame = "InstallGame" 104 | UninstallGame = "UninstallGame" 105 | ImportAchievements = "ImportAchievements" 106 | ImportGameTime = "ImportGameTime" 107 | Chat = "Chat" 108 | ImportUsers = "ImportUsers" 109 | VerifyGame = "VerifyGame" 110 | ImportFriends = "ImportFriends" 111 | ShutdownPlatformClient = "ShutdownPlatformClient" 112 | LaunchPlatformClient = "LaunchPlatformClient" 113 | 114 | 115 | class LicenseType(Enum): 116 | """Possible game license types, understandable for the GOG Galaxy client.""" 117 | Unknown = "Unknown" 118 | SinglePurchase = "SinglePurchase" 119 | FreeToPlay = "FreeToPlay" 120 | OtherUserLicense = "OtherUserLicense" 121 | 122 | 123 | class LocalGameState(Flag): 124 | """Possible states that a local game can be in. 125 | For example a game which is both installed and currently running should have its state set as a "bitwise or" of Running and Installed flags: 126 | ``local_game_state=`` 127 | """ 128 | None_ = 0 129 | Installed = 1 130 | Running = 2 131 | -------------------------------------------------------------------------------- /galaxy/api/errors.py: -------------------------------------------------------------------------------- 1 | from galaxy.api.jsonrpc import ApplicationError, UnknownError 2 | 3 | assert UnknownError 4 | 5 | class AuthenticationRequired(ApplicationError): 6 | def __init__(self, data=None): 7 | super().__init__(1, "Authentication required", data) 8 | 9 | class BackendNotAvailable(ApplicationError): 10 | def __init__(self, data=None): 11 | super().__init__(2, "Backend not available", data) 12 | 13 | class BackendTimeout(ApplicationError): 14 | def __init__(self, data=None): 15 | super().__init__(3, "Backend timed out", data) 16 | 17 | class BackendError(ApplicationError): 18 | def __init__(self, data=None): 19 | super().__init__(4, "Backend 
class UnknownBackendResponse(ApplicationError):
    """Raised when the backend replies in a way the plugin cannot parse.

    NOTE(review): shares error code 4 with BackendError (this mirrors the
    upstream Galaxy API); the code is kept as-is for protocol compatibility.
    """
    def __init__(self, data=None):
        # Fix: "uknown" -> "unknown" in the user-visible message.
        super().__init__(4, "Backend responded in unknown way", data)
class JsonRpcError(Exception):
    """Base JSON-RPC 2.0 error carrying a numeric code, a message, and
    optional extra data, serializable via :meth:`json`.
    """
    def __init__(self, code, message, data=None):
        self.code = code
        self.message = message
        self.data = data
        super().__init__()

    def __eq__(self, other):
        return self.code == other.code and self.message == other.message and self.data == other.data

    def json(self):
        """Return this error as a JSON-RPC 2.0 "error" object dict."""
        obj = {
            "code": self.code,
            "message": self.message
        }

        if self.data is not None:
            # Fix: the optional data member belongs directly on the error
            # object. The original wrote obj["error"]["data"], which raised
            # KeyError whenever data was set, because obj has no "error" key.
            obj["data"] = self.data

        return obj
class Server():
    """JSON-RPC 2.0 server.

    Reads newline-delimited JSON messages from a stream reader, dispatches
    them to registered method/notification handlers, and writes responses
    back through the writer.
    """
    def __init__(self, reader, writer, encoder=json.JSONEncoder()):
        # Sharing one default encoder instance across Servers is fine:
        # JSONEncoder is stateless.
        self._active = True
        self._reader = StreamLineReader(reader)
        self._writer = writer
        self._encoder = encoder
        self._methods = {}
        self._notifications = {}
        self._task_manager = TaskManager("jsonrpc server")

    def register_method(self, name, callback, immediate, sensitive_params=False):
        """
        Register method

        :param name:
        :param callback:
        :param immediate: if True the callback will be processed immediately (synchronously)
        :param sensitive_params: list of parameters that are anonymized before logging; \
        if False - no params are considered sensitive, if True - all params are considered sensitive
        """
        self._methods[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params)

    def register_notification(self, name, callback, immediate, sensitive_params=False):
        """
        Register notification

        :param name:
        :param callback:
        :param immediate: if True the callback will be processed immediately (synchronously)
        :param sensitive_params: list of parameters that are anonymized before logging; \
        if False - no params are considered sensitive, if True - all params are considered sensitive
        """
        self._notifications[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params)

    async def run(self):
        """Main read loop; returns after close() is called or EOF is seen."""
        while self._active:
            try:
                data = await self._reader.readline()
                if not data:
                    self._eof()
                    continue
            except:  # noqa: E722 - deliberately broad: any read failure is treated as EOF
                self._eof()
                continue
            data = data.strip()
            logging.debug("Received %d bytes of data", len(data))
            self._handle_input(data)
            await asyncio.sleep(0)  # To not starve task queue

    def close(self):
        """Stop reading further messages; in-flight handler tasks keep running."""
        logging.info("Closing JSON-RPC server - not more messages will be read")
        self._active = False

    async def wait_closed(self):
        """Wait until all handler tasks spawned by this server finish."""
        await self._task_manager.wait()

    def _eof(self):
        logging.info("Received EOF")
        self.close()

    def _handle_input(self, data):
        """Parse one raw message and route it as a request or a notification."""
        try:
            request = self._parse_request(data)
        except JsonRpcError as error:
            # id is unknown for unparsable input, so respond with null id.
            self._send_error(None, error)
            return

        if request.id is not None:
            self._handle_request(request)
        else:
            self._handle_notification(request)

    def _handle_notification(self, request):
        """Dispatch a notification (no response expected by the peer)."""
        method = self._notifications.get(request.method)
        if not method:
            logging.error("Received unknown notification: %s", request.method)
            return

        callback, signature, immediate, sensitive_params = method
        self._log_request(request, sensitive_params)

        try:
            bound_args = signature.bind(**request.params)
        except TypeError:
            self._send_error(request.id, InvalidParams())
            # Fix: without this return the code fell through and used the
            # unbound bound_args below, raising UnboundLocalError.
            return

        if immediate:
            callback(*bound_args.args, **bound_args.kwargs)
        else:
            try:
                self._task_manager.create_task(callback(*bound_args.args, **bound_args.kwargs), request.method)
            except Exception:
                logging.exception("Unexpected exception raised in notification handler")

    def _handle_request(self, request):
        """Dispatch a request and send back a response or an error."""
        method = self._methods.get(request.method)
        if not method:
            logging.error("Received unknown request: %s", request.method)
            self._send_error(request.id, MethodNotFound())
            return

        callback, signature, immediate, sensitive_params = method
        self._log_request(request, sensitive_params)

        try:
            bound_args = signature.bind(**request.params)
        except TypeError:
            self._send_error(request.id, InvalidParams())
            # Fix: same unbound bound_args fall-through as in
            # _handle_notification.
            return

        if immediate:
            response = callback(*bound_args.args, **bound_args.kwargs)
            self._send_response(request.id, response)
        else:
            async def handle():
                try:
                    result = await callback(*bound_args.args, **bound_args.kwargs)
                    self._send_response(request.id, result)
                except NotImplementedError:
                    self._send_error(request.id, MethodNotFound())
                except JsonRpcError as error:
                    self._send_error(request.id, error)
                except asyncio.CancelledError:
                    self._send_error(request.id, Aborted())
                except Exception as e:  #pylint: disable=broad-except
                    logging.exception("Unexpected exception raised in plugin handler")
                    self._send_error(request.id, UnknownError(str(e)))

            self._task_manager.create_task(handle(), request.method)

    @staticmethod
    def _parse_request(data):
        """Decode raw bytes into a Request.

        :raises ParseError: if the payload is not valid JSON
        :raises InvalidRequest: if it is JSON but not a JSON-RPC 2.0 request
        """
        try:
            # Fix: json.loads() lost its "encoding" keyword in Python 3.9
            # (it had been ignored/deprecated since 3.1); utf-8 bytes input
            # is detected automatically.
            jsonrpc_request = json.loads(data)
            if jsonrpc_request.get("jsonrpc") != "2.0":
                raise InvalidRequest()
            del jsonrpc_request["jsonrpc"]
            return Request(**jsonrpc_request)
        except json.JSONDecodeError:
            raise ParseError()
        except TypeError:
            raise InvalidRequest()

    def _send(self, data):
        """Encode *data* as one newline-terminated JSON line and write it."""
        try:
            line = self._encoder.encode(data)
            logging.debug("Sending data: %s", line)
            data = (line + "\n").encode("utf-8")
            self._writer.write(data)
            self._task_manager.create_task(self._writer.drain(), "drain")
        except TypeError as error:
            logging.error(str(error))

    def _send_response(self, request_id, result):
        response = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": result
        }
        self._send(response)

    def _send_error(self, request_id, error):
        response = {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": error.json()
        }

        self._send(response)

    @staticmethod
    def _log_request(request, sensitive_params):
        params = anonymise_sensitive_params(request.params, sensitive_params)
        if request.id is not None:
            logging.info("Handling request: id=%s, method=%s, params=%s", request.id, request.method, params)
        else:
            logging.info("Handling notification: method=%s, params=%s", request.method, params)
str(error)) 295 | 296 | @staticmethod 297 | def _log(method, params, sensitive_params): 298 | params = anonymise_sensitive_params(params, sensitive_params) 299 | logging.info("Sending notification: method=%s, params=%s", method, params) 300 | -------------------------------------------------------------------------------- /galaxy/api/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Dict, Optional 3 | 4 | from galaxy.api.consts import LicenseType, LocalGameState 5 | 6 | @dataclass 7 | class Authentication(): 8 | """Return this from :meth:`.authenticate` or :meth:`.pass_login_credentials` 9 | to inform the client that authentication has successfully finished. 10 | 11 | :param user_id: id of the authenticated user 12 | :param user_name: username of the authenticated user 13 | """ 14 | user_id: str 15 | user_name: str 16 | 17 | @dataclass 18 | class Cookie(): 19 | """Cookie 20 | 21 | :param name: name of the cookie 22 | :param value: value of the cookie 23 | :param domain: optional domain of the cookie 24 | :param path: optional path of the cookie 25 | """ 26 | name: str 27 | value: str 28 | domain: Optional[str] = None 29 | path: Optional[str] = None 30 | 31 | @dataclass 32 | class NextStep(): 33 | """Return this from :meth:`.authenticate` or :meth:`.pass_login_credentials` to open client built-in browser with given url. 34 | For example: 35 | 36 | .. 
@dataclass
class LicenseInfo():
    """Information about the license of related product.

    :param license_type: type of license
    :param owner: optional owner of the related product, defaults to currently authenticated user
    """
    license_type: LicenseType
    owner: Optional[str] = None

@dataclass
class Dlc():
    """Downloadable content object.

    :param dlc_id: id of the dlc
    :param dlc_title: title of the dlc
    :param license_info: information about the license attached to the dlc
    """
    dlc_id: str
    dlc_title: str
    license_info: LicenseInfo

@dataclass
class Game():
    """Game object.

    :param game_id: unique identifier of the game, this will be passed as parameter for methods such as launch_game
    :param game_title: title of the game
    :param dlcs: list of dlcs available for the game
    :param license_info: information about the license attached to the game
    """
    game_id: str
    game_title: str
    dlcs: Optional[List[Dlc]]
    license_info: LicenseInfo

@dataclass
class Achievement():
    """Achievement, has to be initialized with either id or name.

    :param unlock_time: unlock time of the achievement (**unix timestamp**)
    :param achievement_id: optional id of the achievement
    :param achievement_name: optional name of the achievement
    """
    unlock_time: int
    achievement_id: Optional[str] = None
    achievement_name: Optional[str] = None

    def __post_init__(self):
        # At least one identifier is mandatory so the client can match the achievement.
        assert self.achievement_id or self.achievement_name, \
            "One of achievement_id or achievement_name is required"

@dataclass
class LocalGame():
    """Game locally present on the authenticated user's computer.

    :param game_id: id of the game
    :param local_game_state: state of the game
    """
    game_id: str
    local_game_state: LocalGameState

@dataclass
class FriendInfo():
    """Information about a friend of the currently authenticated user.

    :param user_id: id of the user
    :param user_name: username of the user
    """
    user_id: str
    user_name: str

@dataclass
class GameTime():
    """Game time of a game, defines the total time spent in the game
    and the last time the game was played.

    :param game_id: id of the related game
    :param time_played: the total time spent in the game in **minutes**
    :param last_played_time: last time the game was played (**unix timestamp**)
    """
    game_id: str
    time_played: Optional[int]
    last_played_time: Optional[int]
#: Default limit of the simultaneous connections for ssl connector.
DEFAULT_LIMIT = 20
#: Default timeout in seconds used for client session.
DEFAULT_TIMEOUT = 60


class HttpClient:
    """
    .. deprecated:: 0.41
        Use http module functions instead
    """
    def __init__(self, limit=DEFAULT_LIMIT, timeout=aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT), cookie_jar=None):
        connector = create_tcp_connector(limit=limit)
        self._session = create_client_session(connector=connector, timeout=timeout, cookie_jar=cookie_jar)

    async def close(self):
        """Closes connection. Should be called in :meth:`~galaxy.api.plugin.Plugin.shutdown`"""
        await self._session.close()

    async def request(self, method, url, *args, **kwargs):
        # Translate low-level network failures into galaxy.api.errors exceptions.
        with handle_exception():
            return await self._session.request(method, url, *args, **kwargs)


def create_tcp_connector(*args, **kwargs) -> aiohttp.TCPConnector:
    """
    Creates TCP connector with reasonable defaults.
    For details about available parameters refer to
    `aiohttp.TCPConnector `_
    """
    # Verify server certificates against the bundled certifi CA store.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.load_verify_locations(certifi.where())
    kwargs.setdefault("ssl", ssl_context)
    kwargs.setdefault("limit", DEFAULT_LIMIT)
    return aiohttp.TCPConnector(*args, **kwargs)  # type: ignore due to https://github.com/python/mypy/issues/4001


def create_client_session(*args, **kwargs) -> aiohttp.ClientSession:
    """
    Creates client session with reasonable defaults.
    For details about available parameters refer to
    `aiohttp.ClientSession `_

    Exemplary customization:

    .. code-block:: python

        from galaxy.http import create_client_session, create_tcp_connector

        session = create_client_session(
            headers={
                "Keep-Alive": "true"
            },
            connector=create_tcp_connector(limit=40),
            timeout=100)
    """
    kwargs.setdefault("connector", create_tcp_connector())
    kwargs.setdefault("timeout", aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT))
    # Non-2xx responses raise aiohttp.ClientResponseError, translated by handle_exception().
    kwargs.setdefault("raise_for_status", True)
    return aiohttp.ClientSession(*args, **kwargs)  # type: ignore due to https://github.com/python/mypy/issues/4001


@contextmanager
def handle_exception():
    """
    Context manager translating network related exceptions
    to custom :mod:`~galaxy.api.errors`.
    """
    try:
        yield
    except asyncio.TimeoutError:
        raise BackendTimeout()
    except aiohttp.ServerDisconnectedError:
        raise BackendNotAvailable()
    except aiohttp.ClientConnectionError:
        raise NetworkError()
    except aiohttp.ContentTypeError:
        raise UnknownBackendResponse()
    except aiohttp.ClientResponseError as error:
        # Map well-known HTTP statuses to dedicated errors first.
        if error.status == HTTPStatus.UNAUTHORIZED:
            raise AuthenticationRequired()
        if error.status == HTTPStatus.FORBIDDEN:
            raise AccessDenied()
        if error.status == HTTPStatus.SERVICE_UNAVAILABLE:
            raise BackendNotAvailable()
        if error.status == HTTPStatus.TOO_MANY_REQUESTS:
            raise TooManyRequests()
        if error.status >= 500:
            raise BackendError()
        if error.status >= 400:
            logging.warning(
                "Got status %d while performing %s request for %s",
                error.status, error.request_info.method, str(error.request_info.url)
            )
        raise UnknownError()
    except aiohttp.ClientError:
        logging.exception("Caught exception while performing request")
        raise UnknownError()
-------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | from typing import Iterable, NewType, Optional, List, cast 4 | 5 | 6 | 7 | ProcessId = NewType("ProcessId", int) 8 | 9 | 10 | @dataclass 11 | class ProcessInfo: 12 | pid: ProcessId 13 | binary_path: Optional[str] 14 | 15 | 16 | if sys.platform == "win32": 17 | from ctypes import byref, sizeof, windll, create_unicode_buffer, FormatError, WinError 18 | from ctypes.wintypes import DWORD 19 | 20 | 21 | def pids() -> Iterable[ProcessId]: 22 | _PROC_ID_T = DWORD 23 | list_size = 4096 24 | 25 | def try_get_pids(list_size: int) -> List[ProcessId]: 26 | result_size = DWORD() 27 | proc_id_list = (_PROC_ID_T * list_size)() 28 | 29 | if not windll.psapi.EnumProcesses(byref(proc_id_list), sizeof(proc_id_list), byref(result_size)): 30 | raise WinError(descr="Failed to get process ID list: %s" % FormatError()) # type: ignore 31 | 32 | return cast(List[ProcessId], proc_id_list[:int(result_size.value / sizeof(_PROC_ID_T()))]) 33 | 34 | while True: 35 | proc_ids = try_get_pids(list_size) 36 | if len(proc_ids) < list_size: 37 | return proc_ids 38 | 39 | list_size *= 2 40 | 41 | 42 | def get_process_info(pid: ProcessId) -> Optional[ProcessInfo]: 43 | _PROC_QUERY_LIMITED_INFORMATION = 0x1000 44 | 45 | process_info = ProcessInfo(pid=pid, binary_path=None) 46 | 47 | h_process = windll.kernel32.OpenProcess(_PROC_QUERY_LIMITED_INFORMATION, False, pid) 48 | if not h_process: 49 | return process_info 50 | 51 | try: 52 | def get_exe_path() -> Optional[str]: 53 | _MAX_PATH = 260 54 | _WIN32_PATH_FORMAT = 0x0000 55 | 56 | exe_path_buffer = create_unicode_buffer(_MAX_PATH) 57 | exe_path_len = DWORD(len(exe_path_buffer)) 58 | 59 | return cast(str, exe_path_buffer[:exe_path_len.value]) if windll.kernel32.QueryFullProcessImageNameW( 60 | h_process, _WIN32_PATH_FORMAT, exe_path_buffer, byref(exe_path_len) 61 | ) else None 62 | 63 | process_info.binary_path = 
get_exe_path() 64 | finally: 65 | windll.kernel32.CloseHandle(h_process) 66 | return process_info 67 | else: 68 | import psutil 69 | 70 | 71 | def pids() -> Iterable[ProcessId]: 72 | for pid in psutil.pids(): 73 | yield pid 74 | 75 | 76 | def get_process_info(pid: ProcessId) -> Optional[ProcessInfo]: 77 | process_info = ProcessInfo(pid=pid, binary_path=None) 78 | try: 79 | process_info.binary_path = psutil.Process(pid=pid).as_dict(attrs=["exe"])["exe"] 80 | except psutil.NoSuchProcess: 81 | pass 82 | finally: 83 | return process_info 84 | 85 | 86 | def process_iter() -> Iterable[Optional[ProcessInfo]]: 87 | for pid in pids(): 88 | yield get_process_info(pid) 89 | -------------------------------------------------------------------------------- /galaxy/reader.py: -------------------------------------------------------------------------------- 1 | from asyncio import StreamReader 2 | 3 | 4 | class StreamLineReader: 5 | """Handles StreamReader readline without buffer limit""" 6 | def __init__(self, reader: StreamReader): 7 | self._reader = reader 8 | self._buffer = bytes() 9 | self._processed_buffer_it = 0 10 | 11 | async def readline(self): 12 | while True: 13 | # check if there is no unprocessed data in the buffer 14 | if not self._buffer or self._processed_buffer_it != 0: 15 | chunk = await self._reader.read(1024) 16 | if not chunk: 17 | return bytes() # EOF 18 | self._buffer += chunk 19 | 20 | it = self._buffer.find(b"\n", self._processed_buffer_it) 21 | if it < 0: 22 | self._processed_buffer_it = len(self._buffer) 23 | continue 24 | 25 | line = self._buffer[:it] 26 | self._buffer = self._buffer[it+1:] 27 | self._processed_buffer_it = 0 28 | return line 29 | -------------------------------------------------------------------------------- /galaxy/task_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from collections import OrderedDict 4 | from itertools import count 5 | 6 | class 
TaskManager: 7 | def __init__(self, name): 8 | self._name = name 9 | self._tasks = OrderedDict() 10 | self._task_counter = count() 11 | 12 | def create_task(self, coro, description, handle_exceptions=True): 13 | """Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown""" 14 | 15 | async def task_wrapper(task_id): 16 | try: 17 | result = await coro 18 | logging.debug("Task manager %s: finished task %d (%s)", self._name, task_id, description) 19 | return result 20 | except asyncio.CancelledError: 21 | if handle_exceptions: 22 | logging.debug("Task manager %s: canceled task %d (%s)", self._name, task_id, description) 23 | else: 24 | raise 25 | except Exception: 26 | if handle_exceptions: 27 | logging.exception("Task manager %s: exception raised in task %d (%s)", self._name, task_id, description) 28 | else: 29 | raise 30 | finally: 31 | del self._tasks[task_id] 32 | 33 | task_id = next(self._task_counter) 34 | logging.debug("Task manager %s: creating task %d (%s)", self._name, task_id, description) 35 | task = asyncio.create_task(task_wrapper(task_id)) 36 | self._tasks[task_id] = task 37 | return task 38 | 39 | def cancel(self): 40 | for task in self._tasks.values(): 41 | task.cancel() 42 | 43 | async def wait(self): 44 | # Tasks can spawn other tasks 45 | while True: 46 | tasks = self._tasks.values() 47 | if not tasks: 48 | return 49 | await asyncio.gather(*tasks, return_exceptions=True) 50 | -------------------------------------------------------------------------------- /galaxy/tools.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import zipfile 4 | from glob import glob 5 | 6 | 7 | def zip_folder(folder): 8 | files = glob(os.path.join(folder, "**"), recursive=True) 9 | files = [file.replace(folder + os.sep, "") for file in files] 10 | files = [file for file in files if file] 11 | 12 | zip_buffer = io.BytesIO() 13 | with zipfile.ZipFile(zip_buffer, mode="w", 
compression=zipfile.ZIP_DEFLATED) as zipf: 14 | for file in files: 15 | zipf.write(os.path.join(folder, file), arcname=file) 16 | return zip_buffer 17 | 18 | 19 | def zip_folder_to_file(folder, filename): 20 | zip_content = zip_folder(folder).getbuffer() 21 | with open(filename, "wb") as archive: 22 | archive.write(zip_content) 23 | -------------------------------------------------------------------------------- /galaxy/unittest/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massimilianodelliubaldini/galaxy-integration-rpcs3/e18319c503dc9a97c315375b9570821ce5636544/galaxy/unittest/__init__.py -------------------------------------------------------------------------------- /galaxy/unittest/mock.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from unittest.mock import MagicMock 3 | 4 | 5 | class AsyncMock(MagicMock): 6 | """ 7 | .. deprecated:: 0.45 8 | Use: :class:`MagicMock` with meth:`~.async_return_value`. 9 | """ 10 | async def __call__(self, *args, **kwargs): 11 | return super(AsyncMock, self).__call__(*args, **kwargs) 12 | 13 | 14 | def coroutine_mock(): 15 | """ 16 | .. deprecated:: 0.45 17 | Use: :class:`MagicMock` with meth:`~.async_return_value`. 
18 | """ 19 | coro = MagicMock(name="CoroutineResult") 20 | corofunc = MagicMock(name="CoroutineFunction", side_effect=asyncio.coroutine(coro)) 21 | corofunc.coro = coro 22 | return corofunc 23 | 24 | async def skip_loop(iterations=1): 25 | for _ in range(iterations): 26 | await asyncio.sleep(0) 27 | 28 | 29 | async def async_return_value(return_value, loop_iterations_delay=0): 30 | await skip_loop(loop_iterations_delay) 31 | return return_value 32 | -------------------------------------------------------------------------------- /manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Galaxy RPCS3 Plugin", 3 | "platform": "vision", 4 | "guid": "80F9D16B-5D72-4B95-9D46-2A1EF417C1FC", 5 | "version": "0.63", 6 | "description": "GOG Galaxy Integration with RPCS3 Emulator", 7 | "author": "mpm11011", 8 | "email": "markuscicero5@gmail.com", 9 | "url": "https://github.com/mpm11011/galaxy-integration-rpcs3", 10 | "script": "plugin.py" 11 | } -------------------------------------------------------------------------------- /plugin.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import subprocess 4 | import sys 5 | import os 6 | import time 7 | 8 | import trophy 9 | from trophy import Trophy 10 | from config import Config 11 | from backend import BackendClient 12 | from version import get_version 13 | 14 | from galaxy.api.consts import LicenseType, LocalGameState, Platform 15 | from galaxy.api.plugin import Plugin, create_and_run_plugin 16 | from galaxy.api.types import Authentication, Game, GameTime, LicenseInfo, LocalGame, Achievement 17 | 18 | 19 | class RPCS3Plugin(Plugin): 20 | def __init__(self, reader, writer, token): 21 | super().__init__(Platform.ColecoVision, get_version(), reader, writer, token) 22 | try: 23 | self.config = Config() # If we can't create a good config, we can't run the plugin. 
24 | except FileNotFoundError: 25 | self.close() 26 | else: 27 | self.backend_client = BackendClient(self.config) 28 | self.games = [] 29 | self.local_games_cache = self.local_games_list() 30 | self.process = None 31 | self.running_game_id = None 32 | 33 | 34 | async def authenticate(self, stored_credentials=None): 35 | return self.do_auth() 36 | 37 | 38 | async def pass_login_credentials(self, step, credentials, cookies): 39 | return self.do_auth() 40 | 41 | 42 | def do_auth(self): 43 | 44 | username = '' 45 | with open(self.config.localusername, 'r') as username_file: 46 | username = username_file.read() 47 | 48 | user_data = {} 49 | user_data['username'] = username 50 | self.store_credentials(user_data) 51 | return Authentication('rpcs3_user', user_data['username']) 52 | 53 | 54 | async def launch_game(self, game_id): 55 | 56 | args = [] 57 | eboot_bin = self.config.joinpath( 58 | self.backend_client.get_game_path(game_id), 59 | 'USRDIR', 60 | 'EBOOT.BIN') 61 | 62 | if self.config.no_gui: 63 | args.append('--no-gui') 64 | 65 | command = [self.config.rpcs3_exe, eboot_bin] + args 66 | self.process = subprocess.Popen(command) 67 | self.backend_client.start_game_time() 68 | self.running_game_id = game_id 69 | 70 | return 71 | 72 | 73 | # Only as placeholders so the feature is recognized 74 | async def install_game(self, game_id): 75 | pass 76 | 77 | async def uninstall_game(self, game_id): 78 | pass 79 | 80 | 81 | async def prepare_game_times_context(self, game_ids): 82 | return self.get_game_times(game_ids) 83 | 84 | 85 | async def prepare_achievements_context(self, game_ids): 86 | return self.get_trophy_achs() 87 | 88 | 89 | async def get_game_time(self, game_id, context): 90 | game_time = context.get(game_id) 91 | return game_time 92 | 93 | 94 | async def get_unlocked_achievements(self, game_id, context): 95 | achs = context.get(game_id) 96 | return achs 97 | 98 | 99 | def get_game_times(self, game_ids): 100 | 101 | # Get the path of the game times file. 
102 | base_path = os.path.dirname(os.path.realpath(__file__)) 103 | game_times_path = os.path.join(base_path, 'game_times.json') 104 | game_times = {} 105 | 106 | # If the file does not exist, create it with default values. 107 | if not os.path.exists(game_times_path): 108 | for game in self.games: 109 | 110 | game_id = str(game[0]) 111 | game_times[game_id] = GameTime(game_id, 0, None) 112 | 113 | with open(game_times_path, 'w', encoding='utf-8') as game_times_file: 114 | json.dump(game_times, game_times_file, indent=4) 115 | 116 | # If (when) the file exists, read it and return the game times. 117 | with open(game_times_path, 'r', encoding='utf-8') as game_times_file: 118 | game_times_json = json.load(game_times_file) 119 | 120 | for game_id in game_times_json: 121 | if game_id in game_ids: 122 | time_played = game_times_json.get(game_id).get('time_played') 123 | last_time_played = game_times_json.get(game_id).get('last_time_played') 124 | 125 | game_times[game_id] = GameTime(game_id, time_played, last_time_played) 126 | 127 | return game_times 128 | 129 | def get_trophy_achs(self): 130 | 131 | game_ids = [] 132 | for game in self.games: 133 | game_ids.append(game[0]) 134 | 135 | trophies = None 136 | all_achs = {} 137 | for game_id in game_ids: 138 | game_path = self.backend_client.get_game_path(game_id) 139 | 140 | try: 141 | trophies = Trophy(self.config, game_path) 142 | keys = trophies.tropusr.table6.keys() 143 | 144 | game_achs = [] 145 | for key in keys: 146 | ach = trophies.trop2ach(key) 147 | if ach is not None: 148 | game_achs.append(trophies.trop2ach(key)) 149 | 150 | all_achs[game_id] = game_achs 151 | 152 | # If tropusr doesn't exist, this game has no trophies. 
153 | except AttributeError: 154 | all_achs[game_id] = [] 155 | 156 | return all_achs 157 | 158 | 159 | def tick(self): 160 | try: 161 | if self.process.poll() is not None: 162 | 163 | self.backend_client.end_game_time() 164 | self.update_json_game_time( 165 | self.running_game_id, 166 | self.backend_client.get_session_duration(), 167 | int(time.time())) 168 | 169 | # Only update recently played games. Updating all game times every second fills up log way too quickly. 170 | self.create_task(self.update_galaxy_game_times(self.running_game_id), 'Update Galaxy game times') 171 | 172 | self.process = None 173 | self.running_game_id = None 174 | 175 | except AttributeError: 176 | pass 177 | 178 | self.create_task(self.update_local_games(), 'Update local games') 179 | self.create_task(self.update_achievements(), 'Update achievements') 180 | 181 | 182 | async def update_local_games(self): 183 | loop = asyncio.get_running_loop() 184 | 185 | new_list = await loop.run_in_executor(None, self.local_games_list) 186 | notify_list = self.backend_client.get_state_changes(self.local_games_cache, new_list) 187 | self.local_games_cache = new_list 188 | 189 | for local_game_notify in notify_list: 190 | self.update_local_game_status(local_game_notify) 191 | 192 | 193 | async def update_galaxy_game_times(self, game_id): 194 | 195 | # Leave time for Galaxy to fetch games before updating times 196 | await asyncio.sleep(60) 197 | loop = asyncio.get_running_loop() 198 | 199 | game_times = await loop.run_in_executor(None, self.get_game_times, [game_id]) 200 | for game_id in game_times: 201 | self.update_game_time(game_times[game_id]) 202 | 203 | 204 | async def update_achievements(self): 205 | 206 | # Leave time for Galaxy to fetch games before updating times 207 | await asyncio.sleep(60) 208 | loop = asyncio.get_running_loop() 209 | 210 | achs = await loop.run_in_executor(None, self.get_trophy_achs) 211 | # for ach in achs: 212 | # self.unlock_achievement(ach) # TODO - how/when to handle 
this? 213 | 214 | 215 | def update_json_game_time(self, game_id, duration, last_time_played): 216 | 217 | # Get the path of the game times file. 218 | base_path = os.path.dirname(os.path.realpath(__file__)) 219 | game_times_path = '{}/game_times.json'.format(base_path) 220 | game_times_json = None 221 | 222 | with open(game_times_path, 'r', encoding='utf-8') as game_times_file: 223 | game_times_json = json.load(game_times_file) 224 | 225 | old_time_played = game_times_json.get(game_id).get('time_played') 226 | new_time_played = old_time_played + duration 227 | 228 | game_times_json[game_id]['time_played'] = new_time_played 229 | game_times_json[game_id]['last_time_played'] = last_time_played 230 | 231 | with open(game_times_path, 'w', encoding='utf-8') as game_times_file: 232 | json.dump(game_times_json, game_times_file, indent=4) 233 | 234 | self.update_game_time(GameTime(game_id, new_time_played, last_time_played)) 235 | 236 | 237 | def local_games_list(self): 238 | local_games = [] 239 | 240 | for game in self.games: 241 | local_games.append(LocalGame( 242 | game[0], 243 | LocalGameState.Installed)) 244 | 245 | return local_games 246 | 247 | 248 | async def get_owned_games(self): 249 | self.games = self.backend_client.get_games() 250 | owned_games = [] 251 | 252 | for game in self.games: 253 | owned_games.append(Game( 254 | game[0], 255 | game[1], 256 | None, 257 | LicenseInfo(LicenseType.SinglePurchase, None))) 258 | 259 | return owned_games 260 | 261 | 262 | async def get_local_games(self): 263 | return self.local_games_cache 264 | 265 | 266 | def main(): 267 | create_and_run_plugin(RPCS3Plugin, sys.argv) 268 | 269 | 270 | # run plugin event loop 271 | if __name__ == '__main__': 272 | main() 273 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def create_message(request): 5 | return 
json.dumps(request).encode() + b"\n" 6 | 7 | 8 | def get_messages(write_mock): 9 | messages = [] 10 | for call_args in write_mock.call_args_list: 11 | data = call_args[0][0] 12 | for line in data.splitlines(): 13 | message = json.loads(line) 14 | messages.append(message) 15 | return messages 16 | 17 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from contextlib import ExitStack 2 | import logging 3 | from unittest.mock import patch, MagicMock 4 | 5 | import pytest 6 | 7 | from galaxy.api.plugin import Plugin 8 | from galaxy.api.consts import Platform 9 | from galaxy.unittest.mock import async_return_value 10 | 11 | @pytest.fixture() 12 | def reader(): 13 | stream = MagicMock(name="stream_reader") 14 | stream.read = MagicMock() 15 | yield stream 16 | 17 | @pytest.fixture() 18 | async def writer(): 19 | stream = MagicMock(name="stream_writer") 20 | stream.drain.side_effect = lambda: async_return_value(None) 21 | yield stream 22 | 23 | @pytest.fixture() 24 | def read(reader): 25 | yield reader.read 26 | 27 | @pytest.fixture() 28 | def write(writer): 29 | yield writer.write 30 | 31 | @pytest.fixture() 32 | async def plugin(reader, writer): 33 | """Return plugin instance with all feature methods mocked""" 34 | methods = ( 35 | "handshake_complete", 36 | "authenticate", 37 | "get_owned_games", 38 | "prepare_achievements_context", 39 | "get_unlocked_achievements", 40 | "achievements_import_complete", 41 | "get_local_games", 42 | "launch_game", 43 | "launch_platform_client", 44 | "install_game", 45 | "uninstall_game", 46 | "get_friends", 47 | "get_game_time", 48 | "prepare_game_times_context", 49 | "game_times_import_complete", 50 | "shutdown_platform_client", 51 | "shutdown", 52 | "tick" 53 | ) 54 | 55 | with ExitStack() as stack: 56 | for method in methods: 57 | stack.enter_context(patch.object(Plugin, method)) 58 | 59 | async with 
@pytest.mark.asyncio
async def test_get_unlocked_achievements_success(plugin, read, write):
    """Happy path: achievements are imported and reported back to the client."""
    plugin.prepare_achievements_context.return_value = async_return_value(5)
    request = {
        "jsonrpc": "2.0",
        "id": "3",
        "method": "start_achievements_import",
        "params": {
            "game_ids": ["14"]
        }
    }
    read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)]
    plugin.get_unlocked_achievements.return_value = async_return_value([
        Achievement(achievement_id="lvl10", unlock_time=1548421241),
        Achievement(achievement_name="Got level 20", unlock_time=1548422395),
        Achievement(achievement_id="lvl30", achievement_name="Got level 30", unlock_time=1548495633)
    ])
    await plugin.run()
    plugin.prepare_achievements_context.assert_called_with(["14"])
    plugin.get_unlocked_achievements.assert_called_with("14", 5)
    # BUGFIX: was `asert_called_with`, a typo that MagicMock silently accepts
    # as a no-op attribute access, so the assertion never ran.
    plugin.achievements_import_complete.assert_called_with()

    assert get_messages(write) == [
        {
            "jsonrpc": "2.0",
            "id": "3",
            "result": None
        },
        {
            "jsonrpc": "2.0",
            "method": "game_achievements_import_success",
            "params": {
                "game_id": "14",
                "unlocked_achievements": [
                    {
                        "achievement_id": "lvl10",
                        "unlock_time": 1548421241
                    },
                    {
                        "achievement_name": "Got level 20",
                        "unlock_time": 1548422395
                    },
                    {
                        "achievement_id": "lvl30",
                        "achievement_name": "Got level 30",
                        "unlock_time": 1548495633
                    }
                ]
            }
        },
        {
            "jsonrpc": "2.0",
            "method": "achievements_import_finished",
            "params": None
        }
    ]


@pytest.mark.asyncio
@pytest.mark.parametrize("exception,code,message", [
    (BackendError, 4, "Backend error"),
    (KeyError, 0, "Unknown error")
])
async def test_get_unlocked_achievements_error(exception, code, message, plugin, read, write):
    """Errors raised by get_unlocked_achievements are reported as import failures."""
    plugin.prepare_achievements_context.return_value = async_return_value(None)
    request = {
        "jsonrpc": "2.0",
        "id": "3",
        "method": "start_achievements_import",
        "params": {
            "game_ids": ["14"]
        }
    }

    read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)]
    plugin.get_unlocked_achievements.side_effect = exception
    await plugin.run()
    plugin.get_unlocked_achievements.assert_called()
    # BUGFIX: was `asert_called_with` (see success test above).
    plugin.achievements_import_complete.assert_called_with()

    assert get_messages(write) == [
        {
            "jsonrpc": "2.0",
            "id": "3",
            "result": None
        },
        {
            "jsonrpc": "2.0",
            "method": "game_achievements_import_failure",
            "params": {
                "game_id": "14",
                "error": {
                    "code": code,
                    "message": message
                }
            }
        },
        {
            "jsonrpc": "2.0",
            "method": "achievements_import_finished",
            "params": None
        }
    ]
124 | } 125 | ] 126 | 127 | @pytest.mark.asyncio 128 | async def test_prepare_get_unlocked_achievements_context_error(plugin, read, write): 129 | plugin.prepare_achievements_context.side_effect = BackendError() 130 | request = { 131 | "jsonrpc": "2.0", 132 | "id": "3", 133 | "method": "start_achievements_import", 134 | "params": { 135 | "game_ids": ["14"] 136 | } 137 | } 138 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 139 | 140 | await plugin.run() 141 | 142 | assert get_messages(write) == [ 143 | { 144 | "jsonrpc": "2.0", 145 | "id": "3", 146 | "error": { 147 | "code": 4, 148 | "message": "Backend error" 149 | } 150 | } 151 | ] 152 | 153 | 154 | @pytest.mark.asyncio 155 | async def test_import_in_progress(plugin, read, write): 156 | plugin.prepare_achievements_context.return_value = async_return_value(None) 157 | plugin.get_unlocked_achievements.return_value = async_return_value([]) 158 | requests = [ 159 | { 160 | "jsonrpc": "2.0", 161 | "id": "3", 162 | "method": "start_achievements_import", 163 | "params": { 164 | "game_ids": ["14"] 165 | } 166 | }, 167 | { 168 | "jsonrpc": "2.0", 169 | "id": "4", 170 | "method": "start_achievements_import", 171 | "params": { 172 | "game_ids": ["15"] 173 | } 174 | } 175 | ] 176 | read.side_effect = [ 177 | async_return_value(create_message(requests[0])), 178 | async_return_value(create_message(requests[1])), 179 | async_return_value(b"", 10) 180 | ] 181 | 182 | await plugin.run() 183 | 184 | messages = get_messages(write) 185 | assert { 186 | "jsonrpc": "2.0", 187 | "id": "3", 188 | "result": None 189 | } in messages 190 | assert { 191 | "jsonrpc": "2.0", 192 | "id": "4", 193 | "error": { 194 | "code": 600, 195 | "message": "Import already in progress" 196 | } 197 | } in messages 198 | 199 | 200 | @pytest.mark.asyncio 201 | async def test_unlock_achievement(plugin, write): 202 | achievement = Achievement(achievement_id="lvl20", unlock_time=1548422395) 203 | 
plugin.unlock_achievement("14", achievement) 204 | response = json.loads(write.call_args[0][0]) 205 | 206 | assert response == { 207 | "jsonrpc": "2.0", 208 | "method": "achievement_unlocked", 209 | "params": { 210 | "game_id": "14", 211 | "achievement": { 212 | "achievement_id": "lvl20", 213 | "unlock_time": 1548422395 214 | } 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /tests/test_authenticate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.api.types import Authentication 4 | from galaxy.api.errors import ( 5 | UnknownError, InvalidCredentials, NetworkError, LoggedInElsewhere, ProtocolError, 6 | BackendNotAvailable, BackendTimeout, BackendError, TemporaryBlocked, Banned, AccessDenied 7 | ) 8 | from galaxy.unittest.mock import async_return_value 9 | 10 | from tests import create_message, get_messages 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_success(plugin, read, write): 15 | request = { 16 | "jsonrpc": "2.0", 17 | "id": "3", 18 | "method": "init_authentication" 19 | } 20 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 21 | plugin.authenticate.return_value = async_return_value(Authentication("132", "Zenek")) 22 | await plugin.run() 23 | plugin.authenticate.assert_called_with() 24 | 25 | assert get_messages(write) == [ 26 | { 27 | "jsonrpc": "2.0", 28 | "id": "3", 29 | "result": { 30 | "user_id": "132", 31 | "user_name": "Zenek" 32 | } 33 | } 34 | ] 35 | 36 | 37 | @pytest.mark.asyncio 38 | @pytest.mark.parametrize("error,code,message", [ 39 | pytest.param(UnknownError, 0, "Unknown error", id="unknown_error"), 40 | pytest.param(BackendNotAvailable, 2, "Backend not available", id="backend_not_available"), 41 | pytest.param(BackendTimeout, 3, "Backend timed out", id="backend_timeout"), 42 | pytest.param(BackendError, 4, "Backend error", id="backend_error"), 43 | 
pytest.param(InvalidCredentials, 100, "Invalid credentials", id="invalid_credentials"), 44 | pytest.param(NetworkError, 101, "Network error", id="network_error"), 45 | pytest.param(LoggedInElsewhere, 102, "Logged in elsewhere", id="logged_elsewhere"), 46 | pytest.param(ProtocolError, 103, "Protocol error", id="protocol_error"), 47 | pytest.param(TemporaryBlocked, 104, "Temporary blocked", id="temporary_blocked"), 48 | pytest.param(Banned, 105, "Banned", id="banned"), 49 | pytest.param(AccessDenied, 106, "Access denied", id="access_denied"), 50 | ]) 51 | async def test_failure(plugin, read, write, error, code, message): 52 | request = { 53 | "jsonrpc": "2.0", 54 | "id": "3", 55 | "method": "init_authentication" 56 | } 57 | 58 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 59 | plugin.authenticate.side_effect = error() 60 | await plugin.run() 61 | plugin.authenticate.assert_called_with() 62 | 63 | assert get_messages(write) == [ 64 | { 65 | "jsonrpc": "2.0", 66 | "id": "3", 67 | "error": { 68 | "code": code, 69 | "message": message 70 | } 71 | } 72 | ] 73 | 74 | 75 | @pytest.mark.asyncio 76 | async def test_stored_credentials(plugin, read, write): 77 | request = { 78 | "jsonrpc": "2.0", 79 | "id": "3", 80 | "method": "init_authentication", 81 | "params": { 82 | "stored_credentials": { 83 | "token": "ABC" 84 | } 85 | } 86 | } 87 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 88 | plugin.authenticate.return_value = async_return_value(Authentication("132", "Zenek")) 89 | await plugin.run() 90 | plugin.authenticate.assert_called_with(stored_credentials={"token": "ABC"}) 91 | write.assert_called() 92 | 93 | 94 | @pytest.mark.asyncio 95 | async def test_store_credentials(plugin, write): 96 | credentials = { 97 | "token": "ABC" 98 | } 99 | plugin.store_credentials(credentials) 100 | 101 | assert get_messages(write) == [ 102 | { 103 | "jsonrpc": "2.0", 104 | "method": 
"store_credentials", 105 | "params": credentials 106 | } 107 | ] 108 | 109 | 110 | @pytest.mark.asyncio 111 | async def test_lost_authentication(plugin, write): 112 | plugin.lost_authentication() 113 | 114 | assert get_messages(write) == [ 115 | { 116 | "jsonrpc": "2.0", 117 | "method": "authentication_lost", 118 | "params": None 119 | } 120 | ] 121 | -------------------------------------------------------------------------------- /tests/test_chunk_messages.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | 5 | from galaxy.unittest.mock import async_return_value 6 | 7 | 8 | @pytest.mark.asyncio 9 | async def test_chunked_messages(plugin, read): 10 | request = { 11 | "jsonrpc": "2.0", 12 | "method": "install_game", 13 | "params": { 14 | "game_id": "3" 15 | } 16 | } 17 | 18 | message = json.dumps(request).encode() + b"\n" 19 | read.side_effect = [async_return_value(message[:5]), async_return_value(message[5:]), async_return_value(b"")] 20 | await plugin.run() 21 | plugin.install_game.assert_called_with(game_id="3") 22 | 23 | 24 | @pytest.mark.asyncio 25 | async def test_joined_messages(plugin, read): 26 | requests = [ 27 | { 28 | "jsonrpc": "2.0", 29 | "method": "install_game", 30 | "params": { 31 | "game_id": "3" 32 | } 33 | }, 34 | { 35 | "jsonrpc": "2.0", 36 | "method": "launch_game", 37 | "params": { 38 | "game_id": "3" 39 | } 40 | } 41 | ] 42 | data = b"".join([json.dumps(request).encode() + b"\n" for request in requests]) 43 | 44 | read.side_effect = [async_return_value(data), async_return_value(b"")] 45 | await plugin.run() 46 | plugin.install_game.assert_called_with(game_id="3") 47 | plugin.launch_game.assert_called_with(game_id="3") 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def test_not_finished(plugin, read): 52 | request = { 53 | "jsonrpc": "2.0", 54 | "method": "install_game", 55 | "params": { 56 | "game_id": "3" 57 | } 58 | } 59 | 60 | message = json.dumps(request).encode() 
# no new line 61 | read.side_effect = [async_return_value(message), async_return_value(b"")] 62 | await plugin.run() 63 | plugin.install_game.assert_not_called() 64 | -------------------------------------------------------------------------------- /tests/test_features.py: -------------------------------------------------------------------------------- 1 | from galaxy.api.consts import Feature, Platform 2 | from galaxy.api.plugin import Plugin 3 | 4 | 5 | def test_base_class(): 6 | plugin = Plugin(Platform.Generic, "0.1", None, None, None) 7 | assert set(plugin.features) == { 8 | Feature.ImportInstalledGames, 9 | Feature.ImportOwnedGames, 10 | Feature.LaunchGame, 11 | Feature.InstallGame, 12 | Feature.UninstallGame, 13 | Feature.ImportAchievements, 14 | Feature.ImportGameTime, 15 | Feature.ImportFriends, 16 | Feature.ShutdownPlatformClient, 17 | Feature.LaunchPlatformClient 18 | } 19 | 20 | 21 | def test_no_overloads(): 22 | class PluginImpl(Plugin): # pylint: disable=abstract-method 23 | pass 24 | 25 | plugin = PluginImpl(Platform.Generic, "0.1", None, None, None) 26 | assert plugin.features == [] 27 | 28 | 29 | def test_one_method_feature(): 30 | class PluginImpl(Plugin): # pylint: disable=abstract-method 31 | async def get_owned_games(self): 32 | pass 33 | 34 | plugin = PluginImpl(Platform.Generic, "0.1", None, None, None) 35 | assert plugin.features == [Feature.ImportOwnedGames] 36 | 37 | 38 | def test_multi_features(): 39 | class PluginImpl(Plugin): # pylint: disable=abstract-method 40 | async def get_owned_games(self): 41 | pass 42 | 43 | async def get_unlocked_achievements(self, game_id, context): 44 | pass 45 | 46 | async def get_game_time(self, game_id, context): 47 | pass 48 | 49 | plugin = PluginImpl(Platform.Generic, "0.1", None, None, None) 50 | assert set(plugin.features) == {Feature.ImportAchievements, Feature.ImportOwnedGames, Feature.ImportGameTime} 51 | -------------------------------------------------------------------------------- 
/tests/test_friends.py: -------------------------------------------------------------------------------- 1 | from galaxy.api.types import FriendInfo 2 | from galaxy.api.errors import UnknownError 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | import pytest 6 | 7 | from tests import create_message, get_messages 8 | 9 | 10 | @pytest.mark.asyncio 11 | async def test_get_friends_success(plugin, read, write): 12 | request = { 13 | "jsonrpc": "2.0", 14 | "id": "3", 15 | "method": "import_friends" 16 | } 17 | 18 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 19 | plugin.get_friends.return_value = async_return_value([ 20 | FriendInfo("3", "Jan"), 21 | FriendInfo("5", "Ola") 22 | ]) 23 | await plugin.run() 24 | plugin.get_friends.assert_called_with() 25 | 26 | assert get_messages(write) == [ 27 | { 28 | "jsonrpc": "2.0", 29 | "id": "3", 30 | "result": { 31 | "friend_info_list": [ 32 | {"user_id": "3", "user_name": "Jan"}, 33 | {"user_id": "5", "user_name": "Ola"} 34 | ] 35 | } 36 | } 37 | ] 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_get_friends_failure(plugin, read, write): 42 | request = { 43 | "jsonrpc": "2.0", 44 | "id": "3", 45 | "method": "import_friends" 46 | } 47 | 48 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 49 | plugin.get_friends.side_effect = UnknownError() 50 | await plugin.run() 51 | plugin.get_friends.assert_called_with() 52 | 53 | assert get_messages(write) == [ 54 | { 55 | "jsonrpc": "2.0", 56 | "id": "3", 57 | "error": { 58 | "code": 0, 59 | "message": "Unknown error", 60 | } 61 | } 62 | ] 63 | 64 | 65 | @pytest.mark.asyncio 66 | async def test_add_friend(plugin, write): 67 | friend = FriendInfo("7", "Kuba") 68 | 69 | plugin.add_friend(friend) 70 | 71 | assert get_messages(write) == [ 72 | { 73 | "jsonrpc": "2.0", 74 | "method": "friend_added", 75 | "params": { 76 | "friend_info": {"user_id": "7", "user_name": "Kuba"} 77 | } 
78 | } 79 | ] 80 | 81 | 82 | @pytest.mark.asyncio 83 | async def test_remove_friend(plugin, write): 84 | plugin.remove_friend("5") 85 | 86 | assert get_messages(write) == [ 87 | { 88 | "jsonrpc": "2.0", 89 | "method": "friend_removed", 90 | "params": { 91 | "user_id": "5" 92 | } 93 | } 94 | ] 95 | -------------------------------------------------------------------------------- /tests/test_game_times.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import call 2 | 3 | import pytest 4 | from galaxy.api.types import GameTime 5 | from galaxy.api.errors import BackendError 6 | from galaxy.unittest.mock import async_return_value 7 | 8 | from tests import create_message, get_messages 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_get_game_time_success(plugin, read, write): 13 | plugin.prepare_game_times_context.return_value = async_return_value("abc") 14 | request = { 15 | "jsonrpc": "2.0", 16 | "id": "3", 17 | "method": "start_game_times_import", 18 | "params": { 19 | "game_ids": ["3", "5", "7"] 20 | } 21 | } 22 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 23 | plugin.get_game_time.side_effect = [ 24 | async_return_value(GameTime("3", 60, 1549550504)), 25 | async_return_value(GameTime("5", 10, None)), 26 | async_return_value(GameTime("7", None, 1549550502)), 27 | ] 28 | await plugin.run() 29 | plugin.get_game_time.assert_has_calls([ 30 | call("3", "abc"), 31 | call("5", "abc"), 32 | call("7", "abc"), 33 | ]) 34 | plugin.game_times_import_complete.assert_called_once_with() 35 | 36 | assert get_messages(write) == [ 37 | { 38 | "jsonrpc": "2.0", 39 | "id": "3", 40 | "result": None 41 | }, 42 | { 43 | "jsonrpc": "2.0", 44 | "method": "game_time_import_success", 45 | "params": { 46 | "game_time": { 47 | "game_id": "3", 48 | "last_played_time": 1549550504, 49 | "time_played": 60 50 | } 51 | } 52 | }, 53 | { 54 | "jsonrpc": "2.0", 55 | "method": 
"game_time_import_success", 56 | "params": { 57 | "game_time": { 58 | "game_id": "5", 59 | "time_played": 10 60 | } 61 | } 62 | }, 63 | { 64 | "jsonrpc": "2.0", 65 | "method": "game_time_import_success", 66 | "params": { 67 | "game_time": { 68 | "game_id": "7", 69 | "last_played_time": 1549550502 70 | } 71 | } 72 | }, 73 | { 74 | "jsonrpc": "2.0", 75 | "method": "game_times_import_finished", 76 | "params": None 77 | } 78 | ] 79 | 80 | 81 | @pytest.mark.asyncio 82 | @pytest.mark.parametrize("exception,code,message", [ 83 | (BackendError, 4, "Backend error"), 84 | (KeyError, 0, "Unknown error") 85 | ]) 86 | async def test_get_game_time_error(exception, code, message, plugin, read, write): 87 | plugin.prepare_game_times_context.return_value = async_return_value(None) 88 | request = { 89 | "jsonrpc": "2.0", 90 | "id": "3", 91 | "method": "start_game_times_import", 92 | "params": { 93 | "game_ids": ["6"] 94 | } 95 | } 96 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 97 | plugin.get_game_time.side_effect = exception 98 | await plugin.run() 99 | plugin.get_game_time.assert_called() 100 | plugin.game_times_import_complete.assert_called_once_with() 101 | 102 | assert get_messages(write) == [ 103 | { 104 | "jsonrpc": "2.0", 105 | "id": "3", 106 | "result": None 107 | }, 108 | { 109 | "jsonrpc": "2.0", 110 | "method": "game_time_import_failure", 111 | "params": { 112 | "game_id": "6", 113 | "error": { 114 | "code": code, 115 | "message": message 116 | } 117 | } 118 | }, 119 | { 120 | "jsonrpc": "2.0", 121 | "method": "game_times_import_finished", 122 | "params": None 123 | } 124 | ] 125 | 126 | 127 | @pytest.mark.asyncio 128 | async def test_prepare_get_game_time_context_error(plugin, read, write): 129 | plugin.prepare_game_times_context.side_effect = BackendError() 130 | request = { 131 | "jsonrpc": "2.0", 132 | "id": "3", 133 | "method": "start_game_times_import", 134 | "params": { 135 | "game_ids": ["6"] 136 | } 137 | } 138 
| read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 139 | await plugin.run() 140 | 141 | assert get_messages(write) == [ 142 | { 143 | "jsonrpc": "2.0", 144 | "id": "3", 145 | "error": { 146 | "code": 4, 147 | "message": "Backend error" 148 | } 149 | } 150 | ] 151 | 152 | 153 | @pytest.mark.asyncio 154 | async def test_import_in_progress(plugin, read, write): 155 | plugin.prepare_game_times_context.return_value = async_return_value(None) 156 | requests = [ 157 | { 158 | "jsonrpc": "2.0", 159 | "id": "3", 160 | "method": "start_game_times_import", 161 | "params": { 162 | "game_ids": ["6"] 163 | } 164 | }, 165 | { 166 | "jsonrpc": "2.0", 167 | "id": "4", 168 | "method": "start_game_times_import", 169 | "params": { 170 | "game_ids": ["7"] 171 | } 172 | } 173 | ] 174 | read.side_effect = [ 175 | async_return_value(create_message(requests[0])), 176 | async_return_value(create_message(requests[1])), 177 | async_return_value(b"", 10) 178 | ] 179 | 180 | await plugin.run() 181 | 182 | messages = get_messages(write) 183 | assert { 184 | "jsonrpc": "2.0", 185 | "id": "3", 186 | "result": None 187 | } in messages 188 | assert { 189 | "jsonrpc": "2.0", 190 | "id": "4", 191 | "error": { 192 | "code": 600, 193 | "message": "Import already in progress" 194 | } 195 | } in messages 196 | 197 | 198 | @pytest.mark.asyncio 199 | async def test_update_game(plugin, write): 200 | game_time = GameTime("3", 60, 1549550504) 201 | plugin.update_game_time(game_time) 202 | 203 | assert get_messages(write) == [ 204 | { 205 | "jsonrpc": "2.0", 206 | "method": "game_time_updated", 207 | "params": { 208 | "game_time": { 209 | "game_id": "3", 210 | "time_played": 60, 211 | "last_played_time": 1549550504 212 | } 213 | } 214 | } 215 | ] 216 | -------------------------------------------------------------------------------- /tests/test_http.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from http 
import HTTPStatus 3 | 4 | import aiohttp 5 | import pytest 6 | from multidict import CIMultiDict, CIMultiDictProxy 7 | from yarl import URL 8 | 9 | from galaxy.api.errors import ( 10 | AccessDenied, AuthenticationRequired, BackendTimeout, BackendNotAvailable, BackendError, NetworkError, 11 | TooManyRequests, UnknownBackendResponse, UnknownError 12 | ) 13 | from galaxy.http import handle_exception 14 | 15 | request_info = aiohttp.RequestInfo(URL("http://o.pl"), "GET", CIMultiDictProxy(CIMultiDict())) 16 | 17 | @pytest.mark.parametrize( 18 | "aiohttp_exception,expected_exception_type", 19 | [ 20 | (asyncio.TimeoutError(), BackendTimeout), 21 | (aiohttp.ServerDisconnectedError(), BackendNotAvailable), 22 | (aiohttp.ClientConnectionError(), NetworkError), 23 | (aiohttp.ContentTypeError(request_info, ()), UnknownBackendResponse), 24 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.UNAUTHORIZED), AuthenticationRequired), 25 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.FORBIDDEN), AccessDenied), 26 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.SERVICE_UNAVAILABLE), BackendNotAvailable), 27 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.TOO_MANY_REQUESTS), TooManyRequests), 28 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.INTERNAL_SERVER_ERROR), BackendError), 29 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.NOT_IMPLEMENTED), BackendError), 30 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.BAD_REQUEST), UnknownError), 31 | (aiohttp.ClientResponseError(request_info, (), status=HTTPStatus.NOT_FOUND), UnknownError), 32 | (aiohttp.ClientError(), UnknownError) 33 | ] 34 | ) 35 | def test_handle_exception(aiohttp_exception, expected_exception_type): 36 | with pytest.raises(expected_exception_type): 37 | with handle_exception(): 38 | raise aiohttp_exception 39 | 40 | 
-------------------------------------------------------------------------------- /tests/test_install_game.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | from tests import create_message 6 | 7 | 8 | @pytest.mark.asyncio 9 | async def test_success(plugin, read): 10 | request = { 11 | "jsonrpc": "2.0", 12 | "method": "install_game", 13 | "params": { 14 | "game_id": "3" 15 | } 16 | } 17 | 18 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 19 | await plugin.run() 20 | plugin.install_game.assert_called_with(game_id="3") 21 | -------------------------------------------------------------------------------- /tests/test_internal.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.api.plugin import Plugin 4 | from galaxy.api.consts import Platform 5 | from galaxy.unittest.mock import async_return_value 6 | 7 | from tests import create_message, get_messages 8 | 9 | 10 | @pytest.mark.asyncio 11 | async def test_get_capabilities(reader, writer, read, write): 12 | class PluginImpl(Plugin): #pylint: disable=abstract-method 13 | async def get_owned_games(self): 14 | pass 15 | 16 | request = { 17 | "jsonrpc": "2.0", 18 | "id": "3", 19 | "method": "get_capabilities" 20 | } 21 | token = "token" 22 | plugin = PluginImpl(Platform.Generic, "0.1", reader, writer, token) 23 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 24 | await plugin.run() 25 | assert get_messages(write) == [ 26 | { 27 | "jsonrpc": "2.0", 28 | "id": "3", 29 | "result": { 30 | "platform_name": "generic", 31 | "features": [ 32 | "ImportOwnedGames" 33 | ], 34 | "token": token 35 | } 36 | } 37 | ] 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_shutdown(plugin, read, write): 42 | request = { 43 | "jsonrpc": "2.0", 44 | "id": "5", 45 | 
"method": "shutdown" 46 | } 47 | read.side_effect = [async_return_value(create_message(request))] 48 | await plugin.run() 49 | await plugin.wait_closed() 50 | plugin.shutdown.assert_called_with() 51 | assert get_messages(write) == [ 52 | { 53 | "jsonrpc": "2.0", 54 | "id": "5", 55 | "result": None 56 | } 57 | ] 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_ping(plugin, read, write): 62 | request = { 63 | "jsonrpc": "2.0", 64 | "id": "7", 65 | "method": "ping" 66 | } 67 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 68 | await plugin.run() 69 | assert get_messages(write) == [ 70 | { 71 | "jsonrpc": "2.0", 72 | "id": "7", 73 | "result": None 74 | } 75 | ] 76 | 77 | 78 | @pytest.mark.asyncio 79 | async def test_tick_before_handshake(plugin, read): 80 | read.side_effect = [async_return_value(b"")] 81 | await plugin.run() 82 | plugin.tick.assert_not_called() 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_tick_after_handshake(plugin, read): 87 | request = { 88 | "jsonrpc": "2.0", 89 | "id": "6", 90 | "method": "initialize_cache", 91 | "params": {"data": {}} 92 | } 93 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 94 | await plugin.run() 95 | plugin.tick.assert_called_with() 96 | -------------------------------------------------------------------------------- /tests/test_launch_game.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | from tests import create_message 6 | 7 | 8 | @pytest.mark.asyncio 9 | async def test_success(plugin, read): 10 | request = { 11 | "jsonrpc": "2.0", 12 | "method": "launch_game", 13 | "params": { 14 | "game_id": "3" 15 | } 16 | } 17 | 18 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 19 | await plugin.run() 20 | plugin.launch_game.assert_called_with(game_id="3") 21 | 
-------------------------------------------------------------------------------- /tests/test_launch_platform_client.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | from tests import create_message 6 | 7 | @pytest.mark.asyncio 8 | async def test_success(plugin, read): 9 | request = { 10 | "jsonrpc": "2.0", 11 | "method": "launch_platform_client" 12 | } 13 | 14 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 15 | plugin.launch_platform_client.return_value = async_return_value(None) 16 | await plugin.run() 17 | plugin.launch_platform_client.assert_called_with() 18 | -------------------------------------------------------------------------------- /tests/test_local_games.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.api.types import LocalGame 4 | from galaxy.api.consts import LocalGameState 5 | from galaxy.api.errors import UnknownError, FailedParsingManifest 6 | from galaxy.unittest.mock import async_return_value 7 | 8 | from tests import create_message, get_messages 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_success(plugin, read, write): 13 | request = { 14 | "jsonrpc": "2.0", 15 | "id": "3", 16 | "method": "import_local_games" 17 | } 18 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 19 | 20 | plugin.get_local_games.return_value = async_return_value([ 21 | LocalGame("1", LocalGameState.Running), 22 | LocalGame("2", LocalGameState.Installed), 23 | LocalGame("3", LocalGameState.Installed | LocalGameState.Running) 24 | ]) 25 | await plugin.run() 26 | plugin.get_local_games.assert_called_with() 27 | 28 | assert get_messages(write) == [ 29 | { 30 | "jsonrpc": "2.0", 31 | "id": "3", 32 | "result": { 33 | "local_games" : [ 34 | { 35 | "game_id": "1", 36 | "local_game_state": 
LocalGameState.Running.value 37 | }, 38 | { 39 | "game_id": "2", 40 | "local_game_state": LocalGameState.Installed.value 41 | }, 42 | { 43 | "game_id": "3", 44 | "local_game_state": (LocalGameState.Installed | LocalGameState.Running).value 45 | } 46 | ] 47 | } 48 | } 49 | ] 50 | 51 | 52 | @pytest.mark.asyncio 53 | @pytest.mark.parametrize( 54 | "error,code,message", 55 | [ 56 | pytest.param(UnknownError, 0, "Unknown error", id="unknown_error"), 57 | pytest.param(FailedParsingManifest, 200, "Failed parsing manifest", id="failed_parsing") 58 | ], 59 | ) 60 | async def test_failure(plugin, read, write, error, code, message): 61 | request = { 62 | "jsonrpc": "2.0", 63 | "id": "3", 64 | "method": "import_local_games" 65 | } 66 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 67 | plugin.get_local_games.side_effect = error() 68 | await plugin.run() 69 | plugin.get_local_games.assert_called_with() 70 | 71 | assert get_messages(write) == [ 72 | { 73 | "jsonrpc": "2.0", 74 | "id": "3", 75 | "error": { 76 | "code": code, 77 | "message": message 78 | } 79 | } 80 | ] 81 | 82 | @pytest.mark.asyncio 83 | async def test_local_game_state_update(plugin, write): 84 | game = LocalGame("1", LocalGameState.Running) 85 | plugin.update_local_game_status(game) 86 | 87 | assert get_messages(write) == [ 88 | { 89 | "jsonrpc": "2.0", 90 | "method": "local_game_status_changed", 91 | "params": { 92 | "local_game": { 93 | "game_id": "1", 94 | "local_game_state": LocalGameState.Running.value 95 | } 96 | } 97 | } 98 | ] 99 | -------------------------------------------------------------------------------- /tests/test_owned_games.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.api.types import Game, Dlc, LicenseInfo 4 | from galaxy.api.consts import LicenseType 5 | from galaxy.api.errors import UnknownError 6 | from galaxy.unittest.mock import async_return_value 7 | 8 | from 
tests import create_message, get_messages 9 | 10 | 11 | @pytest.mark.asyncio 12 | async def test_success(plugin, read, write): 13 | request = { 14 | "jsonrpc": "2.0", 15 | "id": "3", 16 | "method": "import_owned_games" 17 | } 18 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 19 | 20 | plugin.get_owned_games.return_value = async_return_value([ 21 | Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None)), 22 | Game( 23 | "5", 24 | "Witcher 3", 25 | [ 26 | Dlc("7", "Hearts of Stone", LicenseInfo(LicenseType.SinglePurchase, None)), 27 | Dlc("8", "Temerian Armor Set", LicenseInfo(LicenseType.FreeToPlay, None)), 28 | ], 29 | LicenseInfo(LicenseType.SinglePurchase, None)) 30 | ]) 31 | await plugin.run() 32 | plugin.get_owned_games.assert_called_with() 33 | assert get_messages(write) == [ 34 | { 35 | "jsonrpc": "2.0", 36 | "id": "3", 37 | "result": { 38 | "owned_games": [ 39 | { 40 | "game_id": "3", 41 | "game_title": "Doom", 42 | "license_info": { 43 | "license_type": "SinglePurchase" 44 | } 45 | }, 46 | { 47 | "game_id": "5", 48 | "game_title": "Witcher 3", 49 | "dlcs": [ 50 | { 51 | "dlc_id": "7", 52 | "dlc_title": "Hearts of Stone", 53 | "license_info": { 54 | "license_type": "SinglePurchase" 55 | } 56 | }, 57 | { 58 | "dlc_id": "8", 59 | "dlc_title": "Temerian Armor Set", 60 | "license_info": { 61 | "license_type": "FreeToPlay" 62 | } 63 | } 64 | ], 65 | "license_info": { 66 | "license_type": "SinglePurchase" 67 | } 68 | } 69 | ] 70 | } 71 | } 72 | ] 73 | 74 | 75 | @pytest.mark.asyncio 76 | async def test_failure(plugin, read, write): 77 | request = { 78 | "jsonrpc": "2.0", 79 | "id": "3", 80 | "method": "import_owned_games" 81 | } 82 | 83 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] 84 | plugin.get_owned_games.side_effect = UnknownError() 85 | await plugin.run() 86 | plugin.get_owned_games.assert_called_with() 87 | assert get_messages(write) == [ 88 | { 
89 | "jsonrpc": "2.0", 90 | "id": "3", 91 | "error": { 92 | "code": 0, 93 | "message": "Unknown error" 94 | } 95 | } 96 | ] 97 | 98 | 99 | @pytest.mark.asyncio 100 | async def test_add_game(plugin, write): 101 | game = Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None)) 102 | plugin.add_game(game) 103 | assert get_messages(write) == [ 104 | { 105 | "jsonrpc": "2.0", 106 | "method": "owned_game_added", 107 | "params": { 108 | "owned_game": { 109 | "game_id": "3", 110 | "game_title": "Doom", 111 | "license_info": { 112 | "license_type": "SinglePurchase" 113 | } 114 | } 115 | } 116 | } 117 | ] 118 | 119 | 120 | @pytest.mark.asyncio 121 | async def test_remove_game(plugin, write): 122 | plugin.remove_game("5") 123 | assert get_messages(write) == [ 124 | { 125 | "jsonrpc": "2.0", 126 | "method": "owned_game_removed", 127 | "params": { 128 | "game_id": "5" 129 | } 130 | } 131 | ] 132 | 133 | 134 | @pytest.mark.asyncio 135 | async def test_update_game(plugin, write): 136 | game = Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None)) 137 | plugin.update_game(game) 138 | assert get_messages(write) == [ 139 | { 140 | "jsonrpc": "2.0", 141 | "method": "owned_game_updated", 142 | "params": { 143 | "owned_game": { 144 | "game_id": "3", 145 | "game_title": "Doom", 146 | "license_info": { 147 | "license_type": "SinglePurchase" 148 | } 149 | } 150 | } 151 | } 152 | ] 153 | -------------------------------------------------------------------------------- /tests/test_persistent_cache.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | from tests import create_message, get_messages 6 | 7 | 8 | def assert_rpc_response(write, response_id, result=None): 9 | assert get_messages(write) == [ 10 | { 11 | "jsonrpc": "2.0", 12 | "id": str(response_id), 13 | "result": result 14 | } 15 | ] 16 | 17 | 18 | def assert_rpc_request(write, method, 
params=None): 19 | assert get_messages(write) == [ 20 | { 21 | "jsonrpc": "2.0", 22 | "method": method, 23 | "params": {"data": params} 24 | } 25 | ] 26 | 27 | 28 | @pytest.fixture 29 | def cache_data(): 30 | return { 31 | "persistent key": "persistent value", 32 | "persistent object": {"answer to everything": 42} 33 | } 34 | 35 | 36 | @pytest.mark.asyncio 37 | async def test_initialize_cache(plugin, read, write, cache_data): 38 | request_id = 3 39 | request = { 40 | "jsonrpc": "2.0", 41 | "id": str(request_id), 42 | "method": "initialize_cache", 43 | "params": {"data": cache_data} 44 | } 45 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 46 | 47 | assert {} == plugin.persistent_cache 48 | await plugin.run() 49 | plugin.handshake_complete.assert_called_once_with() 50 | assert cache_data == plugin.persistent_cache 51 | assert_rpc_response(write, response_id=request_id) 52 | 53 | 54 | @pytest.mark.asyncio 55 | async def test_set_cache(plugin, write, cache_data): 56 | assert {} == plugin.persistent_cache 57 | 58 | plugin.persistent_cache.update(cache_data) 59 | plugin.push_cache() 60 | 61 | assert_rpc_request(write, "push_cache", cache_data) 62 | assert cache_data == plugin.persistent_cache 63 | 64 | 65 | @pytest.mark.asyncio 66 | async def test_clear_cache(plugin, write, cache_data): 67 | plugin._persistent_cache = cache_data 68 | 69 | plugin.persistent_cache.clear() 70 | plugin.push_cache() 71 | 72 | assert_rpc_request(write, "push_cache", {}) 73 | assert {} == plugin.persistent_cache 74 | -------------------------------------------------------------------------------- /tests/test_shutdown_platform_client.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | from tests import create_message 6 | 7 | @pytest.mark.asyncio 8 | async def test_success(plugin, read): 9 | request = { 10 | "jsonrpc": "2.0", 11 | 
"method": "shutdown_platform_client" 12 | } 13 | 14 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 15 | plugin.shutdown_platform_client.return_value = async_return_value(None) 16 | await plugin.run() 17 | plugin.shutdown_platform_client.assert_called_with() 18 | -------------------------------------------------------------------------------- /tests/test_stream_line_reader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.reader import StreamLineReader 4 | from galaxy.unittest.mock import async_return_value 5 | 6 | 7 | @pytest.fixture() 8 | def stream_line_reader(reader): 9 | return StreamLineReader(reader) 10 | 11 | 12 | @pytest.mark.asyncio 13 | async def test_message(stream_line_reader, read): 14 | read.return_value = async_return_value(b"a\n") 15 | assert await stream_line_reader.readline() == b"a" 16 | read.assert_called_once() 17 | 18 | 19 | @pytest.mark.asyncio 20 | async def test_separate_messages(stream_line_reader, read): 21 | read.side_effect = [async_return_value(b"a\n"), async_return_value(b"b\n")] 22 | assert await stream_line_reader.readline() == b"a" 23 | assert await stream_line_reader.readline() == b"b" 24 | assert read.call_count == 2 25 | 26 | 27 | @pytest.mark.asyncio 28 | async def test_connected_messages(stream_line_reader, read): 29 | read.return_value = async_return_value(b"a\nb\n") 30 | assert await stream_line_reader.readline() == b"a" 31 | assert await stream_line_reader.readline() == b"b" 32 | read.assert_called_once() 33 | 34 | 35 | @pytest.mark.asyncio 36 | async def test_cut_message(stream_line_reader, read): 37 | read.side_effect = [async_return_value(b"a"), async_return_value(b"b\n")] 38 | assert await stream_line_reader.readline() == b"ab" 39 | assert read.call_count == 2 40 | 41 | 42 | @pytest.mark.asyncio 43 | async def test_half_message(stream_line_reader, read): 44 | read.side_effect = [async_return_value(b"a"), 
async_return_value(b"")] 45 | assert await stream_line_reader.readline() == b"" 46 | assert read.call_count == 2 47 | -------------------------------------------------------------------------------- /tests/test_uninstall_game.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from galaxy.unittest.mock import async_return_value 4 | 5 | from tests import create_message 6 | 7 | @pytest.mark.asyncio 8 | async def test_success(plugin, read): 9 | request = { 10 | "jsonrpc": "2.0", 11 | "method": "uninstall_game", 12 | "params": { 13 | "game_id": "3" 14 | } 15 | } 16 | read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] 17 | plugin.get_owned_games.return_value = None 18 | await plugin.run() 19 | plugin.uninstall_game.assert_called_with(game_id="3") 20 | -------------------------------------------------------------------------------- /trophy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import xml.etree.ElementTree as ET 5 | 6 | from config import Config 7 | from galaxy.api.types import Achievement 8 | 9 | 10 | class TropConf: 11 | def __init__(self, trophies_path): 12 | self.path = os.path.join(trophies_path, 'TROPCONF.SFM') 13 | self.tree = ET.parse(self.path) 14 | self.root = self.tree.getroot() 15 | 16 | 17 | # This metaclass allows us to call 'len(TropUsrHeader)' and get '48'. 18 | # Useful for substringing TROPUSR.DAT. 19 | class TropMetaclass(type): 20 | def __len__(self): 21 | total = 0 22 | for v in vars(self)['__annotations__'].values(): 23 | total += len(v) 24 | return total 25 | 26 | 27 | class TropUsrHeader(object, metaclass=TropMetaclass): 28 | magic : bytes(4) 29 | unk1 : bytes(4) 30 | table_count : bytes(4) 31 | unk2 : bytes(4) 32 | reserved : bytes(32) 33 | 34 | # bytestring is only a portion of the file; 35 | # subscriptors are relative to the start of the portion, not the file. 
36 | def __init__(self, bytestring): 37 | self.magic = bytestring[0 : 4] 38 | self.unk1 = bytestring[4 : 8] 39 | self.table_count = bytestring[8 : 12] 40 | self.unk2 = bytestring[12 : 16] 41 | self.reserved = bytestring[16 : 48] 42 | 43 | 44 | class TropUsrTableHeader(object, metaclass=TropMetaclass): 45 | table_type : bytes(4) 46 | entries_size : bytes(4) 47 | unk1 : bytes(4) 48 | entries_count : bytes(4) 49 | offset : bytes(8) 50 | reserved : bytes(8) 51 | 52 | def __init__(self, bytestring): 53 | self.table_type = bytestring[0 : 4] 54 | self.entries_size = bytestring[4 : 8] 55 | self.unk1 = bytestring[8 : 12] 56 | self.entries_count = bytestring[12 : 16] 57 | self.offset = bytestring[16 : 24] 58 | self.reserved = bytestring[24 : 32] 59 | 60 | 61 | class TropUsrEntry4(object, metaclass=TropMetaclass): 62 | # Entry Header 63 | entry_type : bytes(4) 64 | entry_size : bytes(4) 65 | entry_id : bytes(4) 66 | entry_unk1 : bytes(4) 67 | 68 | # Entry Contents 69 | trophy_id : bytes(4) 70 | trophy_grade : bytes(4) 71 | unk5 : bytes(4) 72 | unk6 : bytes(68) 73 | 74 | def __init__(self, bytestring): 75 | # Entry Header 76 | self.entry_type = bytestring[0 : 4] 77 | self.entry_size = bytestring[4 : 8] 78 | self.entry_id = bytestring[8 : 12] 79 | self.entry_unk1 = bytestring[12 : 16] 80 | 81 | # Entry Contents 82 | self.trophy_id = bytestring[16 : 20] 83 | self.trophy_grade = bytestring[20 : 24] 84 | self.unk5 = bytestring[24 : 28] 85 | self.unk6 = bytestring[28 : 96] 86 | 87 | 88 | class TropUsrEntry6(object, metaclass=TropMetaclass): 89 | # Entry Header 90 | entry_type : bytes(4) 91 | entry_size : bytes(4) 92 | entry_id : bytes(4) 93 | entry_unk1 : bytes(4) 94 | 95 | # Entry Contents 96 | trophy_id : bytes(4) 97 | trophy_state : bytes(4) 98 | unk4 : bytes(4) 99 | unk5 : bytes(4) 100 | timestamp1 : bytes(8) 101 | timestamp2 : bytes(8) 102 | unk6 : bytes(64) 103 | 104 | def __init__(self, bytestring): 105 | # Entry Header 106 | self.entry_type = bytestring[0 : 4] 107 | 
self.entry_size = bytestring[4 : 8] 108 | self.entry_id = bytestring[8 : 12] 109 | self.entry_unk1 = bytestring[12 : 16] 110 | 111 | # Entry Contents 112 | self.trophy_id = bytestring[16 : 20] 113 | self.trophy_state = bytestring[20 : 24] 114 | self.unk4 = bytestring[24 : 28] 115 | self.unk5 = bytestring[28 : 32] 116 | self.timestamp1 = bytestring[32 : 40] 117 | self.timestamp2 = bytestring[40 : 48] 118 | self.unk6 = bytestring[48 : 112] 119 | 120 | 121 | class TropUsr: 122 | def __init__(self, trophies_path): 123 | 124 | self.file = None 125 | self.header = None 126 | self.table_headers = {} 127 | self.table4 = {} 128 | self.table6 = {} 129 | 130 | self.path = os.path.join(trophies_path, 'TROPUSR.DAT') 131 | if os.path.exists(self.path): 132 | with open(self.path, 'rb') as file: 133 | self.file = file.read() 134 | 135 | self.load() 136 | 137 | 138 | def load(self) -> bool: 139 | 140 | # TODO - Generate if none exists. Is this necessary? 141 | 142 | if not self.file: 143 | raise FileNotFoundError(self.path + ' file not found.') 144 | return False 145 | 146 | if not self.load_header(): 147 | raise ValueError('Magic bytes not found in file header.') 148 | return False 149 | 150 | if not self.load_table_headers(): 151 | raise ValueError('Unknown table type found in table header.') 152 | return False 153 | 154 | if not self.load_tables(): 155 | raise ValueError('Error loading table entries.') 156 | return False 157 | 158 | return True 159 | 160 | 161 | def load_header(self) -> bool: 162 | 163 | self.header = TropUsrHeader(self.file[0 : len(TropUsrHeader)]) 164 | 165 | if b'\x81\x8F\x54\xAD' not in self.header.magic: 166 | return False 167 | 168 | return True 169 | 170 | 171 | def load_table_headers(self) -> bool: 172 | 173 | table_header = None 174 | table_count = int.from_bytes(self.header.table_count, byteorder='big') 175 | 176 | self.table_headers = {} 177 | for i in range(table_count): 178 | 179 | table_header = TropUsrTableHeader(self.file[ 180 | 
len(TropUsrHeader) + (len(TropUsrTableHeader) * i) : 181 | len(TropUsrHeader) + (len(TropUsrTableHeader) * (i + 1)) 182 | ]) 183 | 184 | if b'\x00\x00\x00\x04' not in table_header.table_type and b'\x00\x00\x00\x06' not in table_header.table_type: 185 | return False 186 | 187 | self.table_headers[i] = table_header 188 | 189 | return True 190 | 191 | 192 | def load_tables(self) -> bool: 193 | 194 | table_header = None 195 | table_count = int.from_bytes(self.header.table_count, byteorder='big') 196 | 197 | for i in range(table_count): 198 | 199 | table_header = self.table_headers[i] 200 | offset = int.from_bytes(table_header.offset, byteorder='big') 201 | entries_count = int.from_bytes(table_header.entries_count, byteorder='big') 202 | 203 | if b'\x00\x00\x00\x04' in table_header.table_type: 204 | self.table4 = {} 205 | for i in range(entries_count): 206 | 207 | entry = TropUsrEntry4(self.file[ 208 | offset + (len(TropUsrEntry4) * i) : 209 | offset + (len(TropUsrEntry4) * (i + 1)) 210 | ]) 211 | 212 | if b'\x00\x00\x00\x04' not in entry.entry_type: 213 | return False 214 | 215 | if b'\x00\x00\x00\x50' not in entry.entry_size: 216 | return False 217 | 218 | trophy_id = int.from_bytes(entry.trophy_id, byteorder='big') 219 | self.table4[trophy_id] = entry 220 | 221 | if b'\x00\x00\x00\x06' in table_header.table_type: 222 | self.table6 = {} 223 | for i in range(entries_count): 224 | 225 | entry = TropUsrEntry6(self.file[ 226 | offset + (len(TropUsrEntry6) * i) : 227 | offset + (len(TropUsrEntry6) * (i + 1)) 228 | ]) 229 | 230 | if b'\x00\x00\x00\x06' not in entry.entry_type: 231 | return False 232 | 233 | if b'\x00\x00\x00\x60' not in entry.entry_size: 234 | return False 235 | 236 | trophy_id = int.from_bytes(entry.trophy_id, byteorder='big') 237 | self.table6[trophy_id] = entry 238 | 239 | return True 240 | 241 | 242 | class Trophy: 243 | config : Config 244 | tropconf : TropConf 245 | tropusr : TropUsr 246 | 247 | def __init__(self, config : Config, game_path : str): 
248 | 249 | try: 250 | self.config = config 251 | tropdir_path = self.config.joinpath(game_path, 'TROPDIR') 252 | game_trophies_id = os.listdir(tropdir_path)[0] 253 | 254 | game_trophies_path = self.config.joinpath( 255 | self.config.trophy_directory, 256 | game_trophies_id) 257 | 258 | self.tropconf = TropConf(game_trophies_path) 259 | self.tropusr = TropUsr(game_trophies_path) 260 | 261 | # File not found, etc. 262 | except: 263 | pass 264 | 265 | 266 | def trop2ach(self, trophy_id : int) -> Achievement: 267 | 268 | trophy_state = self.tropusr.table6[trophy_id].trophy_state 269 | unlocked = int.from_bytes(trophy_state, byteorder='big') 270 | 271 | if bool(unlocked): 272 | unlock_timestamp = self.tropusr.table6[trophy_id].timestamp2 273 | unlock_time = int.from_bytes(unlock_timestamp, byteorder='big') 274 | 275 | pad_tid = format(trophy_id, '03d') 276 | name = self.tropconf.root.find('.//trophy[@id="' + pad_tid + '"]/name') 277 | 278 | ach = Achievement(unlock_time, None, name.text) 279 | return ach 280 | else: 281 | return None -------------------------------------------------------------------------------- /version.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | 5 | def get_version(): 6 | base_path = '' 7 | version = '' 8 | 9 | base_path = os.path.dirname(os.path.realpath(__file__)) 10 | manifest_path = os.path.join(base_path, 'manifest.json') 11 | 12 | with open(manifest_path) as manifest: 13 | man = json.load(manifest) 14 | version = man['version'] 15 | 16 | return version 17 | -------------------------------------------------------------------------------- /yaml/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .error import * 3 | 4 | from .tokens import * 5 | from .events import * 6 | from .nodes import * 7 | 8 | from .loader import * 9 | from .dumper import * 10 | 11 | __version__ = '5.1.2' 12 | try: 13 | from .cyaml 
import * 14 | __with_libyaml__ = True 15 | except ImportError: 16 | __with_libyaml__ = False 17 | 18 | import io 19 | 20 | #------------------------------------------------------------------------------ 21 | # Warnings control 22 | #------------------------------------------------------------------------------ 23 | 24 | # 'Global' warnings state: 25 | _warnings_enabled = { 26 | 'YAMLLoadWarning': True, 27 | } 28 | 29 | # Get or set global warnings' state 30 | def warnings(settings=None): 31 | if settings is None: 32 | return _warnings_enabled 33 | 34 | if type(settings) is dict: 35 | for key in settings: 36 | if key in _warnings_enabled: 37 | _warnings_enabled[key] = settings[key] 38 | 39 | # Warn when load() is called without Loader=... 40 | class YAMLLoadWarning(RuntimeWarning): 41 | pass 42 | 43 | def load_warning(method): 44 | if _warnings_enabled['YAMLLoadWarning'] is False: 45 | return 46 | 47 | import warnings 48 | 49 | message = ( 50 | "calling yaml.%s() without Loader=... is deprecated, as the " 51 | "default Loader is unsafe. Please read " 52 | "https://msg.pyyaml.org/load for full details." 53 | ) % method 54 | 55 | warnings.warn(message, YAMLLoadWarning, stacklevel=3) 56 | 57 | #------------------------------------------------------------------------------ 58 | def scan(stream, Loader=Loader): 59 | """ 60 | Scan a YAML stream and produce scanning tokens. 61 | """ 62 | loader = Loader(stream) 63 | try: 64 | while loader.check_token(): 65 | yield loader.get_token() 66 | finally: 67 | loader.dispose() 68 | 69 | def parse(stream, Loader=Loader): 70 | """ 71 | Parse a YAML stream and produce parsing events. 72 | """ 73 | loader = Loader(stream) 74 | try: 75 | while loader.check_event(): 76 | yield loader.get_event() 77 | finally: 78 | loader.dispose() 79 | 80 | def compose(stream, Loader=Loader): 81 | """ 82 | Parse the first YAML document in a stream 83 | and produce the corresponding representation tree. 
84 | """ 85 | loader = Loader(stream) 86 | try: 87 | return loader.get_single_node() 88 | finally: 89 | loader.dispose() 90 | 91 | def compose_all(stream, Loader=Loader): 92 | """ 93 | Parse all YAML documents in a stream 94 | and produce corresponding representation trees. 95 | """ 96 | loader = Loader(stream) 97 | try: 98 | while loader.check_node(): 99 | yield loader.get_node() 100 | finally: 101 | loader.dispose() 102 | 103 | def load(stream, Loader=None): 104 | """ 105 | Parse the first YAML document in a stream 106 | and produce the corresponding Python object. 107 | """ 108 | if Loader is None: 109 | load_warning('load') 110 | Loader = FullLoader 111 | 112 | loader = Loader(stream) 113 | try: 114 | return loader.get_single_data() 115 | finally: 116 | loader.dispose() 117 | 118 | def load_all(stream, Loader=None): 119 | """ 120 | Parse all YAML documents in a stream 121 | and produce corresponding Python objects. 122 | """ 123 | if Loader is None: 124 | load_warning('load_all') 125 | Loader = FullLoader 126 | 127 | loader = Loader(stream) 128 | try: 129 | while loader.check_data(): 130 | yield loader.get_data() 131 | finally: 132 | loader.dispose() 133 | 134 | def full_load(stream): 135 | """ 136 | Parse the first YAML document in a stream 137 | and produce the corresponding Python object. 138 | 139 | Resolve all tags except those known to be 140 | unsafe on untrusted input. 141 | """ 142 | return load(stream, FullLoader) 143 | 144 | def full_load_all(stream): 145 | """ 146 | Parse all YAML documents in a stream 147 | and produce corresponding Python objects. 148 | 149 | Resolve all tags except those known to be 150 | unsafe on untrusted input. 151 | """ 152 | return load_all(stream, FullLoader) 153 | 154 | def safe_load(stream): 155 | """ 156 | Parse the first YAML document in a stream 157 | and produce the corresponding Python object. 158 | 159 | Resolve only basic YAML tags. This is known 160 | to be safe for untrusted input. 
161 | """ 162 | return load(stream, SafeLoader) 163 | 164 | def safe_load_all(stream): 165 | """ 166 | Parse all YAML documents in a stream 167 | and produce corresponding Python objects. 168 | 169 | Resolve only basic YAML tags. This is known 170 | to be safe for untrusted input. 171 | """ 172 | return load_all(stream, SafeLoader) 173 | 174 | def unsafe_load(stream): 175 | """ 176 | Parse the first YAML document in a stream 177 | and produce the corresponding Python object. 178 | 179 | Resolve all tags, even those known to be 180 | unsafe on untrusted input. 181 | """ 182 | return load(stream, UnsafeLoader) 183 | 184 | def unsafe_load_all(stream): 185 | """ 186 | Parse all YAML documents in a stream 187 | and produce corresponding Python objects. 188 | 189 | Resolve all tags, even those known to be 190 | unsafe on untrusted input. 191 | """ 192 | return load_all(stream, UnsafeLoader) 193 | 194 | def emit(events, stream=None, Dumper=Dumper, 195 | canonical=None, indent=None, width=None, 196 | allow_unicode=None, line_break=None): 197 | """ 198 | Emit YAML parsing events into a stream. 199 | If stream is None, return the produced string instead. 200 | """ 201 | getvalue = None 202 | if stream is None: 203 | stream = io.StringIO() 204 | getvalue = stream.getvalue 205 | dumper = Dumper(stream, canonical=canonical, indent=indent, width=width, 206 | allow_unicode=allow_unicode, line_break=line_break) 207 | try: 208 | for event in events: 209 | dumper.emit(event) 210 | finally: 211 | dumper.dispose() 212 | if getvalue: 213 | return getvalue() 214 | 215 | def serialize_all(nodes, stream=None, Dumper=Dumper, 216 | canonical=None, indent=None, width=None, 217 | allow_unicode=None, line_break=None, 218 | encoding=None, explicit_start=None, explicit_end=None, 219 | version=None, tags=None): 220 | """ 221 | Serialize a sequence of representation trees into a YAML stream. 222 | If stream is None, return the produced string instead. 
223 | """ 224 | getvalue = None 225 | if stream is None: 226 | if encoding is None: 227 | stream = io.StringIO() 228 | else: 229 | stream = io.BytesIO() 230 | getvalue = stream.getvalue 231 | dumper = Dumper(stream, canonical=canonical, indent=indent, width=width, 232 | allow_unicode=allow_unicode, line_break=line_break, 233 | encoding=encoding, version=version, tags=tags, 234 | explicit_start=explicit_start, explicit_end=explicit_end) 235 | try: 236 | dumper.open() 237 | for node in nodes: 238 | dumper.serialize(node) 239 | dumper.close() 240 | finally: 241 | dumper.dispose() 242 | if getvalue: 243 | return getvalue() 244 | 245 | def serialize(node, stream=None, Dumper=Dumper, **kwds): 246 | """ 247 | Serialize a representation tree into a YAML stream. 248 | If stream is None, return the produced string instead. 249 | """ 250 | return serialize_all([node], stream, Dumper=Dumper, **kwds) 251 | 252 | def dump_all(documents, stream=None, Dumper=Dumper, 253 | default_style=None, default_flow_style=False, 254 | canonical=None, indent=None, width=None, 255 | allow_unicode=None, line_break=None, 256 | encoding=None, explicit_start=None, explicit_end=None, 257 | version=None, tags=None, sort_keys=True): 258 | """ 259 | Serialize a sequence of Python objects into a YAML stream. 260 | If stream is None, return the produced string instead. 
261 | """ 262 | getvalue = None 263 | if stream is None: 264 | if encoding is None: 265 | stream = io.StringIO() 266 | else: 267 | stream = io.BytesIO() 268 | getvalue = stream.getvalue 269 | dumper = Dumper(stream, default_style=default_style, 270 | default_flow_style=default_flow_style, 271 | canonical=canonical, indent=indent, width=width, 272 | allow_unicode=allow_unicode, line_break=line_break, 273 | encoding=encoding, version=version, tags=tags, 274 | explicit_start=explicit_start, explicit_end=explicit_end, sort_keys=sort_keys) 275 | try: 276 | dumper.open() 277 | for data in documents: 278 | dumper.represent(data) 279 | dumper.close() 280 | finally: 281 | dumper.dispose() 282 | if getvalue: 283 | return getvalue() 284 | 285 | def dump(data, stream=None, Dumper=Dumper, **kwds): 286 | """ 287 | Serialize a Python object into a YAML stream. 288 | If stream is None, return the produced string instead. 289 | """ 290 | return dump_all([data], stream, Dumper=Dumper, **kwds) 291 | 292 | def safe_dump_all(documents, stream=None, **kwds): 293 | """ 294 | Serialize a sequence of Python objects into a YAML stream. 295 | Produce only basic YAML tags. 296 | If stream is None, return the produced string instead. 297 | """ 298 | return dump_all(documents, stream, Dumper=SafeDumper, **kwds) 299 | 300 | def safe_dump(data, stream=None, **kwds): 301 | """ 302 | Serialize a Python object into a YAML stream. 303 | Produce only basic YAML tags. 304 | If stream is None, return the produced string instead. 305 | """ 306 | return dump_all([data], stream, Dumper=SafeDumper, **kwds) 307 | 308 | def add_implicit_resolver(tag, regexp, first=None, 309 | Loader=Loader, Dumper=Dumper): 310 | """ 311 | Add an implicit scalar detector. 312 | If an implicit scalar value matches the given regexp, 313 | the corresponding tag is assigned to the scalar. 314 | first is a sequence of possible initial characters or None. 
315 | """ 316 | Loader.add_implicit_resolver(tag, regexp, first) 317 | Dumper.add_implicit_resolver(tag, regexp, first) 318 | 319 | def add_path_resolver(tag, path, kind=None, Loader=Loader, Dumper=Dumper): 320 | """ 321 | Add a path based resolver for the given tag. 322 | A path is a list of keys that forms a path 323 | to a node in the representation tree. 324 | Keys can be string values, integers, or None. 325 | """ 326 | Loader.add_path_resolver(tag, path, kind) 327 | Dumper.add_path_resolver(tag, path, kind) 328 | 329 | def add_constructor(tag, constructor, Loader=Loader): 330 | """ 331 | Add a constructor for the given tag. 332 | Constructor is a function that accepts a Loader instance 333 | and a node object and produces the corresponding Python object. 334 | """ 335 | Loader.add_constructor(tag, constructor) 336 | 337 | def add_multi_constructor(tag_prefix, multi_constructor, Loader=Loader): 338 | """ 339 | Add a multi-constructor for the given tag prefix. 340 | Multi-constructor is called for a node if its tag starts with tag_prefix. 341 | Multi-constructor accepts a Loader instance, a tag suffix, 342 | and a node object and produces the corresponding Python object. 343 | """ 344 | Loader.add_multi_constructor(tag_prefix, multi_constructor) 345 | 346 | def add_representer(data_type, representer, Dumper=Dumper): 347 | """ 348 | Add a representer for the given type. 349 | Representer is a function accepting a Dumper instance 350 | and an instance of the given data type 351 | and producing the corresponding representation node. 352 | """ 353 | Dumper.add_representer(data_type, representer) 354 | 355 | def add_multi_representer(data_type, multi_representer, Dumper=Dumper): 356 | """ 357 | Add a representer for the given type. 358 | Multi-representer is a function accepting a Dumper instance 359 | and an instance of the given data type or subtype 360 | and producing the corresponding representation node. 
361 | """ 362 | Dumper.add_multi_representer(data_type, multi_representer) 363 | 364 | class YAMLObjectMetaclass(type): 365 | """ 366 | The metaclass for YAMLObject. 367 | """ 368 | def __init__(cls, name, bases, kwds): 369 | super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds) 370 | if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None: 371 | cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml) 372 | cls.yaml_dumper.add_representer(cls, cls.to_yaml) 373 | 374 | class YAMLObject(metaclass=YAMLObjectMetaclass): 375 | """ 376 | An object that can dump itself to a YAML stream 377 | and load itself from a YAML stream. 378 | """ 379 | 380 | __slots__ = () # no direct instantiation, so allow immutable subclasses 381 | 382 | yaml_loader = Loader 383 | yaml_dumper = Dumper 384 | 385 | yaml_tag = None 386 | yaml_flow_style = None 387 | 388 | @classmethod 389 | def from_yaml(cls, loader, node): 390 | """ 391 | Convert a representation node to a Python object. 392 | """ 393 | return loader.construct_yaml_object(node, cls) 394 | 395 | @classmethod 396 | def to_yaml(cls, dumper, data): 397 | """ 398 | Convert a Python object to a representation node. 399 | """ 400 | return dumper.represent_yaml_object(cls.yaml_tag, data, cls, 401 | flow_style=cls.yaml_flow_style) 402 | 403 | -------------------------------------------------------------------------------- /yaml/composer.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ['Composer', 'ComposerError'] 3 | 4 | from .error import MarkedYAMLError 5 | from .events import * 6 | from .nodes import * 7 | 8 | class ComposerError(MarkedYAMLError): 9 | pass 10 | 11 | class Composer: 12 | 13 | def __init__(self): 14 | self.anchors = {} 15 | 16 | def check_node(self): 17 | # Drop the STREAM-START event. 18 | if self.check_event(StreamStartEvent): 19 | self.get_event() 20 | 21 | # If there are more documents available? 
22 | return not self.check_event(StreamEndEvent) 23 | 24 | def get_node(self): 25 | # Get the root node of the next document. 26 | if not self.check_event(StreamEndEvent): 27 | return self.compose_document() 28 | 29 | def get_single_node(self): 30 | # Drop the STREAM-START event. 31 | self.get_event() 32 | 33 | # Compose a document if the stream is not empty. 34 | document = None 35 | if not self.check_event(StreamEndEvent): 36 | document = self.compose_document() 37 | 38 | # Ensure that the stream contains no more documents. 39 | if not self.check_event(StreamEndEvent): 40 | event = self.get_event() 41 | raise ComposerError("expected a single document in the stream", 42 | document.start_mark, "but found another document", 43 | event.start_mark) 44 | 45 | # Drop the STREAM-END event. 46 | self.get_event() 47 | 48 | return document 49 | 50 | def compose_document(self): 51 | # Drop the DOCUMENT-START event. 52 | self.get_event() 53 | 54 | # Compose the root node. 55 | node = self.compose_node(None, None) 56 | 57 | # Drop the DOCUMENT-END event. 
58 | self.get_event() 59 | 60 | self.anchors = {} 61 | return node 62 | 63 | def compose_node(self, parent, index): 64 | if self.check_event(AliasEvent): 65 | event = self.get_event() 66 | anchor = event.anchor 67 | if anchor not in self.anchors: 68 | raise ComposerError(None, None, "found undefined alias %r" 69 | % anchor, event.start_mark) 70 | return self.anchors[anchor] 71 | event = self.peek_event() 72 | anchor = event.anchor 73 | if anchor is not None: 74 | if anchor in self.anchors: 75 | raise ComposerError("found duplicate anchor %r; first occurrence" 76 | % anchor, self.anchors[anchor].start_mark, 77 | "second occurrence", event.start_mark) 78 | self.descend_resolver(parent, index) 79 | if self.check_event(ScalarEvent): 80 | node = self.compose_scalar_node(anchor) 81 | elif self.check_event(SequenceStartEvent): 82 | node = self.compose_sequence_node(anchor) 83 | elif self.check_event(MappingStartEvent): 84 | node = self.compose_mapping_node(anchor) 85 | self.ascend_resolver() 86 | return node 87 | 88 | def compose_scalar_node(self, anchor): 89 | event = self.get_event() 90 | tag = event.tag 91 | if tag is None or tag == '!': 92 | tag = self.resolve(ScalarNode, event.value, event.implicit) 93 | node = ScalarNode(tag, event.value, 94 | event.start_mark, event.end_mark, style=event.style) 95 | if anchor is not None: 96 | self.anchors[anchor] = node 97 | return node 98 | 99 | def compose_sequence_node(self, anchor): 100 | start_event = self.get_event() 101 | tag = start_event.tag 102 | if tag is None or tag == '!': 103 | tag = self.resolve(SequenceNode, None, start_event.implicit) 104 | node = SequenceNode(tag, [], 105 | start_event.start_mark, None, 106 | flow_style=start_event.flow_style) 107 | if anchor is not None: 108 | self.anchors[anchor] = node 109 | index = 0 110 | while not self.check_event(SequenceEndEvent): 111 | node.value.append(self.compose_node(node, index)) 112 | index += 1 113 | end_event = self.get_event() 114 | node.end_mark = 
end_event.end_mark 115 | return node 116 | 117 | def compose_mapping_node(self, anchor): 118 | start_event = self.get_event() 119 | tag = start_event.tag 120 | if tag is None or tag == '!': 121 | tag = self.resolve(MappingNode, None, start_event.implicit) 122 | node = MappingNode(tag, [], 123 | start_event.start_mark, None, 124 | flow_style=start_event.flow_style) 125 | if anchor is not None: 126 | self.anchors[anchor] = node 127 | while not self.check_event(MappingEndEvent): 128 | #key_event = self.peek_event() 129 | item_key = self.compose_node(node, None) 130 | #if item_key in node.value: 131 | # raise ComposerError("while composing a mapping", start_event.start_mark, 132 | # "found duplicate key", key_event.start_mark) 133 | item_value = self.compose_node(node, item_key) 134 | #node.value[item_key] = item_value 135 | node.value.append((item_key, item_value)) 136 | end_event = self.get_event() 137 | node.end_mark = end_event.end_mark 138 | return node 139 | 140 | -------------------------------------------------------------------------------- /yaml/cyaml.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = [ 3 | 'CBaseLoader', 'CSafeLoader', 'CFullLoader', 'CUnsafeLoader', 'CLoader', 4 | 'CBaseDumper', 'CSafeDumper', 'CDumper' 5 | ] 6 | 7 | from _yaml import CParser, CEmitter 8 | 9 | from .constructor import * 10 | 11 | from .serializer import * 12 | from .representer import * 13 | 14 | from .resolver import * 15 | 16 | class CBaseLoader(CParser, BaseConstructor, BaseResolver): 17 | 18 | def __init__(self, stream): 19 | CParser.__init__(self, stream) 20 | BaseConstructor.__init__(self) 21 | BaseResolver.__init__(self) 22 | 23 | class CSafeLoader(CParser, SafeConstructor, Resolver): 24 | 25 | def __init__(self, stream): 26 | CParser.__init__(self, stream) 27 | SafeConstructor.__init__(self) 28 | Resolver.__init__(self) 29 | 30 | class CFullLoader(CParser, FullConstructor, Resolver): 31 | 32 | def __init__(self, 
stream): 33 | CParser.__init__(self, stream) 34 | FullConstructor.__init__(self) 35 | Resolver.__init__(self) 36 | 37 | class CUnsafeLoader(CParser, UnsafeConstructor, Resolver): 38 | 39 | def __init__(self, stream): 40 | CParser.__init__(self, stream) 41 | UnsafeConstructor.__init__(self) 42 | Resolver.__init__(self) 43 | 44 | class CLoader(CParser, Constructor, Resolver): 45 | 46 | def __init__(self, stream): 47 | CParser.__init__(self, stream) 48 | Constructor.__init__(self) 49 | Resolver.__init__(self) 50 | 51 | class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): 52 | 53 | def __init__(self, stream, 54 | default_style=None, default_flow_style=False, 55 | canonical=None, indent=None, width=None, 56 | allow_unicode=None, line_break=None, 57 | encoding=None, explicit_start=None, explicit_end=None, 58 | version=None, tags=None, sort_keys=True): 59 | CEmitter.__init__(self, stream, canonical=canonical, 60 | indent=indent, width=width, encoding=encoding, 61 | allow_unicode=allow_unicode, line_break=line_break, 62 | explicit_start=explicit_start, explicit_end=explicit_end, 63 | version=version, tags=tags) 64 | Representer.__init__(self, default_style=default_style, 65 | default_flow_style=default_flow_style, sort_keys=sort_keys) 66 | Resolver.__init__(self) 67 | 68 | class CSafeDumper(CEmitter, SafeRepresenter, Resolver): 69 | 70 | def __init__(self, stream, 71 | default_style=None, default_flow_style=False, 72 | canonical=None, indent=None, width=None, 73 | allow_unicode=None, line_break=None, 74 | encoding=None, explicit_start=None, explicit_end=None, 75 | version=None, tags=None, sort_keys=True): 76 | CEmitter.__init__(self, stream, canonical=canonical, 77 | indent=indent, width=width, encoding=encoding, 78 | allow_unicode=allow_unicode, line_break=line_break, 79 | explicit_start=explicit_start, explicit_end=explicit_end, 80 | version=version, tags=tags) 81 | SafeRepresenter.__init__(self, default_style=default_style, 82 | 
default_flow_style=default_flow_style, sort_keys=sort_keys) 83 | Resolver.__init__(self) 84 | 85 | class CDumper(CEmitter, Serializer, Representer, Resolver): 86 | 87 | def __init__(self, stream, 88 | default_style=None, default_flow_style=False, 89 | canonical=None, indent=None, width=None, 90 | allow_unicode=None, line_break=None, 91 | encoding=None, explicit_start=None, explicit_end=None, 92 | version=None, tags=None, sort_keys=True): 93 | CEmitter.__init__(self, stream, canonical=canonical, 94 | indent=indent, width=width, encoding=encoding, 95 | allow_unicode=allow_unicode, line_break=line_break, 96 | explicit_start=explicit_start, explicit_end=explicit_end, 97 | version=version, tags=tags) 98 | Representer.__init__(self, default_style=default_style, 99 | default_flow_style=default_flow_style, sort_keys=sort_keys) 100 | Resolver.__init__(self) 101 | 102 | -------------------------------------------------------------------------------- /yaml/dumper.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ['BaseDumper', 'SafeDumper', 'Dumper'] 3 | 4 | from .emitter import * 5 | from .serializer import * 6 | from .representer import * 7 | from .resolver import * 8 | 9 | class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver): 10 | 11 | def __init__(self, stream, 12 | default_style=None, default_flow_style=False, 13 | canonical=None, indent=None, width=None, 14 | allow_unicode=None, line_break=None, 15 | encoding=None, explicit_start=None, explicit_end=None, 16 | version=None, tags=None, sort_keys=True): 17 | Emitter.__init__(self, stream, canonical=canonical, 18 | indent=indent, width=width, 19 | allow_unicode=allow_unicode, line_break=line_break) 20 | Serializer.__init__(self, encoding=encoding, 21 | explicit_start=explicit_start, explicit_end=explicit_end, 22 | version=version, tags=tags) 23 | Representer.__init__(self, default_style=default_style, 24 | default_flow_style=default_flow_style, 
sort_keys=sort_keys) 25 | Resolver.__init__(self) 26 | 27 | class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver): 28 | 29 | def __init__(self, stream, 30 | default_style=None, default_flow_style=False, 31 | canonical=None, indent=None, width=None, 32 | allow_unicode=None, line_break=None, 33 | encoding=None, explicit_start=None, explicit_end=None, 34 | version=None, tags=None, sort_keys=True): 35 | Emitter.__init__(self, stream, canonical=canonical, 36 | indent=indent, width=width, 37 | allow_unicode=allow_unicode, line_break=line_break) 38 | Serializer.__init__(self, encoding=encoding, 39 | explicit_start=explicit_start, explicit_end=explicit_end, 40 | version=version, tags=tags) 41 | SafeRepresenter.__init__(self, default_style=default_style, 42 | default_flow_style=default_flow_style, sort_keys=sort_keys) 43 | Resolver.__init__(self) 44 | 45 | class Dumper(Emitter, Serializer, Representer, Resolver): 46 | 47 | def __init__(self, stream, 48 | default_style=None, default_flow_style=False, 49 | canonical=None, indent=None, width=None, 50 | allow_unicode=None, line_break=None, 51 | encoding=None, explicit_start=None, explicit_end=None, 52 | version=None, tags=None, sort_keys=True): 53 | Emitter.__init__(self, stream, canonical=canonical, 54 | indent=indent, width=width, 55 | allow_unicode=allow_unicode, line_break=line_break) 56 | Serializer.__init__(self, encoding=encoding, 57 | explicit_start=explicit_start, explicit_end=explicit_end, 58 | version=version, tags=tags) 59 | Representer.__init__(self, default_style=default_style, 60 | default_flow_style=default_flow_style, sort_keys=sort_keys) 61 | Resolver.__init__(self) 62 | 63 | -------------------------------------------------------------------------------- /yaml/error.py: -------------------------------------------------------------------------------- 1 | 2 | __all__ = ['Mark', 'YAMLError', 'MarkedYAMLError'] 3 | 4 | class Mark: 5 | 6 | def __init__(self, name, index, line, column, buffer, pointer): 
7 | self.name = name 8 | self.index = index 9 | self.line = line 10 | self.column = column 11 | self.buffer = buffer 12 | self.pointer = pointer 13 | 14 | def get_snippet(self, indent=4, max_length=75): 15 | if self.buffer is None: 16 | return None 17 | head = '' 18 | start = self.pointer 19 | while start > 0 and self.buffer[start-1] not in '\0\r\n\x85\u2028\u2029': 20 | start -= 1 21 | if self.pointer-start > max_length/2-1: 22 | head = ' ... ' 23 | start += 5 24 | break 25 | tail = '' 26 | end = self.pointer 27 | while end < len(self.buffer) and self.buffer[end] not in '\0\r\n\x85\u2028\u2029': 28 | end += 1 29 | if end-self.pointer > max_length/2-1: 30 | tail = ' ... ' 31 | end -= 5 32 | break 33 | snippet = self.buffer[start:end] 34 | return ' '*indent + head + snippet + tail + '\n' \ 35 | + ' '*(indent+self.pointer-start+len(head)) + '^' 36 | 37 | def __str__(self): 38 | snippet = self.get_snippet() 39 | where = " in \"%s\", line %d, column %d" \ 40 | % (self.name, self.line+1, self.column+1) 41 | if snippet is not None: 42 | where += ":\n"+snippet 43 | return where 44 | 45 | class YAMLError(Exception): 46 | pass 47 | 48 | class MarkedYAMLError(YAMLError): 49 | 50 | def __init__(self, context=None, context_mark=None, 51 | problem=None, problem_mark=None, note=None): 52 | self.context = context 53 | self.context_mark = context_mark 54 | self.problem = problem 55 | self.problem_mark = problem_mark 56 | self.note = note 57 | 58 | def __str__(self): 59 | lines = [] 60 | if self.context is not None: 61 | lines.append(self.context) 62 | if self.context_mark is not None \ 63 | and (self.problem is None or self.problem_mark is None 64 | or self.context_mark.name != self.problem_mark.name 65 | or self.context_mark.line != self.problem_mark.line 66 | or self.context_mark.column != self.problem_mark.column): 67 | lines.append(str(self.context_mark)) 68 | if self.problem is not None: 69 | lines.append(self.problem) 70 | if self.problem_mark is not None: 71 | 
lines.append(str(self.problem_mark)) 72 | if self.note is not None: 73 | lines.append(self.note) 74 | return '\n'.join(lines) 75 | 76 | -------------------------------------------------------------------------------- /yaml/events.py: -------------------------------------------------------------------------------- 1 | 2 | # Abstract classes. 3 | 4 | class Event(object): 5 | def __init__(self, start_mark=None, end_mark=None): 6 | self.start_mark = start_mark 7 | self.end_mark = end_mark 8 | def __repr__(self): 9 | attributes = [key for key in ['anchor', 'tag', 'implicit', 'value'] 10 | if hasattr(self, key)] 11 | arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) 12 | for key in attributes]) 13 | return '%s(%s)' % (self.__class__.__name__, arguments) 14 | 15 | class NodeEvent(Event): 16 | def __init__(self, anchor, start_mark=None, end_mark=None): 17 | self.anchor = anchor 18 | self.start_mark = start_mark 19 | self.end_mark = end_mark 20 | 21 | class CollectionStartEvent(NodeEvent): 22 | def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None, 23 | flow_style=None): 24 | self.anchor = anchor 25 | self.tag = tag 26 | self.implicit = implicit 27 | self.start_mark = start_mark 28 | self.end_mark = end_mark 29 | self.flow_style = flow_style 30 | 31 | class CollectionEndEvent(Event): 32 | pass 33 | 34 | # Implementations. 

class StreamStartEvent(Event):
    def __init__(self, start_mark=None, end_mark=None, encoding=None):
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.encoding = encoding

class StreamEndEvent(Event):
    pass

class DocumentStartEvent(Event):
    def __init__(self, start_mark=None, end_mark=None,
            explicit=None, version=None, tags=None):
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.explicit = explicit    # True for an explicit '---' marker
        self.version = version
        self.tags = tags

class DocumentEndEvent(Event):
    def __init__(self, start_mark=None, end_mark=None,
            explicit=None):
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.explicit = explicit    # True for an explicit '...' marker

class AliasEvent(NodeEvent):
    pass

class ScalarEvent(NodeEvent):
    def __init__(self, anchor, tag, implicit, value,
            start_mark=None, end_mark=None, style=None):
        self.anchor = anchor
        self.tag = tag
        self.implicit = implicit
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.style = style

class SequenceStartEvent(CollectionStartEvent):
    pass

class SequenceEndEvent(CollectionEndEvent):
    pass

class MappingStartEvent(CollectionStartEvent):
    pass

class MappingEndEvent(CollectionEndEvent):
    pass

# ---- /yaml/loader.py ----

__all__ = ['BaseLoader', 'FullLoader', 'SafeLoader', 'Loader', 'UnsafeLoader']

from .reader import *
from .scanner import *
from .parser import *
from .composer import *
from .constructor import *
from .resolver import *

class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver):
    """Loader pipeline wired to the most basic constructor and resolver."""

    def __init__(self, stream):
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        BaseConstructor.__init__(self)
        BaseResolver.__init__(self)

class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver):
    """Loader pipeline wired to FullConstructor."""

    def __init__(self, stream):
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        FullConstructor.__init__(self)
        Resolver.__init__(self)

class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver):
    """Loader pipeline wired to SafeConstructor."""

    def __init__(self, stream):
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        SafeConstructor.__init__(self)
        Resolver.__init__(self)

class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
    """Loader pipeline wired to the full Constructor (unsafe on untrusted
    input; see the note below)."""

    def __init__(self, stream):
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        Constructor.__init__(self)
        Resolver.__init__(self)

# UnsafeLoader is the same as Loader (which is and was always unsafe on
# untrusted input).  Use of either Loader or UnsafeLoader should be rare,
# since FullLoader should be able to load almost all YAML safely.  Loader
# is left intact to ensure backwards compatibility.
class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
    """Explicitly named alias of the Loader pipeline (same constructor)."""

    def __init__(self, stream):
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        Constructor.__init__(self)
        Resolver.__init__(self)

# ---- /yaml/nodes.py ----

class Node(object):
    """Base class for composed representation-graph nodes."""

    def __init__(self, tag, value, start_mark, end_mark):
        self.tag = tag
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark

    def __repr__(self):
        value = repr(self.value)
        return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value)

class ScalarNode(Node):
    id = 'scalar'

    def __init__(self, tag, value,
            start_mark=None, end_mark=None, style=None):
        self.tag = tag
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.style = style

class CollectionNode(Node):
    def __init__(self, tag, value,
            start_mark=None, end_mark=None, flow_style=None):
        self.tag = tag
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.flow_style = flow_style

class SequenceNode(CollectionNode):
    id = 'sequence'

class MappingNode(CollectionNode):
    id = 'mapping'

# ---- /yaml/parser.py ----
| 2 | # The following YAML grammar is LL(1) and is parsed by a recursive descent 3 | # parser. 4 | # 5 | # stream ::= STREAM-START implicit_document? explicit_document* STREAM-END 6 | # implicit_document ::= block_node DOCUMENT-END* 7 | # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* 8 | # block_node_or_indentless_sequence ::= 9 | # ALIAS 10 | # | properties (block_content | indentless_block_sequence)? 11 | # | block_content 12 | # | indentless_block_sequence 13 | # block_node ::= ALIAS 14 | # | properties block_content? 15 | # | block_content 16 | # flow_node ::= ALIAS 17 | # | properties flow_content? 18 | # | flow_content 19 | # properties ::= TAG ANCHOR? | ANCHOR TAG? 20 | # block_content ::= block_collection | flow_collection | SCALAR 21 | # flow_content ::= flow_collection | SCALAR 22 | # block_collection ::= block_sequence | block_mapping 23 | # flow_collection ::= flow_sequence | flow_mapping 24 | # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END 25 | # indentless_sequence ::= (BLOCK-ENTRY block_node?)+ 26 | # block_mapping ::= BLOCK-MAPPING_START 27 | # ((KEY block_node_or_indentless_sequence?)? 28 | # (VALUE block_node_or_indentless_sequence?)?)* 29 | # BLOCK-END 30 | # flow_sequence ::= FLOW-SEQUENCE-START 31 | # (flow_sequence_entry FLOW-ENTRY)* 32 | # flow_sequence_entry? 33 | # FLOW-SEQUENCE-END 34 | # flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? 35 | # flow_mapping ::= FLOW-MAPPING-START 36 | # (flow_mapping_entry FLOW-ENTRY)* 37 | # flow_mapping_entry? 38 | # FLOW-MAPPING-END 39 | # flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? 
40 | # 41 | # FIRST sets: 42 | # 43 | # stream: { STREAM-START } 44 | # explicit_document: { DIRECTIVE DOCUMENT-START } 45 | # implicit_document: FIRST(block_node) 46 | # block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START } 47 | # flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START } 48 | # block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } 49 | # flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } 50 | # block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START } 51 | # flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } 52 | # block_sequence: { BLOCK-SEQUENCE-START } 53 | # block_mapping: { BLOCK-MAPPING-START } 54 | # block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START BLOCK-ENTRY } 55 | # indentless_sequence: { ENTRY } 56 | # flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } 57 | # flow_sequence: { FLOW-SEQUENCE-START } 58 | # flow_mapping: { FLOW-MAPPING-START } 59 | # flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } 60 | # flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } 61 | 62 | __all__ = ['Parser', 'ParserError'] 63 | 64 | from .error import MarkedYAMLError 65 | from .tokens import * 66 | from .events import * 67 | from .scanner import * 68 | 69 | class ParserError(MarkedYAMLError): 70 | pass 71 | 72 | class Parser: 73 | # Since writing a recursive-descendant parser is a straightforward task, we 74 | # do not give many comments here. 
75 | 76 | DEFAULT_TAGS = { 77 | '!': '!', 78 | '!!': 'tag:yaml.org,2002:', 79 | } 80 | 81 | def __init__(self): 82 | self.current_event = None 83 | self.yaml_version = None 84 | self.tag_handles = {} 85 | self.states = [] 86 | self.marks = [] 87 | self.state = self.parse_stream_start 88 | 89 | def dispose(self): 90 | # Reset the state attributes (to clear self-references) 91 | self.states = [] 92 | self.state = None 93 | 94 | def check_event(self, *choices): 95 | # Check the type of the next event. 96 | if self.current_event is None: 97 | if self.state: 98 | self.current_event = self.state() 99 | if self.current_event is not None: 100 | if not choices: 101 | return True 102 | for choice in choices: 103 | if isinstance(self.current_event, choice): 104 | return True 105 | return False 106 | 107 | def peek_event(self): 108 | # Get the next event. 109 | if self.current_event is None: 110 | if self.state: 111 | self.current_event = self.state() 112 | return self.current_event 113 | 114 | def get_event(self): 115 | # Get the next event and proceed further. 116 | if self.current_event is None: 117 | if self.state: 118 | self.current_event = self.state() 119 | value = self.current_event 120 | self.current_event = None 121 | return value 122 | 123 | # stream ::= STREAM-START implicit_document? explicit_document* STREAM-END 124 | # implicit_document ::= block_node DOCUMENT-END* 125 | # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* 126 | 127 | def parse_stream_start(self): 128 | 129 | # Parse the stream start. 130 | token = self.get_token() 131 | event = StreamStartEvent(token.start_mark, token.end_mark, 132 | encoding=token.encoding) 133 | 134 | # Prepare the next state. 135 | self.state = self.parse_implicit_document_start 136 | 137 | return event 138 | 139 | def parse_implicit_document_start(self): 140 | 141 | # Parse an implicit document. 
142 | if not self.check_token(DirectiveToken, DocumentStartToken, 143 | StreamEndToken): 144 | self.tag_handles = self.DEFAULT_TAGS 145 | token = self.peek_token() 146 | start_mark = end_mark = token.start_mark 147 | event = DocumentStartEvent(start_mark, end_mark, 148 | explicit=False) 149 | 150 | # Prepare the next state. 151 | self.states.append(self.parse_document_end) 152 | self.state = self.parse_block_node 153 | 154 | return event 155 | 156 | else: 157 | return self.parse_document_start() 158 | 159 | def parse_document_start(self): 160 | 161 | # Parse any extra document end indicators. 162 | while self.check_token(DocumentEndToken): 163 | self.get_token() 164 | 165 | # Parse an explicit document. 166 | if not self.check_token(StreamEndToken): 167 | token = self.peek_token() 168 | start_mark = token.start_mark 169 | version, tags = self.process_directives() 170 | if not self.check_token(DocumentStartToken): 171 | raise ParserError(None, None, 172 | "expected '', but found %r" 173 | % self.peek_token().id, 174 | self.peek_token().start_mark) 175 | token = self.get_token() 176 | end_mark = token.end_mark 177 | event = DocumentStartEvent(start_mark, end_mark, 178 | explicit=True, version=version, tags=tags) 179 | self.states.append(self.parse_document_end) 180 | self.state = self.parse_document_content 181 | else: 182 | # Parse the end of the stream. 183 | token = self.get_token() 184 | event = StreamEndEvent(token.start_mark, token.end_mark) 185 | assert not self.states 186 | assert not self.marks 187 | self.state = None 188 | return event 189 | 190 | def parse_document_end(self): 191 | 192 | # Parse the document end. 193 | token = self.peek_token() 194 | start_mark = end_mark = token.start_mark 195 | explicit = False 196 | if self.check_token(DocumentEndToken): 197 | token = self.get_token() 198 | end_mark = token.end_mark 199 | explicit = True 200 | event = DocumentEndEvent(start_mark, end_mark, 201 | explicit=explicit) 202 | 203 | # Prepare the next state. 
204 | self.state = self.parse_document_start 205 | 206 | return event 207 | 208 | def parse_document_content(self): 209 | if self.check_token(DirectiveToken, 210 | DocumentStartToken, DocumentEndToken, StreamEndToken): 211 | event = self.process_empty_scalar(self.peek_token().start_mark) 212 | self.state = self.states.pop() 213 | return event 214 | else: 215 | return self.parse_block_node() 216 | 217 | def process_directives(self): 218 | self.yaml_version = None 219 | self.tag_handles = {} 220 | while self.check_token(DirectiveToken): 221 | token = self.get_token() 222 | if token.name == 'YAML': 223 | if self.yaml_version is not None: 224 | raise ParserError(None, None, 225 | "found duplicate YAML directive", token.start_mark) 226 | major, minor = token.value 227 | if major != 1: 228 | raise ParserError(None, None, 229 | "found incompatible YAML document (version 1.* is required)", 230 | token.start_mark) 231 | self.yaml_version = token.value 232 | elif token.name == 'TAG': 233 | handle, prefix = token.value 234 | if handle in self.tag_handles: 235 | raise ParserError(None, None, 236 | "duplicate tag handle %r" % handle, 237 | token.start_mark) 238 | self.tag_handles[handle] = prefix 239 | if self.tag_handles: 240 | value = self.yaml_version, self.tag_handles.copy() 241 | else: 242 | value = self.yaml_version, None 243 | for key in self.DEFAULT_TAGS: 244 | if key not in self.tag_handles: 245 | self.tag_handles[key] = self.DEFAULT_TAGS[key] 246 | return value 247 | 248 | # block_node_or_indentless_sequence ::= ALIAS 249 | # | properties (block_content | indentless_block_sequence)? 250 | # | block_content 251 | # | indentless_block_sequence 252 | # block_node ::= ALIAS 253 | # | properties block_content? 254 | # | block_content 255 | # flow_node ::= ALIAS 256 | # | properties flow_content? 257 | # | flow_content 258 | # properties ::= TAG ANCHOR? | ANCHOR TAG? 
259 | # block_content ::= block_collection | flow_collection | SCALAR 260 | # flow_content ::= flow_collection | SCALAR 261 | # block_collection ::= block_sequence | block_mapping 262 | # flow_collection ::= flow_sequence | flow_mapping 263 | 264 | def parse_block_node(self): 265 | return self.parse_node(block=True) 266 | 267 | def parse_flow_node(self): 268 | return self.parse_node() 269 | 270 | def parse_block_node_or_indentless_sequence(self): 271 | return self.parse_node(block=True, indentless_sequence=True) 272 | 273 | def parse_node(self, block=False, indentless_sequence=False): 274 | if self.check_token(AliasToken): 275 | token = self.get_token() 276 | event = AliasEvent(token.value, token.start_mark, token.end_mark) 277 | self.state = self.states.pop() 278 | else: 279 | anchor = None 280 | tag = None 281 | start_mark = end_mark = tag_mark = None 282 | if self.check_token(AnchorToken): 283 | token = self.get_token() 284 | start_mark = token.start_mark 285 | end_mark = token.end_mark 286 | anchor = token.value 287 | if self.check_token(TagToken): 288 | token = self.get_token() 289 | tag_mark = token.start_mark 290 | end_mark = token.end_mark 291 | tag = token.value 292 | elif self.check_token(TagToken): 293 | token = self.get_token() 294 | start_mark = tag_mark = token.start_mark 295 | end_mark = token.end_mark 296 | tag = token.value 297 | if self.check_token(AnchorToken): 298 | token = self.get_token() 299 | end_mark = token.end_mark 300 | anchor = token.value 301 | if tag is not None: 302 | handle, suffix = tag 303 | if handle is not None: 304 | if handle not in self.tag_handles: 305 | raise ParserError("while parsing a node", start_mark, 306 | "found undefined tag handle %r" % handle, 307 | tag_mark) 308 | tag = self.tag_handles[handle]+suffix 309 | else: 310 | tag = suffix 311 | #if tag == '!': 312 | # raise ParserError("while parsing a node", start_mark, 313 | # "found non-specific tag '!'", tag_mark, 314 | # "Please check 
'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.") 315 | if start_mark is None: 316 | start_mark = end_mark = self.peek_token().start_mark 317 | event = None 318 | implicit = (tag is None or tag == '!') 319 | if indentless_sequence and self.check_token(BlockEntryToken): 320 | end_mark = self.peek_token().end_mark 321 | event = SequenceStartEvent(anchor, tag, implicit, 322 | start_mark, end_mark) 323 | self.state = self.parse_indentless_sequence_entry 324 | else: 325 | if self.check_token(ScalarToken): 326 | token = self.get_token() 327 | end_mark = token.end_mark 328 | if (token.plain and tag is None) or tag == '!': 329 | implicit = (True, False) 330 | elif tag is None: 331 | implicit = (False, True) 332 | else: 333 | implicit = (False, False) 334 | event = ScalarEvent(anchor, tag, implicit, token.value, 335 | start_mark, end_mark, style=token.style) 336 | self.state = self.states.pop() 337 | elif self.check_token(FlowSequenceStartToken): 338 | end_mark = self.peek_token().end_mark 339 | event = SequenceStartEvent(anchor, tag, implicit, 340 | start_mark, end_mark, flow_style=True) 341 | self.state = self.parse_flow_sequence_first_entry 342 | elif self.check_token(FlowMappingStartToken): 343 | end_mark = self.peek_token().end_mark 344 | event = MappingStartEvent(anchor, tag, implicit, 345 | start_mark, end_mark, flow_style=True) 346 | self.state = self.parse_flow_mapping_first_key 347 | elif block and self.check_token(BlockSequenceStartToken): 348 | end_mark = self.peek_token().start_mark 349 | event = SequenceStartEvent(anchor, tag, implicit, 350 | start_mark, end_mark, flow_style=False) 351 | self.state = self.parse_block_sequence_first_entry 352 | elif block and self.check_token(BlockMappingStartToken): 353 | end_mark = self.peek_token().start_mark 354 | event = MappingStartEvent(anchor, tag, implicit, 355 | start_mark, end_mark, flow_style=False) 356 | self.state = self.parse_block_mapping_first_key 357 | elif anchor is not None or tag is not 
None: 358 | # Empty scalars are allowed even if a tag or an anchor is 359 | # specified. 360 | event = ScalarEvent(anchor, tag, (implicit, False), '', 361 | start_mark, end_mark) 362 | self.state = self.states.pop() 363 | else: 364 | if block: 365 | node = 'block' 366 | else: 367 | node = 'flow' 368 | token = self.peek_token() 369 | raise ParserError("while parsing a %s node" % node, start_mark, 370 | "expected the node content, but found %r" % token.id, 371 | token.start_mark) 372 | return event 373 | 374 | # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END 375 | 376 | def parse_block_sequence_first_entry(self): 377 | token = self.get_token() 378 | self.marks.append(token.start_mark) 379 | return self.parse_block_sequence_entry() 380 | 381 | def parse_block_sequence_entry(self): 382 | if self.check_token(BlockEntryToken): 383 | token = self.get_token() 384 | if not self.check_token(BlockEntryToken, BlockEndToken): 385 | self.states.append(self.parse_block_sequence_entry) 386 | return self.parse_block_node() 387 | else: 388 | self.state = self.parse_block_sequence_entry 389 | return self.process_empty_scalar(token.end_mark) 390 | if not self.check_token(BlockEndToken): 391 | token = self.peek_token() 392 | raise ParserError("while parsing a block collection", self.marks[-1], 393 | "expected , but found %r" % token.id, token.start_mark) 394 | token = self.get_token() 395 | event = SequenceEndEvent(token.start_mark, token.end_mark) 396 | self.state = self.states.pop() 397 | self.marks.pop() 398 | return event 399 | 400 | # indentless_sequence ::= (BLOCK-ENTRY block_node?)+ 401 | 402 | def parse_indentless_sequence_entry(self): 403 | if self.check_token(BlockEntryToken): 404 | token = self.get_token() 405 | if not self.check_token(BlockEntryToken, 406 | KeyToken, ValueToken, BlockEndToken): 407 | self.states.append(self.parse_indentless_sequence_entry) 408 | return self.parse_block_node() 409 | else: 410 | self.state = 
self.parse_indentless_sequence_entry 411 | return self.process_empty_scalar(token.end_mark) 412 | token = self.peek_token() 413 | event = SequenceEndEvent(token.start_mark, token.start_mark) 414 | self.state = self.states.pop() 415 | return event 416 | 417 | # block_mapping ::= BLOCK-MAPPING_START 418 | # ((KEY block_node_or_indentless_sequence?)? 419 | # (VALUE block_node_or_indentless_sequence?)?)* 420 | # BLOCK-END 421 | 422 | def parse_block_mapping_first_key(self): 423 | token = self.get_token() 424 | self.marks.append(token.start_mark) 425 | return self.parse_block_mapping_key() 426 | 427 | def parse_block_mapping_key(self): 428 | if self.check_token(KeyToken): 429 | token = self.get_token() 430 | if not self.check_token(KeyToken, ValueToken, BlockEndToken): 431 | self.states.append(self.parse_block_mapping_value) 432 | return self.parse_block_node_or_indentless_sequence() 433 | else: 434 | self.state = self.parse_block_mapping_value 435 | return self.process_empty_scalar(token.end_mark) 436 | if not self.check_token(BlockEndToken): 437 | token = self.peek_token() 438 | raise ParserError("while parsing a block mapping", self.marks[-1], 439 | "expected , but found %r" % token.id, token.start_mark) 440 | token = self.get_token() 441 | event = MappingEndEvent(token.start_mark, token.end_mark) 442 | self.state = self.states.pop() 443 | self.marks.pop() 444 | return event 445 | 446 | def parse_block_mapping_value(self): 447 | if self.check_token(ValueToken): 448 | token = self.get_token() 449 | if not self.check_token(KeyToken, ValueToken, BlockEndToken): 450 | self.states.append(self.parse_block_mapping_key) 451 | return self.parse_block_node_or_indentless_sequence() 452 | else: 453 | self.state = self.parse_block_mapping_key 454 | return self.process_empty_scalar(token.end_mark) 455 | else: 456 | self.state = self.parse_block_mapping_key 457 | token = self.peek_token() 458 | return self.process_empty_scalar(token.start_mark) 459 | 460 | # flow_sequence ::= 
    # flow_sequence     ::= FLOW-SEQUENCE-START
    #                       (flow_sequence_entry FLOW-ENTRY)*
    #                       flow_sequence_entry?
    #                       FLOW-SEQUENCE-END
    # flow_sequence_entry   ::= flow_node | KEY flow_node? (VALUE flow_node?)?
    #
    # Note that while production rules for both flow_sequence_entry and
    # flow_mapping_entry are equal, their interpretations are different.
    # For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
    # generate an inline mapping (set syntax).

    def parse_flow_sequence_first_entry(self):
        # Consume FLOW-SEQUENCE-START and remember its mark so that later
        # errors can point back at the opening '['.
        token = self.get_token()
        self.marks.append(token.start_mark)
        return self.parse_flow_sequence_entry(first=True)

    def parse_flow_sequence_entry(self, first=False):
        # Parse one flow-sequence entry.  Every entry but the first must be
        # preceded by ',' (FlowEntryToken); a trailing ',' before ']' is
        # tolerated by construction of this loop.
        if not self.check_token(FlowSequenceEndToken):
            if not first:
                if self.check_token(FlowEntryToken):
                    self.get_token()
                else:
                    token = self.peek_token()
                    raise ParserError("while parsing a flow sequence", self.marks[-1],
                            "expected ',' or ']', but got %r" % token.id, token.start_mark)

            if self.check_token(KeyToken):
                # A '?' key inside a flow sequence introduces an inline
                # single-pair mapping (the "set syntax" noted above).
                token = self.peek_token()
                event = MappingStartEvent(None, None, True,
                        token.start_mark, token.end_mark,
                        flow_style=True)
                self.state = self.parse_flow_sequence_entry_mapping_key
                return event
            elif not self.check_token(FlowSequenceEndToken):
                self.states.append(self.parse_flow_sequence_entry)
                return self.parse_flow_node()
        # FLOW-SEQUENCE-END: close the sequence and pop the saved mark.
        token = self.get_token()
        event = SequenceEndEvent(token.start_mark, token.end_mark)
        self.state = self.states.pop()
        self.marks.pop()
        return event

    def parse_flow_sequence_entry_mapping_key(self):
        # Consume the KEY token and parse the key node, or fabricate an
        # empty scalar if the key is omitted.
        token = self.get_token()
        if not self.check_token(ValueToken,
                FlowEntryToken, FlowSequenceEndToken):
            self.states.append(self.parse_flow_sequence_entry_mapping_value)
            return self.parse_flow_node()
        else:
            self.state = self.parse_flow_sequence_entry_mapping_value
            return self.process_empty_scalar(token.end_mark)

    def parse_flow_sequence_entry_mapping_value(self):
        # Parse the optional value of the inline mapping pair; a missing
        # value becomes an empty scalar.
        if self.check_token(ValueToken):
            token = self.get_token()
            if not self.check_token(FlowEntryToken, FlowSequenceEndToken):
                self.states.append(self.parse_flow_sequence_entry_mapping_end)
                return self.parse_flow_node()
            else:
                self.state = self.parse_flow_sequence_entry_mapping_end
                return self.process_empty_scalar(token.end_mark)
        else:
            self.state = self.parse_flow_sequence_entry_mapping_end
            token = self.peek_token()
            return self.process_empty_scalar(token.start_mark)

    def parse_flow_sequence_entry_mapping_end(self):
        # Close the synthetic single-pair mapping.  No token is consumed
        # here, so both marks come from the next (peeked) token.
        self.state = self.parse_flow_sequence_entry
        token = self.peek_token()
        return MappingEndEvent(token.start_mark, token.start_mark)

    # flow_mapping  ::= FLOW-MAPPING-START
    #                   (flow_mapping_entry FLOW-ENTRY)*
    #                   flow_mapping_entry?
    #                   FLOW-MAPPING-END
    # flow_mapping_entry    ::= flow_node | KEY flow_node? (VALUE flow_node?)?

    def parse_flow_mapping_first_key(self):
        # Consume FLOW-MAPPING-START and remember its mark for error messages.
        token = self.get_token()
        self.marks.append(token.start_mark)
        return self.parse_flow_mapping_key(first=True)

    def parse_flow_mapping_key(self, first=False):
        # Parse the key part of a flow-mapping entry; entries after the
        # first must be separated by ','.
        if not self.check_token(FlowMappingEndToken):
            if not first:
                if self.check_token(FlowEntryToken):
                    self.get_token()
                else:
                    token = self.peek_token()
                    raise ParserError("while parsing a flow mapping", self.marks[-1],
                            "expected ',' or '}', but got %r" % token.id, token.start_mark)
            if self.check_token(KeyToken):
                token = self.get_token()
                if not self.check_token(ValueToken,
                        FlowEntryToken, FlowMappingEndToken):
                    self.states.append(self.parse_flow_mapping_value)
                    return self.parse_flow_node()
                else:
                    self.state = self.parse_flow_mapping_value
                    return self.process_empty_scalar(token.end_mark)
            elif not self.check_token(FlowMappingEndToken):
                # Bare node used as a key (no explicit '?').
                self.states.append(self.parse_flow_mapping_empty_value)
                return self.parse_flow_node()
        token = self.get_token()
        event = MappingEndEvent(token.start_mark, token.end_mark)
        self.state = self.states.pop()
        self.marks.pop()
        return event

    def parse_flow_mapping_value(self):
        # Parse the value following an explicit key; a missing value becomes
        # an empty scalar.
        if self.check_token(ValueToken):
            token = self.get_token()
            if not self.check_token(FlowEntryToken, FlowMappingEndToken):
                self.states.append(self.parse_flow_mapping_key)
                return self.parse_flow_node()
            else:
                self.state = self.parse_flow_mapping_key
                return self.process_empty_scalar(token.end_mark)
        else:
            self.state = self.parse_flow_mapping_key
            token = self.peek_token()
            return self.process_empty_scalar(token.start_mark)

    def parse_flow_mapping_empty_value(self):
        # A key with no ':' at all -- its value is an empty scalar.
        self.state = self.parse_flow_mapping_key
        return self.process_empty_scalar(self.peek_token().start_mark)

    def process_empty_scalar(self, mark):
        # Fabricate an empty plain scalar event at `mark`; used wherever the
        # grammar allows a node to be omitted.
        return ScalarEvent(None, None, (True, False), '', mark, mark)

-------------------------------------------------------------------------------- /yaml/reader.py: --------------------------------------------------------------------------------

# This module contains abstractions for the input stream. You don't have to
# look further, there is no pretty code here.
#
# We define two classes here.
#
#   Mark(source, line, column)
# It's just a record and its only use is producing nice error messages.
# Parser does not use it for any other purposes.
#
#   Reader(source, data)
# Reader determines the encoding of `data` and converts it to unicode.
# Reader provides the following methods and attributes:
#   reader.peek(length=1) - return the next `length` characters
#   reader.forward(length=1) - move the current position to `length` characters.
#   reader.index - the number of the current character.
#   reader.line, reader.column - the line and the column of the current character.
__all__ = ['Reader', 'ReaderError']

from .error import YAMLError, Mark

import codecs, re

class ReaderError(YAMLError):
    # Raised when the input cannot be decoded, or a decoded character is
    # outside the YAML printable range.

    def __init__(self, name, position, character, encoding, reason):
        self.name = name
        self.character = character
        self.position = position
        self.encoding = encoding
        self.reason = reason

    def __str__(self):
        # `character` is a raw byte (bytes) when decoding failed, or an int
        # code point when the character itself is unacceptable.
        if isinstance(self.character, bytes):
            return "'%s' codec can't decode byte #x%02x: %s\n" \
                    "  in \"%s\", position %d" \
                    % (self.encoding, ord(self.character), self.reason,
                            self.name, self.position)
        else:
            return "unacceptable character #x%04x: %s\n" \
                    "  in \"%s\", position %d" \
                    % (self.character, self.reason,
                            self.name, self.position)

class Reader(object):
    # Reader:
    # - determines the data encoding and converts it to a unicode string,
    # - checks if characters are in allowed range,
    # - adds '\0' to the end.

    # Reader accepts
    #  - a `bytes` object,
    #  - a `str` object,
    #  - a file-like object with its `read` method returning `str`,
    #  - a file-like object with its `read` method returning `unicode`.

    # Yeah, it's ugly and slow.

    def __init__(self, stream):
        self.name = None
        self.stream = None
        self.stream_pointer = 0
        self.eof = True
        self.buffer = ''
        self.pointer = 0
        self.raw_buffer = None
        self.raw_decode = None
        self.encoding = None
        self.index = 0
        self.line = 0
        self.column = 0
        # NOTE(review): the `self.name = ""` literals below are
        # "<unicode string>" / "<byte string>" / "<file>" in upstream PyYAML;
        # the angle-bracketed names look stripped in this copy -- confirm
        # against upstream before relying on error-message text.
        if isinstance(stream, str):
            # Already-decoded text: validate and NUL-terminate.
            self.name = ""
            self.check_printable(stream)
            self.buffer = stream+'\0'
        elif isinstance(stream, bytes):
            self.name = ""
            self.raw_buffer = stream
            self.determine_encoding()
        else:
            # File-like object: decode lazily, chunk by chunk.
            self.stream = stream
            self.name = getattr(stream, 'name', "")
            self.eof = False
            self.raw_buffer = None
            self.determine_encoding()

    def peek(self, index=0):
        # Return the character `index` positions ahead without consuming it.
        try:
            return self.buffer[self.pointer+index]
        except IndexError:
            self.update(index+1)
            return self.buffer[self.pointer+index]

    def prefix(self, length=1):
        # Return (without consuming) the next `length` characters.
        if self.pointer+length >= len(self.buffer):
            self.update(length)
        return self.buffer[self.pointer:self.pointer+length]

    def forward(self, length=1):
        # Advance the current position by `length` characters while
        # maintaining the index/line/column counters.
        if self.pointer+length+1 >= len(self.buffer):
            self.update(length+1)
        while length:
            ch = self.buffer[self.pointer]
            self.pointer += 1
            self.index += 1
            if ch in '\n\x85\u2028\u2029' \
                    or (ch == '\r' and self.buffer[self.pointer] != '\n'):
                # Any line break starts a new line; the CR of a CRLF pair is
                # excluded so CRLF counts as a single break.
                self.line += 1
                self.column = 0
            elif ch != '\uFEFF':
                # A BOM character does not advance the column.
                self.column += 1
            length -= 1

    def get_mark(self):
        # Snapshot the current position.  The buffer snippet (for error
        # context) is only available when reading from an in-memory string.
        if self.stream is None:
            return Mark(self.name, self.index, self.line, self.column,
                    self.buffer, self.pointer)
        else:
            return Mark(self.name, self.index, self.line, self.column,
                    None, None)

    def determine_encoding(self):
        # Sniff a UTF-16 BOM from the first two raw bytes; otherwise decode
        # as UTF-8 (str input skips this entirely).
        while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
            self.update_raw()
        if isinstance(self.raw_buffer, bytes):
            if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
                self.raw_decode = codecs.utf_16_le_decode
                self.encoding = 'utf-16-le'
            elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
                self.raw_decode = codecs.utf_16_be_decode
                self.encoding = 'utf-16-be'
            else:
                self.raw_decode = codecs.utf_8_decode
                self.encoding = 'utf-8'
        self.update(1)

    NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]')
    def check_printable(self, data):
        # Reject characters outside the YAML printable character set.
        match = self.NON_PRINTABLE.search(data)
        if match:
            character = match.group()
            position = self.index+(len(self.buffer)-self.pointer)+match.start()
            raise ReaderError(self.name, position, ord(character),
                    'unicode', "special characters are not allowed")

    def update(self, length):
        # Decode raw bytes until at least `length` characters are available
        # past the current pointer (or EOF is reached).
        if self.raw_buffer is None:
            return
        self.buffer = self.buffer[self.pointer:]
        self.pointer = 0
        while len(self.buffer) < length:
            if not self.eof:
                self.update_raw()
            if self.raw_decode is not None:
                try:
                    # Incremental decode: a truncated multi-byte sequence is
                    # only an error once EOF has been reached.
                    data, converted = self.raw_decode(self.raw_buffer,
                            'strict', self.eof)
                except UnicodeDecodeError as exc:
                    character = self.raw_buffer[exc.start]
                    if self.stream is not None:
                        position = self.stream_pointer-len(self.raw_buffer)+exc.start
                    else:
                        position = exc.start
                    raise ReaderError(self.name, position, character,
                            exc.encoding, exc.reason)
            else:
                data = self.raw_buffer
                converted = len(data)
            self.check_printable(data)
            self.buffer += data
            self.raw_buffer = self.raw_buffer[converted:]
            if self.eof:
                self.buffer += '\0'
                self.raw_buffer = None
                break

    def update_raw(self, size=4096):
        # Pull the next raw chunk from the underlying stream and append it
        # to `raw_buffer`.
        data = self.stream.read(size)
        if self.raw_buffer is None:
            self.raw_buffer = data
        else:
            self.raw_buffer += data
        self.stream_pointer += len(data)
        # (EOF flagging of this method continues on the next dump line.)
        if not data:
            # An empty read means the stream is exhausted.
            self.eof = True

-------------------------------------------------------------------------------- /yaml/representer.py: --------------------------------------------------------------------------------

__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
    'RepresenterError']

from .error import *
from .nodes import *

import datetime, sys, copyreg, types, base64, collections

class RepresenterError(YAMLError):
    pass

class BaseRepresenter:
    # Turns native Python objects into a tree of representation nodes.

    # Class-level registries; the add_* classmethods copy-on-write them into
    # each subclass that registers something (see add_representer below).
    yaml_representers = {}
    yaml_multi_representers = {}

    def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
        self.default_style = default_style
        self.sort_keys = sort_keys
        self.default_flow_style = default_flow_style
        self.represented_objects = {}
        self.object_keeper = []
        self.alias_key = None

    def represent(self, data):
        # Represent one document, then reset the per-document caches.
        node = self.represent_data(data)
        self.serialize(node)
        self.represented_objects = {}
        self.object_keeper = []
        self.alias_key = None

    def represent_data(self, data):
        # Dispatch on the object's type: exact type first, then the MRO
        # against the multi-representers.  Aliasable objects are cached by
        # id() so repeated objects become anchor/alias pairs.
        if self.ignore_aliases(data):
            self.alias_key = None
        else:
            self.alias_key = id(data)
        if self.alias_key is not None:
            if self.alias_key in self.represented_objects:
                node = self.represented_objects[self.alias_key]
                #if node is None:
                #    raise RepresenterError("recursive objects are not allowed: %r" % data)
                return node
            #self.represented_objects[alias_key] = None
            # Keep the object alive so its id() cannot be reused mid-dump.
            self.object_keeper.append(data)
        data_types = type(data).__mro__
        if data_types[0] in self.yaml_representers:
            node = self.yaml_representers[data_types[0]](self, data)
        else:
            for data_type in data_types:
                if data_type in self.yaml_multi_representers:
                    node = self.yaml_multi_representers[data_type](self, data)
                    break
            else:
                # Fallbacks registered under the key None, then str().
                if None in self.yaml_multi_representers:
                    node = self.yaml_multi_representers[None](self, data)
                elif None in self.yaml_representers:
                    node = self.yaml_representers[None](self, data)
                else:
                    node = ScalarNode(None, str(data))
        #if alias_key is not None:
        #    self.represented_objects[alias_key] = node
        return node

    @classmethod
    def add_representer(cls, data_type, representer):
        # Copy-on-write so a subclass never mutates its parent's registry.
        if not 'yaml_representers' in cls.__dict__:
            cls.yaml_representers = cls.yaml_representers.copy()
        cls.yaml_representers[data_type] = representer

    @classmethod
    def add_multi_representer(cls, data_type, representer):
        if not 'yaml_multi_representers' in cls.__dict__:
            cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
        cls.yaml_multi_representers[data_type] = representer

    def represent_scalar(self, tag, value, style=None):
        if style is None:
            style = self.default_style
        node = ScalarNode(tag, value, style=style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        return node

    def represent_sequence(self, tag, sequence, flow_style=None):
        value = []
        node = SequenceNode(tag, value, flow_style=flow_style)
        # Register the node *before* recursing so self-references alias back
        # to it instead of recursing forever.
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        best_style = True
        for item in sequence:
            node_item = self.represent_data(item)
            if not (isinstance(node_item, ScalarNode) and not node_item.style):
                best_style = False
            value.append(node_item)
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node

    def represent_mapping(self, tag, mapping, flow_style=None):
        value = []
        node = MappingNode(tag, value, flow_style=flow_style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        best_style = True
        if hasattr(mapping, 'items'):
            mapping = list(mapping.items())
            if self.sort_keys:
                try:
                    mapping = sorted(mapping)
                except TypeError:
                    # Keys of mixed/unorderable types: keep insertion order.
                    pass
        for item_key, item_value in mapping:
            node_key = self.represent_data(item_key)
            node_value = self.represent_data(item_value)
            if not (isinstance(node_key, ScalarNode) and not node_key.style):
                best_style = False
            if not (isinstance(node_value, ScalarNode) and not node_value.style):
                best_style = False
            value.append((node_key, node_value))
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node

    def ignore_aliases(self, data):
        # Overridden by subclasses; the base class never suppresses anchors.
        return False

class SafeRepresenter(BaseRepresenter):
    # Representer restricted to the standard, safe YAML tags.

    def ignore_aliases(self, data):
        # Immutable scalar-ish values never get anchors.  (Falls through,
        # returning None i.e. falsy, for everything else.)
        if data is None:
            return True
        if isinstance(data, tuple) and data == ():
            return True
        if isinstance(data, (str, bytes, bool, int, float)):
            return True

    def represent_none(self, data):
        return self.represent_scalar('tag:yaml.org,2002:null', 'null')

    def represent_str(self, data):
        return self.represent_scalar('tag:yaml.org,2002:str', data)

    def represent_binary(self, data):
        # base64-encode bytes; encodestring() is the pre-3.1 spelling kept
        # for compatibility with very old interpreters.
        if hasattr(base64, 'encodebytes'):
            data = base64.encodebytes(data).decode('ascii')
        else:
            data = base64.encodestring(data).decode('ascii')
        return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|')

    def represent_bool(self, data):
        if data:
            value = 'true'
        else:
            value = 'false'
        return self.represent_scalar('tag:yaml.org,2002:bool', value)

    def represent_int(self, data):
        return self.represent_scalar('tag:yaml.org,2002:int', str(data))

    # Compute +infinity portably at class-definition time: square until the
    # value stops changing.
    inf_value = 1e300
    while repr(inf_value) != repr(inf_value*inf_value):
        inf_value *= inf_value

    def represent_float(self, data):
        # `data != data` is the NaN test; the `data == 0.0 and data == 1.0`
        # clause guards against broken float comparison on odd platforms.
        if data != data or (data == 0.0 and data == 1.0):
            value = '.nan'
        elif data == self.inf_value:
            value = '.inf'
        elif data == -self.inf_value:
            value = '-.inf'
        else:
            value = repr(data).lower()
            # Note that in some cases `repr(data)` represents a float number
            # without the decimal parts.  For instance:
            #   >>> repr(1e17)
            #   '1e17'
            # Unfortunately, this is not a valid float representation according
            # to the definition of the `!!float` tag.  We fix this by adding
            # '.0' before the 'e' symbol.
            if '.' not in value and 'e' in value:
                value = value.replace('e', '.0e', 1)
        return self.represent_scalar('tag:yaml.org,2002:float', value)

    def represent_list(self, data):
        #pairs = (len(data) > 0 and isinstance(data, list))
        #if pairs:
        #    for item in data:
        #        if not isinstance(item, tuple) or len(item) != 2:
        #            pairs = False
        #            break
        #if not pairs:
        return self.represent_sequence('tag:yaml.org,2002:seq', data)
        #value = []
        #for item_key, item_value in data:
        #    value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
        #        [(item_key, item_value)]))
        #return SequenceNode(u'tag:yaml.org,2002:pairs', value)

    def represent_dict(self, data):
        return self.represent_mapping('tag:yaml.org,2002:map', data)

    def represent_set(self, data):
        # A YAML set is a mapping whose values are all null.
        value = {}
        for key in data:
            value[key] = None
        return self.represent_mapping('tag:yaml.org,2002:set', value)

    def represent_date(self, data):
        value = data.isoformat()
        return self.represent_scalar('tag:yaml.org,2002:timestamp', value)

    def represent_datetime(self, data):
        value = data.isoformat(' ')
        return self.represent_scalar('tag:yaml.org,2002:timestamp', value)

    def represent_yaml_object(self, tag, data, cls, flow_style=None):
        # Represent the object's state dict under the given tag.
        if hasattr(data, '__getstate__'):
            state = data.__getstate__()
        else:
            state = data.__dict__.copy()
        return self.represent_mapping(tag, state, flow_style=flow_style)

    def represent_undefined(self, data):
        # Registered under None: any unregistered type is an error for the
        # safe representer.
        raise RepresenterError("cannot represent an object", data)

SafeRepresenter.add_representer(type(None),
        SafeRepresenter.represent_none)

SafeRepresenter.add_representer(str,
        SafeRepresenter.represent_str)

SafeRepresenter.add_representer(bytes,
        SafeRepresenter.represent_binary)

SafeRepresenter.add_representer(bool,
        SafeRepresenter.represent_bool)

SafeRepresenter.add_representer(int,
        SafeRepresenter.represent_int)

SafeRepresenter.add_representer(float,
        SafeRepresenter.represent_float)

SafeRepresenter.add_representer(list,
        SafeRepresenter.represent_list)

SafeRepresenter.add_representer(tuple,
        SafeRepresenter.represent_list)

SafeRepresenter.add_representer(dict,
        SafeRepresenter.represent_dict)

SafeRepresenter.add_representer(set,
        SafeRepresenter.represent_set)

SafeRepresenter.add_representer(datetime.date,
        SafeRepresenter.represent_date)

SafeRepresenter.add_representer(datetime.datetime,
        SafeRepresenter.represent_datetime)

SafeRepresenter.add_representer(None,
        SafeRepresenter.represent_undefined)

class Representer(SafeRepresenter):
    # Full (unsafe) representer: adds Python-specific !!python/* tags.

    def represent_complex(self, data):
        # Emit the shortest exact form: real only, imaginary only, or both.
        if data.imag == 0.0:
            data = '%r' % data.real
        elif data.real == 0.0:
            data = '%rj' % data.imag
        elif data.imag > 0:
            data = '%r+%rj' % (data.real, data.imag)
        else:
            data = '%r%rj' % (data.real, data.imag)
        return self.represent_scalar('tag:yaml.org,2002:python/complex', data)

    def represent_tuple(self, data):
        # (This statement continues on the next dump line.)
        return
        self.represent_sequence('tag:yaml.org,2002:python/tuple', data)

    def represent_name(self, data):
        # Classes and functions are dumped by dotted name only.
        name = '%s.%s' % (data.__module__, data.__name__)
        return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '')

    def represent_module(self, data):
        return self.represent_scalar(
                'tag:yaml.org,2002:python/module:'+data.__name__, '')

    def represent_object(self, data):
        # We use __reduce__ API to save the data. data.__reduce__ returns
        # a tuple of length 2-5:
        #   (function, args, state, listitems, dictitems)

        # For reconstructing, we calls function(*args), then set its state,
        # listitems, and dictitems if they are not None.

        # A special case is when function.__name__ == '__newobj__'. In this
        # case we create the object with args[0].__new__(*args).

        # Another special case is when __reduce__ returns a string - we don't
        # support it.

        # We produce a !!python/object, !!python/object/new or
        # !!python/object/apply node.

        cls = type(data)
        if cls in copyreg.dispatch_table:
            reduce = copyreg.dispatch_table[cls](data)
        elif hasattr(data, '__reduce_ex__'):
            reduce = data.__reduce_ex__(2)
        elif hasattr(data, '__reduce__'):
            reduce = data.__reduce__()
        else:
            raise RepresenterError("cannot represent an object", data)
        # Pad the reduce tuple out to exactly five fields.
        reduce = (list(reduce)+[None]*5)[:5]
        function, args, state, listitems, dictitems = reduce
        args = list(args)
        if state is None:
            state = {}
        if listitems is not None:
            listitems = list(listitems)
        if dictitems is not None:
            dictitems = dict(dictitems)
        if function.__name__ == '__newobj__':
            function = args[0]
            args = args[1:]
            tag = 'tag:yaml.org,2002:python/object/new:'
            newobj = True
        else:
            tag = 'tag:yaml.org,2002:python/object/apply:'
            newobj = False
        function_name = '%s.%s' % (function.__module__, function.__name__)
        if not args and not listitems and not dictitems \
                and isinstance(state, dict) and newobj:
            # Simplest case: plain object, state dict only.
            return self.represent_mapping(
                    'tag:yaml.org,2002:python/object:'+function_name, state)
        if not listitems and not dictitems \
                and isinstance(state, dict) and not state:
            # Constructor arguments only.
            return self.represent_sequence(tag+function_name, args)
        # General case: explicit args/state/listitems/dictitems mapping.
        value = {}
        if args:
            value['args'] = args
        if state or not isinstance(state, dict):
            value['state'] = state
        if listitems:
            value['listitems'] = listitems
        if dictitems:
            value['dictitems'] = dictitems
        return self.represent_mapping(tag+function_name, value)

    def represent_ordered_dict(self, data):
        # Provide uniform representation across different Python versions.
        data_type = type(data)
        tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \
                % (data_type.__module__, data_type.__name__)
        items = [[key, value] for key, value in data.items()]
        return self.represent_sequence(tag, [items])

Representer.add_representer(complex,
        Representer.represent_complex)

Representer.add_representer(tuple,
        Representer.represent_tuple)

Representer.add_representer(type,
        Representer.represent_name)

Representer.add_representer(collections.OrderedDict,
        Representer.represent_ordered_dict)

Representer.add_representer(types.FunctionType,
        Representer.represent_name)

Representer.add_representer(types.BuiltinFunctionType,
        Representer.represent_name)

Representer.add_representer(types.ModuleType,
        Representer.represent_module)

Representer.add_multi_representer(object,
        Representer.represent_object)

-------------------------------------------------------------------------------- /yaml/resolver.py: --------------------------------------------------------------------------------

__all__ = ['BaseResolver', 'Resolver']

from .error import *
from .nodes import *

import re

class ResolverError(YAMLError):
    pass

class BaseResolver:
    # Maps (node kind, value, position-in-document) to a concrete YAML tag.

    DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str'
    DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq'
    DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map'

    # Class-level registries, copied on first write per subclass (see the
    # add_* classmethods).
    yaml_implicit_resolvers = {}
    yaml_path_resolvers = {}

    def __init__(self):
        self.resolver_exact_paths = []
        self.resolver_prefix_paths = []

    @classmethod
    def add_implicit_resolver(cls, tag, regexp, first):
        # `first` is the list of characters a matching plain scalar may
        # start with; None means "any first character".
        if not 'yaml_implicit_resolvers' in cls.__dict__:
            # Copy the per-character lists too, so the parent class registry
            # is never mutated through a subclass.
            implicit_resolvers = {}
            for key in cls.yaml_implicit_resolvers:
                implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:]
            cls.yaml_implicit_resolvers = implicit_resolvers
        if first is None:
            first = [None]
        for ch in first:
            cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))

    @classmethod
    def add_path_resolver(cls, tag, path, kind=None):
        # Note: `add_path_resolver` is experimental.  The API could be changed.
        # `new_path` is a pattern that is matched against the path from the
        # root to the node that is being considered.  `node_path` elements are
        # tuples `(node_check, index_check)`.  `node_check` is a node class:
        # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`.  `None`
        # matches any kind of a node.  `index_check` could be `None`, a boolean
        # value, a string value, or a number.  `None` and `False` match against
        # any _value_ of sequence and mapping nodes.  `True` matches against
        # any _key_ of a mapping node.  A string `index_check` matches against
        # a mapping value that corresponds to a scalar key which content is
        # equal to the `index_check` value.  An integer `index_check` matches
        # against a sequence value with the index equal to `index_check`.
        if not 'yaml_path_resolvers' in cls.__dict__:
            # Copy-on-write, same pattern as add_implicit_resolver.
            cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
        new_path = []
        for element in path:
            if isinstance(element, (list, tuple)):
                if len(element) == 2:
                    node_check, index_check = element
                elif len(element) == 1:
                    node_check = element[0]
                    index_check = True
                else:
                    raise ResolverError("Invalid path element: %s" % element)
            else:
                node_check = None
                index_check = element
            # Plain Python types are accepted as shorthand for node classes.
            if node_check is str:
                node_check = ScalarNode
            elif node_check is list:
                node_check = SequenceNode
            elif node_check is dict:
                node_check = MappingNode
            elif node_check not in [ScalarNode, SequenceNode, MappingNode] \
                    and not isinstance(node_check, str) \
                    and node_check is not None:
                raise ResolverError("Invalid node checker: %s" % node_check)
            if not isinstance(index_check, (str, int)) \
                    and index_check is not None:
                raise ResolverError("Invalid index checker: %s" % index_check)
            new_path.append((node_check, index_check))
        if kind is str:
            kind = ScalarNode
        elif kind is list:
            kind = SequenceNode
        elif kind is dict:
            kind = MappingNode
        elif kind not in [ScalarNode, SequenceNode, MappingNode] \
                and kind is not None:
            raise ResolverError("Invalid node kind: %s" % kind)
        cls.yaml_path_resolvers[tuple(new_path), kind] = tag

    def descend_resolver(self, current_node, current_index):
        # Push path-resolver bookkeeping for one level of nesting; called by
        # the composer/serializer as they walk into a collection node.
        if not self.yaml_path_resolvers:
            return
        exact_paths = {}
        prefix_paths = []
        if current_node:
            depth = len(self.resolver_prefix_paths)
            for path, kind in self.resolver_prefix_paths[-1]:
                if self.check_resolver_prefix(depth, path, kind,
                        current_node, current_index):
                    if len(path) > depth:
                        prefix_paths.append((path, kind))
                    else:
                        exact_paths[kind] = self.yaml_path_resolvers[path, kind]
        else:
            # Root node: seed from the full registry.
            for path, kind in self.yaml_path_resolvers:
                if not path:
                    exact_paths[kind] = self.yaml_path_resolvers[path, kind]
                else:
                    prefix_paths.append((path, kind))
        self.resolver_exact_paths.append(exact_paths)
        self.resolver_prefix_paths.append(prefix_paths)

    def ascend_resolver(self):
        if not self.yaml_path_resolvers:
            return
        self.resolver_exact_paths.pop()
        self.resolver_prefix_paths.pop()

    def check_resolver_prefix(self, depth, path, kind,
            current_node, current_index):
        # Return True when (current_node, current_index) matches the path
        # element at `depth-1`; a bare `return` (None) means no match.
        node_check, index_check = path[depth-1]
        if isinstance(node_check, str):
            if current_node.tag != node_check:
                return
        elif node_check is not None:
            if not isinstance(current_node, node_check):
                return
        if index_check is True and current_index is not None:
            return
        if (index_check is False or index_check is None) \
                and current_index is None:
            return
        if isinstance(index_check, str):
            if not (isinstance(current_index, ScalarNode)
                    and index_check == current_index.value):
                return
        elif isinstance(index_check, int) and not isinstance(index_check, bool):
            if index_check != current_index:
                return
        return True

    def resolve(self, kind, value, implicit):
        # Tag resolution order: implicit resolvers (plain scalars only),
        # then path resolvers, then the per-kind default tag.
        if kind is ScalarNode and implicit[0]:
            if value == '':
                resolvers = self.yaml_implicit_resolvers.get('', [])
            else:
                resolvers = self.yaml_implicit_resolvers.get(value[0], [])
            resolvers += self.yaml_implicit_resolvers.get(None, [])
            for tag, regexp in resolvers:
                if regexp.match(value):
                    return tag
            implicit = implicit[1]
        if self.yaml_path_resolvers:
            exact_paths = self.resolver_exact_paths[-1]
            if kind in exact_paths:
                return exact_paths[kind]
            if None in exact_paths:
                return exact_paths[None]
        if kind is ScalarNode:
            return self.DEFAULT_SCALAR_TAG
        elif kind is SequenceNode:
            return self.DEFAULT_SEQUENCE_TAG
        elif kind is MappingNode:
            return self.DEFAULT_MAPPING_TAG

class Resolver(BaseResolver):
    pass

# YAML 1.1 implicit tag resolution rules; each registration pairs a regexp
# with the set of characters the scalar may start with.

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:bool',
        re.compile(r'''^(?:yes|Yes|YES|no|No|NO
                    |true|True|TRUE|false|False|FALSE
                    |on|On|ON|off|Off|OFF)$''', re.X),
        list('yYnNtTfFoO'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:float',
        re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)?
                    |\.[0-9_]+(?:[eE][-+][0-9]+)?
                    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
                    |[-+]?\.(?:inf|Inf|INF)
                    |\.(?:nan|NaN|NAN))$''', re.X),
        list('-+0123456789.'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:int',
        re.compile(r'''^(?:[-+]?0b[0-1_]+
                    |[-+]?0[0-7_]+
                    |[-+]?(?:0|[1-9][0-9_]*)
                    |[-+]?0x[0-9a-fA-F_]+
                    |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X),
        list('-+0123456789'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:merge',
        re.compile(r'^(?:<<)$'),
        ['<'])

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:null',
        re.compile(r'''^(?: ~
                    |null|Null|NULL
                    | )$''', re.X),
        ['~', 'n', 'N', ''])

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:timestamp',
        re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
                    |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
                     (?:[Tt]|[ \t]+)[0-9][0-9]?
                     :[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)?
                     (?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
        list('0123456789'))

Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:value',
        re.compile(r'^(?:=)$'),
        ['='])

# The following resolver is only for documentation purposes. It cannot work
# because plain scalars cannot start with '!', '&', or '*'.
Resolver.add_implicit_resolver(
        'tag:yaml.org,2002:yaml',
        re.compile(r'^(?:!|&|\*)$'),
        list('!&*'))

-------------------------------------------------------------------------------- /yaml/serializer.py: --------------------------------------------------------------------------------

__all__ = ['Serializer', 'SerializerError']

from .error import YAMLError
from .events import *
from .nodes import *

class SerializerError(YAMLError):
    pass

class Serializer:
    # Walks a node tree and emits the corresponding event stream, assigning
    # anchors to nodes that occur more than once.

    ANCHOR_TEMPLATE = 'id%03d'

    def __init__(self, encoding=None,
            explicit_start=None, explicit_end=None, version=None, tags=None):
        self.use_encoding = encoding
        self.use_explicit_start = explicit_start
        self.use_explicit_end = explicit_end
        self.use_version = version
        self.use_tags = tags
        self.serialized_nodes = {}
        self.anchors = {}
        self.last_anchor_id = 0
        # Tri-state: None = never opened, False = open, True = closed.
        self.closed = None

    def open(self):
        # Emit STREAM-START; may only be called once.
        if self.closed is None:
            self.emit(StreamStartEvent(encoding=self.use_encoding))
            self.closed = False
        elif self.closed:
            raise SerializerError("serializer is closed")
        else:
            raise SerializerError("serializer is already opened")

    def close(self):
        # Emit STREAM-END; idempotent once closed.
        if self.closed is None:
            raise SerializerError("serializer is not opened")
        elif not self.closed:
            self.emit(StreamEndEvent())
            self.closed = True

    #def __del__(self):
    #    self.close()

    def serialize(self, node):
        # Emit one complete document for `node`, then reset per-document
        # anchor/serialization state.
        if self.closed is None:
            raise SerializerError("serializer is not opened")
        elif self.closed:
            raise SerializerError("serializer is closed")
        self.emit(DocumentStartEvent(explicit=self.use_explicit_start,
            version=self.use_version, tags=self.use_tags))
        self.anchor_node(node)
        self.serialize_node(node, None, None)
        self.emit(DocumentEndEvent(explicit=self.use_explicit_end))
        self.serialized_nodes = {}
        self.anchors = {}
        self.last_anchor_id = 0

    def anchor_node(self, node):
        # First visit records the node with anchor None; a second visit
        # upgrades it to a generated anchor name.  Thus only nodes that are
        # referenced more than once receive anchors.
        if node in self.anchors:
            if self.anchors[node] is None:
                self.anchors[node] = self.generate_anchor(node)
        else:
            self.anchors[node] = None
            if isinstance(node, SequenceNode):
                for item in node.value:
                    self.anchor_node(item)
            elif isinstance(node, MappingNode):
                for key, value in node.value:
                    self.anchor_node(key)
                    self.anchor_node(value)

    def generate_anchor(self, node):
        # Sequentially numbered anchors: id001, id002, ...
        self.last_anchor_id += 1
        return self.ANCHOR_TEMPLATE % self.last_anchor_id

    def serialize_node(self, node, parent, index):
        alias = self.anchors[node]
        if node in self.serialized_nodes:
            # Node already emitted: reference it by alias.
            self.emit(AliasEvent(alias))
        else:
            self.serialized_nodes[node] = True
            self.descend_resolver(parent, index)
            if isinstance(node, ScalarNode):
                # `implicit` records whether the tag could be left off a
                # plain scalar (detected) and a quoted scalar (default).
                detected_tag = self.resolve(ScalarNode, node.value, (True, False))
                default_tag = self.resolve(ScalarNode, node.value, (False, True))
                implicit = (node.tag == detected_tag), (node.tag == default_tag)
                self.emit(ScalarEvent(alias, node.tag, implicit, node.value,
                    style=node.style))
            elif isinstance(node, SequenceNode):
                implicit = (node.tag
                            == self.resolve(SequenceNode, node.value, True))
                self.emit(SequenceStartEvent(alias, node.tag, implicit,
                    flow_style=node.flow_style))
                index = 0
                for item in node.value:
                    self.serialize_node(item, node, index)
                    index += 1
                self.emit(SequenceEndEvent())
            elif isinstance(node, MappingNode):
                implicit = (node.tag
                            == self.resolve(MappingNode, node.value, True))
                self.emit(MappingStartEvent(alias, node.tag, implicit,
                    flow_style=node.flow_style))
                for key, value in node.value:
                    self.serialize_node(key, node, None)
                    self.serialize_node(value, node, key)
                self.emit(MappingEndEvent())
            self.ascend_resolver()
# ---- /yaml/tokens.py ------------------------------------------------------
# NOTE(review): in this dump every angle-bracketed token `id` string had been
# stripped to '' by markup rendering (the punctuation ids '[', '?', ':' etc.
# survived). Restored below to the upstream PyYAML values ('<directive>',
# '<stream start>', ...), which the parser uses in error messages.

class Token(object):
    """Base class for scanner tokens.

    Each token records the source positions it was scanned from
    (``start_mark``/``end_mark``); subclasses carry a class-level ``id``
    string that identifies the token kind in parser error messages.
    """
    def __init__(self, start_mark, end_mark):
        self.start_mark = start_mark
        self.end_mark = end_mark
    def __repr__(self):
        # Show every data attribute except the position marks, in sorted order.
        attributes = [key for key in self.__dict__
                if not key.endswith('_mark')]
        attributes.sort()
        arguments = ', '.join(['%s=%r' % (key, getattr(self, key))
                for key in attributes])
        return '%s(%s)' % (self.__class__.__name__, arguments)

#class BOMToken(Token):
#    id = '<byte order mark>'

class DirectiveToken(Token):
    # e.g. "%YAML 1.1" -> name='YAML', value=(1, 1)
    id = '<directive>'
    def __init__(self, name, value, start_mark, end_mark):
        self.name = name
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark

class DocumentStartToken(Token):
    id = '<document start>'

class DocumentEndToken(Token):
    id = '<document end>'

class StreamStartToken(Token):
    id = '<stream start>'
    def __init__(self, start_mark=None, end_mark=None,
            encoding=None):
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.encoding = encoding

class StreamEndToken(Token):
    id = '<stream end>'

class BlockSequenceStartToken(Token):
    id = '<block sequence start>'

class BlockMappingStartToken(Token):
    id = '<block mapping start>'

class BlockEndToken(Token):
    id = '<block end>'

class FlowSequenceStartToken(Token):
    id = '['

class FlowMappingStartToken(Token):
    id = '{'

class FlowSequenceEndToken(Token):
    id = ']'

class FlowMappingEndToken(Token):
    id = '}'

class KeyToken(Token):
    id = '?'

class ValueToken(Token):
    id = ':'

class BlockEntryToken(Token):
    id = '-'

class FlowEntryToken(Token):
    id = ','

class AliasToken(Token):
    id = '<alias>'
    def __init__(self, value, start_mark, end_mark):
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark

class AnchorToken(Token):
    id = '<anchor>'
    def __init__(self, value, start_mark, end_mark):
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark

class TagToken(Token):
    id = '<tag>'
    def __init__(self, value, start_mark, end_mark):
        # value is the (handle, suffix) pair produced by the scanner.
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark

class ScalarToken(Token):
    id = '<scalar>'
    def __init__(self, value, plain, start_mark, end_mark, style=None):
        self.value = value
        self.plain = plain       # True if the scalar was written in plain style
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.style = style       # quote/block style indicator, or None