├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── rnsh ├── __init__.py ├── args.py ├── docopt.py ├── exception.py ├── helpers.py ├── initiator.py ├── listener.py ├── loop.py ├── process.py ├── protocol.py ├── retry.py ├── rnsh.py ├── rnslogging.py ├── session.py └── testlogging.py ├── tests ├── __init__.py ├── helpers.py ├── reticulum_test_config ├── test_args.py ├── test_exception.py ├── test_process.py ├── test_protocol.py ├── test_retry.py └── test_rnsh.py └── tty_test.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - 'main' 8 | paths-ignore: 9 | - 'README.md' 10 | - '.github/**' 11 | pull_request: 12 | types: 13 | - opened 14 | branches: 15 | - 'main' 16 | paths-ignore: 17 | - 'README.md' 18 | 19 | jobs: 20 | 21 | test: 22 | runs-on: [self-hosted, linux] 23 | 24 | steps: 25 | - uses: actions/checkout@v1 26 | with: 27 | fetch-depth: 1 28 | 29 | - name: Set up Python 3.9 30 | uses: actions/setup-python@v4 31 | with: 32 | python-version: 3.9 33 | 34 | - name: Install Poetry Action 35 | uses: snok/install-poetry@v1.3.3 36 | 37 | 38 | - name: Cache Poetry virtualenv 39 | uses: actions/cache@v1 40 | id: cache 41 | with: 42 | path: ~/.virtualenvs 43 | key: poetry-${{ hashFiles('**/poetry.lock') }} 44 | restore-keys: | 45 | poetry-${{ hashFiles('**/poetry.lock') }} 46 | 47 | 48 | - name: Install Dependencies 49 | run: poetry install 50 | if: steps.cache.outputs.cache-hit != 'true' 51 | 52 | # - name: Code Quality 53 | # run: poetry run black . 
--check 54 | 55 | - name: Test with pytest 56 | run: poetry run pytest -m "not skip_ci" tests 57 | 58 | # - name: Vulnerability check 59 | # run: poetry run safety check 60 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Publish 10 | 11 | on: 12 | workflow_dispatch: 13 | push: 14 | # Sequence of patterns matched against refs/tags 15 | tags: 16 | - 'release/v*' # Push events to matching v*, i.e. 
v1.0, v20.15.10 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | deploy: 23 | 24 | runs-on: [self-hosted, linux] 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | with: 29 | fetch-depth: 1 30 | 31 | - name: Set up Python 3.9 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: 3.9 35 | 36 | - name: Install Poetry Action 37 | uses: snok/install-poetry@v1.3.3 38 | 39 | 40 | - name: Cache Poetry virtualenv 41 | uses: actions/cache@v3 42 | id: cache 43 | with: 44 | path: ~/.virtualenvs 45 | key: poetry-${{ hashFiles('**/poetry.lock') }} 46 | restore-keys: | 47 | poetry-${{ hashFiles('**/poetry.lock') }} 48 | 49 | 50 | - name: Install Dependencies 51 | run: poetry install 52 | if: steps.cache.outputs.cache-hit != 'true' 53 | 54 | 55 | - name: Test with pytest 56 | run: poetry run pytest -m "not skip_ci" tests 57 | 58 | 59 | - name: Build package 60 | run: poetry build 61 | 62 | - name: Set Versions 63 | uses: actions/github-script@v6.4.0 64 | id: set_version 65 | with: 66 | script: | 67 | const tag = context.ref.substring(18) 68 | const no_v = tag.replace('v', '') 69 | const dash_index = no_v.lastIndexOf('-') 70 | const no_dash = (dash_index > -1) ? no_v.substring(0, dash_index) : no_v 71 | core.setOutput('tag', tag) 72 | core.setOutput('no-v', no_v) 73 | core.setOutput('no-dash', no_dash) 74 | 75 | 76 | # - name: Upload a Build Artifact 77 | # uses: actions/upload-artifact@v3.1.2 78 | # with: 79 | # # Artifact name 80 | # name: "pip package" 81 | # # A file, directory or wildcard pattern that describes what to upload 82 | # path: "dist/*" 83 | # # The desired behavior if no files are found using the provided path. 
84 | # if-no-files-found: error 85 | 86 | - name: Create Release 87 | id: create_release 88 | uses: actions/create-release@v1 89 | env: 90 | GITHUB_TOKEN: ${{ secrets.GH_API_TOKEN }} 91 | with: 92 | tag_name: ${{ github.ref }} 93 | release_name: ${{ steps.set_version.outputs.tag }} 94 | draft: true 95 | prerelease: false 96 | 97 | - name: Upload Release Artefact 98 | uses: actions/upload-release-asset@v1 99 | env: 100 | GITHUB_TOKEN: ${{ secrets.GH_API_TOKEN }} 101 | with: 102 | upload_url: ${{ steps.create_release.outputs.upload_url }} 103 | asset_path: ./dist/rnsh-${{ steps.set_version.outputs.no-v }}-py3-none-any.whl 104 | asset_name: rnsh-${{ steps.set_version.outputs.no-v }}-py3-none-any.whl 105 | asset_content_type: application/zip 106 | 107 | 108 | - name: Publish to PyPI 109 | run: poetry publish --username __token__ --password ${{ secrets.PYPI_API_TOKEN }} 110 | 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /venv/ 2 | testconfig/ 3 | /.idea/ 4 | /poetry.lock 5 | /rnsh.egg-info/ 6 | /build/ 7 | /dist/ 8 | .pytest_cache/ 9 | *__pycache__ 10 | /RNS 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Aaron Heise 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall 
be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `r n s h`  Shell over Reticulum 2 | [![CI](https://github.com/acehoss/rnsh/actions/workflows/python-package.yml/badge.svg)](https://github.com/acehoss/rnsh/actions/workflows/python-package.yml)  3 | [![Release](https://github.com/acehoss/rnsh/actions/workflows/python-publish.yml/badge.svg)](https://github.com/acehoss/rnsh/actions/workflows/python-publish.yml)  4 | [![PyPI version](https://badge.fury.io/py/rnsh.svg)](https://badge.fury.io/py/rnsh)   5 | ![PyPI - Downloads](https://img.shields.io/pypi/dw/rnsh?color=informational&label=Installs&logo=pypi) 6 | 7 | `rnsh` is a utility written in Python that facilitates shell 8 | sessions over [Reticulum](https://reticulum.network) networks. 9 | It is based on the `rnx` utility that ships with Reticulum and 10 | aims to provide a similar experience to SSH. 
11 | 12 | ## Contents 13 | 14 | - [Beta Disclaimer](#reminder-beta-software) 15 | - [Recent Changes](#recent-changes) 16 | - [Quickstart](#quickstart) 17 | - [Options](#options) 18 | - [How it works](#how-it-works) 19 | - [Roadmap](#roadmap) 20 | - [Active TODO](#todo) 21 | 22 | ### Reminder: Beta Software 23 | The interface is starting to firm up, but some bug fixes at this 24 | point still may introduce breaking changes, especially in the 25 | protocol layers of the software. 26 | 27 | ## Recent Changes 28 | ### v0.1.4 29 | - Fix invalid escape sequence handling for terminal escape sequences 30 | 31 | ### v0.1.3 32 | - Fix an issue where disconnecting a session using ~. would cause further connections to 33 | the same initiator to appear to hang. 34 | - Setting `-q` will suppress the pre-connect spinners 35 | 36 | ### v0.1.2 37 | - Adaptive compression (RNS update) provides significant performance improvements ([PR](https://github.com/acehoss/rnsh/pull/24)) 38 | - Allowed identities file - put allowed identities in a file instead of on the command 39 | line for easier service invocations. ([see PR for details](https://github.com/acehoss/rnsh/pull/25)) 40 | - Escape Sequences, Session Termination & Line-Interactive Mode ([see PR for details](https://github.com/acehoss/rnsh/pull/26)) 41 | 42 | ### v0.1.1 43 | - Fix issue with intermittent data corruption 44 | 45 | ### v0.1.0 46 | - First beta! Includes performance improvements. 47 | 48 | ### v0.0.13, v0.0.14 49 | - Bug fixes 50 | 51 | ### v0.0.12 52 | - Remove service name from RNS destination aspects. Service name 53 | now selects a suffix for the identity file and should only be 54 | supplied on the listener. The initiator only needs the destination 55 | hash of the listener to connect. 
56 | - Show a spinner during link establishment on tty sessions 57 | - Attempt to catch and beautify exceptions on initiator 58 | 59 | ### v0.0.11 60 | - Event loop bursting improves throughput and CPU utilization on 61 | both listener and initiator. 62 | - Packet retries use RNS resend feature to prevent duplicate 63 | packets. 64 | 65 | ### v0.0.10 66 | - Rate limit window change events to prevent saturation of transport 67 | - Tweaked some loop timers to improve CPU utilization 68 | 69 | ### v0.0.9 70 | - Switch to a new packet-based protocol 71 | - Bug fixes and dependency updates 72 | 73 | ## Quickstart 74 | 75 | Tested (thus far) on Python 3.11 macOS 13.1 ARM64. Should 76 | run on Python 3.6+ on Linux or Unix. WSL probably works. 77 | Cygwin might work, too. 78 | 79 | - Activate a virtualenv 80 | - `pip3 install rnsh` 81 | - Or from a `whl` release, `pip3 install /path/to/rnsh-0.0.1-py3-none-any.whl` 82 | - Configure Reticulum interfaces, check with `rnstatus` 83 | - Ready to run `rnsh`. The options are shown below. 84 | 85 | ### Example: Shell server 86 | #### Setup 87 | Before running the listener or initiator, you'll need to get the 88 | listener destination hash and the initiator identity hash. 89 | ```shell 90 | # On listener 91 | rnsh -l -p 92 | 93 | # On initiator 94 | rnsh -p 95 | ``` 96 | Note: service name no longer is supplied on initiator. The destination 97 | hash encapsulates this information. 98 | 99 | #### Listener 100 | - Listening for default service name ("default"). 101 | - Using user's default Reticulum config dir (~/.reticulum). 102 | - Using default identity ($RNSCONFIGDIR/storage/identities/rnsh). 103 | - Allowing remote identity `6d47805065fa470852cf1b1ef417a1ac` to connect. 104 | - Launching `/bin/zsh` on authorized connect. 105 | ```shell 106 | rnsh -l -a 6d47805065fa470852cf1b1ef417a1ac -- /bin/zsh 107 | ``` 108 | #### Initiator 109 | - Connecting to default service name ("default"). 
110 | - Using user's default Reticulum config dir (~/.reticulum). 111 | - Using default identity ($RNSCONFIGDIR/storage/identities/rnsh). 112 | - Connecting to destination `a5f72aefc2cb3cdba648f73f77c4e887` 113 | ```shell 114 | rnsh a5f72aefc2cb3cdba648f73f77c4e887 115 | ``` 116 | 117 | ## Options 118 | ``` 119 | Usage: 120 | rnsh -l [-c ] [-i | -s ] [-v... | -q...] -p 121 | rnsh -l [-c ] [-i | -s ] [-v... | -q...] 122 | [-b ] (-n | -a [-a ] ...) [-A | -C] 123 | [[--] [ ...]] 124 | rnsh [-c ] [-i ] [-v... | -q...] -p 125 | rnsh [-c ] [-i ] [-v... | -q...] [-N] [-m] [-w ] 126 | [[--] [ ...]] 127 | rnsh -h 128 | rnsh --version 129 | 130 | Options: 131 | -c DIR --config DIR Alternate Reticulum config directory to use 132 | -i FILE --identity FILE Specific identity file to use 133 | -s NAME --service NAME Service name for identity file if not default 134 | -p --print-identity Print identity information and exit 135 | -l --listen Listen (server) mode. If supplied, ...will 136 | be used as the command line when the initiator does not 137 | provide one or when remote command is disabled. If 138 | is not supplied, the default shell of the 139 | user rnsh is running under will be used. 140 | -b --announce PERIOD Announce on startup and every PERIOD seconds 141 | Specify 0 for PERIOD to announce on startup only. 
142 | -a HASH --allowed HASH Specify identities allowed to connect 143 | -n --no-auth Disable authentication 144 | -N --no-id Disable identify on connect 145 | -A --remote-command-as-args Concatenate remote command to argument list of /shell 146 | -C --no-remote-command Disable executing command line from remote 147 | -m --mirror Client returns with code of remote process 148 | -w TIME --timeout TIME Specify client connect and request timeout in seconds 149 | -q --quiet Increase quietness (move level up), multiple increases effect 150 | DEFAULT LOGGING LEVEL 151 | CRITICAL (silent) 152 | Initiator -> ERROR 153 | WARNING 154 | Listener -> INFO 155 | DEBUG (insane) 156 | -v --verbose Increase verbosity (move level down), multiple increases effect 157 | --version Show version 158 | -h --help Show this help 159 | ``` 160 | 161 | ## How it works 162 | ### Listeners 163 | Listener instances are the servers. Each listener is configured 164 | with an RNS identity, and a service name. Together, RNS makes 165 | these into a destination hash that can be used to connect to 166 | your listener. 167 | 168 | Each listener must use a unique identity. The `-s` parameter 169 | can be used to specify a service name, which creates a unique 170 | identity file. 171 | 172 | Listeners can be configured with a command line to run on 173 | connect. Initiators can supply a command line as well. There 174 | are several different options for the way the command line 175 | is handled: 176 | 177 | - `-C` no initiator command line is allowed; the connection will 178 | be terminated if one is supplied. 179 | - `-A` initiator-supplied command line is appended to listener- 180 | configured command line 181 | - With neither of these options, the listener will use the first 182 | valid command line from this list: 183 | 1. Initiator-supplied command line 184 | 2. Listener command line argument 185 | 3. 
Default shell of user listener is running under 186 | 187 | 188 | If the `-n` option is not set on the listener, the initiator 189 | is required to identify before starting a command. The program 190 | will be started with the initiator's identity hash string set 191 | in the environment variable `RNS_REMOTE_IDENTITY`. 192 | 193 | Listeners are set up using the `-l` flag. 194 | 195 | ### Initiators 196 | Initiators are the clients. Each initiator has an identity 197 | hash which is used as an authentication mechanism on Reticulum. 198 | You'll need this value to configure the listener to allow 199 | your connection. It is possible to run the server without 200 | authentication, but hopefully it's obvious that this is an 201 | advanced use case. 202 | 203 | To get the identity hash, use the `-p` flag. 204 | 205 | With the initiator identity set up in the listener command 206 | line, and with the listener identity copied (you'll need to 207 | do `-p` on the listener side, too), you can run the 208 | initiator. 209 | 210 | I recommend staying pretty vanilla to start with and 211 | trying `/bin/zsh` or whatever your favorite shell is these 212 | days. The shell should start in login mode. Ideally it 213 | works just like an `ssh` shell session. 214 | 215 | ## Protocol 216 | The protocol is built on top of the Reticulum `Packet` API. 217 | Application software sends and receives `Message` objects, 218 | which are encapsulated by `Packet` objects. Messages are 219 | (currently) sent one per packet, and only one packet is 220 | sent at a time (per link). The next packet is not sent until 221 | the receiver proves the outstanding packet. 222 | 223 | A future update will work to allow a sliding window of 224 | outstanding packets to improve channel utilization. 225 | 226 | ### Session Establishment 227 | 1. Initiator establishes link. Listener session enters state 228 | `LSSTATE_WAIT_IDENT`, or `LSSTATE_WAIT_VERS` if running 229 | with `--no-auth` option. 230 | 231 | 2. 
Initiator identifies on link if not using `--no-id`. 232 | - If using `--allowed`, listener validates identity 233 | against configuration and if no match, sends a 234 | protocol error message and tears down link after prune 235 | timer. 236 | 3. Initiator transmits a `VersionInformationMessage`, which 237 | is evaluated by the server for compatibility. If 238 | incompatible, a protocol error is sent. 239 | 4. Listener responds with a `VersionInfoMessage` and enters 240 | state `LSSTATE_WAIT_CMD` 241 | 5. Initiator evaluates the listener's version information 242 | for compatibility and if incompatible, tears down link. 243 | 6. Initiator sends an `ExecuteCommandMessage` (which could 244 | be an empty command) and enters the session event loop. 245 | 7. Listener evaluates the command message against the 246 | configured options such as `-A` or `-C` and responds 247 | with a protocol error if not allowed. 248 | 8. Listener starts the program. If success, the listener 249 | enters the session event loop. If failure, responds 250 | with a `CommandExitedMessage`. 251 | 252 | ### Session Event Loop 253 | ##### Listener state `LSSTATE_RUNNING` 254 | Process messages received from initiator. 255 | - `WindowSizeMessage`: set window size on child tty if appropriate 256 | - `StreamDataMessage`: binary data stream for child process; 257 | stream ids 0, 1, 2 = stdin, stdout, stderr 258 | - `NoopMessage`: no operation - listener replies with `NoopMessage` 259 | - When link is torn down, child process is terminated if running and 260 | session destroyed 261 | - If command terminates, a `CommandExitedMessage` is sent and session 262 | is pruned after an idle timeout. 263 | ##### Initiator state `ISSTATE_RUNNING` 264 | Process messages received from listener. 
265 | - `ErrorMessage`: print error, terminate link, and exit 266 | - `StreamDataMessage`: binary stream information; 267 | streams ids 0, 1, 2 = stdin, stdout, stderr 268 | - `CommandExitedMessage`: remote command exited 269 | - If link is torn down unexpectedly, print message and exit 270 | 271 | 272 | ## Roadmap 273 | 1. Plan a better roadmap 274 | 2. ? 275 | 3. Keep my day job 276 | 277 | ## TODO 278 | - [X] ~~Initial version~~ 279 | - [X] ~~Pip package with command-line utility support~~ 280 | - [X] ~~Publish to PyPI~~ 281 | - [X] ~~Improve signal handling~~ 282 | - [X] ~~Make it scriptable (currently requires a tty)~~ 283 | - [X] ~~Protocol improvements (throughput!)~~ 284 | - [X] ~~Documentation improvements~~ 285 | - [X] ~~Fix issues with running `rnsh` in a binary pipeline, i.e. 286 | piping the output of `tar` over `rsh`.~~ 287 | - [X] ~~Test on several platforms~~ 288 | - [X] ~~Fix issues that come up with testing~~ 289 | - [X] ~~v0.1.0 beta~~ 290 | - [X] ~~Test and fix more issues~~ 291 | - [ ] More betas 292 | - [ ] Enhancement Ideas 293 | - [x] `authorized_keys` mode similar to SSH to allow one listener 294 | process to serve multiple users 295 | - [ ] Git over `rnsh` (git remote helper) 296 | - [ ] Sliding window acknowledgements for improved throughput 297 | - [ ] v1.0 someday probably maybe 298 | 299 | ## Miscellaneous 300 | 301 | By piping into/out of `rnsh`, it is possible to transfer 302 | files using the same method discussed in 303 | [this article](https://cromwell-intl.com/open-source/tar-and-ssh.html). 304 | It's not terribly fast currently, due to the round-trip rule 305 | enforced by the protocol. Sliding window acknowledgements will 306 | speed this up significantly. 
307 | 308 | Running tests: `poetry run pytest tests` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "rnsh" 3 | version = "0.1.5" 4 | description = "Shell over Reticulum" 5 | authors = ["acehoss "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.7" 11 | rns = ">=0.9.0" 12 | 13 | [tool.poetry.scripts] 14 | rnsh = 'rnsh.rnsh:rnsh_cli' 15 | 16 | [tool.poetry.group.dev.dependencies] 17 | pytest = "^7.2.1" 18 | setuptools = "^67.2.0" 19 | pytest-asyncio = "^0.20.3" 20 | safety = "^2.3.5" 21 | tomli = "^2.0.1" 22 | 23 | [tool.pytest.ini_options] 24 | markers = [ 25 | "skip_ci: marks tests that should not be run in CI builds (deselect with '-m \"not skip_ci\"')" 26 | ] 27 | 28 | [build-system] 29 | requires = ["poetry-core"] 30 | build-backend = "poetry.core.masonry.api" 31 | -------------------------------------------------------------------------------- /rnsh/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Aaron Heise 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
import os

# Absolute location of this module and of the package directory that holds it;
# the directory above the package is where a source checkout keeps pyproject.toml.
module_abs_filename = os.path.abspath(__file__)
module_dir = os.path.dirname(module_abs_filename)


def _get_version():
    """Return the rnsh version string, or ``"0.0.0"`` when it cannot be found.

    If a ``pyproject.toml`` exists in the directory above the package, the
    version is read from its ``[tool.poetry]`` table with ``tomli``.
    Otherwise the version of the installed ``rnsh`` distribution is looked up
    through ``pkg_resources``.

    The lookup is deliberately best-effort: a missing optional module, a
    malformed file or a missing distribution all yield ``"0.0.0"`` rather
    than an exception, so importing the package never fails because of it.
    """
    pyproject_path = os.path.join(os.path.dirname(module_dir), "pyproject.toml")

    def pkg_res_version():
        # Imported lazily: only needed when there is no pyproject.toml.
        import pkg_resources
        return pkg_resources.get_distribution("rnsh").version

    def tomli_version():
        # Imported lazily: tomli is only a development dependency.
        import tomli
        # The context manager guarantees the file handle is closed.
        with open(pyproject_path, "rb") as pyproject_file:
            return tomli.load(pyproject_file)["tool"]["poetry"]["version"]

    try:
        if os.path.isfile(pyproject_path):
            return tomli_version()
        return pkg_res_version()
    except Exception:
        # Best-effort fallback; KeyboardInterrupt/SystemExit still propagate.
        return "0.0.0"


__version__ = _get_version()
[-A | -C] 24 | [[--] [ ...]] 25 | rnsh [-c ] [-i ] [-v... | -q...] -p 26 | rnsh [-c ] [-i ] [-v... | -q...] [-N] [-m] [-w ] 27 | [[--] [ ...]] 28 | rnsh -h 29 | rnsh --version 30 | 31 | Options: 32 | -c DIR --config DIR Alternate Reticulum config directory to use 33 | -i FILE --identity FILE Specific identity file to use 34 | -s NAME --service NAME Service name for identity file if not default 35 | -p --print-identity Print identity information and exit 36 | -l --listen Listen (server) mode. If supplied, ...will 37 | be used as the command line when the initiator does not 38 | provide one or when remote command is disabled. If 39 | is not supplied, the default shell of the 40 | user rnsh is running under will be used. 41 | -b --announce PERIOD Announce on startup and every PERIOD seconds 42 | Specify 0 for PERIOD to announce on startup only. 43 | -a HASH --allowed HASH Specify identities allowed to connect. Allowed identities 44 | can also be specified in ~/.rnsh/allowed_identities or 45 | ~/.config/rnsh/allowed_identities, one hash per line. 
DEFAULT_SERVICE_NAME = "default"


class Args:
    """Parse an rnsh command line into plain attributes.

    Everything after a literal ``--`` is split off as the program's own
    argument list before docopt sees the command line. The remaining
    arguments are parsed by docopt against the module-level ``usage`` text
    and copied onto attributes of this object (``listen``, ``service_name``,
    ``identity``, ``config``, ``announce``, ``allowed``, ``timeout``,
    ``command_line`` and so on).

    On a usage error the usage text is printed and the process exits with
    status 1.
    """

    def __init__(self, argv: [str]):
        """Parse ``argv``.

        :param argv: full argument vector; element 0 (the program name) is
            skipped when the arguments are handed to docopt.
        """
        try:
            self.argv = argv
            self.program_args = []
            self.docopts_argv, self.program_args = _split_array_at(self.argv, "--")
            # need to add first arg after -- back onto argv for docopts, but only for listener
            is_listener = any(a in ("-l", "--listen") for a in self.docopts_argv)
            if is_listener and self.program_args:
                self.docopts_argv.append(self.program_args[0])
                self.program_args = self.program_args[1:]

            args = docopt.docopt(usage, argv=self.docopts_argv[1:], version=f"rnsh {rnsh.__version__}")
            # json.dump(args, sys.stdout)

            self.listen = args.get("--listen", None) or False
            self.service_name = args.get("--service", None)
            # Fall back to the default service name only when the listener was
            # not given one. (The previous test, ``len(...) > 0``, replaced
            # every user-supplied name with the default.)
            if self.listen and not self.service_name:
                self.service_name = DEFAULT_SERVICE_NAME
            self.identity = args.get("--identity", None)
            self.config = args.get("--config", None)
            self.print_identity = args.get("--print-identity", None) or False
            self.verbose = args.get("--verbose", None) or 0
            self.quiet = args.get("--quiet", None) or 0
            announce = args.get("--announce", None)
            self.announce = None
            try:
                if announce:
                    self.announce = int(announce)
            except ValueError:
                print("Invalid value for --announce")
                sys.exit(1)
            self.no_auth = args.get("--no-auth", None) or False
            self.allowed = args.get("--allowed", None) or []
            self.remote_cmd_as_args = args.get("--remote-command-as-args", None) or False
            self.no_remote_cmd = args.get("--no-remote-command", None) or False
            # NOTE(review): the lookup keys for the positional values below are
            # empty strings; presumably they should be the positional
            # placeholder names from the usage text - confirm against `usage`.
            self.program = args.get("", None)
            if len(self.program_args) == 0:
                self.program_args = args.get("", None) or []
            self.no_id = args.get("--no-id", None) or False
            self.mirror = args.get("--mirror", None) or False
            timeout = args.get("--timeout", None)
            self.timeout = None
            try:
                if timeout:
                    self.timeout = int(timeout)
            except ValueError:
                print("Invalid value for --timeout")
                sys.exit(1)
            self.destination = args.get("", None)
            self.help = args.get("--help", None) or False
            # Full command line to run: the program (if any) followed by its arguments.
            self.command_line = [self.program] if self.program else []
            self.command_line.extend(self.program_args)
        except docopt.DocoptExit:
            print()
            print(usage)
            sys.exit(1)
        except Exception as e:
            print(f"Error parsing arguments: {e}")
            print()
            print(usage)
            sys.exit(1)

        if self.help:
            sys.exit(0)
2 | 3 | * http://docopt.org 4 | * Repository and issue-tracker: https://github.com/docopt/docopt 5 | * Licensed under terms of MIT license (see LICENSE-MIT) 6 | * Copyright (c) 2013 Vladimir Keleshev, vladimir@keleshev.com 7 | 8 | """ 9 | import sys 10 | import re 11 | 12 | __all__ = ['docopt'] 13 | __version__ = '0.6.2' 14 | 15 | class DocoptLanguageError(Exception): 16 | 17 | """Error in construction of usage-message by developer.""" 18 | 19 | 20 | class DocoptExit(SystemExit): 21 | 22 | usage = '' 23 | 24 | def __init__(self, message=''): 25 | SystemExit.__init__(self, (message + '\n' + self.usage).strip()) 26 | 27 | 28 | class Pattern(object): 29 | 30 | def __eq__(self, other): 31 | return repr(self) == repr(other) 32 | 33 | def __hash__(self): 34 | return hash(repr(self)) 35 | 36 | def fix(self): 37 | self.fix_identities() 38 | self.fix_repeating_arguments() 39 | return self 40 | 41 | def fix_identities(self, uniq=None): 42 | if not hasattr(self, 'children'): 43 | return self 44 | uniq = list(set(self.flat())) if uniq is None else uniq 45 | for i, c in enumerate(self.children): 46 | if not hasattr(c, 'children'): 47 | assert c in uniq 48 | self.children[i] = uniq[uniq.index(c)] 49 | else: 50 | c.fix_identities(uniq) 51 | 52 | def fix_repeating_arguments(self): 53 | either = [list(c.children) for c in self.either.children] 54 | for case in either: 55 | for e in [c for c in case if case.count(c) > 1]: 56 | if type(e) is Argument or type(e) is Option and e.argcount: 57 | if e.value is None: 58 | e.value = [] 59 | elif type(e.value) is not list: 60 | e.value = e.value.split() 61 | if type(e) is Command or type(e) is Option and e.argcount == 0: 62 | e.value = 0 63 | return self 64 | 65 | @property 66 | def either(self): 67 | ret = [] 68 | groups = [[self]] 69 | while groups: 70 | children = groups.pop(0) 71 | types = [type(c) for c in children] 72 | if Either in types: 73 | either = [c for c in children if type(c) is Either][0] 74 | 
children.pop(children.index(either)) 75 | for c in either.children: 76 | groups.append([c] + children) 77 | elif Required in types: 78 | required = [c for c in children if type(c) is Required][0] 79 | children.pop(children.index(required)) 80 | groups.append(list(required.children) + children) 81 | elif Optional in types: 82 | optional = [c for c in children if type(c) is Optional][0] 83 | children.pop(children.index(optional)) 84 | groups.append(list(optional.children) + children) 85 | elif AnyOptions in types: 86 | optional = [c for c in children if type(c) is AnyOptions][0] 87 | children.pop(children.index(optional)) 88 | groups.append(list(optional.children) + children) 89 | elif OneOrMore in types: 90 | oneormore = [c for c in children if type(c) is OneOrMore][0] 91 | children.pop(children.index(oneormore)) 92 | groups.append(list(oneormore.children) * 2 + children) 93 | else: 94 | ret.append(children) 95 | return Either(*[Required(*e) for e in ret]) 96 | 97 | 98 | class ChildPattern(Pattern): 99 | 100 | def __init__(self, name, value=None): 101 | self.name = name 102 | self.value = value 103 | 104 | def __repr__(self): 105 | return '%s(%r, %r)' % (self.__class__.__name__, self.name, self.value) 106 | 107 | def flat(self, *types): 108 | return [self] if not types or type(self) in types else [] 109 | 110 | def match(self, left, collected=None): 111 | collected = [] if collected is None else collected 112 | pos, match = self.single_match(left) 113 | if match is None: 114 | return False, left, collected 115 | left_ = left[:pos] + left[pos + 1:] 116 | same_name = [a for a in collected if a.name == self.name] 117 | if type(self.value) in (int, list): 118 | if type(self.value) is int: 119 | increment = 1 120 | else: 121 | increment = ([match.value] if type(match.value) is str 122 | else match.value) 123 | if not same_name: 124 | match.value = increment 125 | return True, left_, collected + [match] 126 | same_name[0].value += increment 127 | return True, left_, 
class Argument(ChildPattern):
    """Leaf pattern for a positional argument such as ``<file>``."""

    def single_match(self, left):
        """Locate the first plain ``Argument`` in *left*.

        Returns ``(index, Argument(self.name, value))`` for the first element
        whose type is exactly ``Argument``, or ``(None, None)`` when there is
        none.
        """
        for n, p in enumerate(left):
            # Exact type check on purpose: subclasses are not accepted here.
            if type(p) is Argument:
                return n, Argument(self.name, p.value)
        return None, None

    @classmethod
    def parse(class_, source):
        """Build an instance from a description line.

        The name is the first ``<...>`` token found in *source*; the value is
        the text of a ``[default: ...]`` marker (matched case-insensitively)
        or ``None`` when no such marker is present.

        Raises ``IndexError`` when *source* contains no ``<...>`` token.
        """
        # Raw strings: ``\S`` and ``\[`` are regex escapes, not string
        # escapes, and non-raw literals trigger invalid-escape warnings.
        name = re.findall(r'(<\S*?>)', source)[0]
        value = re.findall(r'\[default: (.*)\]', source, flags=re.I)
        return class_(name, value[0] if value else None)
class Required(ParentPattern):
    """Group whose children must all match, one after another."""

    def match(self, left, collected=None):
        """Match every child in order against *left*.

        Returns ``(True, remaining, gathered)`` when all children matched.
        As soon as one child fails, the untouched inputs are handed back as
        ``(False, left, collected)``.
        """
        if collected is None:
            collected = []
        remaining, gathered = left, collected
        for child in self.children:
            ok, remaining, gathered = child.match(remaining, gathered)
            if not ok:
                # Discard partial progress: report the original state.
                return False, left, collected
        return True, remaining, gathered
class TokenStream(list):
    """List of tokens consumed from the front, carrying an error class.

    *source* is split on whitespace when it has a ``split`` method;
    otherwise its items are used as they are.  *error* is stored on the
    instance as ``error``.
    """

    def __init__(self, source, error):
        tokens = source.split() if hasattr(source, 'split') else source
        self.extend(tokens)
        self.error = error

    def move(self):
        """Remove and return the first token, or ``None`` when empty."""
        if not self:
            return None
        return self.pop(0)

    def current(self):
        """Return the first token without removing it, or ``None`` when empty."""
        if not self:
            return None
        return self[0]
def parse_shorts(tokens, options):
    """Parse one token of stacked short options, e.g. ``-abc`` or ``-ofile``.

    The token is taken from *tokens*; each character after the leading
    dashes is looked up in *options* by its ``short`` name.  Unknown flags
    are appended to *options*.  A flag that takes an argument consumes the
    rest of the token, or the next token when nothing is left.

    Returns the list of ``Option`` objects parsed from the token.  Raises
    ``tokens.error`` when a flag is declared more than once or when a
    required argument is missing.
    """
    token = tokens.move()
    assert token.startswith('-') and not token.startswith('--')
    pending = token.lstrip('-')
    parsed = []
    while pending:
        short = '-' + pending[0]
        pending = pending[1:]
        matches = [opt for opt in options if opt.short == short]
        if len(matches) > 1:
            raise tokens.error('%s is specified ambiguously %d times' %
                               (short, len(matches)))
        if not matches:
            # Unknown flag: register it without a value, and hand back a
            # set copy only when the stream reports errors as DocoptExit.
            option = Option(short, None, 0)
            options.append(option)
            if tokens.error is DocoptExit:
                option = Option(short, None, 0, True)
        else:
            known = matches[0]
            # Work on a copy so the declared option keeps its own value.
            option = Option(short, known.long, known.argcount, known.value)
            value = None
            if option.argcount != 0:
                if pending:
                    # The remainder of the token is the argument.
                    value, pending = pending, ''
                elif tokens.current() is None:
                    raise tokens.error('%s requires argument' % short)
                else:
                    value = tokens.move()
            if tokens.error is DocoptExit:
                option.value = True if value is None else value
        parsed.append(option)
    return parsed
def printable_usage(doc):
    """Extract the usage section of *doc*.

    The section starts at the single case-insensitive ``usage:`` marker and
    ends at the first blank line after it; the result is stripped of
    surrounding whitespace.

    Raises ``DocoptLanguageError`` when the marker is missing or occurs more
    than once.
    """
    # Splitting on a capturing group keeps the marker itself, so exactly one
    # marker yields exactly three pieces: before, marker, after.
    pieces = re.split(r'([Uu][Ss][Aa][Gg][Ee]:)', doc)
    count = len(pieces)
    if count < 3:
        raise DocoptLanguageError('"usage:" (case-insensitive) not found.')
    if count > 3:
        raise DocoptLanguageError('More than one "usage:" (case-insensitive).')
    _, marker, rest = pieces
    first_paragraph = re.split(r'\n\s*\n', marker + rest)[0]
    return first_paragraph.strip()
class Dict(dict):
    """``dict`` whose ``repr`` lists its items sorted, one per line."""

    def __repr__(self):
        entries = ['%r: %r' % (key, value)
                   for key, value in sorted(self.items())]
        return '{' + ',\n '.join(entries) + '}'
class SleepRate:
    """Pace a polling loop so it wakes about once per ``target_period``.

    Each call to ``next_sleep_time`` records the current time as the latest
    wake-up and reports how long the caller should still sleep so that
    consecutive wake-ups are ``target_period`` seconds apart.
    """

    def __init__(self, target_period: float):
        self.target_period = target_period
        self.last_wake = time.time()

    def next_sleep_time(self) -> float:
        """Return the seconds left until the next scheduled wake-up.

        The result is never negative: when the period has already elapsed
        since the previous wake-up, ``0`` is returned.
        """
        previous_wake = self.last_wake
        self.last_wake = time.time()
        # Use the configured period; a fixed 0.01 s was used here before,
        # which silently ignored the constructor argument.
        return max(0.0, previous_wake + self.target_period - self.last_wake)

    async def sleep_async(self):
        """Sleep for ``next_sleep_time()`` seconds without blocking the loop."""
        await asyncio.sleep(self.next_sleep_time())

    def sleep_block(self):
        """Sleep for ``next_sleep_time()`` seconds, blocking the thread."""
        time.sleep(self.next_sleep_time())
copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | from __future__ import annotations 26 | 27 | import asyncio 28 | import base64 29 | import enum 30 | import functools 31 | import logging as __logging 32 | import os 33 | import queue 34 | import shlex 35 | import signal 36 | import sys 37 | import termios 38 | import threading 39 | import time 40 | import tty 41 | from typing import Callable, TypeVar 42 | import RNS 43 | import rnsh.exception as exception 44 | import rnsh.process as process 45 | import rnsh.retry as retry 46 | import rnsh.rnslogging as rnslogging 47 | import rnsh.session as session 48 | import re 49 | import contextlib 50 | import rnsh.args 51 | import pwd 52 | import bz2 53 | import rnsh.protocol as protocol 54 | import rnsh.helpers as helpers 55 | import rnsh.rnsh 56 | 57 | module_logger = __logging.getLogger(__name__) 58 | 59 | 60 | def _get_logger(name: str): 61 | global module_logger 62 | return module_logger.getChild(name) 63 | 64 | 65 | _identity = None 66 | _reticulum = None 67 | _cmd: [str] | None = None 68 | DATA_AVAIL_MSG = "data available" 69 | _finished: asyncio.Event = None 70 | _retry_timer: retry.RetryThread | None = 
async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = None) -> bool:
    """Wait quietly until ``until()`` is true or *timeout* seconds pass.

    Between checks the coroutine awaits ``_check_finished(0.1)`` and raises
    ``asyncio.CancelledError`` when that reports true.  *msg* is accepted
    but not used here.

    Returns ``False`` when a timeout was given and the clock is past the
    deadline once waiting stops, ``True`` otherwise.
    """
    deadline = None if timeout is None else timeout + time.time()

    while True:
        # The deadline is tested first, so ``until`` is not called once the
        # time has run out.
        if deadline is not None and time.time() >= deadline:
            break
        if until():
            break
        if await _check_finished(0.1):
            raise asyncio.CancelledError()

    timed_out = deadline is not None and time.time() > deadline
    return not timed_out
class RemoteExecutionError(Exception):
    """Error carrying a human-readable message in its ``msg`` attribute."""

    def __init__(self, msg):
        super().__init__(msg)
        self.msg = msg
link") 201 | _link = RNS.Link(_destination) 202 | _link.did_identify = False 203 | 204 | _link.set_link_closed_callback(_client_link_closed) 205 | 206 | log.info(f"Establishing link...") 207 | if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...", 208 | timeout=timeout, quiet=quietness > 0): 209 | raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash)) 210 | 211 | log.debug("Have link") 212 | if not noid and not _link.did_identify: 213 | _link.identify(_identity) 214 | _link.did_identify = True 215 | 216 | 217 | async def _handle_error(errmsg: RNS.MessageBase): 218 | if isinstance(errmsg, protocol.ErrorMessage): 219 | with contextlib.suppress(Exception): 220 | if _link and _link.status == RNS.Link.ACTIVE: 221 | _link.teardown() 222 | await asyncio.sleep(0.1) 223 | raise RemoteExecutionError(f"Remote error: {errmsg.msg}") 224 | 225 | 226 | async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str, 227 | timeout: float, command: [str] | None = None): 228 | global _finished, _link 229 | log = _get_logger("_initiate") 230 | with process.TTYRestorer(sys.stdin.fileno()) as ttyRestorer: 231 | loop = asyncio.get_running_loop() 232 | state = InitiatorState.IS_INITIAL 233 | data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray() 234 | line_buffer = bytearray() 235 | 236 | await _initiate_link( 237 | configdir=configdir, 238 | identitypath=identitypath, 239 | verbosity=verbosity, 240 | quietness=quietness, 241 | noid=noid, 242 | destination=destination, 243 | timeout=timeout 244 | ) 245 | 246 | if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]: 247 | return 255 248 | 249 | state = InitiatorState.IS_LINKED 250 | outlet = session.RNSOutlet(_link) 251 | channel = _link.get_channel() 252 | protocol.register_message_types(channel) 253 | 
channel.add_message_handler(_client_message_handler) 254 | 255 | # Next step after linking and identifying: send version 256 | # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0): 257 | # print("Error bringing up link") 258 | # return 253 259 | 260 | channel.send(protocol.VersionInfoMessage()) 261 | try: 262 | vm = _pq.get(timeout=max(outlet.rtt * 20, 5)) 263 | await _handle_error(vm) 264 | if not isinstance(vm, protocol.VersionInfoMessage): 265 | raise Exception("Invalid message received") 266 | log.debug(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}") 267 | state = InitiatorState.IS_RUNNING 268 | except queue.Empty: 269 | print("Protocol error") 270 | return 254 271 | 272 | winch = False 273 | def sigwinch_handler(): 274 | nonlocal winch 275 | # log.debug("WindowChanged") 276 | winch = True 277 | 278 | esc = False 279 | pre_esc = True 280 | line_mode = False 281 | line_flush = False 282 | blind_write_count = 0 283 | flush_chars = ["\x01", "\x03", "\x04", "\x05", "\x0c", "\x11", "\x13", "\x15", "\x19", "\t", "\x1A", "\x1B"] 284 | def handle_escape(b): 285 | nonlocal line_mode 286 | if b == "?": 287 | os.write(1, "\n\r\n\rSupported rnsh escape sequences:".encode("utf-8")) 288 | os.write(1, "\n\r ~~ Send the escape character by typing it twice".encode("utf-8")) 289 | os.write(1, "\n\r ~. Terminate session and exit immediately".encode("utf-8")) 290 | os.write(1, "\n\r ~L Toggle line-interactive mode".encode("utf-8")) 291 | os.write(1, "\n\r ~? 
Display this quick reference\n\r".encode("utf-8")) 292 | os.write(1, "\n\r(Escape sequences are only recognized immediately after newline)\n\r".encode("utf-8")) 293 | return None 294 | elif b == ".": 295 | _link.teardown() 296 | return None 297 | elif b == "L": 298 | line_mode = not line_mode 299 | if line_mode: 300 | os.write(1, "\n\rLine-interactive mode enabled\n\r".encode("utf-8")) 301 | else: 302 | os.write(1, "\n\rLine-interactive mode disabled\n\r".encode("utf-8")) 303 | return None 304 | 305 | return b 306 | 307 | stdin_eof = False 308 | def stdin(): 309 | nonlocal stdin_eof, pre_esc, esc, line_mode 310 | nonlocal line_flush, blind_write_count 311 | try: 312 | in_data = process.tty_read(sys.stdin.fileno()) 313 | if in_data is not None: 314 | data = bytearray() 315 | for b in bytes(in_data): 316 | c = chr(b) 317 | if c == "\r": 318 | pre_esc = True 319 | line_flush = True 320 | data.append(b) 321 | elif line_mode and c in flush_chars: 322 | pre_esc = False 323 | line_flush = True 324 | data.append(b) 325 | elif line_mode and (c == "\b" or c == "\x7f"): 326 | pre_esc = False 327 | if len(line_buffer)>0: 328 | line_buffer.pop(-1) 329 | blind_write_count -= 1 330 | os.write(1, "\b \b".encode("utf-8")) 331 | elif pre_esc == True and c == "~": 332 | pre_esc = False 333 | esc = True 334 | elif esc == True: 335 | ret = handle_escape(c) 336 | if ret != None: 337 | if ret != "~": 338 | data.append(ord("~")) 339 | data.append(ord(ret)) 340 | esc = False 341 | else: 342 | pre_esc = False 343 | data.append(b) 344 | 345 | if not line_mode: 346 | data_buffer.extend(data) 347 | else: 348 | line_buffer.extend(data) 349 | if line_flush: 350 | data_buffer.extend(line_buffer) 351 | line_buffer.clear() 352 | os.write(1, ("\b \b"*blind_write_count).encode("utf-8")) 353 | line_flush = False 354 | blind_write_count = 0 355 | else: 356 | os.write(1, data) 357 | blind_write_count += len(data) 358 | 359 | except EOFError: 360 | if os.isatty(0): 361 | 
data_buffer.extend(process.CTRL_D) 362 | stdin_eof = True 363 | process.tty_unset_reader_callbacks(sys.stdin.fileno()) 364 | 365 | process.tty_add_reader_callback(sys.stdin.fileno(), stdin) 366 | 367 | tcattr = None 368 | rows, cols, hpix, vpix = (None, None, None, None) 369 | try: 370 | tcattr = termios.tcgetattr(0) 371 | rows, cols, hpix, vpix = process.tty_get_winsize(0) 372 | except: 373 | try: 374 | tcattr = termios.tcgetattr(1) 375 | rows, cols, hpix, vpix = process.tty_get_winsize(1) 376 | except: 377 | try: 378 | tcattr = termios.tcgetattr(2) 379 | rows, cols, hpix, vpix = process.tty_get_winsize(2) 380 | except: 381 | pass 382 | 383 | await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0) 384 | channel.send(protocol.ExecuteCommandMesssage(cmdline=command, 385 | pipe_stdin=not os.isatty(0), 386 | pipe_stdout=not os.isatty(1), 387 | pipe_stderr=not os.isatty(2), 388 | tcflags=tcattr, 389 | term=os.environ.get("TERM", None), 390 | rows=rows, 391 | cols=cols, 392 | hpix=hpix, 393 | vpix=vpix)) 394 | 395 | loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler) 396 | _finished = asyncio.Event() 397 | loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop)) 398 | loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop)) 399 | mdu = _link.MDU - 16 400 | sent_eof = False 401 | last_winch = time.time() 402 | sleeper = helpers.SleepRate(0.01) 403 | processed = False 404 | while not await _check_finished() and state in [InitiatorState.IS_RUNNING]: 405 | try: 406 | try: 407 | message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005) 408 | await _handle_error(message) 409 | processed = True 410 | if isinstance(message, protocol.StreamDataMessage): 411 | if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT: 412 | if message.data and len(message.data) > 0: 413 | ttyRestorer.raw() 414 | log.debug(f"stdout: 
{message.data}") 415 | os.write(1, message.data) 416 | sys.stdout.flush() 417 | if message.eof: 418 | os.close(1) 419 | if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR: 420 | if message.data and len(message.data) > 0: 421 | ttyRestorer.raw() 422 | log.debug(f"stdout: {message.data}") 423 | os.write(2, message.data) 424 | sys.stderr.flush() 425 | if message.eof: 426 | os.close(2) 427 | elif isinstance(message, protocol.CommandExitedMessage): 428 | log.debug(f"received return code {message.return_code}, exiting") 429 | return message.return_code 430 | elif isinstance(message, protocol.ErrorMessage): 431 | log.error(message.data) 432 | if message.fatal: 433 | _link.teardown() 434 | return 200 435 | 436 | except queue.Empty: 437 | processed = False 438 | 439 | if channel.is_ready_to_send(): 440 | def compress_adaptive(buf: bytes): 441 | comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES 442 | comp_try = 1 443 | comp_success = False 444 | 445 | chunk_len = len(buf) 446 | if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: 447 | chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN 448 | chunk_segment = None 449 | 450 | chunk_segment = None 451 | max_data_len = channel.mdu - protocol.StreamDataMessage.OVERHEAD 452 | while chunk_len > 32 and comp_try < comp_tries: 453 | chunk_segment_length = int(chunk_len/comp_try) 454 | compressed_chunk = bz2.compress(buf[:chunk_segment_length]) 455 | compressed_length = len(compressed_chunk) 456 | if compressed_length < max_data_len and compressed_length < chunk_segment_length: 457 | comp_success = True 458 | break 459 | else: 460 | comp_try += 1 461 | 462 | if comp_success: 463 | diff = max_data_len - len(compressed_chunk) 464 | chunk = compressed_chunk 465 | processed_length = chunk_segment_length 466 | else: 467 | chunk = bytes(buf[:max_data_len]) 468 | processed_length = len(chunk) 469 | 470 | return comp_success, processed_length, chunk 471 | 472 | comp_success, processed_length, chunk = compress_adaptive(data_buffer) 
473 | stdin = chunk 474 | data_buffer = data_buffer[processed_length:] 475 | eof = not sent_eof and stdin_eof and len(stdin) == 0 476 | if len(stdin) > 0 or eof: 477 | channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success)) 478 | sent_eof = eof 479 | processed = True 480 | 481 | # send window change, but rate limited 482 | if winch and time.time() - last_winch > _link.rtt * 25: 483 | last_winch = time.time() 484 | winch = False 485 | with contextlib.suppress(Exception): 486 | r, c, h, v = process.tty_get_winsize(0) 487 | channel.send(protocol.WindowSizeMessage(r, c, h, v)) 488 | processed = True 489 | except RemoteExecutionError as e: 490 | print(e.msg) 491 | return 255 492 | except Exception as ex: 493 | print(f"Client exception: {ex}") 494 | if _link and _link.status != RNS.Link.CLOSED: 495 | _link.teardown() 496 | return 127 497 | 498 | # await process.event_wait_any([_new_data, _finished], timeout=min(max(rtt * 50, 5), 120)) 499 | # await sleeper.sleep_async() 500 | log.debug("after main loop") 501 | return 0 -------------------------------------------------------------------------------- /rnsh/listener.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2016-2022 Mark Qvist / unsigned.io 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the 
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import annotations

import asyncio
import base64
import enum
import functools
import logging as __logging
import os
import queue
import shlex
import signal
import sys
import termios
import threading
import time
import tty
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.rnslogging as rnslogging
import rnsh.session as session
import re
import contextlib
import rnsh.args
import pwd
import rnsh.protocol as protocol
import rnsh.helpers as helpers
import rnsh.rnsh

module_logger = __logging.getLogger(__name__)


def _get_logger(name: str):
    """Return the child logger *name* of this module's logger."""
    return module_logger.getChild(name)


_identity = None
_reticulum = None
_allow_all = False
_allowed_file = None
_allowed_identity_hashes = []
_allowed_file_identity_hashes = []
_cmd: [str] | None = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
_no_remote_command = True
_remote_cmd_as_args = False


async def _check_finished(timeout: float = 0):
    """Wait up to *timeout* seconds for the module-level ``_finished`` event."""
    return await process.event_wait(_finished, timeout=timeout)


def _sigint_handler(sig, loop):
    """
    SIGINT handler: request shutdown of the listen loop.

    Sets ``_finished`` when it exists; before the event has been created the
    signal is turned into a KeyboardInterrupt instead.
    :param sig: signal number
    :param loop: second handler argument (unused)
    """
    log = _get_logger("_sigint_handler")
    log.debug(signal.Signals(sig).name)
    if _finished is None:
        raise KeyboardInterrupt()
    _finished.set()


def _reload_allowed_file():
    """
    Re-read the allowed-identities file into ``_allowed_file_identity_hashes``.

    Does nothing when no file is configured. Lines whose length is not the
    expected hex length are skipped silently; lines of the right length that
    are not valid hex are skipped with a debug message. Any other failure is
    logged and the previously loaded list is kept.
    """
    global _allowed_file, _allowed_file_identity_hashes
    log = _get_logger("_listen")
    if _allowed_file is None:
        return

    try:
        dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
        loaded = []
        with open(_allowed_file, "r") as file:
            entries = file.read().replace("\r", "").split("\n")
        for line, allow in enumerate(entries, start=1):
            if len(allow) != dest_len:
                continue
            try:
                loaded.append(bytes.fromhex(allow))
            except ValueError:
                log.debug(f"Discarded invalid Identity hash in {_allowed_file} at line {line}")

        # Swap the list in only after the whole file was parsed.
        _allowed_file_identity_hashes = loaded
        added = len(loaded)
        ms = "y" if added == 1 else "ies"
        log.debug(f"Loaded {added} allowed identit{ms} from "+str(_allowed_file))
    except Exception as e:
        log.error(f"Error while reloading allowed indetities file: {e}")


async def listen(configdir, command, identitypath=None, service_name=None, verbosity=0, quietness=0, allowed=None,
                 allowed_file=None, disable_auth=None, announce_period=900, no_remote_command=True,
                 remote_cmd_as_args=False, loop: asyncio.AbstractEventLoop = None):
    """
    Run the listener until the ``_finished`` event is set (e.g. by SIGINT).

    :param configdir: Reticulum configuration directory
    :param command: default command line (list of str); when empty the current user's login shell is used
    :param identitypath: passed to ``rnsh.rnsh.prepare_identity``
    :param service_name: service name, "default" when empty
    :param verbosity: added to the RNS log level
    :param quietness: subtracted from the RNS log level
    :param allowed: iterable of hex identity hashes that may connect
    :param allowed_file: path of a file with one hex identity hash per line, re-read on every new link
    :param disable_auth: when truthy, every identity is allowed
    :param announce_period: seconds between announces; None disables announcing, 0 announces once at startup
    :param no_remote_command: when True, remote command lines are not allowed
    :param remote_cmd_as_args: stored on ``session.ListenerSession.remote_cmd_as_args``
    :param loop: event loop handed to new sessions; defaults to the running loop
    :raises Exception: when no default command can be determined but one is required
    """
    global _identity, _allow_all, _allowed_identity_hashes, _allowed_file, _allowed_file_identity_hashes
    global _reticulum, _cmd, _destination, _no_remote_command, _remote_cmd_as_args, _finished
    log = _get_logger("_listen")
    if not loop:
        loop = asyncio.get_running_loop()
    if not service_name:
        service_name = "default"

    log.info(f"Using service name {service_name}")

    targetloglevel = RNS.LOG_INFO + verbosity - quietness
    _reticulum = RNS.Reticulum(configdir=configdir, loglevel=targetloglevel)
    rnslogging.RnsHandler.set_log_level_with_rns_level(targetloglevel)
    _identity = rnsh.rnsh.prepare_identity(identitypath, service_name)
    _destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, rnsh.rnsh.APP_NAME)

    _cmd = command
    if not _cmd:
        shell = None
        try:
            shell = pwd.getpwuid(os.getuid()).pw_shell
        except Exception as e:
            log.error(f"Error looking up shell: {e}")
        log.info(f"Using {shell} for default command.")
        _cmd = [shell] if shell else None
    else:
        log.info(f"Using command {shlex.join(_cmd)}")

    _no_remote_command = no_remote_command
    session.ListenerSession.allow_remote_command = not no_remote_command
    _remote_cmd_as_args = remote_cmd_as_args
    if (not _cmd or not _cmd[0]) and (_no_remote_command or _remote_cmd_as_args):
        # The original interpolated the os.getlogin function object itself.
        # getlogin() needs a controlling terminal, so fall back to the uid.
        try:
            login = os.getlogin()
        except OSError:
            login = f"uid {os.getuid()}"
        raise Exception(f"Unable to look up shell for {login}, cannot proceed with -A or -C and no .")

    session.ListenerSession.default_command = _cmd
    session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args

    if disable_auth:
        _allow_all = True
        session.ListenerSession.allow_all = True
    else:
        if allowed_file is not None:
            _allowed_file = allowed_file
            _reload_allowed_file()

        if allowed is not None:
            dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
            for a in allowed:
                try:
                    if len(a) != dest_len:
                        # One f-string: previously .format() was bound to the
                        # second literal only, leaving "{hex}" unformatted.
                        raise ValueError(
                            f"Allowed destination length is invalid, must be {dest_len} hexadecimal "
                            f"characters ({dest_len // 2} bytes).")
                    try:
                        destination_hash = bytes.fromhex(a)
                        _allowed_identity_hashes.append(destination_hash)
                        session.ListenerSession.allowed_identity_hashes.append(destination_hash)
                    except Exception:
                        raise ValueError("Invalid destination entered. Check your input.")
                except Exception as e:
                    log.error(str(e))
                    # sys.exit instead of the site-provided exit() helper.
                    sys.exit(1)

    if (len(_allowed_identity_hashes) < 1 and len(_allowed_file_identity_hashes) < 1) and not disable_auth:
        log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!")

    def link_established(lnk: RNS.Link):
        # Pick up edits to the allowed file before the new session is created.
        _reload_allowed_file()
        session.ListenerSession.allowed_file_identity_hashes = _allowed_file_identity_hashes
        session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
    _destination.set_link_established_callback(link_established)

    _finished = asyncio.Event()
    signal.signal(signal.SIGINT, _sigint_handler)

    log.info("rnsh listening for commands on " + RNS.prettyhexrep(_destination.hash))

    if announce_period is not None:
        _destination.announce()

    last_announce = time.time()
    sleeper = helpers.SleepRate(0.01)

    try:
        while not await _check_finished():
            if announce_period and 0 < announce_period < time.time() - last_announce:
                last_announce = time.time()
                _destination.announce()
            if len(session.ListenerSession.sessions) > 0:
                # no sleep if there's work to do
                if not await session.ListenerSession.pump_all():
                    await sleeper.sleep_async()
            else:
                await asyncio.sleep(0.25)
    finally:
        log.warning("Shutting down")
        await session.ListenerSession.terminate_all("Shutting down")
        await asyncio.sleep(1)
        # Iterate over a snapshot; the status is checked at teardown time.
        for link in list(_destination.links):
            if link.status != RNS.Link.CLOSED:
                link.teardown()
        await asyncio.sleep(0.01)
import asyncio
import functools
from typing import Any, Callable


def sig_handler_sys_to_loop(handler: Callable[[int, Any], None]) -> Callable[[int, asyncio.AbstractEventLoop], None]:
    """
    Adapt a ``signal.signal``-style handler to the ``(signal, loop)`` handler shape used by loop_set_signal.

    The adapted handler ignores the loop argument and calls *handler* with
    ``None`` in place of the stack frame.
    :param handler: callable taking (signal number, frame)
    :return: callable taking (signal number, loop)
    """
    def wrapped(cb: Callable[[int, Any], None], signal: int, loop: asyncio.AbstractEventLoop):
        cb(signal, None)
    return functools.partial(wrapped, handler)


def loop_set_signal(sig, handler: Callable[[int, asyncio.AbstractEventLoop], None],
                    loop: asyncio.AbstractEventLoop = None):
    """
    Install *handler* for signal *sig* on an asyncio loop, replacing any handler the loop already had.

    :param sig: signal number
    :param handler: callable invoked as ``handler(sig, loop)``
    :param loop: loop to install on; the running loop when None
    """
    if loop is None:
        loop = asyncio.get_running_loop()
    loop.remove_signal_handler(sig)
    loop.add_signal_handler(sig, functools.partial(handler, sig, loop))
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from __future__ import annotations 24 | import asyncio 25 | import contextlib 26 | import copy 27 | import errno 28 | import fcntl 29 | import functools 30 | import logging as __logging 31 | import os 32 | import pty 33 | import select 34 | import signal 35 | import struct 36 | import sys 37 | import termios 38 | import threading 39 | import tty 40 | import types 41 | import typing 42 | 43 | import rnsh.exception as exception 44 | 45 | module_logger = __logging.getLogger(__name__) 46 | 47 | CTRL_C = "\x03".encode("utf-8") 48 | CTRL_D = "\x04".encode("utf-8") 49 | 50 | def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop = None): 51 | """ 52 | Add an async reader callback for a tty file descriptor. 53 | 54 | Example usage: 55 | 56 | def reader(): 57 | data = tty_read(fd) 58 | # do something with data 59 | 60 | tty_add_reader_callback(self._child_fd, reader, self._loop) 61 | 62 | :param fd: file descriptor 63 | :param callback: callback function 64 | :param loop: asyncio event loop to which the reader should be added. If None, use the currently-running loop. 65 | """ 66 | if loop is None: 67 | loop = asyncio.get_running_loop() 68 | loop.add_reader(fd, callback) 69 | 70 | 71 | def tty_read(fd: int) -> bytes: 72 | """ 73 | Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using 74 | tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys. 
def tty_read(fd: int) -> bytes:
    """
    Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using
    tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys.

    :param fd: tty file descriptor
    :return: bytes read so far (possibly empty when nothing is pending)
    :raises EOFError: if the descriptor is closed, or end-of-file is reached before any data was read
    """
    if fd_is_closed(fd):
        raise EOFError

    result = bytearray()
    try:
        while not fd_is_closed(fd):
            ready, _, _ = select.select([fd], [], [], 0)
            if not ready:
                break
            try:
                data = os.read(fd, 4096)
            except OSError as e:
                if e.errno == errno.EWOULDBLOCK:
                    break
                if e.errno != errno.EIO:
                    raise
                # EIO is handled as end-of-file, the same way tty_read_poll
                # does; swallowing it and retrying would loop forever.
                data = b""
            if not data:  # EOF
                if result:
                    return result
                raise EOFError
            result.extend(data)
    except EOFError:
        raise
    except Exception as ex:
        module_logger.error(f"tty_read error: {ex}")
    return result


def tty_read_poll(fd: int) -> bytes:
    """
    Make one non-blocking read of up to 4096 bytes from a file descriptor.

    Sets O_NONBLOCK on the descriptor (the flag is left set afterwards).

    :param fd: tty or pipe file descriptor
    :return: bytes read; empty when no data is currently available
    :raises EOFError: if the descriptor is closed, the read fails with EIO, or the read returns no bytes
    """
    if fd_is_closed(fd):
        raise EOFError

    result = bytearray()
    try:
        flags = fcntl.fcntl(fd, fcntl.F_GETFL)
        fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
        try:
            data = os.read(fd, 4096)
        except OSError as e:
            if e.errno == errno.EIO:
                raise EOFError from e
            if e.errno != errno.EWOULDBLOCK:
                raise
        else:
            if not data:
                # "No data yet" on a non-blocking fd is EWOULDBLOCK, so a
                # zero-length read means end-of-file (this is how pipes signal it).
                raise EOFError
            result.extend(data)
    except EOFError:
        raise
    except Exception as ex:
        module_logger.error(f"tty_read error: {ex}")
    return result


def fd_is_closed(fd: int) -> bool:
    """
    Check if file descriptor is closed
    :param fd: file descriptor
    :return: True if file descriptor is closed (or negative), False if it is open
    """
    try:
        fcntl.fcntl(fd, fcntl.F_GETFL)
    except OSError as ose:
        return ose.errno == errno.EBADF
    except ValueError:
        # Python rejects negative descriptors before making the system call.
        return True
    return False


def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop = None):
    """
    Remove async reader callbacks for file descriptor.
    :param fd: file descriptor
    :param loop: asyncio event loop from which to remove callbacks; the running loop when None
    """
    with exception.permit(SystemExit):
        if loop is None:
            loop = asyncio.get_running_loop()
        loop.remove_reader(fd)


def tty_get_winsize(fd: int) -> tuple[int, int, int, int]:
    """
    Get the window size of a tty.
    :param fd: file descriptor of tty
    :return: (rows, cols, h_pixels, v_pixels)
    """
    packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
    rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed)
    return rows, cols, h_pixels, v_pixels
def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int):
    """
    Set the window size on a tty.

    Does nothing when *fd* is None or negative.
    :param fd: file descriptor of tty
    :param rows: number of visible rows
    :param cols: number of visible columns
    :param h_pixels: number of visible horizontal pixels
    :param v_pixels: number of visible vertical pixels
    """
    if fd is None or fd < 0:
        return
    packed = struct.pack('HHHH', rows, cols, h_pixels, v_pixels)
    fcntl.ioctl(fd, termios.TIOCSWINSZ, packed)


def process_exists(pid) -> bool:
    """
    Check for the existence of a unix pid.
    :param pid: process id to check
    :return: True if process exists
    """
    try:
        # Signal 0 performs the existence/permission check without delivering anything.
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists but belongs to someone we may not signal.
        return True
    except OSError:
        return False
    return True
class TTYRestorer(contextlib.AbstractContextManager):
    """Context manager that captures a tty's termios attributes and restores them on exit."""

    # Indexes of flags within the attrs list returned by termios.tcgetattr():
    # [iflag, oflag, cflag, lflag, ispeed, ospeed, cc]
    ATTR_IDX_IFLAG = 0
    ATTR_IDX_OFLAG = 1
    ATTR_IDX_CFLAG = 2
    ATTR_IDX_LFLAG = 3
    ATTR_IDX_CC = 6

    def __init__(self, fd: int, suppress_logs=False):
        """
        Saves termios attributes for a tty for later restoration.

        The attributes are an array of values with the following meanings.

            tcflag_t c_iflag;      /* input modes */
            tcflag_t c_oflag;      /* output modes */
            tcflag_t c_cflag;      /* control modes */
            tcflag_t c_lflag;      /* local modes */
            cc_t     c_cc[NCCS];   /* special characters */

        :param fd: file descriptor of tty
        :param suppress_logs: when True, do not log a failure to read the attributes
        """
        self._log = module_logger.getChild(self.__class__.__name__)
        self._fd = fd
        self._tattr = None
        self._suppress_logs = suppress_logs
        self._tattr = self.current_attr()
        if not self._tattr and not self._suppress_logs:
            self._log.debug(f"Could not get attrs for fd {fd}")

    def raw(self):
        """
        Set raw mode on tty. termios errors are ignored.
        """
        if self._fd is None:
            return
        with contextlib.suppress(termios.error):
            tty.setraw(self._fd, termios.TCSANOW)

    def original_attr(self) -> [any]:
        """
        :return: a copy of the attributes captured in the constructor (None if they could not be read)
        """
        return copy.deepcopy(self._tattr)

    def current_attr(self) -> [any]:
        """
        Get the current termios attributes for the wrapped fd.
        :return: attribute array, or None when there is no fd or the attributes cannot be read
        """
        if self._fd is None:
            return None

        with contextlib.suppress(termios.error):
            return copy.deepcopy(termios.tcgetattr(self._fd))
        return None

    def set_attr(self, attr: [any], when: int = termios.TCSADRAIN):
        """
        Set termios attributes. Does nothing for empty *attr* or a missing fd; termios errors are ignored.
        :param attr: attribute list to set
        :param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
        """
        if not attr or self._fd is None:
            return

        with contextlib.suppress(termios.error):
            termios.tcsetattr(self._fd, when, attr)

    def isatty(self):
        """
        :return: os.isatty() for the wrapped fd, or None when there is no fd
        """
        return os.isatty(self._fd) if self._fd is not None else None

    def restore(self):
        """
        Restore termios settings to state captured in constructor.
        """
        self.set_attr(self._tattr, termios.TCSADRAIN)

    def __exit__(self, exc_type: typing.Type[BaseException], exc_value: BaseException,
                 traceback: types.TracebackType) -> bool:
        """Restore the captured attributes; exceptions from the with-body are never suppressed."""
        self.restore()
        return False
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
    """
    Create a task on *loop* (default: the running loop) that completes with True once *evt* is set.

    The task polls the event every 0.1 s.
    """
    if not loop:
        loop = asyncio.get_running_loop()

    # TODO: this is hacky
    # NOTE(review): presumably polling (rather than evt.wait()) is what lets an
    # Event that was set from a signal handler be noticed -- confirm before changing.
    async def wait():
        while not evt.is_set():
            await asyncio.sleep(0.1)
        return True

    return loop.create_task(wait())


class AggregateException(Exception):
    """Raised when one or more of the tasks awaited together failed; the failures are in ``inner_exceptions``."""

    def __init__(self, inner_exceptions: [Exception]):
        super().__init__()
        self.inner_exceptions = inner_exceptions

    def __str__(self):
        return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(str, self.inner_exceptions))


async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> asyncio.Event | None:
    """
    Wait until any one of *evts* is set, or *timeout* expires.

    :param evts: events to wait on
    :param timeout: maximum number of seconds to wait; None waits indefinitely
    :return: the first event (in the order given) whose watcher task completed, or None on timeout
        or when *evts* is empty
    :raises AggregateException: if a watcher task failed
    """
    tasks = [(evt, _task_from_event(evt)) for evt in evts]
    if not tasks:
        # asyncio.wait() rejects an empty set; there is nothing to wait for.
        return None
    try:
        finished, unfinished = await asyncio.wait([task for _, task in tasks],
                                                  timeout=timeout,
                                                  return_when=asyncio.FIRST_COMPLETED)

        if unfinished:
            for task in unfinished:
                task.cancel()
            await asyncio.wait(unfinished)

        exceptions = []
        for f in finished:
            if f.cancelled():
                continue
            ex = f.exception()
            if ex and not isinstance(ex, (asyncio.CancelledError, TimeoutError)):
                exceptions.append(ex)

        if exceptions:
            raise AggregateException(exceptions)

        # Return the event that actually fired. The previous expression
        # always yielded the first event passed in.
        return next((evt for evt, task in tasks if task in finished), None)
    finally:
        # Leave no watcher task pending and no task exception unretrieved,
        # also when this coroutine itself gets cancelled.
        unfinished = []
        for _, task in tasks:
            if task.done():
                if not task.cancelled():
                    task.exception()
            else:
                task.cancel()
                unfinished.append(task)
        if unfinished:
            await asyncio.wait(unfinished)


async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
    """
    Wait for event to be set, or timeout to expire.
    :param evt: asyncio.Event to wait on
    :param timeout: maximum number of seconds to wait.
    :return: True if event was set, False if timeout expired
    """
    await event_wait_any([evt], timeout=timeout)
    return evt.is_set()
def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, stdout_is_pipe: bool,
                  stderr_is_pipe: bool) -> tuple[int, int, int, int]:
    """
    Fork and exec *cmd_line*, wiring each standard stream to either a pipe or a shared PTY.

    :param cmd_line: command line, tokenized; element 0 is the program to execute
    :param env: complete environment for the child
    :param stdin_is_pipe: use a pipe (instead of the PTY) for the child's stdin
    :param stdout_is_pipe: use a pipe (instead of the PTY) for the child's stdout
    :param stderr_is_pipe: use a pipe (instead of the PTY) for the child's stderr
    :return: (child pid, parent-side stdin fd, parent-side stdout fd, parent-side stderr fd)
    """
    # Set up PTY and/or pipes. A PTY is only needed if some stream is not a pipe.
    child_fd = parent_fd = None
    if not (stdin_is_pipe and stdout_is_pipe and stderr_is_pipe):
        parent_fd, child_fd = pty.openpty()
    child_stdin, parent_stdin = (os.pipe() if stdin_is_pipe else (child_fd, parent_fd))
    parent_stdout, child_stdout = (os.pipe() if stdout_is_pipe else (parent_fd, child_fd))
    parent_stderr, child_stderr = (os.pipe() if stderr_is_pipe else (parent_fd, child_fd))

    pid = os.fork()

    if pid == 0:
        try:
            # We are in the child process, so close all open sockets and pipes except for the PTY and/or pipes
            keep = (child_stdin, child_stdout, child_stderr)
            max_fd = os.sysconf("SC_OPEN_MAX")
            for fd in range(3, max_fd):
                if fd not in keep:
                    with contextlib.suppress(OSError):
                        os.close(fd)

            os.dup2(child_stdin, 0)
            os.dup2(child_stdout, 1)
            os.dup2(child_stderr, 2)
            # Make PTY controlling if necessary
            if child_fd is not None:
                os.setsid()
                # Best effort: open (and close again) the tty behind the first
                # non-pipe stream. Was a bare "except:".
                with contextlib.suppress(OSError):
                    tmp_fd = os.open(os.ttyname(0 if not stdin_is_pipe else 1 if not stdout_is_pipe else 2),
                                     os.O_RDWR)
                    os.close(tmp_fd)

            # Execute the command; does not return on success.
            os.execvpe(cmd_line[0], cmd_line, env)
        except Exception as err:
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            print(f"Unable to start {cmd_line[0]}: {err} ({fname}:{exc_tb.tb_lineno})")
            sys.stdout.flush()
        finally:
            # don't let any other modules get in our way, do an immediate silent exit.
            # In a finally so the forked child can never unwind into the parent's code.
            os._exit(255)

    # We are in the parent process, so close the child-side of the PTY and/or pipes
    if child_fd is not None:
        os.close(child_fd)
    if child_stdin != child_fd:
        os.close(child_stdin)
    if child_stdout != child_fd:
        os.close(child_stdout)
    if child_stderr != child_fd:
        os.close(child_stderr)
    # Return the child PID and the file descriptors for the PTY and/or pipes
    return pid, parent_stdin, parent_stdout, parent_stderr


class CallbackSubprocess:
    """Forks a child process and reports its output and termination through callbacks on an asyncio loop."""

    # time between checks of child process
    PROCESS_POLL_TIME: float = 0.1
    # delay before the parent-side descriptors of an ended process are closed
    PROCESS_PIPE_TIME: int = 60

    def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
                 stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
                 stderr_is_pipe: bool):
        """
        Fork a child process and generate callbacks with output from the process.
        :param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
        :param env: environment variables to override
        :param loop: the asyncio event loop to use
        :param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None
        :param stderr_callback: callback for stderr data, same shape as stdout_callback
        :param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None
        :param stdin_is_pipe: use a pipe instead of the PTY for the child's stdin
        :param stdout_is_pipe: use a pipe instead of the PTY for the child's stdout
        :param stderr_is_pipe: use a pipe instead of the PTY for the child's stderr
        """
        assert loop is not None, "loop should not be None"
        assert stdout_callback is not None, "stdout_callback should not be None"
        assert terminated_callback is not None, "terminated_callback should not be None"

        self._log = module_logger.getChild(self.__class__.__name__)
        self._command: [str] = argv
        self._env = env or {}
        self._loop = loop
        self._stdout_cb = stdout_callback
        self._stderr_cb = stderr_callback
        self._terminated_cb = terminated_callback
        self._pid: int = None
        self._child_stdin: int = None
        self._child_stdout: int = None
        self._child_stderr: int = None
        self._return_code: int = None
        self._stdout_eof: bool = False
        self._stderr_eof: bool = False
        self._stdin_is_pipe = stdin_is_pipe
        self._stdout_is_pipe = stdout_is_pipe
        self._stderr_is_pipe = stderr_is_pipe

    @staticmethod
    def _decode_wait_status(status: int) -> int:
        """
        Turn an os.waitpid() status into a return code.

        :return: the exit code for a normal exit; otherwise the low status byte, which holds the
            terminating signal
        """
        if os.WIFEXITED(status):
            return os.WEXITSTATUS(status)
        return status & 0xff

    def _ensure_pipes_closed(self):
        """Schedule closing of the parent-side descriptors PROCESS_PIPE_TIME seconds from now."""
        fds = {fd for fd in (self._child_stdin, self._child_stdout, self._child_stderr) if fd is not None}
        self._log.debug(f"Queuing close of pipes for ended process (fds: {fds})")

        def ensure_pipes_closed_inner():
            # Re-read the descriptors: if an earlier scheduled close already ran,
            # these numbers may since have been handed out to something else.
            open_fds = {fd for fd in (self._child_stdin, self._child_stdout, self._child_stderr)
                        if fd is not None}
            self._log.debug(f"Ensuring pipes are closed (fds: {open_fds})")
            for fd in open_fds:
                self._log.debug(f"Closing fd {fd}")
                with contextlib.suppress(OSError):
                    tty_unset_reader_callbacks(fd)
                with contextlib.suppress(OSError):
                    os.close(fd)

            self._child_stdin = None
            self._child_stdout = None
            self._child_stderr = None

        self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner)

    def terminate(self, kill_delay: float = 1.0):
        """
        Terminate child process if running
        :param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
        """

        self._log.debug("terminate()")
        if not self.running:
            return

        with exception.permit(SystemExit):
            os.kill(self._pid, signal.SIGTERM)

        def kill():
            if process_exists(self._pid):
                self._log.debug("kill()")
                with exception.permit(SystemExit):
                    os.kill(self._pid, signal.SIGHUP)
                    os.kill(self._pid, signal.SIGKILL)

        self._loop.call_later(kill_delay, kill)

        def wait():
            self._log.debug("wait()")
            with contextlib.suppress(OSError):
                os.waitpid(self._pid, 0)
            # This runs in a worker thread: loop.call_later() is not thread-safe,
            # so hand the scheduling over to the loop's own thread.
            self._loop.call_soon_threadsafe(self._ensure_pipes_closed)
            self._log.debug("wait() finish")

        threading.Thread(target=wait, daemon=True).start()

    def close_stdin(self):
        """Close the parent side of the child's stdin; any error is ignored."""
        with contextlib.suppress(Exception):
            os.close(self._child_stdin)

    @property
    def started(self) -> bool:
        """
        :return: True if child process has been started
        """
        return self._pid is not None

    @property
    def running(self) -> bool:
        """
        :return: True if child process is still running
        """
        return self._pid is not None and process_exists(self._pid)

    def write(self, data: bytes):
        """
        Write bytes to the stdin of the child process.
        :param data: bytes to write
        """
        os.write(self._child_stdin, data)

    def set_winsize(self, r: int, c: int, h: int, v: int):
        """
        Set the window size on the tty of the child process.
        :param r: rows visible
        :param c: columns visible
        :param h: horizontal pixels visible
        :param v: vertical pixels visible
        """
        self._log.debug(f"set_winsize({r},{c},{h},{v}")
        tty_set_winsize(self._child_stdout, r, c, h, v)

    def copy_winsize(self, fromfd: int):
        """
        Copy window size from one tty to another.
        :param fromfd: source tty file descriptor
        """
        r, c, h, v = tty_get_winsize(fromfd)
        self.set_winsize(r, c, h, v)

    def tcsetattr(self, when: int, attr: list[any]):  # actual type is list[int | list[int | bytes]]
        """
        Set tty attributes.
        :param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
        :param attr: attributes to set
        """
        termios.tcsetattr(self._child_stdin, when, attr)

    def tcgetattr(self) -> list[any]:  # actual type is list[int | list[int | bytes]]
        """
        Get tty attributes.
        :return: tty attributes value
        """
        return termios.tcgetattr(self._child_stdout)

    def ttysetraw(self):
        """Put the child's stdout tty into raw mode."""
        tty.setraw(self._child_stdout, termios.TCSADRAIN)

    def start(self):
        """
        Start the child process.

        The child inherits os.environ overridden by the env given to the constructor. Afterwards the
        child is polled every PROCESS_POLL_TIME seconds, and readers are registered on the loop for
        stdout and, when it is a separate descriptor, stderr.
        """
        self._log.debug("start()")

        env = os.environ.copy()
        env.update(self._env)

        program = self._command[0]
        assert isinstance(program, str)

        self._pid, \
            self._child_stdin, \
            self._child_stdout, \
            self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
                                               self._stderr_is_pipe)
        self._log.debug("Started pid %d, fds: %d, %d, %d", self.pid, self._child_stdin, self._child_stdout,
                        self._child_stderr)

        def poll():
            try:
                pid, status = os.waitpid(self._pid, os.WNOHANG)
                # pid == 0 means the child has not changed state; keep the
                # return code unset until the child has really been reaped.
                if pid != 0:
                    self._return_code = self._decode_wait_status(status)
                if self._return_code is not None and not process_exists(self._pid):
                    self._log.debug(f"polled return code {self._return_code}")
                    self._terminated_cb(self._return_code)
                if self.running:
                    self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
                else:
                    self._ensure_pipes_closed()
            except Exception as e:
                if not hasattr(e, "errno") or e.errno != errno.ECHILD:
                    self._log.debug(f"Error in process poll: {e}")

        self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)

        def stdout():
            try:
                with exception.permit(SystemExit):
                    data = tty_read_poll(self._child_stdout)
                    if data is not None and len(data) > 0:
                        self._stdout_cb(data)
            except EOFError:
                self._stdout_eof = True
                tty_unset_reader_callbacks(self._child_stdout)
                # An empty chunk is how end-of-stream is passed to the callback.
                self._stdout_cb(bytearray())

        def stderr():
            try:
                with exception.permit(SystemExit):
                    data = tty_read_poll(self._child_stderr)
                    if data is not None and len(data) > 0:
                        self._stderr_cb(data)
            except EOFError:
                self._stderr_eof = True
                tty_unset_reader_callbacks(self._child_stderr)
                # Report stderr's end-of-stream to the stderr callback (this
                # used to call the stdout callback).
                if self._stderr_cb is not None:
                    self._stderr_cb(bytearray())

        tty_add_reader_callback(self._child_stdout, stdout, self._loop)
        if self._child_stderr != self._child_stdout:
            tty_add_reader_callback(self._child_stderr, stderr, self._loop)

    @property
    def stdout_eof(self):
        """True once stdout reached end-of-file or the child is no longer running."""
        return self._stdout_eof or not self.running

    @property
    def stderr_eof(self):
        """True once stderr reached end-of-file or the child is no longer running."""
        return self._stderr_eof or not self.running

    @property
    def return_code(self) -> int:
        """Return code of the child, or None while it has not been reaped."""
        return self._return_code

    @property
    def pid(self) -> int:
        """Process id of the child, or None before start()."""
        return self._pid


async def main():
    """
    A test driver for the CallbackProcess class.
        python ./process.py /bin/zsh --login
    """

    log = module_logger.getChild("main")
    if len(sys.argv) <= 1:
        print(f"Usage: {sys.argv[0]} program [child_arg ...]")
        sys.exit(1)

    loop = asyncio.get_running_loop()
    retcode = loop.create_future()

    def stdout(data: bytes):
        os.write(sys.stdout.fileno(), data)

    def stderr(data: bytes):
        os.write(sys.stderr.fileno(), data)

    def terminated(rc: int):
        if not retcode.done():
            retcode.set_result(rc)

    # Every constructor parameter is required; the stream flags select the PTY
    # (not pipes) for all three streams.
    process = CallbackSubprocess(argv=sys.argv[1:],
                                 env={"TERM": os.environ.get("TERM", "xterm")},
                                 loop=loop,
                                 stdout_callback=stdout,
                                 stderr_callback=stderr,
                                 terminated_callback=terminated,
                                 stdin_is_pipe=False,
                                 stdout_is_pipe=False,
                                 stderr_is_pipe=False)

    def sigint_handler(sig, frame):
        if process is None or process.started and not process.running:
            raise KeyboardInterrupt
        elif process.running:
            process.write("\x03".encode("utf-8"))

    def sigwinch_handler(sig, frame):
        process.copy_winsize(sys.stdin.fileno())

    signal.signal(signal.SIGINT, sigint_handler)
    signal.signal(signal.SIGWINCH, sigwinch_handler)

    def stdin():
        try:
            data = tty_read(sys.stdin.fileno())
            if data is not None:
                process.write(data)
        except EOFError:
            tty_unset_reader_callbacks(sys.stdin.fileno())
            process.write(CTRL_D)

    tty_add_reader_callback(sys.stdin.fileno(), stdin)
    process.start()
    # call_soon called it too soon, not sure why.
    loop.call_later(0.001, functools.partial(process.copy_winsize, sys.stdin.fileno()))

    val = await retcode
    log.debug(f"got retcode {val}")
    return val


if __name__ == "__main__":
    tr = TTYRestorer(sys.stdin.fileno())
    try:
        tr.raw()
        asyncio.run(main())
    finally:
        tty_unset_reader_callbacks(sys.stdin.fileno())
        tr.restore()
730 | loop.call_later(0.001, functools.partial(process.copy_winsize, sys.stdin.fileno())) 731 | 732 | val = await retcode 733 | log.debug(f"got retcode {val}") 734 | return val 735 | 736 | 737 | if __name__ == "__main__": 738 | tr = TTYRestorer(sys.stdin.fileno()) 739 | try: 740 | tr.raw() 741 | asyncio.run(main()) 742 | finally: 743 | tty_unset_reader_callbacks(sys.stdin.fileno()) 744 | tr.restore() 745 | -------------------------------------------------------------------------------- /rnsh/protocol.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import enum 4 | import queue 5 | import threading 6 | import time 7 | import typing 8 | import uuid 9 | from types import TracebackType 10 | from typing import Type, Callable, TypeVar, Tuple 11 | import RNS 12 | from RNS.vendor import umsgpack 13 | from RNS.Buffer import StreamDataMessage as RNSStreamDataMessage 14 | import rnsh.retry 15 | import abc 16 | import contextlib 17 | import struct 18 | import logging as __logging 19 | from abc import ABC, abstractmethod 20 | 21 | module_logger = __logging.getLogger(__name__) 22 | 23 | MSG_MAGIC = 0xac 24 | PROTOCOL_VERSION = 1 25 | 26 | 27 | def _make_MSGTYPE(val: int): 28 | return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff) 29 | 30 | 31 | class NoopMessage(RNS.MessageBase): 32 | MSGTYPE = _make_MSGTYPE(0) 33 | 34 | def pack(self) -> bytes: 35 | return bytes() 36 | 37 | def unpack(self, raw): 38 | pass 39 | 40 | 41 | class WindowSizeMessage(RNS.MessageBase): 42 | MSGTYPE = _make_MSGTYPE(2) 43 | 44 | def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None): 45 | super().__init__() 46 | self.rows = rows 47 | self.cols = cols 48 | self.hpix = hpix 49 | self.vpix = vpix 50 | 51 | def pack(self) -> bytes: 52 | return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix)) 53 | 54 | def unpack(self, raw): 55 | self.rows, self.cols, self.hpix, self.vpix = 
class ExecuteCommandMesssage(RNS.MessageBase):
    """
    Channel message bundling a command line with stream and terminal settings.

    The ten attributes are serialized as a single msgpack tuple, in the same order
    as the constructor arguments.
    """
    MSGTYPE = _make_MSGTYPE(3)

    def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
                 pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
                 cols: int = None, hpix: int = None, vpix: int = None):
        super().__init__()
        # what to run
        self.cmdline = cmdline
        # which of the three standard streams are flagged as pipes
        self.pipe_stdin, self.pipe_stdout, self.pipe_stderr = pipe_stdin, pipe_stdout, pipe_stderr
        # terminal description
        self.tcflags, self.term = tcflags, term
        # window geometry
        self.rows, self.cols, self.hpix, self.vpix = rows, cols, hpix, vpix

    def pack(self) -> bytes:
        """Serialize all ten fields into one msgpack-encoded tuple."""
        fields = (
            self.cmdline,
            self.pipe_stdin,
            self.pipe_stdout,
            self.pipe_stderr,
            self.tcflags,
            self.term,
            self.rows,
            self.cols,
            self.hpix,
            self.vpix,
        )
        return umsgpack.packb(fields)

    def unpack(self, raw):
        """Populate all ten fields from *raw*, which must decode to exactly ten items."""
        fields = umsgpack.unpackb(raw)
        (self.cmdline,
         self.pipe_stdin,
         self.pipe_stdout,
         self.pipe_stderr,
         self.tcflags,
         self.term,
         self.rows,
         self.cols,
         self.hpix,
         self.vpix) = fields
class CommandExitedMessage(RNS.MessageBase):
    """Channel message whose entire payload is a single msgpack-encoded return code."""
    MSGTYPE = _make_MSGTYPE(7)

    def __init__(self, return_code: int = None):
        super().__init__()
        self.return_code = return_code

    def pack(self) -> bytes:
        """Encode the return code on its own (not wrapped in a tuple)."""
        payload = self.return_code
        return umsgpack.packb(payload)

    def unpack(self, raw: bytes):
        """Decode *raw* and store the resulting value as the return code."""
        decoded = umsgpack.unpackb(raw)
        self.return_code = decoded
14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import asyncio 24 | import logging 25 | import threading 26 | import time 27 | import rnsh.exception as exception 28 | import logging as __logging 29 | from typing import Callable 30 | from contextlib import AbstractContextManager 31 | import types 32 | import typing 33 | 34 | module_logger = __logging.getLogger(__name__) 35 | 36 | 37 | class RetryStatus: 38 | def __init__(self, tag: any, try_limit: int, wait_delay: float, retry_callback: Callable[[any, int], any], 39 | timeout_callback: Callable[[any, int], None], tries: int = 1): 40 | self._log = module_logger.getChild(self.__class__.__name__) 41 | self._log.setLevel(logging.INFO) 42 | self.tag = tag 43 | self.try_limit = try_limit 44 | self.tries = tries 45 | self.wait_delay = wait_delay 46 | self.retry_callback = retry_callback 47 | self.timeout_callback = timeout_callback 48 | self.try_time = time.time() 49 | self.completed = False 50 | 51 | @property 52 | def ready(self): 53 | ready = time.time() > self.try_time + self.wait_delay 54 | self._log.debug(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} " + 55 | f"next_try {self.try_time + self.wait_delay} now {time.time()} " + 56 | f"exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}") 57 | return ready 58 | 59 | @property 60 | def timed_out(self): 61 | return self.ready and self.tries >= self.try_limit 62 | 63 | def timeout(self): 64 | self.completed = True 65 | self.timeout_callback(self.tag, 
class RetryThread(AbstractContextManager):
    """
    Runs a daemon thread that periodically re-invokes callbacks registered through
    begin() until each one is completed, times out, or raises.

    Can be used as a context manager; leaving the ``with`` block calls close().
    """

    def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
        """
        Set up the bookkeeping state and immediately start the worker thread.

        :param loop_period: seconds the worker sleeps between passes over the retry list
        :param name: name given to the worker thread
        """
        self._log = module_logger.getChild(self.__class__.__name__)
        self._loop_period = loop_period
        self._statuses: list[RetryStatus] = []
        self._tag_counter = 0
        self._lock = threading.RLock()
        self._run = True
        self._finished: asyncio.Future = None
        self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
        self._thread.start()

    def is_alive(self):
        """Return True while the worker thread is running."""
        return self._thread.is_alive()

    def close(self, loop: asyncio.AbstractEventLoop = None) -> asyncio.Future:
        """
        Stop the worker thread.

        :param loop: when omitted, the call blocks until the worker has exited and returns
            None. When given, the call returns at once with a future created on *loop*;
            the worker resolves it (with None) as it exits.
        :return: None, or the future described above
        """
        self._log.debug("stopping timer thread")
        if loop is None:
            self._run = False
            self._thread.join()
            return None
        else:
            self._finished = loop.create_future()
            return self._finished

    def wait(self, timeout: float = None):
        """
        Block until no retries are registered, polling every 0.1 seconds.

        :param timeout: maximum number of seconds to wait; None waits indefinitely
        """
        deadline = None
        if timeout is not None:
            deadline = timeout + time.time()

        while deadline is None or time.time() < deadline:
            with self._lock:
                task_count = len(self._statuses)
            if task_count == 0:
                return
            time.sleep(0.1)

    def _thread_run(self):
        """
        Worker loop: once per loop period, time out or retry every status that is ready,
        then drop the statuses that timed out or whose callback raised.
        """
        while self._run and self._finished is None:
            time.sleep(self._loop_period)
            prune: list[RetryStatus] = []
            # Snapshot under the lock; the callbacks themselves run without holding it.
            with self._lock:
                ready: list[RetryStatus] = [s for s in self._statuses if s.ready]
            for retry in ready:
                try:
                    if not retry.completed:
                        if retry.timed_out:
                            self._log.debug(f"timed out {retry.tag} after {retry.try_limit} tries")
                            retry.timeout()
                            prune.append(retry)
                        elif retry.ready:
                            self._log.debug(f"retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}")
                            should_continue = retry.retry()
                            if not should_continue:
                                self.complete(retry.tag)
                except Exception as e:
                    self._log.error(f"error processing retry id {retry.tag}: {e}")
                    prune.append(retry)

            with self._lock:
                for retry in prune:
                    self._log.debug(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}")
                    # NOTE(review): presumably guards against the status having been removed
                    # already (e.g. by complete()); confirm against rnsh.exception.permit.
                    with exception.permit(SystemExit):
                        self._statuses.remove(retry)

        finished = self._finished
        if finished is not None:
            # asyncio futures are not thread-safe, so the result has to be set from the
            # future's own event loop rather than from this worker thread.
            def _resolve():
                if not finished.done():
                    finished.set_result(None)

            try:
                finished.get_loop().call_soon_threadsafe(_resolve)
            except RuntimeError as e:
                # call_soon_threadsafe raises RuntimeError when the loop is already closed.
                self._log.debug(f"could not signal retry thread completion: {e}")

    def _get_next_tag(self):
        """Return the next value of the internal tag counter."""
        self._tag_counter += 1
        return self._tag_counter

    def has_tag(self, tag: typing.Any) -> bool:
        """Return True if a retry with the given tag is currently registered."""
        with self._lock:
            return any(s.tag == tag for s in self._statuses)

    def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[typing.Any, int], typing.Any],
              timeout_callback: Callable[[typing.Any, int], None]) -> typing.Any:
        """
        Make the first attempt in the calling thread and, unless it reports completion,
        register the operation for retries on the worker thread.

        :param try_limit: passed to RetryStatus as the try limit
        :param wait_delay: passed to RetryStatus as the delay between tries
        :param try_callback: called as ``try_callback(tag, tries)``; the first call receives
            ``(None, 1)`` and its return value becomes the tag. A falsy return value means
            there is nothing to retry.
        :param timeout_callback: passed to RetryStatus as the timeout callback
        :return: the tag identifying the registered retry, or None if the first try
            returned a falsy value
        """
        self._log.debug(f"running first try")
        tag = try_callback(None, 1)
        self._log.debug(f"first try got id {tag}")
        if not tag:
            self._log.debug(f"callback returned None/False/0, considering complete.")
            return None
        with self._lock:
            # A retry already registered under this tag is superseded by the new one.
            self.complete(tag)
            self._statuses.append(RetryStatus(tag=tag,
                                              tries=1,
                                              try_limit=try_limit,
                                              wait_delay=wait_delay,
                                              retry_callback=try_callback,
                                              timeout_callback=timeout_callback))
        self._log.debug(f"added retry timer for {tag}")
        return tag

    def complete(self, tag: typing.Any):
        """Mark the retry registered under *tag* as completed and stop tracking it."""
        assert tag is not None
        with self._lock:
            status = next((s for s in self._statuses if s.tag == tag), None)
            if status is not None:
                status.completed = True
                self._statuses.remove(status)
                self._log.debug(f"completed {tag}")
                return

            self._log.debug(f"status not found to complete {tag}")

    def complete_all(self):
        """Mark every registered retry as completed and stop tracking all of them."""
        with self._lock:
            for status in self._statuses:
                status.completed = True
                self._log.debug(f"completed {status.tag}")
            self._statuses.clear()

    def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
                 __traceback: types.TracebackType) -> bool:
        """Stop the worker thread (blocking) and let any in-flight exception propagate."""
        self.close()
        return False
24 | 25 | from __future__ import annotations 26 | 27 | import asyncio 28 | import base64 29 | import enum 30 | import functools 31 | import logging as __logging 32 | import os 33 | import queue 34 | import shlex 35 | import signal 36 | import sys 37 | import termios 38 | import threading 39 | import time 40 | import tty 41 | from typing import Callable, TypeVar 42 | import RNS 43 | import rnsh.exception as exception 44 | import rnsh.process as process 45 | import rnsh.retry as retry 46 | import rnsh.rnslogging as rnslogging 47 | import rnsh.session as session 48 | import re 49 | import contextlib 50 | import rnsh.args 51 | import pwd 52 | import rnsh.protocol as protocol 53 | import rnsh.helpers as helpers 54 | import rnsh.loop 55 | import rnsh.listener as listener 56 | import rnsh.initiator as initiator 57 | 58 | module_logger = __logging.getLogger(__name__) 59 | 60 | 61 | def _get_logger(name: str): 62 | global module_logger 63 | return module_logger.getChild(name) 64 | 65 | 66 | APP_NAME = "rnsh" 67 | loop: asyncio.AbstractEventLoop | None = None 68 | 69 | 70 | def _sanitize_service_name(service_name:str) -> str: 71 | return re.sub(r'\W+', '', service_name) 72 | 73 | 74 | def prepare_identity(identity_path, service_name: str = None) -> tuple[RNS.Identity]: 75 | log = _get_logger("_prepare_identity") 76 | service_name = _sanitize_service_name(service_name or "") 77 | if identity_path is None: 78 | identity_path = RNS.Reticulum.identitypath + "/" + APP_NAME + \ 79 | (f".{service_name}" if service_name and len(service_name) > 0 else "") 80 | 81 | identity = None 82 | if os.path.isfile(identity_path): 83 | identity = RNS.Identity.from_file(identity_path) 84 | 85 | if identity is None: 86 | log.info("No valid saved identity found, creating new...") 87 | identity = RNS.Identity() 88 | identity.to_file(identity_path) 89 | return identity 90 | 91 | 92 | def print_identity(configdir, identitypath, service_name, include_destination: bool): 93 | reticulum = 
RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO) 94 | if service_name and len(service_name) > 0: 95 | print(f"Using service name \"{service_name}\"") 96 | identity = prepare_identity(identitypath, service_name) 97 | destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME) 98 | print("Identity : " + str(identity)) 99 | if include_destination: 100 | print("Listening on : " + RNS.prettyhexrep(destination.hash)) 101 | exit(0) 102 | 103 | 104 | verbose_set = False 105 | 106 | 107 | async def _rnsh_cli_main(): 108 | global verbose_set 109 | log = _get_logger("main") 110 | _loop = asyncio.get_running_loop() 111 | rnslogging.set_main_loop(_loop) 112 | args = rnsh.args.Args(sys.argv) 113 | verbose_set = args.verbose > 0 114 | 115 | if args.print_identity: 116 | print_identity(args.config, args.identity, args.service_name, args.listen) 117 | return 0 118 | 119 | if args.listen: 120 | allowed_file = None 121 | dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 122 | if os.path.isfile(os.path.expanduser("~/.config/rnsh/allowed_identities")): 123 | allowed_file = os.path.expanduser("~/.config/rnsh/allowed_identities") 124 | elif os.path.isfile(os.path.expanduser("~/.rnsh/allowed_identities")): 125 | allowed_file = os.path.expanduser("~/.rnsh/allowed_identities") 126 | 127 | await listener.listen(configdir=args.config, 128 | command=args.command_line, 129 | identitypath=args.identity, 130 | service_name=args.service_name, 131 | verbosity=args.verbose, 132 | quietness=args.quiet, 133 | allowed=args.allowed, 134 | allowed_file=allowed_file, 135 | disable_auth=args.no_auth, 136 | announce_period=args.announce, 137 | no_remote_command=args.no_remote_cmd, 138 | remote_cmd_as_args=args.remote_cmd_as_args) 139 | return 0 140 | 141 | if args.destination is not None: 142 | return_code = await initiator.initiate(configdir=args.config, 143 | identitypath=args.identity, 144 | verbosity=args.verbose, 145 | quietness=args.quiet, 146 | 
def rnsh_cli():
    """
    Console entry point: run the async CLI, remove the reader callbacks registered for
    file descriptor 0 and exit the interpreter.

    Exit status:
      * the status carried by a SystemExit raised while the CLI was running, if any;
      * otherwise the value returned by the CLI coroutine (255 if it returned None);
      * 1 if the run was interrupted from the keyboard or failed with an exception.

    When verbose output was requested, an unhandled exception is re-raised after
    cleanup so that its traceback is shown.
    """
    global verbose_set
    return_code = 1
    exc = None
    exit_request = None
    try:
        return_code = asyncio.run(_rnsh_cli_main())
    except SystemExit as ex:
        # Remember the requested status instead of discarding it; otherwise a
        # deliberate exit(0) would be reported as failure (status 1).
        exit_request = ex
    except KeyboardInterrupt:
        pass
    except Exception as ex:
        print(f"Unhandled exception: {ex}")
        exc = ex
    finally:
        process.tty_unset_reader_callbacks(0)
    if verbose_set and exc:
        raise exc
    if exit_request is not None:
        sys.exit(exit_request.code)
    sys.exit(return_code if return_code is not None else 255)
class RnsHandler(Handler):
    """
    A handler class which writes logging records, appropriately formatted,
    to the RNS logger.
    """

    def __init__(self):
        """
        Initialize the handler.
        """
        Handler.__init__(self)

    @staticmethod
    def get_rns_loglevel(loglevel: int) -> int:
        """
        Translate a ``logging`` level into the corresponding RNS log level.

        Only the five standard ``logging`` levels are matched (by equality); any
        other value maps to ``RNS.LOG_DEBUG``.
        """
        if loglevel == logging.CRITICAL:
            return RNS.LOG_CRITICAL
        if loglevel == logging.ERROR:
            return RNS.LOG_ERROR
        if loglevel == logging.WARNING:
            return RNS.LOG_WARNING
        if loglevel == logging.INFO:
            return RNS.LOG_INFO
        if loglevel == logging.DEBUG:
            return RNS.LOG_DEBUG
        return RNS.LOG_DEBUG

    @staticmethod
    def get_logging_loglevel(rnsloglevel: int) -> int:
        """
        Translate an RNS log level into the corresponding ``logging`` level.

        ``RNS.LOG_NOTICE`` and ``RNS.LOG_INFO`` both map to ``logging.INFO``; every
        level from ``RNS.LOG_VERBOSE`` upwards, and any unrecognized value, maps to
        ``logging.DEBUG``.
        """
        if rnsloglevel == RNS.LOG_CRITICAL:
            return logging.CRITICAL
        if rnsloglevel == RNS.LOG_ERROR:
            return logging.ERROR
        if rnsloglevel == RNS.LOG_WARNING:
            return logging.WARNING
        if rnsloglevel == RNS.LOG_NOTICE:
            return logging.INFO
        if rnsloglevel == RNS.LOG_INFO:
            return logging.INFO
        # The result is handed to logging.setLevel(), so it must be a logging
        # constant, not an RNS one.
        return logging.DEBUG

    @classmethod
    def set_log_level_with_rns_level(cls, rns_log_level: int):
        """Set both the root ``logging`` level and ``RNS.loglevel`` from an RNS level."""
        logging.getLogger().setLevel(cls.get_logging_loglevel(rns_log_level))
        RNS.loglevel = rns_log_level

    @classmethod
    def set_log_level_with_logging_level(cls, logging_log_level: int):
        """Set both the root ``logging`` level and ``RNS.loglevel`` from a ``logging`` level."""
        logging.getLogger().setLevel(logging_log_level)
        RNS.loglevel = cls.get_rns_loglevel(logging_log_level)

    def emit(self, record):
        """
        Emit a record.

        The record is formatted and passed to ``RNS.log`` with the translated level.
        """
        try:
            msg = self.format(record)

            RNS.log(msg, RnsHandler.get_rns_loglevel(record.levelno))
        except RecursionError:  # See issue 36272
            raise
        except Exception:
            self.handleError(record)

    def __repr__(self):
        level = getLevelName(self.level)
        return '<%s (%s)>' % (self.__class__.__name__, level)
class SessionException(Exception):
    """
    Exception carrying a session error type alongside the usual message.

    :param setype: stored on the instance as ``type``
    :param msg: first exception argument; this is what ``str(exc)`` shows when no
        extra arguments are given
    :param args: additional exception arguments, appended after *msg* in ``exc.args``
    """

    def __init__(self, setype: "SEType", msg: str, *args):
        # Unpack the extras: passing the tuple itself made exc.args == (msg, ())
        # and str(exc) the repr of that pair instead of the message.
        super().__init__(msg, *args)
        self.type = setype
raise NotImplemented() 58 | 59 | @property 60 | @abstractmethod 61 | def rtt(self): 62 | raise NotImplemented() 63 | 64 | @abstractmethod 65 | def teardown(self): 66 | raise NotImplemented() 67 | 68 | 69 | class ListenerSession: 70 | sessions: List[ListenerSession] = [] 71 | allowed_identity_hashes: [any] = [] 72 | allowed_file_identity_hashes: [any] = [] 73 | allow_all: bool = False 74 | allow_remote_command: bool = False 75 | default_command: [str] = [] 76 | remote_cmd_as_args = False 77 | 78 | def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop): 79 | self._log = module_logger.getChild(self.__class__.__name__) 80 | self._log.info(f"Session started for {outlet}") 81 | self.outlet = outlet 82 | self.channel = channel 83 | self.outlet.set_initiator_identified_callback(self._initiator_identified) 84 | self.outlet.set_link_closed_callback(self._link_closed) 85 | self.loop = loop 86 | self.state: LSState = None 87 | self.remote_identity = None 88 | self.term: str | None = None 89 | self.stdin_is_pipe: bool = False 90 | self.stdout_is_pipe: bool = False 91 | self.stderr_is_pipe: bool = False 92 | self.tcflags: [any] = None 93 | self.cmdline: [str] = None 94 | self.rows: int = 0 95 | self.cols: int = 0 96 | self.hpix: int = 0 97 | self.vpix: int = 0 98 | self.stdout_buf = bytearray() 99 | self.stdout_eof_sent = False 100 | self.stderr_buf = bytearray() 101 | self.stderr_eof_sent = False 102 | self.return_code: int | None = None 103 | self.return_code_sent = False 104 | self.process: process.CallbackSubprocess | None = None 105 | if self.allow_all: 106 | self._set_state(LSState.LSSTATE_WAIT_VERS) 107 | else: 108 | self._set_state(LSState.LSSTATE_WAIT_IDENT) 109 | self.sessions.append(self) 110 | protocol.register_message_types(self.channel) 111 | self.channel.add_message_handler(self._handle_message) 112 | 113 | def _terminated(self, return_code: int): 114 | self.return_code = return_code 115 | 116 | def _set_state(self, 
def terminate(self, error: str = None):
    """
    Shut the session down on a best-effort basis; no exception escapes this method.

    When *error* is given and the session is not already in the teardown state, a
    fatal ErrorMessage carrying the text is sent first. The session then enters the
    error state, its process is terminated and pruning is scheduled.
    """
    with contextlib.suppress(Exception):
        detail = f": {error}" if error else ""
        self._log.debug("Terminating session" + detail)

        should_notify = error and self.state != LSState.LSSTATE_TEARDOWN
        if should_notify:
            # A failed send must not prevent the rest of the shutdown.
            with contextlib.suppress(Exception):
                error_message = protocol.ErrorMessage(error, True)
                self.send(error_message)

        self.state = LSState.LSSTATE_ERROR
        self._terminate_process()

        prune_delay = max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5)
        self._call(self._prune, prune_delay)
def _initiator_identified(self, outlet, identity):
    """
    Outlet callback invoked when the initiator has identified itself.

    The identity is accepted only if it arrives on this session's outlet, the session
    is still waiting for identification (or for version info), and either all
    identities are allowed or the identity's hash is in one of the allow lists.
    On acceptance the identity is recorded and the session waits for version info.
    Every rejection path returns immediately so a terminated session can never be
    moved back into a waiting state.

    :param outlet: the outlet reporting the identification
    :param identity: the identity presented by the initiator; its ``hash`` attribute is
        what is compared against the allow lists
    """
    if outlet != self.outlet:
        self._log.debug("Identity received from incorrect outlet")
        return

    self._log.info(f"initiator_identified {identity} on link {outlet}")
    if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]:
        self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
        # The session has just been terminated; do not fall through.
        return

    identity_allowed = (self.allow_all
                        or identity.hash in self.allowed_identity_hashes
                        or identity.hash in self.allowed_file_identity_hashes)
    if not identity_allowed:
        self.terminate("Identity is not allowed.")
        # Without this return the state would be reset to WAIT_VERS below and the
        # rejected initiator could carry on with the protocol.
        return

    self.remote_identity = identity
    self._set_state(LSState.LSSTATE_WAIT_VERS)
while chunk_len > 32 and comp_try < comp_tries: 223 | chunk_segment_length = int(chunk_len/comp_try) 224 | compressed_chunk = bz2.compress(buf[:chunk_segment_length]) 225 | compressed_length = len(compressed_chunk) 226 | if compressed_length < max_data_len and compressed_length < chunk_segment_length: 227 | comp_success = True 228 | break 229 | else: 230 | comp_try += 1 231 | 232 | if comp_success: 233 | diff = max_data_len - len(compressed_chunk) 234 | chunk = compressed_chunk 235 | processed_length = chunk_segment_length 236 | else: 237 | chunk = bytes(buf[:max_data_len]) 238 | processed_length = len(chunk) 239 | 240 | return comp_success, processed_length, chunk 241 | 242 | try: 243 | if self.state != LSState.LSSTATE_RUNNING: 244 | return False 245 | elif not self.channel.is_ready_to_send(): 246 | return False 247 | elif len(self.stderr_buf) > 0: 248 | comp_success, processed_length, data = compress_adaptive(self.stderr_buf) 249 | self.stderr_buf = self.stderr_buf[processed_length:] 250 | send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent 251 | self.stderr_eof_sent = self.stderr_eof_sent or send_eof 252 | msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR, 253 | data, send_eof, comp_success) 254 | self.send(msg) 255 | if send_eof: 256 | self.stderr_eof_sent = True 257 | return True 258 | elif len(self.stdout_buf) > 0: 259 | comp_success, processed_length, data = compress_adaptive(self.stdout_buf) 260 | self.stdout_buf = self.stdout_buf[processed_length:] 261 | send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent 262 | self.stdout_eof_sent = self.stdout_eof_sent or send_eof 263 | msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT, 264 | data, send_eof, comp_success) 265 | self.send(msg) 266 | if send_eof: 267 | self.stdout_eof_sent = True 268 | return True 269 | elif self.return_code is not None and not self.return_code_sent: 270 | msg = 
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
               term: str | None, rows: int, cols: int, hpix: int, vpix: int):
    """
    Work out the command line for this session and start it in a CallbackSubprocess.

    Command line selection:
      * a non-empty remote *cmdline* while remote commands are disallowed terminates
        the session;
      * with ``remote_cmd_as_args`` set, a non-empty remote *cmdline* is appended to the
        listener's default command;
      * otherwise a non-empty remote *cmdline* replaces the default command;
      * an empty or missing *cmdline* runs the default command unchanged.

    Output from the process is accumulated in ``stdout_buf`` / ``stderr_buf``. If the
    process cannot be started, the session is terminated.

    :param cmdline: command line requested by the initiator, may be empty or None
    :param pipe_stdin: stored as ``stdin_is_pipe`` and passed on to the subprocess
    :param pipe_stdout: stored as ``stdout_is_pipe`` and passed on to the subprocess
    :param pipe_stderr: stored as ``stderr_is_pipe`` and passed on to the subprocess
    :param tcflags: stored on the session as ``tcflags``
    :param term: value for the child's TERM variable; falls back to this process's TERM
    :param rows: initial window size, applied once the process has been started
    :param cols: see *rows*
    :param hpix: see *rows*
    :param vpix: see *rows*
    """
    # Work on a copy: default_command is a class-level list shared by all sessions,
    # and extending it in place would leak one initiator's arguments into every
    # later session.
    self.cmdline = list(self.default_command)
    if not self.allow_remote_command and cmdline and len(cmdline) > 0:
        self.terminate("Remote command line not allowed by listener")
        return

    if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
        self.cmdline.extend(cmdline)
    elif cmdline and len(cmdline) > 0:
        self.cmdline = cmdline

    self.stdin_is_pipe = pipe_stdin
    self.stdout_is_pipe = pipe_stdout
    self.stderr_is_pipe = pipe_stderr
    self.tcflags = tcflags
    self.term = term

    def stdout(data: bytes):
        self.stdout_buf.extend(data)

    def stderr(data: bytes):
        self.stderr_buf.extend(data)

    try:
        self.process = process.CallbackSubprocess(argv=self.cmdline,
                                                  env={"TERM": self.term or os.environ.get("TERM", None),
                                                       "RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash)
                                                                               if self.remote_identity and self.remote_identity.hash else "")},
                                                  loop=self.loop,
                                                  stdout_callback=stdout,
                                                  stderr_callback=stderr,
                                                  terminated_callback=self._terminated,
                                                  stdin_is_pipe=self.stdin_is_pipe,
                                                  stdout_is_pipe=self.stdout_is_pipe,
                                                  stderr_is_pipe=self.stderr_is_pipe)
        self.process.start()
        self._set_window_size(rows, cols, hpix, vpix)
    except Exception:
        # Logger.exception() records the active exception by itself; an extra
        # positional argument would be treated as a %-format argument for a message
        # that has no placeholders.
        self._log.exception(f"Unable to start process for link {self.outlet}")
        self.terminate("Unable to start process")
class RNSOutlet(LSOutletBase):
    """``LSOutletBase`` implementation backed by an ``RNS.Link``.

    The outlet registers itself on the link as ``link.lsoutlet`` so that
    ``get_outlet`` can find the existing outlet for a link.
    """

    def __init__(self, link: RNS.Link):
        self.link = link
        # Back-reference used by get_outlet() to reuse this outlet.
        link.lsoutlet = self

    @staticmethod
    def get_outlet(link: RNS.Link):
        """Return the outlet already attached to *link*, or wrap it in a new one."""
        if not hasattr(link, "lsoutlet"):
            return RNSOutlet(link)
        return link.lsoutlet

    @property
    def rtt(self) -> float:
        """The ``rtt`` attribute of the underlying link."""
        return self.link.rtt

    def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
        """Register *cb* on the link's remote-identified callback.

        *cb* receives this outlet (not the raw link) and the identity.
        """
        def on_identified(link, identity: _TIdentity):
            cb(self, identity)

        self.link.set_remote_identified_callback(on_identified)

    def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
        """Register *cb* on the link's closed callback; *cb* receives this outlet."""
        def on_closed(link):
            cb(self)

        self.link.set_link_closed_callback(on_closed)

    def unset_link_closed_callback(self):
        """Clear the link-closed callback on the underlying link."""
        self.link.set_link_closed_callback(None)

    def teardown(self):
        """Tear down the underlying link."""
        self.link.teardown()

    def __str__(self):
        return f"Outlet RNS Link {self.link}"
-------------------------------------------------------------------------------- /rnsh/testlogging.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Aaron Heise 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import logging as __logging 24 | import os 25 | 26 | log_format = '%(levelname)-6s %(name)-40s %(message)s [%(threadName)s]' \ 27 | if os.environ.get('UNDER_SYSTEMD') == "1" \ 28 | else '\r%(asctime)s.%(msecs)03d %(levelname)-6s %(name)-40s %(message)s [%(threadName)s]' 29 | 30 | __logging.basicConfig( 31 | level=__logging.INFO, 32 | # format='%(asctime)s.%(msecs)03d %(levelname)-6s %(threadName)-15s %(name)-15s %(message)s', 33 | format=log_format, 34 | datefmt='%Y-%m-%d %H:%M:%S', 35 | handlers=[__logging.StreamHandler()]) 36 | 37 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acehoss/rnsh/428aa39b59295bc5fcf7482e48156dbf4cab07f4/tests/__init__.py -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import types 4 | import typing 5 | import tempfile 6 | 7 | import pytest 8 | 9 | import rnsh.rnsh 10 | import asyncio 11 | import rnsh.process 12 | import contextlib 13 | import threading 14 | import os 15 | import pathlib 16 | import tests 17 | import shutil 18 | import random 19 | 20 | module_logger = logging.getLogger(__name__) 21 | 22 | module_abs_filename = os.path.abspath(tests.__file__) 23 | module_dir = os.path.dirname(module_abs_filename) 24 | 25 | 26 | class SubprocessReader(contextlib.AbstractContextManager): 27 | def __init__(self, argv: [str], env: dict = None, name: str = None, stdin_is_pipe: bool = False, 28 | stdout_is_pipe: bool = False, stderr_is_pipe: bool = False): 29 | self._log = module_logger.getChild(self.__class__.__name__ + ("" if name is None else f"({name})")) 30 | self.name = name or "subproc" 31 | self.process: rnsh.process.CallbackSubprocess 32 | self.loop = asyncio.get_running_loop() 33 | 
self.env = env or os.environ.copy() 34 | self.argv = argv 35 | self._lock = threading.RLock() 36 | self._stdout = bytearray() 37 | self._stderr = bytearray() 38 | self.return_code: int = None 39 | self.process = rnsh.process.CallbackSubprocess(argv=self.argv, 40 | env=self.env, 41 | loop=self.loop, 42 | stdout_callback=self._stdout_cb, 43 | terminated_callback=self._terminated_cb, 44 | stderr_callback=self._stderr_cb, 45 | stdin_is_pipe=stdin_is_pipe, 46 | stdout_is_pipe=stdout_is_pipe, 47 | stderr_is_pipe=stderr_is_pipe) 48 | 49 | def _stdout_cb(self, data): 50 | self._log.debug(f"_stdout_cb({data})") 51 | with self._lock: 52 | self._stdout.extend(data) 53 | 54 | def read(self): 55 | self._log.debug(f"read()") 56 | with self._lock: 57 | data = self._stdout.copy() 58 | self._stdout.clear() 59 | self._log.debug(f"read() returns {data}") 60 | return data 61 | 62 | def _stderr_cb(self, data): 63 | self._log.debug(f"_stderr_cb({data})") 64 | with self._lock: 65 | self._stderr.extend(data) 66 | 67 | def read_err(self): 68 | self._log.debug(f"read_err()") 69 | with self._lock: 70 | data = self._stderr.copy() 71 | self._stderr.clear() 72 | self._log.debug(f"read_err() returns {data}") 73 | return data 74 | 75 | def _terminated_cb(self, rc): 76 | self._log.debug(f"_terminated_cb({rc})") 77 | self.return_code = rc 78 | 79 | def start(self): 80 | self._log.debug(f"start()") 81 | self.process.start() 82 | 83 | def cleanup(self): 84 | self._log.debug(f"cleanup()") 85 | if self.process and self.process.running: 86 | self.process.terminate(kill_delay=0.1) 87 | time.sleep(0.5) 88 | 89 | def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, 90 | __traceback: types.TracebackType) -> bool: 91 | self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") 92 | self.cleanup() 93 | return False 94 | 95 | 96 | def replace_text_in_file(filename: str, text: str, replacement: str): 97 | # Read in the file 98 | with open(filename, 'r') as file: 
def wait_for_condition(condition: typing.Callable[[], object], timeout: float) -> bool:
    """Block until *condition* returns a truthy value or *timeout* seconds elapse.

    The condition is polled roughly every 10 ms. The deadline is measured
    with the monotonic clock so wall-clock adjustments cannot shorten or
    extend the wait.

    Args:
        condition: zero-argument callable polled for a truthy result.
        timeout: maximum time to wait, in seconds. With a timeout of zero or
            less the condition is never called.

    Returns:
        True if the condition was observed truthy before the deadline,
        False if the wait timed out.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition():
            return True
        time.sleep(0.01)
    return False


async def wait_for_condition_async(condition: typing.Callable[[], object], timeout: float) -> bool:
    """Asynchronous variant of ``wait_for_condition``.

    Polls *condition* roughly every 10 ms, yielding to the event loop between
    polls, until it is truthy or *timeout* seconds elapse (monotonic clock).

    Args:
        condition: zero-argument (synchronous) callable polled for a truthy result.
        timeout: maximum time to wait, in seconds. With a timeout of zero or
            less the condition is never called.

    Returns:
        True if the condition was observed truthy before the deadline,
        False if the wait timed out.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition():
            return True
        await asyncio.sleep(0.01)
    return False
| docopt_threw = False 45 | try: 46 | args = rnsh.args.Args(shlex.split("rnsh -N one")) 47 | assert not args.listen 48 | assert args.destination == "one" 49 | assert args.no_id 50 | assert args.command_line == [] 51 | except docopt.DocoptExit: 52 | docopt_threw = True 53 | assert not docopt_threw 54 | 55 | 56 | def test_program_initiate_dash_args(): 57 | docopt_threw = False 58 | try: 59 | args = rnsh.args.Args(shlex.split("rnsh --config ~/Projects/rnsh/testconfig -vvvvvvv a5f72aefc2cb3cdba648f73f77c4e887 -- -l")) 60 | assert not args.listen 61 | assert args.config == "~/Projects/rnsh/testconfig" 62 | assert args.verbose == 7 63 | assert args.destination == "a5f72aefc2cb3cdba648f73f77c4e887" 64 | assert args.command_line == ["-l"] 65 | except docopt.DocoptExit: 66 | docopt_threw = True 67 | assert not docopt_threw 68 | 69 | 70 | def test_program_listen_dash_args(): 71 | docopt_threw = False 72 | try: 73 | args = rnsh.args.Args(shlex.split("rnsh -l --config ~/Projects/rnsh/testconfig -n -C -- /bin/pwd")) 74 | assert args.listen 75 | assert args.config == "~/Projects/rnsh/testconfig" 76 | assert args.destination is None 77 | assert args.no_auth 78 | assert args.no_remote_cmd 79 | assert args.command_line == ["/bin/pwd"] 80 | except docopt.DocoptExit: 81 | docopt_threw = True 82 | assert not docopt_threw 83 | 84 | 85 | def test_program_listen_config_print(): 86 | docopt_threw = False 87 | try: 88 | args = rnsh.args.Args(shlex.split("rnsh -l --config testconfig -p")) 89 | assert args.listen 90 | assert args.config == "testconfig" 91 | assert args.print_identity 92 | assert args.command_line == [] 93 | except docopt.DocoptExit: 94 | docopt_threw = True 95 | assert not docopt_threw 96 | 97 | 98 | def test_split_at(): 99 | a, b = rnsh.args._split_array_at(["one", "two", "three"], "two") 100 | assert a == ["one"] 101 | assert b == ["three"] 102 | 103 | def test_split_at_not_found(): 104 | a, b = rnsh.args._split_array_at(["one", "two", "three"], "four") 105 | assert a == 
["one", "two", "three"] 106 | assert b == [] -------------------------------------------------------------------------------- /tests/test_exception.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import rnsh.exception as exception 3 | 4 | def test_permit(): 5 | with pytest.raises(SystemExit): 6 | with exception.permit(SystemExit): 7 | raise Exception("Should not bubble") 8 | with exception.permit(SystemExit): 9 | raise SystemExit() -------------------------------------------------------------------------------- /tests/test_process.py: -------------------------------------------------------------------------------- 1 | import tests.helpers 2 | import time 3 | import pytest 4 | import rnsh.process 5 | import asyncio 6 | import logging 7 | import multiprocessing.pool 8 | logging.getLogger().setLevel(logging.DEBUG) 9 | 10 | 11 | @pytest.mark.skip_ci 12 | @pytest.mark.asyncio 13 | async def test_echo(): 14 | """ 15 | Echoing some text through cat. 
16 | """ 17 | with tests.helpers.SubprocessReader(argv=["/bin/cat"]) as state: 18 | state.start() 19 | assert state.process is not None 20 | assert state.process.running 21 | message = "test\n" 22 | state.process.write(message.encode("utf-8")) 23 | await asyncio.sleep(0.1) 24 | data = state.read() 25 | state.process.write(rnsh.process.CTRL_D) 26 | await asyncio.sleep(0.1) 27 | assert len(data) > 0 28 | decoded = data.decode("utf-8") 29 | assert decoded == message.replace("\n", "\r\n") * 2 30 | assert not state.process.running 31 | 32 | 33 | @pytest.mark.skip_ci 34 | @pytest.mark.asyncio 35 | async def test_echo_live(): 36 | """ 37 | Check for immediate echo 38 | """ 39 | with tests.helpers.SubprocessReader(argv=["/bin/cat"]) as state: 40 | state.start() 41 | assert state.process is not None 42 | assert state.process.running 43 | message = "t" 44 | state.process.write(message.encode("utf-8")) 45 | await asyncio.sleep(0.1) 46 | data = state.read() 47 | state.process.write(rnsh.process.CTRL_C) 48 | await asyncio.sleep(0.1) 49 | assert len(data) > 0 50 | decoded = data.decode("utf-8") 51 | assert decoded == message 52 | assert not state.process.running 53 | 54 | 55 | @pytest.mark.skip_ci 56 | @pytest.mark.asyncio 57 | async def test_echo_live_pipe_in(): 58 | """ 59 | Check for immediate echo 60 | """ 61 | with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdin_is_pipe=True) as state: 62 | state.start() 63 | assert state.process is not None 64 | assert state.process.running 65 | message = "t" 66 | state.process.write(message.encode("utf-8")) 67 | await asyncio.sleep(0.1) 68 | data = state.read() 69 | state.process.close_stdin() 70 | await asyncio.sleep(0.1) 71 | assert len(data) > 0 72 | decoded = data.decode("utf-8") 73 | assert decoded == message 74 | assert not state.process.running 75 | 76 | 77 | @pytest.mark.skip_ci 78 | @pytest.mark.asyncio 79 | async def test_echo_live_pipe_out(): 80 | """ 81 | Check for immediate echo 82 | """ 83 | with 
tests.helpers.SubprocessReader(argv=["/bin/cat"], stdout_is_pipe=True) as state: 84 | state.start() 85 | assert state.process is not None 86 | assert state.process.running 87 | message = "t" 88 | state.process.write(message.encode("utf-8")) 89 | state.process.write(rnsh.process.CTRL_D) 90 | await asyncio.sleep(0.1) 91 | data = state.read() 92 | assert len(data) > 0 93 | decoded = data.decode("utf-8") 94 | assert decoded == message 95 | data = state.read_err() 96 | assert len(data) > 0 97 | state.process.close_stdin() 98 | await asyncio.sleep(0.1) 99 | assert not state.process.running 100 | 101 | 102 | @pytest.mark.skip_ci 103 | @pytest.mark.asyncio 104 | async def test_echo_live_pipe_err(): 105 | """ 106 | Check for immediate echo 107 | """ 108 | with tests.helpers.SubprocessReader(argv=["/bin/cat"], stderr_is_pipe=True) as state: 109 | state.start() 110 | assert state.process is not None 111 | assert state.process.running 112 | message = "t" 113 | state.process.write(message.encode("utf-8")) 114 | await asyncio.sleep(0.1) 115 | data = state.read() 116 | state.process.write(rnsh.process.CTRL_C) 117 | await asyncio.sleep(0.1) 118 | assert len(data) > 0 119 | decoded = data.decode("utf-8") 120 | assert decoded == message 121 | assert not state.process.running 122 | 123 | 124 | @pytest.mark.skip_ci 125 | @pytest.mark.asyncio 126 | async def test_echo_live_pipe_out_err(): 127 | """ 128 | Check for immediate echo 129 | """ 130 | with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdout_is_pipe=True, stderr_is_pipe=True) as state: 131 | state.start() 132 | assert state.process is not None 133 | assert state.process.running 134 | message = "t" 135 | state.process.write(message.encode("utf-8")) 136 | state.process.write(rnsh.process.CTRL_D) 137 | await asyncio.sleep(0.1) 138 | data = state.read() 139 | assert len(data) > 0 140 | decoded = data.decode("utf-8") 141 | assert decoded == message 142 | data = state.read_err() 143 | assert len(data) == 0 144 | 
@pytest.mark.skip_ci
@pytest.mark.asyncio
async def test_echo_live_pipe_all():
    """
    Check for immediate echo when stdin, stdout and stderr are all pipes
    """
    with tests.helpers.SubprocessReader(argv=["/bin/cat"], stdout_is_pipe=True, stderr_is_pipe=True,
                                        stdin_is_pipe=True) as state:
        state.start()
        assert state.process is not None
        assert state.process.running
        message = "t"
        state.process.write(message.encode("utf-8"))
        await asyncio.sleep(0.1)
        data = state.read()
        # Closing stdin is what lets cat exit.
        state.process.close_stdin()
        await asyncio.sleep(0.1)
        assert len(data) > 0
        decoded = data.decode("utf-8")
        assert decoded == message
        assert not state.process.running


@pytest.mark.skip_ci
@pytest.mark.asyncio
async def test_double_echo_live():
    """
    Check for immediate echo from two subprocesses running at the same time
    """
    with tests.helpers.SubprocessReader(name="state", argv=["/bin/cat"]) as state:
        with tests.helpers.SubprocessReader(name="state2", argv=["/bin/cat"]) as state2:
            state.start()
            state2.start()
            assert state.process is not None
            assert state.process.running
            assert state2.process is not None
            assert state2.process.running
            message = "t"
            state.process.write(message.encode("utf-8"))
            state2.process.write(message.encode("utf-8"))
            await asyncio.sleep(0.1)
            data = state.read()
            data2 = state2.read()
            state.process.write(rnsh.process.CTRL_C)
            state2.process.write(rnsh.process.CTRL_C)
            await asyncio.sleep(0.1)
            assert len(data) > 0
            assert len(data2) > 0
            decoded = data.decode("utf-8")
            # Decode the second process's own output; decoding `data` again
            # would leave state2's echo unchecked.
            decoded2 = data2.decode("utf-8")
            assert decoded == message
            assert decoded2 == message
            assert not state.process.running
            assert not state2.process.running
async def test_event_wait_any(): 208 | delay = 0.5 209 | with multiprocessing.pool.ThreadPool() as pool: 210 | loop = asyncio.get_running_loop() 211 | evt1 = asyncio.Event() 212 | evt2 = asyncio.Event() 213 | 214 | def assert_between(min, max, val): 215 | assert min <= val <= max 216 | 217 | # test 1: both timeout 218 | ts = time.time() 219 | finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay*2) 220 | assert_between(delay*2, delay*2.1, time.time() - ts) 221 | assert finished is None 222 | assert not evt1.is_set() 223 | assert not evt2.is_set() 224 | 225 | #test 2: evt1 set, evt2 not set 226 | hits = 0 227 | 228 | def test2_bg(): 229 | nonlocal hits 230 | hits += 1 231 | time.sleep(delay) 232 | evt1.set() 233 | 234 | ts = time.time() 235 | pool.apply_async(test2_bg) 236 | finished = await rnsh.process.event_wait_any([evt1, evt2], timeout=delay * 2) 237 | assert_between(delay * 0.5, delay * 1.5, time.time() - ts) 238 | assert hits == 1 239 | assert evt1.is_set() 240 | assert not evt2.is_set() 241 | assert finished == evt1 242 | -------------------------------------------------------------------------------- /tests/test_protocol.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from RNS.Channel import TPacket, MessageState, ChannelOutletBase, Channel 6 | from typing import Callable 7 | 8 | logging.getLogger().setLevel(logging.DEBUG) 9 | 10 | import rnsh.protocol 11 | import contextlib 12 | import typing 13 | import types 14 | import time 15 | import uuid 16 | from RNS.Channel import MessageBase 17 | 18 | 19 | module_logger = logging.getLogger(__name__) 20 | 21 | 22 | def test_send_receive_streamdata(): 23 | message = rnsh.protocol.StreamDataMessage(stream_id=rnsh.protocol.StreamDataMessage.STREAM_ID_STDIN, 24 | data=b'Test', eof=True) 25 | rx_message = message.__class__() 26 | rx_message.unpack(message.pack()) 27 | 28 | assert 
isinstance(rx_message, message.__class__) 29 | assert rx_message.stream_id == message.stream_id 30 | assert rx_message.data == message.data 31 | assert rx_message.eof == message.eof 32 | 33 | 34 | def test_send_receive_noop(): 35 | message = rnsh.protocol.NoopMessage() 36 | 37 | rx_message = message.__class__() 38 | rx_message.unpack(message.pack()) 39 | 40 | assert isinstance(rx_message, message.__class__) 41 | 42 | 43 | def test_send_receive_execute(): 44 | message = rnsh.protocol.ExecuteCommandMesssage(cmdline=["test", "one", "two"], 45 | pipe_stdin=False, 46 | pipe_stdout=True, 47 | pipe_stderr=False, 48 | tcflags=[12, 34, 56, [78, 90]], 49 | term="xtermmmm") 50 | rx_message = message.__class__() 51 | rx_message.unpack(message.pack()) 52 | 53 | assert isinstance(rx_message, message.__class__) 54 | assert rx_message.cmdline == message.cmdline 55 | assert rx_message.pipe_stdin == message.pipe_stdin 56 | assert rx_message.pipe_stdout == message.pipe_stdout 57 | assert rx_message.pipe_stderr == message.pipe_stderr 58 | assert rx_message.tcflags == message.tcflags 59 | assert rx_message.term == message.term 60 | 61 | 62 | def test_send_receive_windowsize(): 63 | message = rnsh.protocol.WindowSizeMessage(1, 2, 3, 4) 64 | rx_message = message.__class__() 65 | rx_message.unpack(message.pack()) 66 | 67 | assert isinstance(rx_message, message.__class__) 68 | assert rx_message.rows == message.rows 69 | assert rx_message.cols == message.cols 70 | assert rx_message.hpix == message.hpix 71 | assert rx_message.vpix == message.vpix 72 | 73 | 74 | def test_send_receive_versioninfo(): 75 | message = rnsh.protocol.VersionInfoMessage(sw_version="1.2.3") 76 | message.protocol_version = 30 77 | rx_message = message.__class__() 78 | rx_message.unpack(message.pack()) 79 | 80 | assert isinstance(rx_message, message.__class__) 81 | assert rx_message.sw_version == message.sw_version 82 | assert rx_message.protocol_version == message.protocol_version 83 | 84 | 85 | def 
test_send_receive_error(): 86 | message = rnsh.protocol.ErrorMessage(msg="TESTerr", 87 | fatal=True, 88 | data={"one": 2}) 89 | rx_message = message.__class__() 90 | rx_message.unpack(message.pack()) 91 | 92 | assert isinstance(rx_message, message.__class__) 93 | assert rx_message.msg == message.msg 94 | assert rx_message.fatal == message.fatal 95 | assert rx_message.data == message.data 96 | 97 | 98 | def test_send_receive_cmdexit(): 99 | message = rnsh.protocol.CommandExitedMessage(5) 100 | rx_message = message.__class__() 101 | rx_message.unpack(message.pack()) 102 | 103 | assert isinstance(rx_message, message.__class__) 104 | assert rx_message.return_code == message.return_code 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /tests/test_retry.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import time 3 | from types import TracebackType 4 | from typing import Type 5 | 6 | import rnsh.retry 7 | from contextlib import AbstractContextManager 8 | import logging 9 | logging.getLogger().setLevel(logging.DEBUG) 10 | 11 | 12 | class State(AbstractContextManager): 13 | def __init__(self, delay: float): 14 | self.delay = delay 15 | self.retry_thread = rnsh.retry.RetryThread(self.delay / 10.0) 16 | self.tries = 0 17 | self.callbacks = 0 18 | self.timed_out = False 19 | self.tag = str(uuid.uuid4()) 20 | self.results = [self.tag, self.tag, self.tag] 21 | self.got_tag = None 22 | assert self.retry_thread.is_alive() 23 | 24 | def cleanup(self): 25 | self.retry_thread.wait() 26 | assert self.tries != 0 27 | self.retry_thread.close() 28 | assert not self.retry_thread.is_alive() 29 | 30 | def retry(self, tag, tries): 31 | self.tries = tries 32 | self.got_tag = tag 33 | self.callbacks += 1 34 | return self.results[tries - 1] 35 | 36 | def timeout(self, tag, tries): 37 | self.tries = tries 38 | self.got_tag = tag 39 | self.timed_out = True 40 | 
self.callbacks += 1 41 | 42 | def __exit__(self, __exc_type: Type[BaseException], __exc_value: BaseException, 43 | __traceback: TracebackType) -> bool: 44 | self.cleanup() 45 | return False 46 | 47 | 48 | def test_retry_timeout(): 49 | 50 | with State(0.1) as state: 51 | return_tag = state.retry_thread.begin(try_limit=3, 52 | wait_delay=state.delay, 53 | try_callback=state.retry, 54 | timeout_callback=state.timeout) 55 | assert return_tag == state.tag 56 | assert state.tries == 1 57 | assert state.callbacks == 1 58 | assert state.got_tag is None 59 | assert not state.timed_out 60 | time.sleep(state.delay / 2.0) 61 | time.sleep(state.delay) 62 | assert state.tries == 2 63 | assert state.callbacks == 2 64 | assert state.got_tag == state.tag 65 | assert not state.timed_out 66 | time.sleep(state.delay) 67 | assert state.tries == 3 68 | assert state.callbacks == 3 69 | assert state.got_tag == state.tag 70 | assert not state.timed_out 71 | 72 | # check timeout 73 | time.sleep(state.delay) 74 | assert state.tries == 3 75 | assert state.callbacks == 4 76 | assert state.got_tag == state.tag 77 | assert state.timed_out 78 | 79 | # check no more callbacks 80 | time.sleep(state.delay * 3.0) 81 | assert state.callbacks == 4 82 | assert state.tries == 3 83 | 84 | 85 | def test_retry_immediate_complete(): 86 | with State(0.01) as state: 87 | state.results[0] = False 88 | return_tag = state.retry_thread.begin(try_limit=3, 89 | wait_delay=state.delay, 90 | try_callback=state.retry, 91 | timeout_callback=state.timeout) 92 | assert not return_tag 93 | assert state.callbacks == 1 94 | assert not state.got_tag 95 | assert not state.timed_out 96 | time.sleep(state.delay * 3) 97 | assert state.tries == 1 98 | assert state.callbacks == 1 99 | assert not state.got_tag 100 | assert not state.timed_out 101 | 102 | 103 | def test_retry_return_complete(): 104 | with State(0.01) as state: 105 | state.results[1] = False 106 | return_tag = state.retry_thread.begin(try_limit=3, 107 | 
wait_delay=state.delay, 108 | try_callback=state.retry, 109 | timeout_callback=state.timeout) 110 | assert return_tag == state.tag 111 | assert state.callbacks == 1 112 | assert state.got_tag is None 113 | assert not state.timed_out 114 | time.sleep(state.delay / 2.0) 115 | time.sleep(state.delay) 116 | assert state.tries == 2 117 | assert state.callbacks == 2 118 | assert state.got_tag == state.tag 119 | assert not state.timed_out 120 | 121 | time.sleep(state.delay) 122 | assert state.tries == 2 123 | assert state.callbacks == 2 124 | assert state.got_tag == state.tag 125 | assert not state.timed_out 126 | 127 | # check no more callbacks 128 | time.sleep(state.delay * 3.0) 129 | assert state.callbacks == 2 130 | assert state.tries == 2 131 | 132 | 133 | def test_retry_set_complete(): 134 | with State(0.01) as state: 135 | return_tag = state.retry_thread.begin(try_limit=3, 136 | wait_delay=state.delay, 137 | try_callback=state.retry, 138 | timeout_callback=state.timeout) 139 | assert return_tag == state.tag 140 | assert state.callbacks == 1 141 | assert state.got_tag is None 142 | assert not state.timed_out 143 | time.sleep(state.delay / 2.0) 144 | time.sleep(state.delay) 145 | assert state.tries == 2 146 | assert state.callbacks == 2 147 | assert state.got_tag == state.tag 148 | assert not state.timed_out 149 | 150 | state.retry_thread.complete(state.tag) 151 | 152 | time.sleep(state.delay) 153 | assert state.tries == 2 154 | assert state.callbacks == 2 155 | assert state.got_tag == state.tag 156 | assert not state.timed_out 157 | 158 | # check no more callbacks 159 | time.sleep(state.delay * 3.0) 160 | assert state.callbacks == 2 161 | assert state.tries == 2 162 | 163 | -------------------------------------------------------------------------------- /tests/test_rnsh.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.getLogger().setLevel(logging.DEBUG) 3 | 4 | import tests.helpers 5 | import 
rnsh.rnsh 6 | import rnsh.process 7 | import shlex 8 | import pytest 9 | import time 10 | import asyncio 11 | import re 12 | import os 13 | 14 | 15 | def test_version(): 16 | assert rnsh.__version__ != "0.0.0" 17 | assert rnsh.__version__ != "0.0.1" 18 | 19 | 20 | @pytest.mark.skip_ci 21 | @pytest.mark.asyncio 22 | async def test_wrapper(): 23 | with tests.helpers.tempdir() as td: 24 | with tests.helpers.SubprocessReader(argv=shlex.split(f"date")) as wrapper: 25 | wrapper.start() 26 | assert wrapper.process is not None 27 | assert wrapper.process.running 28 | await asyncio.sleep(1) 29 | text = wrapper.read().decode("utf-8") 30 | assert len(text) > 5 31 | assert not wrapper.process.running 32 | 33 | 34 | 35 | @pytest.mark.skip_ci 36 | @pytest.mark.asyncio 37 | async def test_rnsh_listen_start_stop(): 38 | with tests.helpers.tempdir() as td: 39 | with tests.helpers.SubprocessReader(argv=shlex.split(f"poetry run rnsh -l --config \"{td}\" -n -C -vvvvvv -- /bin/ls")) as wrapper: 40 | wrapper.start() 41 | await asyncio.sleep(0.1) 42 | assert wrapper.process.running 43 | # wait for process to start up 44 | await asyncio.sleep(3) 45 | # read the output 46 | text = wrapper.read().decode("utf-8") 47 | # listener should have printed "listening 48 | assert text.index("listening") is not None 49 | # stop process with SIGINT 50 | wrapper.process.write(rnsh.process.CTRL_C) 51 | # wait for process to wind down 52 | start_time = time.time() 53 | while wrapper.process.running and time.time() - start_time < 5: 54 | await asyncio.sleep(0.1) 55 | assert not wrapper.process.running 56 | 57 | 58 | async def get_listener_id_and_dest(td: str) -> tuple[str, str]: 59 | with tests.helpers.SubprocessReader(name="getid", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" -p")) as wrapper: 60 | wrapper.start() 61 | await asyncio.sleep(0.1) 62 | assert wrapper.process.running 63 | # wait for process to start up 64 | await tests.helpers.wait_for_condition_async(lambda: not 
wrapper.process.running, 5) 65 | assert not wrapper.process.running 66 | await asyncio.sleep(2) 67 | # read the output 68 | text = wrapper.read().decode("utf-8").replace("\r", "").replace("\n", "") 69 | assert text.index("Using service name \"default\"") is not None 70 | assert text.index("Identity") is not None 71 | match = re.search(r"<([a-f0-9]{32})>[^<]+<([a-f0-9]{32})>", text) 72 | assert match is not None 73 | ih = match.group(1) 74 | assert len(ih) == 32 75 | dh = match.group(2) 76 | assert len(dh) == 32 77 | await asyncio.sleep(0.1) 78 | return ih, dh 79 | 80 | 81 | async def get_initiator_id(td: str) -> str: 82 | with tests.helpers.SubprocessReader(name="getid", argv=shlex.split(f"poetry run -- rnsh -c \"{td}\" -p")) as wrapper: 83 | wrapper.start() 84 | await asyncio.sleep(0.1) 85 | assert wrapper.process.running 86 | # wait for process to start up 87 | await tests.helpers.wait_for_condition_async(lambda: not wrapper.process.running, 5) 88 | assert not wrapper.process.running 89 | # read the output 90 | text = wrapper.read().decode("utf-8").replace("\r", "").replace("\n", "") 91 | assert text.index("Identity") is not None 92 | match = re.search(r"<([a-f0-9]{32})>", text) 93 | assert match is not None 94 | ih = match.group(1) 95 | assert len(ih) == 32 96 | await asyncio.sleep(0.1) 97 | return ih 98 | 99 | 100 | 101 | @pytest.mark.skip_ci 102 | @pytest.mark.asyncio 103 | async def test_rnsh_get_listener_id_and_dest() -> [int]: 104 | with tests.helpers.tempdir() as td: 105 | ih, dh = await get_listener_id_and_dest(td) 106 | assert len(ih) == 32 107 | assert len(dh) == 32 108 | 109 | 110 | @pytest.mark.skip_ci 111 | @pytest.mark.asyncio 112 | async def test_rnsh_get_initiator_id() -> [int]: 113 | with tests.helpers.tempdir() as td: 114 | ih = await get_initiator_id(td) 115 | assert len(ih) == 32 116 | 117 | 118 | async def do_connected_test(listener_args: str, initiator_args: str, test: callable): 119 | with tests.helpers.tempdir() as td: 120 | ih, dh = await 
get_listener_id_and_dest(td) 121 | iih = await get_initiator_id(td) 122 | assert len(ih) == 32 123 | assert len(dh) == 32 124 | assert len(iih) == 32 125 | assert "dh" in initiator_args 126 | initiator_args = initiator_args.replace("dh", dh) 127 | listener_args = listener_args.replace("iih", iih) 128 | with tests.helpers.SubprocessReader(name="listener", argv=shlex.split(f"poetry run -- rnsh -l -c \"{td}\" {listener_args}")) as listener, \ 129 | tests.helpers.SubprocessReader(name="initiator", argv=shlex.split(f"poetry run -- rnsh -q -c \"{td}\" {initiator_args}")) as initiator: 130 | # listener startup 131 | listener.start() 132 | await asyncio.sleep(0.1) 133 | assert listener.process.running 134 | # wait for process to start up 135 | await asyncio.sleep(5) 136 | # read the output 137 | text = listener.read().decode("utf-8") 138 | assert text.index(dh) is not None 139 | 140 | # initiator run 141 | initiator.start() 142 | assert initiator.process.running 143 | 144 | await test(td, ih, dh, iih, listener, initiator) 145 | 146 | # expect test to shut down initiator 147 | assert not initiator.process.running 148 | 149 | # stop process with SIGINT 150 | listener.process.write(rnsh.process.CTRL_C) 151 | # wait for process to wind down 152 | start_time = time.time() 153 | while listener.process.running and time.time() - start_time < 5: 154 | await asyncio.sleep(0.1) 155 | assert not listener.process.running 156 | 157 | 158 | @pytest.mark.skip_ci 159 | @pytest.mark.asyncio 160 | async def test_rnsh_get_echo_through(): 161 | cwd = os.getcwd() 162 | 163 | async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, 164 | initiator: tests.helpers.SubprocessReader): 165 | start_time = time.time() 166 | while initiator.return_code is None and time.time() - start_time < 3: 167 | await asyncio.sleep(0.1) 168 | text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") 169 | assert text == cwd 170 | 171 | await 
do_connected_test("-n -C -- /bin/pwd", "dh", test) 172 | 173 | 174 | @pytest.mark.skip_ci 175 | @pytest.mark.asyncio 176 | async def test_rnsh_no_ident(): 177 | cwd = os.getcwd() 178 | 179 | async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, 180 | initiator: tests.helpers.SubprocessReader): 181 | start_time = time.time() 182 | while initiator.return_code is None and time.time() - start_time < 3: 183 | await asyncio.sleep(0.1) 184 | text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") 185 | assert text == cwd 186 | 187 | await do_connected_test("-n -C -- /bin/pwd", "-N dh", test) 188 | 189 | 190 | @pytest.mark.skip_ci 191 | @pytest.mark.asyncio 192 | async def test_rnsh_invalid_ident(): 193 | cwd = os.getcwd() 194 | 195 | async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, 196 | initiator: tests.helpers.SubprocessReader): 197 | start_time = time.time() 198 | while initiator.return_code is None and time.time() - start_time < 3: 199 | await asyncio.sleep(0.1) 200 | text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") 201 | assert "not allowed" in text 202 | 203 | await do_connected_test("-a 12345678901234567890123456789012 -C -- /bin/pwd", "dh", test) 204 | 205 | 206 | @pytest.mark.skip_ci 207 | @pytest.mark.asyncio 208 | async def test_rnsh_valid_ident(): 209 | cwd = os.getcwd() 210 | 211 | async def test(td: str, ih: str, dh: str, iih: str, listener: tests.helpers.SubprocessReader, 212 | initiator: tests.helpers.SubprocessReader): 213 | start_time = time.time() 214 | while initiator.return_code is None and time.time() - start_time < 3: 215 | await asyncio.sleep(0.1) 216 | text = initiator.read().decode("utf-8").replace("\r", "").replace("\n", "") 217 | assert (text == cwd) 218 | 219 | await do_connected_test("-a iih -C -- /bin/pwd", "dh", test) 220 | 221 | 222 | 223 | 224 | 
"""tty_test.py: report whether each standard stream is attached to a TTY.

Prints one line per standard stream (stdin, stdout, stderr) in the form
``"<name>  tty"`` or ``"<name>  not tty"``, followed by the command-line
arguments the script was started with.
"""
import os
import sys


def describe_stream(stream) -> str:
    """Return the report line for *stream*.

    The stream's ``name`` is left-aligned in an 8-character column and
    followed by ``"tty"`` or ``"not tty"``.

    A stream that is ``None``, closed, or has no usable file descriptor
    (for example an in-memory replacement) is reported as ``"not tty"``
    instead of raising, so the remaining streams and the argument list are
    still printed.
    """
    # str() because ``name`` is an int for files opened from a descriptor,
    # and the ``s`` format code rejects non-string values.
    name = str(getattr(stream, "name", "<unknown>"))
    try:
        is_tty = os.isatty(stream.fileno())
    except (AttributeError, OSError, ValueError):
        # AttributeError: stream is None or has no fileno().
        # OSError/ValueError: io.UnsupportedOperation or a closed file.
        is_tty = False
    return f"{name:8s} " + ("tty" if is_tty else "not tty")


def main() -> None:
    """Print the TTY status of the standard streams, then ``sys.argv``."""
    for stream in (sys.stdin, sys.stdout, sys.stderr):
        print(describe_stream(stream))

    print(f"args: {sys.argv}")


if __name__ == "__main__":
    main()