├── .gitignore ├── LICENSE ├── README.md ├── README.zh_cn.md ├── debug_main.py ├── docs └── README.md ├── naja_atra ├── __init__.py ├── __main__.py ├── app_conf.py ├── http_servers │ ├── __init__.py │ ├── coroutine_http_server.py │ ├── http_server.py │ ├── routing_server.py │ └── threading_http_server.py ├── models.py ├── request_handlers │ ├── __init__.py │ ├── http_controller_handler.py │ ├── http_request_handler.py │ ├── http_session_local_impl.py │ ├── model_bindings.py │ └── websocket_controller_handler.py ├── server.py └── utils │ ├── __init__.py │ ├── http_utils.py │ └── logger.py ├── pyproject.toml ├── setup.py ├── tests ├── __init__.py ├── ctrls │ ├── __init__.py │ ├── my_controllers.py │ ├── my_controllers_model_binding.py │ └── ws_controllers.py ├── static │ ├── a.txt │ ├── inner │ │ ├── b.txt │ │ └── y.ini │ ├── x.ini │ └── 中文.txt └── test_all_ctrls.py └── upload.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | v1/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # for development only 109 | tests/certs/ 110 | temp/ 111 | 112 | # vscode 113 | .vscode/ 114 | **/tmp/** 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Keijack Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice 
and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Naja-Atra 2 | 3 | As more functions have been added to this project and it is no longer that simple, it has been renamed to `naja-atra` for continued maintenance. You can find the new code [here](https://www.github.com/naja-atra/naja-atra), along with some new extensions. 4 | -------------------------------------------------------------------------------- /debug_main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # If you want to run this file, please install following package to run. 4 | # python3 -m pip install werkzeug 'uvicorn[standard]' 5 | # 6 | 7 | from naja_atra import request_map 8 | import naja_atra.server as server 9 | 10 | import os 11 | import signal 12 | 13 | 14 | from naja_atra.http_servers.http_server import HttpServer 15 | from naja_atra.utils.logger import get_logger, set_level 16 | 17 | from naja_atra import get_app_conf 18 | set_level("DEBUG") 19 | 20 | 21 | @request_map("/stop") 22 | def stop(): 23 | server.stop() 24 | return "关闭关闭成功!"
25 | 26 | 27 | _logger = get_logger("http_test") 28 | PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__)) 29 | 30 | _server: HttpServer = None 31 | 32 | app = get_app_conf("2") 33 | 34 | 35 | @app.route("/") 36 | def idx(): 37 | return {"msg": "hello, world!"} 38 | 39 | 40 | def start_via_class(): 41 | global _server 42 | 43 | _server = HttpServer(host=('', 9091), 44 | resources={"/p/**": f"{PROJECT_ROOT}/tests/static"}, 45 | app_conf=app) 46 | _server.start() 47 | 48 | 49 | def start_server(): 50 | _logger.info("start server in background. ") 51 | 52 | server.scan(base_dir="tests/ctrls", regx=r'.*controllers.*') 53 | server.start( 54 | # port=9443, 55 | port=9090, 56 | resources={"/public/*": f"{PROJECT_ROOT}/tests/static", 57 | "/*": f"{PROJECT_ROOT}/tests/static", 58 | '/inn/**': f"{PROJECT_ROOT}/tests/static", 59 | '**.txt': f"{PROJECT_ROOT}/tests/static", 60 | '*.ini': f"{PROJECT_ROOT}/tests/static", 61 | }, 62 | # ssl=True, 63 | # certfile=f"{PROJECT_ROOT}/tests/certs/fullchain.pem", 64 | # keyfile=f"{PROJECT_ROOT}/tests/certs//privkey.pem", 65 | gzip_content_types={"image/x-icon", "text/plain"}, 66 | gzip_compress_level=9, 67 | prefer_coroutine=False) 68 | 69 | 70 | def on_sig_term(signum, frame): 71 | server.stop() 72 | if _server: 73 | _server.shutdown() 74 | 75 | 76 | if __name__ == "__main__": 77 | signal.signal(signal.SIGTERM, on_sig_term) 78 | signal.signal(signal.SIGINT, on_sig_term) 79 | # Thread(target=start_via_class, daemon=True).start() 80 | # sleep(1) 81 | # start_via_class() 82 | # main(sys.argv[1:]) 83 | start_server() 84 | # start_server_wsgi() 85 | # start_server_werkzeug() 86 | # start_server_uvicorn() 87 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # python-simple-http-server 2 | 3 | [![PyPI version](https://badge.fury.io/py/simple-http-server.png)](https://badge.fury.io/py/simple-http-server) 4 
| 5 | ## Description 6 | 7 | This is a simple HTTP server with an MVC-like design. 8 | 9 | ## Support Python Version 10 | 11 | Python 3.7+ 12 | 13 | ## Why choose 14 | 15 | * Lightweight. 16 | * Functional programming. 17 | * Filter chain support. 18 | * Session support; distributed sessions can be supported by [this extension](https://github.com/keijack/python-simple-http-server-redis-session). 19 | * You can use [this extension](https://github.com/keijack/python-simple-http-server-jinja) to support `jinja` views. 20 | * Spring MVC-like request mapping. 21 | * SSL support. 22 | * Websocket support. 23 | * Easy to use. 24 | * Free-style controller writing. 25 | * Easily integrated with WSGI servers. 26 | * Easily integrated with ASGI servers. Websocket is supported when the ASGI server enables websocket functions. 27 | * Coroutine mode support. 28 | 29 | ## Dependencies 30 | 31 | There are no other dependencies needed to run this project. However, if you want to run the unit tests in the `tests` folder, you need to install `websocket-client` via pip: 32 | 33 | ```shell 34 | python3 -m pip install websocket-client 35 | ``` 36 | 37 | ## How to use 38 | 39 | ### Install 40 | 41 | ```shell 42 | python3 -m pip install simple_http_server 43 | ``` 44 | 45 | ### Minimum code / component requirement setup 46 | 47 | Minimum code to get things started should have at least one controller function,
48 | using the route and server modules from simple_http_server 49 | 50 | ```python 51 | from simple_http_server import route, server 52 | 53 | @route("/") 54 | def index(): 55 | return {"hello": "world"} 56 | 57 | server.start(port=9090) 58 | ``` 59 | 60 | ### Write Controllers 61 | 62 | ```python 63 | 64 | from simple_http_server import request_map 65 | from simple_http_server import Response 66 | from simple_http_server import MultipartFile 67 | from simple_http_server import Parameter 68 | from simple_http_server import Parameters 69 | from simple_http_server import Header 70 | from simple_http_server import JSONBody 71 | from simple_http_server import HttpError 72 | from simple_http_server import StaticFile 73 | from simple_http_server import Headers 74 | from simple_http_server import Cookies 75 | from simple_http_server import Cookie 76 | from simple_http_server import Redirect 77 | from simple_http_server import ModelDict 78 | 79 | # request_map has an alias name `route`, you can select the one you familiar with. 80 | @request_map("/index") 81 | def my_ctrl(): 82 | return {"code": 0, "message": "success"} # You can return a dictionary, a string or a `simple_http_server.simple_http_server.Response` object. 
83 | 84 | 85 | @route("/say_hello", method=["GET", "POST"]) 86 | def my_ctrl2(name, name2=Parameter("name", default="KEIJACK"), model=ModelDict()): 87 | """name and name2 is the same""" 88 | name == name2 # True 89 | name == model["name"] # True 90 | return "hello, %s, %s" % (name, name2) 91 | 92 | 93 | @request_map("/error") 94 | def my_ctrl3(): 95 | return Response(status_code=500) 96 | 97 | 98 | @request_map("/exception") 99 | def exception_ctrl(): 100 | raise HttpError(400, "Exception") 101 | 102 | @request_map("/upload", method="GET") 103 | def show_upload(): 104 | root = os.path.dirname(os.path.abspath(__file__)) 105 | return StaticFile("%s/my_dev/my_test_index.html" % root, "text/html; charset=utf-8") 106 | 107 | 108 | @request_map("/upload", method="POST") 109 | def my_upload(img=MultipartFile("img")): 110 | root = os.path.dirname(os.path.abspath(__file__)) 111 | img.save_to_file(root + "/my_dev/imgs/" + img.filename) 112 | return "upload ok!" 113 | 114 | 115 | @request_map("/post_txt", method="POST") 116 | def normal_form_post(txt): 117 | return "hi, %s" % txt 118 | 119 | @request_map("/tuple") 120 | def tuple_results(): 121 | # The order here is not important, we consider the first `int` value as status code, 122 | # All `Headers` object will be sent to the response 123 | # And the first valid object whose type in (str, unicode, dict, StaticFile, bytes) will 124 | # be considered as the body 125 | return 200, Headers({"my-header": "headers"}), {"success": True} 126 | 127 | """ 128 | " Cookie_sc will not be written to response. 
It's just some kind of default 129 | " value 130 | """ 131 | @request_map("tuple_cookie") 132 | def tuple_with_cookies(all_cookies=Cookies(), cookie_sc=Cookie("sc")): 133 | print("=====> cookies ") 134 | print(all_cookies) 135 | print("=====> cookie sc ") 136 | print(cookie_sc) 137 | print("======<") 138 | import datetime 139 | expires = datetime.datetime(2018, 12, 31) 140 | 141 | cks = Cookies() 142 | # cks = cookies.SimpleCookie() # you could also use the built-in cookie objects 143 | cks["ck1"] = "keijack" 144 | cks["ck1"]["path"] = "/" 145 | cks["ck1"]["expires"] = expires.strftime(Cookies.EXPIRE_DATE_FORMAT) 146 | # You can omit the status code, headers, cookies, or even the body in this tuple. 147 | return Headers({"xx": "yyy"}), cks, "OK" 148 | 149 | """ 150 | " If you visit /a/b/xyz/x, this controller function will be called, and `path_val` will be `xyz` 151 | """ 152 | @request_map("/a/b/{path_val}/x") 153 | def my_path_val_ctr(path_val=PathValue()): 154 | return f"{path_val}" 155 | 156 | @request_map("/star/*") # /star/c will find this controller, but /star/c/d will not. 157 | @request_map("*/star") # /c/star will find this controller, but /c/d/star will not. 158 | def star_path(path_val=PathValue()): 159 | return f"{path_val}" 160 | 161 | @request_map("/star/**") # Both /star/c and /star/c/d will find this controller. 162 | @request_map("**/star") # Both /c/star and /c/d/star will find this controller. 163 | def star_path(path_val=PathValue()): 164 | return f"{path_val}" 165 | 166 | @request_map("/redirect") 167 | def redirect(): 168 | return Redirect("/index") 169 | 170 | @request_map("session") 171 | def test_session(session=Session(), invalid=False): 172 | ins = session.get_attribute("in-session") 173 | if not ins: 174 | session.set_attribute("in-session", "Hello, Session!") 175 | 176 | _logger.info("session id: %s" % session.id) 177 | if invalid: 178 | _logger.info("session[%s] is being invalidated.
" % session.id) 179 | session.invalidate() 180 | return "%s" % str(ins) 181 | 182 | # use coroutine, these controller functions will work both in a coroutine mode or threading mode. 183 | 184 | async def say(sth: str = ""): 185 | _logger.info(f"Say: {sth}") 186 | return f"Success! {sth}" 187 | 188 | @request_map("/中文/coroutine") 189 | async def coroutine_ctrl(hey: str = "Hey!"): 190 | return await say(hey) 191 | 192 | @route("/res/write/bytes") 193 | def res_writer(response: Response): 194 | response.status_code = 200 195 | response.add_header("Content-Type", "application/octet-stream") 196 | response.write_bytes(b'abcd') 197 | response.write_bytes(bytearray(b'efg')) 198 | response.close() 199 | ``` 200 | 201 | Beside using the default values, you can also use variable annotations to specify your controller function's variables. 202 | 203 | ```python 204 | @request_map("/say_hello/to/{name}", method=["GET", "POST", "PUT"]) 205 | def your_ctroller_function( 206 | user_name: str, # req.parameter["user_name"],400 error will raise when there's no such parameter in the query string. 207 | password: str, # req.parameter["password"],400 error will raise when there's no such parameter in the query string. 208 | skills: list, # req.parameters["skills"],400 error will raise when there's no such parameter in the query string. 209 | all_headers: Headers, # req.headers 210 | user_token: Header, # req.headers["user_token"],400 error will raise when there's no such parameter in the quest headers. 211 | all_cookies: Cookies, # req.cookies, return all cookies 212 | user_info: Cookie, # req.cookies["user_info"],400 error will raise when there's no such parameter in the cookies. 213 | name: PathValue, # req.path_values["name"],get the {name} value from your path. 214 | session: Session # req.getSession(True),get the session, if there is no sessions, create one. 215 | ): 216 | return "Hello, World!" 
217 | 218 | # You can use `params` to narrow the controller mapping. The following examples show only the `params` mapping, omitting the 219 | # `headers` examples because their usage is almost the same as `params`. 220 | @request_map("/exact_params", method="GET", params="a=b") 221 | def exact_params(a: str): 222 | print(f"{a}") # b 223 | return {"result": "ok"} 224 | 225 | @request_map("/exact_params", method="GET", params="a!=b") 226 | def exact_not_params(a: str): 227 | print(f"{a}") # any value except b 228 | return {"result": "ok"} 229 | 230 | @request_map("/exact_params", method="GET", params="a^=b") 231 | def exact_startwith_params(a: str): 232 | print(f"{a}") # a value starting with b 233 | return {"result": "ok"} 234 | 235 | @request_map("/exact_params", method="GET", params="!a") 236 | def no_params(): 237 | return {"result": "ok"} 238 | 239 | @request_map("/exact_params", method="GET", params="a") 240 | def must_has_params(): 241 | return {"result": "ok"} 242 | 243 | # If multiple expressions are set, all expressions must be matched to enter this controller function. 244 | @request_map("/exact_params", method="GET", params=["a=b", "c!=d"]) 245 | def multipul_params(): 246 | return {"result": "ok"} 247 | 248 | # You can set `match_all_params_expressions` to False so that the URL can enter this controller function even if only one expression matches. 249 | @request_map("/exact_params", method="GET", params=["a=b", "c!=d"], match_all_params_expressions=False) 250 | def multipul_params(): 251 | return {"result": "ok"} 252 | ``` 253 | 254 | We recommend using functional programming to write controller functions, but if you really want to use objects, you can use `@request_map` on a class method. In this case, a new MyController object will be created every time a new request comes in.
255 | 256 | ```python 257 | 258 | class MyController: 259 | 260 | def __init__(self) -> None: 261 | self._name = "ctr object" 262 | 263 | @request_map("/obj/say_hello", method="GET") 264 | def my_ctrl_mth(self, name: str): 265 | return {"message": f"hello, {name}, {self._name} says. "} 266 | 267 | ``` 268 | 269 | If you want a singleton, you can add a `@controller` decorator to the class. 270 | 271 | ```python 272 | 273 | @controller 274 | class MyController: 275 | 276 | def __init__(self) -> None: 277 | self._name = "ctr object" 278 | 279 | @request_map("/obj/say_hello", method="GET") 280 | def my_ctrl_mth(self, name: str): 281 | return {"message": f"hello, {name}, {self._name} says. "} 282 | 283 | ``` 284 | 285 | You can also add `@request_map` to your class; its path will be used as a prefix of the URL. 286 | 287 | ```python 288 | 289 | @controller 290 | @request_map("/obj", method="GET") 291 | class MyController: 292 | 293 | def __init__(self) -> None: 294 | self._name = "ctr object" 295 | 296 | @request_map 297 | def my_ctrl_default_mth(self, name: str): 298 | return {"message": f"hello, {name}, {self._name} says. "} 299 | 300 | @request_map("/say_hello", method=("GET", "POST")) 301 | def my_ctrl_mth(self, name: str): 302 | return {"message": f"hello, {name}, {self._name} says. "} 303 | 304 | ``` 305 | 306 | You can specify the `__init__` arguments in the `@controller` decorator. 307 | 308 | ```python 309 | 310 | @controller(args=["ctr_name"], kwargs={"desc": "this is a key word argument"}) 311 | @request_map("/obj", method="GET") 312 | class MyController: 313 | 314 | def __init__(self, name, desc="") -> None: 315 | self._name = f"ctr[{name}] - {desc}" 316 | 317 | @request_map 318 | def my_ctrl_default_mth(self, name: str): 319 | return {"message": f"hello, {name}, {self._name} says. "} 320 | 321 | @request_map("/say_hello", method=("GET", "POST")) 322 | def my_ctrl_mth(self, name: str): 323 | return {"message": f"hello, {name}, {self._name} says.
"} 324 | 325 | ``` 326 | 327 | From `0.7.0`, `@request_map` supports regular expression mapping. 328 | 329 | ```python 330 | # The URL `/reg/abcef/aref/xxx` maps to the following controller: 331 | @route(regexp="^(reg/(.+))$", method="GET") 332 | def my_reg_ctr(reg_groups: RegGroups, reg_group: RegGroup = RegGroup(1)): 333 | print(reg_groups) # will output ("reg/abcef/aref/xxx", "abcef/aref/xxx") 334 | print(reg_group) # will output "abcef/aref/xxx" 335 | return f"{reg_group.group},{reg_group}" 336 | ``` 337 | Regular expression mapping in a class: 338 | 339 | ```python 340 | @controller(args=["ctr_name"], kwargs={"desc": "this is a key word argument"}) 341 | @request_map("/obj", method="GET") # regexp does not work here, but method will still be applied 342 | class MyController: 343 | 344 | def __init__(self, name, desc="") -> None: 345 | self._name = f"ctr[{name}] - {desc}" 346 | 347 | @request_map 348 | def my_ctrl_default_mth(self, name: str): 349 | return {"message": f"hello, {name}, {self._name} says. "} 350 | 351 | @route(regexp="^(reg/(.+))$") # The prefix `/obj` from the class decorator will be ignored, but `method` (GET in this example) from the class decorator will still work. 352 | def my_ctrl_mth(self, name: str): 353 | return {"message": f"hello, {name}, {self._name} says.
"} 354 | 355 | ``` 356 | 357 | ### Session 358 | 359 | By default, sessions are stored locally. You can extend the `SessionFactory` and `Session` classes to implement your own session storage (for example, storing all data in redis or memcache). 360 | 361 | ```python 362 | from simple_http_server import Session, SessionFactory, set_session_factory 363 | 364 | class MySessionImpl(Session): 365 | 366 | def __init__(self): 367 | super().__init__() 368 | # your own implementation 369 | 370 | @property 371 | def id(self) -> str: 372 | # your own implementation 373 | 374 | @property 375 | def creation_time(self) -> float: 376 | # your own implementation 377 | 378 | @property 379 | def last_accessed_time(self) -> float: 380 | # your own implementation 381 | 382 | @property 383 | def is_new(self) -> bool: 384 | # your own implementation 385 | 386 | @property 387 | def attribute_names(self) -> Tuple: 388 | # your own implementation 389 | 390 | def get_attribute(self, name: str) -> Any: 391 | # your own implementation 392 | 393 | def set_attribute(self, name: str, value: Any) -> None: 394 | # your own implementation 395 | 396 | def invalidate(self) -> None: 397 | # your own implementation 398 | 399 | class MySessionFacImpl(SessionFactory): 400 | 401 | def __init__(self): 402 | super().__init__() 403 | # your own implementation 404 | 405 | 406 | def get_session(self, session_id: str, create: bool = False) -> Session: 407 | # your own implementation 408 | return MySessionImpl() 409 | 410 | set_session_factory(MySessionFacImpl()) 411 | 412 | ``` 413 | 414 | There is an official Redis implementation here: https://github.com/keijack/python-simple-http-server-redis-session.git 415 | 416 | ### Websocket 417 | 418 | To handle a websocket session, you need to handle multiple events, so it is more reasonable to use a class rather than functions to do it. 419 | 420 | In this framework, you should use `@websocket_handler` to decorate the class that handles websocket sessions.
Specific event listener methods should be defined in a fixed way. However, the easiest way to do it is to inherit from the `simple_http_server.WebsocketHandler` class and implement the events you need. This inheritance is not compulsory, though. 421 | 422 | You can configure `endpoint` or `regexp` in `@websocket_handler` to set up which URL the class should handle. There is also a `singleton` field, which is set to `True` by default. This means that all connections are handled by ONE object of this class. If this field is set to `False`, a new object will be created each time a `WebsocketSession` tries to connect. 423 | 424 | ```python 425 | from simple_http_server import WebsocketHandler, WebsocketRequest,WebsocketSession, websocket_handler 426 | 427 | @websocket_handler(endpoint="/ws/{path_val}") 428 | class WSHandler(WebsocketHandler): 429 | 430 | def on_handshake(self, request: WebsocketRequest): 431 | """ 432 | " 433 | " You can get path/headers/path_values/cookies/query_string/query_parameters from request. 434 | " 435 | " You should return a tuple of (http_status_code, headers) 436 | " 437 | " If the status code is in (0, None, 101), the websocket will be connected; otherwise the status you return will be sent. 438 | " 439 | " All headers will be sent to the client 440 | " 441 | """ 442 | _logger.info(f"handshake: {request.path_values}") 443 | return 0, {} 444 | 445 | def on_open(self, session: WebsocketSession): 446 | """ 447 | " 448 | " Will be called when the connection is opened. 449 | " 450 | """ 451 | _logger.info(f">>{session.id}<< open! {session.request.path_values}") 452 | 453 | def on_close(self, session: WebsocketSession, reason: str): 454 | """ 455 | " 456 | " Will be called when the connection is closed. 457 | " 458 | """ 459 | _logger.info(f">>{session.id}<< close::{reason}") 460 | 461 | def on_ping_message(self, session: WebsocketSession = None, message: bytes = b''): 462 | """ 463 | " 464 | " Will be called when a ping message is received.
Will send all the message bytes back to the client by default. 465 | " 466 | """ 467 | session.send_pone(message) 468 | 469 | def on_pong_message(self, session: WebsocketSession = None, message: bytes = b''): 470 | """ 471 | " 472 | " Will be called when a pong message is received. 473 | " 474 | """ 475 | pass 476 | 477 | def on_text_message(self, session: WebsocketSession, message: str): 478 | """ 479 | " 480 | " Will be called when a text message is received. 481 | " 482 | """ 483 | _logger.info(f">>{session.id}<< on text message: {message}") 484 | session.send(message) 485 | 486 | def on_binary_message(self, session: WebsocketSession = None, message: bytes = b''): 487 | """ 488 | " 489 | " Will be called when a binary message is received, if you have not consumed all the bytes in the `on_binary_frame` 490 | " method. 491 | " 492 | """ 493 | pass 494 | 495 | def on_binary_frame(self, session: WebsocketSession = None, fin: bool = False, frame_payload: bytes = b''): 496 | """ 497 | " 498 | " If you are sending a continuation binary message to the server, this will be called every time a frame is 499 | " received; you can consume all the bytes in this method, e.g. save all bytes to a file. In that case, 500 | " you should not return any value in this method. 501 | " 502 | " If you do not implement this method or return True in this method, all the bytes will be cached in 503 | " memory and be sent to your `on_binary_message` method. 504 | " 505 | """ 506 | return True 507 | 508 | @websocket_handler(regexp="^/ws-reg/([a-zA-Z0-9]+)$", singleton=False) 509 | class WSHandler(WebsocketHandler): 510 | 511 | """ 512 | " Your code here 513 | """ 514 | 515 | ``` 516 | 517 | But if you only want to handle one event, you can also use a function to handle it.
518 | 519 | ```python 520 | 521 | from simple_http_server import WebsocketCloseReason, WebsocketHandler, WebsocketRequest, WebsocketSession, websocket_message, websocket_handshake, websocket_open, websocket_close, WEBSOCKET_MESSAGE_TEXT 522 | 523 | @websocket_handshake(endpoint="/ws-fun/{path_val}") 524 | def ws_handshake(request: WebsocketRequest): 525 | return 0, {} 526 | 527 | 528 | @websocket_open(endpoint="/ws-fun/{path_val}") 529 | def ws_open(session: WebsocketSession): 530 | _logger.info(f">>{session.id}<< open! {session.request.path_values}") 531 | 532 | 533 | @websocket_close(endpoint="/ws-fun/{path_val}") 534 | def ws_close(session: WebsocketSession, reason: WebsocketCloseReason): 535 | _logger.info( 536 | f">>{session.id}<< close::{reason.message}-{reason.code}-{reason.reason}") 537 | 538 | 539 | @websocket_message(endpoint="/ws-fun/{path_val}", message_type=WEBSOCKET_MESSAGE_TEXT) 540 | # You can define a function in a sync or async way. 541 | async def ws_text(session: WebsocketSession, message: str): 542 | _logger.info(f">>{session.id}<< on text message: {message}") 543 | session.send(f"{session.request.path_values['path_val']}-{message}") 544 | if message == "close": 545 | session.close() 546 | ``` 547 | 548 | ### Error pages 549 | 550 | You can use `@error_message` to specify your own error page. See: 551 | 552 | ```python 553 | from simple_http_server import error_message 554 | # map specified codes 555 | @error_message("403", "404") 556 | def my_40x_page(message: str, explain=""): 557 | return f""" 558 | 559 | 560 | 发生错误! 
561 | 562 | 563 | message: {message}, explain: {explain} 564 | 565 | 566 | """ 567 | 568 | # map specified code ranges 569 | @error_message("40x", "50x") 570 | def my_error_message(code, message, explain=""): 571 | return f"{code}-{message}-{explain}" 572 | 573 | # map all error pages 574 | @error_message 575 | def my_error_message(code, message, explain=""): 576 | return f"{code}-{message}-{explain}" 577 | ``` 578 | 579 | ### Write filters 580 | 581 | This server supports filters; you can use the `request_filter` decorator to define your filters. 582 | 583 | ```python 584 | from simple_http_server import request_filter 585 | 586 | @request_filter("/tuple/**") # use wildcard 587 | @request_filter(regexp="^/tuple") # use regular expression 588 | def filter_tuple(ctx): 589 | print("---------- through filter ---------------") 590 | # add a header to request header 591 | ctx.request.headers["filter-set"] = "through filter" 592 | if "user_name" not in ctx.request.parameter: 593 | ctx.response.send_redirect("/index") 594 | elif "pass" not in ctx.request.parameter: 595 | ctx.response.send_error(400, "pass should be passed") 596 | # you can also raise a HttpError 597 | # raise HttpError(400, "pass should be passed") 598 | else: 599 | # you should always use do_chain method to go to the next 600 | ctx.do_chain() 601 | ``` 602 | 603 | ### Start your server 604 | 605 | ```python 606 | # If you place the controller functions in other files, you should import them here. 607 | 608 | import simple_http_server.server as server 609 | import my_test_ctrl 610 | 611 | 612 | def main(*args): 613 | # The following method can import several controller files at once.
614 | server.scan("my_ctr_pkg", r".*controller.*") 615 | server.start() 616 | 617 | if __name__ == "__main__": 618 | main() 619 | ``` 620 | 621 | If you want to specify the host and port: 622 | 623 | ```python 624 | server.start(host="", port=8080) 625 | ``` 626 | 627 | If you want to specify the resource paths: 628 | 629 | ```python 630 | server.start(resources={"/path_prefix/*": "/absolute/dir/root/path", # Match the files in the given folder with a special path prefix. 631 | "/path_prefix/**": "/absolute/dir/root/path", # Match all the files in the given folder and its sub-folders with a special path prefix. 632 | "*.suffix": "/absolute/dir/root/path", # Match the specific files in the given folder. 633 | "**.suffix": "/absolute/dir/root/path", # Match the specific files in the given folder and its sub-folders. 634 | }) 635 | ``` 636 | 637 | If you want to use SSL: 638 | 639 | ```python 640 | server.start(host="", 641 | port=8443, 642 | ssl=True, 643 | ssl_protocol=ssl.PROTOCOL_TLS_SERVER, # Optional, default is ssl.PROTOCOL_TLS_SERVER, which will auto-detect the highest protocol version that both server and client support. 644 | ssl_check_hostname=False, # Optional; if set to True, the connection cannot be established when the hostname does not match the certificate. Default is False. 645 | keyfile="/path/to/your/keyfile.key", 646 | certfile="/path/to/your/certfile.cert", 647 | keypass="", # Optional, your private key's password 648 | ) 649 | ``` 650 | 651 | ### Coroutine 652 | 653 | From `0.12.0`, you can use coroutine tasks rather than threads to handle requests. Set the `prefer_coroutine` parameter in the start method to enable the coroutine mode. 654 | 655 | ```python 656 | server.start(prefer_coroutine=True) 657 | ``` 658 | 659 | From `0.13.0`, coroutine mode uses the coroutine server, which means all requests use async I/O rather than blocking I/O. So you can now use `async def` to define all your controllers, including the Websocket event callback methods.
660 | 661 | If you start the server in an async function, you can call its async version; by doing this, the server will use the same event loop as your other async functions. 662 | 663 | ```python 664 | await server.start_async(prefer_coroutine=True) 665 | ``` 666 | 667 | ## Logger 668 | 669 | The default logger writes logs to the screen; you can specify a logger handler to write them to a file. 670 | 671 | ```python 672 | import simple_http_server.logger as logger 673 | import logging.handlers 674 | 675 | _formatter = logging.Formatter(fmt='[%(asctime)s]-[%(name)s]-%(levelname)-4s: %(message)s') 676 | _handler = logging.handlers.TimedRotatingFileHandler("/var/log/simple_http_server.log", when="midnight", backupCount=7) 677 | _handler.setFormatter(_formatter) 678 | _handler.setLevel("INFO") 679 | 680 | logger.set_handler(_handler) 681 | ``` 682 | 683 | If you want to add a handler rather than replace the inner one, you can use: 684 | 685 | ```python 686 | logger.add_handler(_handler) 687 | ``` 688 | 689 | If you want to change the logger level: 690 | 691 | ```python 692 | logger.set_level("DEBUG") 693 | ``` 694 | 695 | You can get a standalone logger, independent of the framework's one, via the class `logger.LoggerFactory`. 696 | 697 | ```python 698 | import simple_http_server.logger as logger 699 | 700 | log = logger.get_logger("my_service", "my_log_fac") 701 | 702 | # If you want to set a different log level for this logger factory: 703 | 704 | log_fac = logger.get_logger_factory("my_log_fac") 705 | log_fac.log_level = "DEBUG" 706 | log = log_fac.get_logger("my_service") 707 | 708 | log.info(...) 709 | 710 | ``` 711 | 712 | 713 | ## WSGI Support 714 | 715 | You can use this module in WSGI apps.
716 | 717 | ```python 718 | import simple_http_server.server as server 719 | import os 720 | from simple_http_server import request_map 721 | 722 | 723 | # scan all your controllers 724 | server.scan("tests/ctrls", r'.*controllers.*') 725 | # or define a new controller function here 726 | @request_map("/hello_wsgi") 727 | def my_controller(name: str): 728 | return 200, "Hello, WSGI!" 729 | # resources is optional 730 | wsgi_proxy = server.init_wsgi_proxy(resources={"/public/*": "/your/static/files/path"}) 731 | 732 | # WSGI app entry point. 733 | def simple_app(environ, start_response): 734 | return wsgi_proxy.app_proxy(environ, start_response) 735 | 736 | # If your entry point is async: 737 | async def simple_app(environ, start_response): 738 | return await wsgi_proxy.async_app_proxy(environ, start_response) 739 | ``` 740 | 741 | ## ASGI Support 742 | 743 | You can use this module in an ASGI server; take `uvicorn` for example: 744 | 745 | ```python 746 | 747 | import asyncio 748 | import uvicorn 749 | import simple_http_server.server as server 750 | from simple_http_server.server import ASGIProxy 751 | 752 | 753 | asgi_proxy: ASGIProxy = None 754 | init_asgi_proxy_lock: asyncio.Lock = asyncio.Lock() 755 | 756 | 757 | async def init_asgi_proxy(): 758 | global asgi_proxy 759 | if asgi_proxy is None: 760 | async with init_asgi_proxy_lock: 761 | if asgi_proxy is None: 762 | server.scan(base_dir="tests/ctrls", regx=r'.*controllers.*') 763 | asgi_proxy = server.init_asgi_proxy(resources={"/public/*": "tests/static"}) 764 | 765 | async def app(scope, receive, send): 766 | await init_asgi_proxy() 767 | await asgi_proxy.app_proxy(scope, receive, send) 768 | 769 | def main(): 770 | config = uvicorn.Config("main:app", host="0.0.0.0", port=9090, log_level="info") 771 | asgi_server = uvicorn.Server(config) 772 | asgi_server.run() 773 | 774 | if __name__ == "__main__": 775 | main() 776 | 777 | ``` 778 | 779 | ## Thanks 780 | 781 | The code that processes WebSocket connections comes from the following
project: https://github.com/Pithikos/python-websocket-server 782 | -------------------------------------------------------------------------------- /naja_atra/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | from .app_conf import * 26 | from .models import * 27 | 28 | name = "naja-atra" 29 | version = "1.0.2" 30 | -------------------------------------------------------------------------------- /naja_atra/__main__.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import signal 5 | import getopt 6 | from . 
import server 7 | from .utils import logger 8 | 9 | _logger = logger.get_logger("naja_atra.__main__") 10 | 11 | 12 | def print_help(): 13 | print(""" 14 | python3 -m naja_atra [options] 15 | 16 | Options: 17 | -h --help Show this message and exit. 18 | -p --port Specify alternate port (default: 9090) 19 | -b --bind Specify alternate bind address (default: all interfaces) 20 | -s --scan Scan this path to find controllers, if absent, will scan the current work directory 21 | --regex Only import the files that macthes this regular expresseion. 22 | -r --resources Specify the resource directory 23 | --loglevel Specify log level (default: info) 24 | """) 25 | 26 | 27 | def on_sig_term(signum, frame): 28 | _logger.info(f"Receive signal [{signum}], stop server now...") 29 | server.stop() 30 | 31 | 32 | def main(argv): 33 | try: 34 | opts = getopt.getopt(argv, "p:b:s:r:h", 35 | ["port=", "bind=", "scan=", "resources=", "regex=", "loglevel=", "help"])[0] 36 | opts = dict(opts) 37 | if "-h" in opts or "--help" in opts: 38 | print_help() 39 | return 40 | port = int(opts.get("-p", opts.get("--port", "9090"))) 41 | scan_path = opts.get("-s", opts.get("--scan", os.getcwd())) 42 | regex = opts.get("--regex", "") 43 | res_dir = opts.get("-r", opts.get("--resources", "")) 44 | binding_host = opts.get("-b", opts.get("--bind", "0.0.0.0")) 45 | log_level = opts.get("--loglevel", "") 46 | 47 | if log_level: 48 | logger.set_level(log_level) 49 | signal.signal(signal.SIGTERM, on_sig_term) 50 | signal.signal(signal.SIGINT, on_sig_term) 51 | server.scan(regx=regex, project_dir=scan_path) 52 | server.start( 53 | host=binding_host, 54 | port=port, 55 | resources={"/**": res_dir}, 56 | keep_alive=False, 57 | prefer_coroutine=False) 58 | except Exception as e: 59 | print(f"Start server error: {e}") 60 | print_help() 61 | 62 | 63 | if __name__ == "__main__": 64 | main(sys.argv[1:]) 65 | -------------------------------------------------------------------------------- 
/naja_atra/http_servers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | -------------------------------------------------------------------------------- /naja_atra/http_servers/coroutine_http_server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | """ 5 | Copyright (c) 2018 Keijack Wu 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | """ 25 | 26 | 27 | from .routing_server import RoutingServer 28 | from ..request_handlers.model_bindings import ModelBindingConf 29 | from ..request_handlers.http_request_handler import HttpRequestHandler 30 | 31 | 32 | import asyncio 33 | import threading 34 | from asyncio.base_events import Server 35 | from asyncio.streams import StreamReader, StreamWriter 36 | from ssl import SSLContext 37 | from time import sleep 38 | 39 | from ..utils.logger import get_logger 40 | 41 | 42 | _logger = get_logger("naja_atra.http_servers.coroutine_http_server") 43 | 44 | 45 | class CoroutineHTTPServer(RoutingServer): 46 | 47 | def __init__(self, host: str = '', port: int = 9090, ssl: SSLContext = None, res_conf={}, model_binding_conf: ModelBindingConf = ModelBindingConf()) -> None: 48 | RoutingServer.__init__( 49 | self, res_conf, model_binding_conf=model_binding_conf) 50 | self.host: str = host 51 | self.port: int = port 52 | self.ssl: SSLContext = ssl 53 | self.server: Server = None 54 | self.__thread_local = threading.local() 55 | 56 | async def callback(self, reader: StreamReader, writer: StreamWriter): 57 | handler = HttpRequestHandler(reader, writer, routing_conf=self) 58 | await handler.handle_request() 59 | _logger.debug("Connection ends, close the writer.") 60 | writer.close() 61 | 62 | async def start_async(self): 63 | self.server = await asyncio.start_server( 64 | self.callback, host=self.host, port=self.port, ssl=self.ssl) 65 | async with self.server: 66 | try: 67 | await self.server.serve_forever() 68 | except asyncio.CancelledError: 69 | _logger.debug( 70 | "Some requests are lost for the reason that the server is shutted down.") 71 | finally: 72 | await self.server.wait_closed() 73 | 74 | def _get_event_loop(self) -> asyncio.AbstractEventLoop: 75 | if not hasattr(self.__thread_local, "event_loop"): 76 | try: 77 | self.__thread_local.event_loop = asyncio.new_event_loop() 78 | except: 79 | self.__thread_local.event_loop = asyncio.get_event_loop() 80 | return 
self.__thread_local.event_loop 81 | 82 | def start(self): 83 | self._get_event_loop().run_until_complete(self.start_async()) 84 | 85 | def _shutdown(self): 86 | _logger.debug("Try to shutdown server.") 87 | self.server.close() 88 | loop = self.server.get_loop() 89 | loop.call_soon_threadsafe(loop.stop) 90 | 91 | def shutdown(self): 92 | wait_time = 3 93 | while wait_time: 94 | sleep(1) 95 | _logger.debug(f"couting to shutdown: {wait_time}") 96 | wait_time = wait_time - 1 97 | if wait_time == 0: 98 | _logger.debug("shutdown server....") 99 | self._shutdown() 100 | -------------------------------------------------------------------------------- /naja_atra/http_servers/http_server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | """ 5 | Copyright (c) 2018 Keijack Wu 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | """ 25 | 26 | 27 | from ssl import PROTOCOL_TLS_SERVER, SSLContext 28 | 29 | from typing import Dict, Tuple 30 | from .coroutine_http_server import CoroutineHTTPServer 31 | from .threading_http_server import ThreadingHTTPServer 32 | 33 | 34 | from ..app_conf import _ControllerFunction, _WebsocketHandlerClass, AppConf, get_app_conf 35 | 36 | 37 | from ..utils.logger import get_logger 38 | 39 | 40 | _logger = get_logger("naja_atra.http_servers.http_server") 41 | 42 | 43 | class HttpServer: 44 | """Dispatcher Http server""" 45 | 46 | def map_filter(self, filter_conf): 47 | self.server.map_filter(filter_conf) 48 | 49 | def map_controller(self, ctrl: _ControllerFunction): 50 | self.server.map_controller(ctrl) 51 | 52 | def map_websocket_handler(self, handler: _WebsocketHandlerClass): 53 | self.server.map_websocket_handler(handler) 54 | 55 | def map_error_page(self, code, func): 56 | self.server.map_error_page(code, func) 57 | 58 | def __init__(self, 59 | host: Tuple[str, int] = ('', 9090), 60 | ssl: bool = False, 61 | ssl_protocol: int = PROTOCOL_TLS_SERVER, 62 | ssl_check_hostname: bool = False, 63 | keyfile: str = "", 64 | certfile: str = "", 65 | keypass: str = "", 66 | ssl_context: SSLContext = None, 67 | resources: Dict[str, str] = {}, 68 | prefer_corountine=False, 69 | max_workers: int = None, 70 | connection_idle_time=None, 71 | keep_alive=True, 72 | keep_alive_max_request=None, 73 | gzip_content_types=set(), 74 | gzip_compress_level=9, 75 | app_conf: AppConf = None): 76 | self.host = host 77 | self.__ready = False 78 | 79 | self.ssl = ssl 80 | 81 | if ssl: 82 | if ssl_context: 83 | self.ssl_ctx = ssl_context 84 | else: 85 | assert keyfile and certfile, "keyfile and certfile should be provided. 
" 86 | ssl_ctx = SSLContext(protocol=ssl_protocol) 87 | ssl_ctx.check_hostname = ssl_check_hostname 88 | ssl_ctx.load_cert_chain( 89 | certfile=certfile, keyfile=keyfile, password=keypass) 90 | self.ssl_ctx = ssl_ctx 91 | else: 92 | self.ssl_ctx = None 93 | 94 | appconf = app_conf or get_app_conf() 95 | if prefer_corountine: 96 | _logger.info( 97 | f"Start server in corouting mode, listen to port: {self.host[1]}") 98 | self.server = CoroutineHTTPServer( 99 | self.host[0], self.host[1], self.ssl_ctx, resources, model_binding_conf=appconf.model_binding_conf) 100 | else: 101 | _logger.info( 102 | f"Start server in threading mixed mode, listen to port {self.host[1]}") 103 | self.server = ThreadingHTTPServer( 104 | self.host, resources, model_binding_conf=appconf.model_binding_conf, max_workers=max_workers) 105 | if self.ssl_ctx: 106 | self.server.socket = self.ssl_ctx.wrap_socket( 107 | self.server.socket, server_side=True) 108 | 109 | self.server.gzip_compress_level = gzip_compress_level 110 | self.server.gzip_content_types = gzip_content_types 111 | 112 | filters = appconf._get_filters() 113 | # filter configuration 114 | for ft in filters: 115 | self.map_filter(ft) 116 | 117 | request_mappings = appconf._get_request_mappings() 118 | # request mapping 119 | for ctr in request_mappings: 120 | self.map_controller(ctr) 121 | 122 | ws_handlers = appconf._get_websocket_handlers() 123 | 124 | for wshandler in ws_handlers: 125 | self.map_websocket_handler(wshandler) 126 | 127 | err_pages = appconf._get_error_pages() 128 | for code, func in err_pages.items(): 129 | self.map_error_page(code, func) 130 | self.server.keep_alive = keep_alive 131 | self.server.connection_idle_time = connection_idle_time 132 | self.server.keep_alive_max_request = keep_alive_max_request 133 | self.server.session_factory = appconf.session_factory 134 | 135 | @property 136 | def ready(self): 137 | return self.__ready 138 | 139 | def resources(self, res={}): 140 | self.server.res_conf = res 141 | 142 
| def start(self): 143 | try: 144 | self.__ready = True 145 | self.server.start() 146 | except: 147 | self.__ready = False 148 | raise 149 | 150 | async def start_async(self): 151 | try: 152 | self.__ready = True 153 | await self.server.start_async() 154 | except: 155 | self.__ready = False 156 | raise 157 | 158 | def shutdown(self): 159 | # shutdown it in a seperate thread. 160 | self.server.shutdown() 161 | -------------------------------------------------------------------------------- /naja_atra/http_servers/routing_server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | """ 5 | Copyright (c) 2018 Keijack Wu 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | """ 25 | 26 | 27 | from abc import abstractmethod 28 | import json 29 | import os 30 | import re 31 | 32 | 33 | from urllib.parse import unquote 34 | 35 | from typing import Any, Callable, Dict, List, Set, Tuple, Union 36 | 37 | from ..models import StaticFile, HttpSessionFactory 38 | from ..request_handlers.model_bindings import ModelBindingConf 39 | from ..app_conf import _WebsocketHandlerClass, _ControllerFunction 40 | 41 | from ..utils.http_utils import remove_url_first_slash, get_function_args, get_function_kwargs, get_path_reg_pattern 42 | from ..utils.logger import get_logger 43 | 44 | _logger = get_logger("naja_atra.http_servers.routing_server") 45 | 46 | 47 | _EXT_CONTENT_TYPE = { 48 | ".html": "text/html", 49 | ".htm": "text/html", 50 | ".xhtml": "text/html", 51 | ".css": "text/css", 52 | ".js": "text/javascript", 53 | ".jpg": "image/jpeg", 54 | ".jpeg": "image/jpeg", 55 | ".png": "image/png", 56 | ".ico": "image/x-icon", 57 | ".svg": "image/svg+xml", 58 | ".gif": "image/gif", 59 | ".avif": "image/avif", 60 | ".avifs": "image/avif", 61 | ".webp": "image/webp", 62 | ".pdf": "application/pdf", 63 | ".json": "application/json", 64 | ".mp4": "video/mp4", 65 | ".mp3": "video/mp3", 66 | ".txt": "text/plain" 67 | } 68 | 69 | 70 | class RoutingServer: 71 | 72 | HTTP_METHODS = ["OPTIONS", "GET", "HEAD", 73 | "POST", "PUT", "DELETE", "TRACE", "CONNECT"] 74 | 75 | def __init__(self, res_conf={}, model_binding_conf: ModelBindingConf = ModelBindingConf()): 76 | self.method_url_mapping: Dict[str, 77 | Dict[str, List[_ControllerFunction]]] = {"_": {}} 78 | self.path_val_url_mapping: Dict[str, Dict[str, List[Tuple[_ControllerFunction, List[str]]]]] = { 79 | "_": {}} 80 | self.method_regexp_mapping: Dict[str, Dict[str, List[_ControllerFunction]]] = { 81 | "_": {}} 82 | for mth in self.HTTP_METHODS: 83 | self.method_url_mapping[mth] = {} 84 | self.path_val_url_mapping[mth] = {} 85 | self.method_regexp_mapping[mth] = {} 86 | 87 | self.filter_mapping = {} 88 | 
self._res_conf = [] 89 | self.add_res_conf(res_conf) 90 | 91 | self.ws_mapping: Dict[str, _ControllerFunction] = {} 92 | self.ws_path_val_mapping: Dict[str, _ControllerFunction] = {} 93 | self.ws_regx_mapping: Dict[str, _ControllerFunction] = {} 94 | 95 | self.error_page_mapping = {} 96 | self.keep_alive = True 97 | self.__connection_idle_time: float = 60 98 | self.__keep_alive_max_request: int = 10 99 | self.session_factory: HttpSessionFactory = None 100 | self.model_binding_conf = model_binding_conf 101 | self.gzip_content_types: Set[str] = set() 102 | self.gzip_compress_level = 9 103 | 104 | @property 105 | def connection_idle_time(self): 106 | return self.__connection_idle_time 107 | 108 | @connection_idle_time.setter 109 | def connection_idle_time(self, val: float): 110 | if (isinstance(val, float) or isinstance(val, int)) and val > 0: 111 | self.__connection_idle_time = val 112 | 113 | @property 114 | def keep_alive_max_request(self): 115 | return self.__keep_alive_max_request 116 | 117 | @keep_alive_max_request.setter 118 | def keep_alive_max_request(self, val): 119 | if isinstance(val, int) and val > 0: 120 | self.__keep_alive_max_request = val 121 | 122 | def extend_gzip_content_types(self, content_types: Union[Set[str], List[str]]): 123 | for ctype in content_types: 124 | self.gzip_content_types.add(ctype.lower()) 125 | 126 | def put_to_method_url_mapping(self, method, url, ctrl): 127 | if url not in self.method_url_mapping[method]: 128 | self.method_url_mapping[method][url] = [] 129 | self.method_url_mapping[method][url].insert(0, ctrl) 130 | 131 | def put_to_path_val_url_mapping(self, method, path_pattern, ctrl, path_names): 132 | if path_pattern not in self.path_val_url_mapping[method]: 133 | self.path_val_url_mapping[method][path_pattern] = [] 134 | self.path_val_url_mapping[method][path_pattern].insert( 135 | 0, (ctrl, path_names)) 136 | 137 | def put_to_method_regexp_mapping(self, method, regexp, ctrl): 138 | if regexp not in 
self.method_regexp_mapping[method]: 139 | self.method_regexp_mapping[method][regexp] = [] 140 | self.method_regexp_mapping[method][regexp].insert(0, ctrl) 141 | 142 | @property 143 | def res_conf(self): 144 | return self._res_conf 145 | 146 | @res_conf.setter 147 | def res_conf(self, val: Dict[str, str]): 148 | self._res_conf.clear() 149 | self.add_res_conf(val) 150 | 151 | def add_res_conf(self, val: Dict[str, str]): 152 | if not val or not isinstance(val, dict): 153 | return 154 | for k, v in val.items(): 155 | res_k = k 156 | if not res_k.startswith("*") and res_k.endswith('/'): 157 | # xxx/ equals xxx/* 158 | res_k = res_k + "*" 159 | if res_k.startswith('*'): 160 | suffix = res_k[2:] if res_k.startswith("**") else res_k[1:] 161 | assert suffix.find('/') < 0 and suffix.find( 162 | '*') < 0, "If a resource path starts with *, only suffix can be configurated. " 163 | if res_k.startswith('**.'): 164 | # **.xxx 165 | suffix = res_k[3:] 166 | key = f'^[\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+\\.{suffix}$' 167 | elif res_k.startswith('*.'): 168 | # *.xxx 169 | suffix = res_k[2:] 170 | key = f'^[\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+\\.{suffix}$' 171 | elif res_k.endswith("/**"): 172 | # xx/** 173 | prefix = res_k[0:-2] 174 | while prefix.startswith('/'): 175 | prefix = prefix[1:] 176 | assert prefix.find( 177 | "*") < 0, "You can only config a * or ** at the start or end of a path." 178 | key = f'^{prefix}([\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+)$' 179 | elif res_k.endswith("/*"): 180 | # xx/* 181 | prefix = res_k[0:-1] 182 | while prefix.startswith('/'): 183 | prefix = prefix[1:] 184 | assert prefix.find( 185 | "*") < 0, "You can only config a * or ** at the start or end of a path." 
186 | key = f'^{prefix}([\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+)$' 187 | 188 | if v.endswith(os.path.sep): 189 | val = v 190 | else: 191 | val = v + os.path.sep 192 | self._res_conf.append((key, val)) 193 | 194 | def map_controller(self, ctrl: _ControllerFunction): 195 | url = ctrl.url 196 | regexp = ctrl.regexp 197 | method = ctrl.method 198 | _logger.debug( 199 | f"map url {url}|{regexp} with method::{method}, headers::{ctrl.headers} and params::{ctrl.params} to function {ctrl.func}. ") 200 | assert method is None or method == "" or method.upper() in self.HTTP_METHODS 201 | _method = method.upper() if method is not None and method != "" else "_" 202 | if regexp: 203 | self.put_to_method_regexp_mapping(_method, regexp, ctrl) 204 | else: 205 | _url = remove_url_first_slash(url) 206 | 207 | path_pattern, path_names = get_path_reg_pattern(_url) 208 | if path_pattern is None: 209 | self.put_to_method_url_mapping(_method, _url, ctrl) 210 | else: 211 | self.put_to_path_val_url_mapping( 212 | _method, path_pattern, ctrl, path_names) 213 | 214 | def _res_(self, fpath: str): 215 | fext = os.path.splitext(fpath)[1] 216 | ext = fext.lower() 217 | content_type = _EXT_CONTENT_TYPE.get(ext, "application/octet-stream") 218 | return StaticFile(fpath, content_type) 219 | 220 | def get_url_controllers(self, path: str = "", method: str = "") -> List[Tuple[_ControllerFunction, Dict, List]]: 221 | # explicitly url matching 222 | if path in self.method_url_mapping[method]: 223 | return [(ctrl, {}, ()) for ctrl in self.method_url_mapping[method][path]] 224 | elif path in self.method_url_mapping["_"]: 225 | return [(ctrl, {}, ()) for ctrl in self.method_url_mapping["_"][path]] 226 | 227 | # url with path value matching 228 | path_val_res = self.__try_get_from_path_val(path, method) 229 | if path_val_res is None: 230 | path_val_res = self.__try_get_from_path_val(path, "_") 231 | if path_val_res is not None: 232 | return path_val_res 233 | 234 | # regexp 235 | regexp_res = 
self.__try_get_ctrl_from_regexp(path, method) 236 | if regexp_res is None: 237 | regexp_res = self.__try_get_ctrl_from_regexp(path, "_") 238 | if regexp_res is not None: 239 | return regexp_res 240 | # static files 241 | for k, v in self.res_conf: 242 | match_static_path_conf = re.match(k, path) 243 | _logger.debug( 244 | f"{path} macth static file conf {k} ? {match_static_path_conf}") 245 | if match_static_path_conf: 246 | if match_static_path_conf.groups(): 247 | fpath = f"{v}{match_static_path_conf.group(1)}" 248 | else: 249 | fpath = f"{v}{path}" 250 | 251 | def static_fun(): 252 | return self._res_(fpath) 253 | return [(_ControllerFunction(func=static_fun), {}, ())] 254 | return [] 255 | 256 | def __try_get_ctrl_from_regexp(self, path, method): 257 | for regex, ctrls in self.method_regexp_mapping[method].items(): 258 | m = re.match(regex, f"/{path}") or re.match(regex, path) 259 | _logger.debug( 260 | f"regexp::pattern::[{regex}] => path::[{path}] match? {m is not None}") 261 | if m: 262 | res = [] 263 | grps = tuple([unquote(v) for v in m.groups()]) 264 | for ctrl in ctrls: 265 | res.append((ctrl, [], grps)) 266 | return res 267 | return None 268 | 269 | def __try_get_from_path_val(self, path, method): 270 | for patterns, val in self.path_val_url_mapping[method].items(): 271 | m = re.match(patterns, path) 272 | _logger.debug( 273 | f"url with path value::pattern::[{patterns}] => path::[{path}] match? 
{m is not None}") 274 | if m: 275 | res = [] 276 | for ctrl_fun, path_names in val: 277 | path_values = {} 278 | for idx in range(len(path_names)): 279 | key = unquote(path_names[idx]) 280 | path_values[key] = unquote(m.groups()[idx]) 281 | res.append((ctrl_fun, path_values, ())) 282 | return res 283 | return None 284 | 285 | def map_filter(self, filter_conf: Dict[str, Any]): 286 | # {"path": p, "url_pattern": r, "func": filter_fun} 287 | path = filter_conf["path"] if "path" in filter_conf else "" 288 | regexp = filter_conf["url_pattern"] 289 | filter_fun = filter_conf["func"] 290 | if path: 291 | regexp = get_path_reg_pattern(path)[0] 292 | if not regexp: 293 | regexp = f"^{path}$" 294 | _logger.debug( 295 | f"[path: {path}] map url regexp {regexp} to function: {filter_fun}") 296 | self.filter_mapping[regexp] = filter_fun 297 | 298 | def get_matched_filters(self, path): 299 | return self._get_matched_filters(remove_url_first_slash(path)) + self._get_matched_filters(path) 300 | 301 | def _get_matched_filters(self, path): 302 | available_filters = [] 303 | for regexp, val in self.filter_mapping.items(): 304 | m = re.match(regexp, path) 305 | _logger.debug( 306 | f"filter:: [{regexp}], path:: [{path}] match? 
{m is not None}") 307 | if m: 308 | available_filters.append(val) 309 | return available_filters 310 | 311 | def map_websocket_handler(self, handler: _WebsocketHandlerClass): 312 | url = handler.url 313 | regexp = handler.regexp 314 | _logger.debug( 315 | f"map url {url}|{regexp} to controller class {handler.cls}") 316 | if regexp: 317 | self.ws_regx_mapping[regexp] = handler 318 | else: 319 | url = remove_url_first_slash(url) 320 | path_pattern, path_names = get_path_reg_pattern(url) 321 | if path_pattern is None: 322 | self.ws_mapping[url] = handler 323 | else: 324 | self.ws_path_val_mapping[path_pattern] = (handler, path_names) 325 | 326 | def get_websocket_handler(self, path): 327 | # explicitly mapping 328 | if path in self.ws_mapping: 329 | return self.ws_mapping[path], {}, () 330 | 331 | # path value mapping 332 | handler, path_vals = self.__try_get_ws_handler_from_path_val(path) 333 | if handler is not None: 334 | return handler, path_vals, () 335 | # regexp mapping 336 | return self.__try_get_ws_hanlder_from_regexp(path) 337 | 338 | def __try_get_ws_hanlder_from_regexp(self, path): 339 | for regex, handler in self.ws_regx_mapping.items(): 340 | m = re.match(regex, f"/{path}") or re.match(regex, path) 341 | _logger.debug( 342 | f"regexp::pattern::[{regex}] => path::[{path}] match? {m is not None}") 343 | if m: 344 | return handler, {}, tuple([unquote(v) for v in m.groups()]) 345 | return None, {}, () 346 | 347 | def __try_get_ws_handler_from_path_val(self, path): 348 | for patterns, val in self.ws_path_val_mapping.items(): 349 | m = re.match(patterns, path) 350 | _logger.debug( 351 | f"websocket endpoint with path value::pattern::[{patterns}] => path::[{path}] match? 
{m is not None}") 352 | if m: 353 | handler, path_names = val 354 | path_values = {} 355 | for idx in range(len(path_names)): 356 | key = unquote(path_names[idx]) 357 | path_values[key] = unquote(m.groups()[idx]) 358 | return handler, path_values 359 | return None, {} 360 | 361 | def map_error_page(self, code: str, error_page_fun: Callable): 362 | if not code: 363 | c = "_" 364 | else: 365 | c = str(code).lower() 366 | self.error_page_mapping[c] = error_page_fun 367 | 368 | def _default_error_page(self, code: int, message: str = "", explain: str = ""): 369 | return json.dumps({ 370 | "code": code, 371 | "message": message, 372 | "explain": explain 373 | }) 374 | 375 | def error_page(self, code: int, message: str = "", explain: str = ""): 376 | c = str(code) 377 | func = None 378 | if c in self.error_page_mapping: 379 | func = self.error_page_mapping[c] 380 | elif code > 200: 381 | c0x = c[0:2] + "x" 382 | if c0x in self.error_page_mapping: 383 | func = self.error_page_mapping[c0x] 384 | elif "_" in self.error_page_mapping: 385 | func = self.error_page_mapping["_"] 386 | 387 | if not func: 388 | func = self._default_error_page 389 | _logger.debug(f"error page function:: {func}") 390 | 391 | co = code 392 | msg = message 393 | exp = explain 394 | 395 | args_def = get_function_args(func, None) 396 | kwargs_def = get_function_kwargs(func, None) 397 | 398 | args = [] 399 | for n, t in args_def: 400 | _logger.debug(f"set value to error_page function -> {n}") 401 | if co is not None: 402 | if t is None or t == int: 403 | args.append(co) 404 | co = None 405 | continue 406 | if msg is not None: 407 | if t is None or t == str: 408 | args.append(msg) 409 | msg = None 410 | continue 411 | if exp is not None: 412 | if t is None or t == str: 413 | args.append(exp) 414 | exp = None 415 | continue 416 | args.append(None) 417 | 418 | kwargs = {} 419 | for n, v, t in kwargs_def: 420 | if co is not None: 421 | if (t is None and isinstance(v, int)) or t == int: 422 | kwargs[n] = co 
423 | co = None 424 | continue 425 | if msg is not None: 426 | if (t is None and isinstance(v, str)) or t == str: 427 | kwargs[n] = msg 428 | msg = None 429 | continue 430 | if exp is not None: 431 | if (t is None and isinstance(v, str)) or t == str: 432 | kwargs[n] = exp 433 | exp = None 434 | continue 435 | kwargs[n] = v 436 | 437 | if args and kwargs: 438 | return func(*args, **kwargs) 439 | elif args: 440 | return func(*args) 441 | elif kwargs: 442 | return func(**kwargs) 443 | else: 444 | return func() 445 | 446 | @abstractmethod 447 | def start(self): 448 | pass 449 | 450 | @abstractmethod 451 | async def start_async(self): 452 | pass 453 | 454 | @abstractmethod 455 | def shutdown(self): 456 | pass 457 | -------------------------------------------------------------------------------- /naja_atra/http_servers/threading_http_server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | """ 5 | Copyright (c) 2018 Keijack Wu 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | 26 | 27 | from .routing_server import RoutingServer 28 | from ..request_handlers.model_bindings import ModelBindingConf 29 | from ..request_handlers.http_request_handler import SocketServerStreamRequestHandlerWraper 30 | 31 | 32 | import socket 33 | import threading 34 | from concurrent.futures import ThreadPoolExecutor 35 | from socketserver import TCPServer 36 | 37 | from ..utils.logger import get_logger 38 | 39 | 40 | _logger = get_logger("naja_atra.http_servers.threading_http_server") 41 | 42 | 43 | class ThreadingHTTPServer(TCPServer, RoutingServer): 44 | 45 | allow_reuse_address = 1 # Seems to make sense in testing environment 46 | 47 | _default_max_workers = 50 48 | 49 | def server_bind(self): 50 | """Override server_bind to store the server name.""" 51 | TCPServer.server_bind(self) 52 | host, port = self.server_address[:2] 53 | self.server_name = socket.getfqdn(host) 54 | self.server_port = port 55 | 56 | def __init__(self, addr, res_conf={}, model_binding_conf: ModelBindingConf = ModelBindingConf(), max_workers: int = None): 57 | RoutingServer.__init__( 58 | self, res_conf, model_binding_conf=model_binding_conf) 59 | self.max_workers = max_workers or self._default_max_workers 60 | self.threadpool: ThreadPoolExecutor = ThreadPoolExecutor( 61 | thread_name_prefix="ReqThread", 62 | max_workers=self.max_workers) 63 | TCPServer.__init__(self, addr, SocketServerStreamRequestHandlerWraper) 64 | 65 | def process_request_thread(self, request, client_address): 66 | try: 67 | self.finish_request(request, client_address) 68 | except Exception: 69 | self.handle_error(request, client_address) 70 | finally: 71 | self.shutdown_request(request) 72 | 73 | # override 74 | def 
process_request(self, request, client_address): 75 | self.threadpool.submit( 76 | self.process_request_thread, request, client_address) 77 | 78 | def server_close(self): 79 | super().server_close() 80 | self.threadpool.shutdown(True) 81 | 82 | def start(self): 83 | self.serve_forever() 84 | 85 | async def start_async(self): 86 | self.start() 87 | 88 | def _shutdown(self) -> None: 89 | _logger.debug("shutdown http server in a seperate thread..") 90 | super().shutdown() 91 | 92 | def shutdown(self) -> None: 93 | threading.Thread(target=self._shutdown, daemon=False).start() 94 | -------------------------------------------------------------------------------- /naja_atra/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | import http.cookies 25 | import time 26 | from abc import abstractmethod 27 | from typing import Any, Dict, List, Tuple, Union 28 | 29 | 30 | DEFAULT_ENCODING: str = "UTF-8" 31 | 32 | SESSION_COOKIE_NAME: str = "PY_SIM_HTTP_SER_SESSION_ID" 33 | 34 | WEBSOCKET_OPCODE_CONTINUATION: int = 0x0 35 | WEBSOCKET_OPCODE_TEXT: int = 0x1 36 | WEBSOCKET_OPCODE_BINARY: int = 0x2 37 | WEBSOCKET_OPCODE_CLOSE: int = 0x8 38 | WEBSOCKET_OPCODE_PING: int = 0x9 39 | WEBSOCKET_OPCODE_PONG: int = 0xA 40 | 41 | WEBSOCKET_MESSAGE_TEXT: str = "WEBSOCKET_MESSAGE_TEXT" 42 | WEBSOCKET_MESSAGE_BINARY: str = "WEBSOCKET_MESSAGE_BINARY" 43 | WEBSOCKET_MESSAGE_BINARY_FRAME: str = "WEBSOCKET_MESSAGE_BINARY_FRAME" 44 | WEBSOCKET_MESSAGE_PING: str = "WEBSOCKET_MESSAGE_PING" 45 | WEBSOCKET_MESSAGE_PONG: str = "WEBSOCKET_MESSAGE_PONG" 46 | 47 | 48 | class HttpSession: 49 | 50 | def __init__(self): 51 | self.max_inactive_interval: int = 30 * 60 52 | 53 | @property 54 | def id(self) -> str: 55 | return "" 56 | 57 | @property 58 | def creation_time(self) -> float: 59 | return 0 60 | 61 | @property 62 | def last_accessed_time(self) -> float: 63 | return 0 64 | 65 | @property 66 | def attribute_names(self) -> Tuple: 67 | return () 68 | 69 | @property 70 | def is_new(self) -> bool: 71 | return False 72 | 73 | @property 74 | def is_valid(self) -> bool: 75 | return time.time() - self.last_accessed_time < self.max_inactive_interval 76 | 77 | @abstractmethod 78 | def get_attribute(self, name: str) -> Any: 79 | return NotImplemented 80 | 81 | @abstractmethod 82 | def set_attribute(self, name: str, value: str) -> None: 83 | return NotImplemented 84 | 85 | @abstractmethod 86 | def invalidate(self) -> None: 87 | return NotImplemented 88 | 89 | 90 | class HttpSessionFactory: 91 | 92 | @abstractmethod 93 | def get_session(self, session_id: str, create: bool = False) -> HttpSession: 94 | return NotImplemented 95 | 96 | 97 | class Cookies(http.cookies.SimpleCookie): 98 | EXPIRE_DATE_FORMAT = "%a, %d %b %Y 
%H:%M:%S GMT" 99 | 100 | 101 | class RequestBodyReader: 102 | 103 | @abstractmethod 104 | async def read(self, n: int = -1) -> bytes: 105 | return NotImplemented 106 | 107 | 108 | class Request: 109 | """Request""" 110 | 111 | def __init__(self): 112 | self.method: str = "" # GET, POST, PUT, DELETE, HEAD, etc. 113 | self.headers: Dict[str, str] = {} # Request headers 114 | self.__cookies = Cookies() 115 | self.query_string: str = "" # Query String 116 | self.path_values: Dict[str, str] = {} 117 | self.reg_groups = () # If controller is matched via regexp, then ,all groups are save here 118 | self.path: str = "" # Path 119 | self.__parameters = {} # Parameters, key-value array, merged by query string and request body if the `Content-Type` in request header is `application/x-www-form-urlencoded` or `multipart/form-data` 120 | # Parameters, key-value, if more than one parameters with the same key, only the first one will be stored. 121 | self.__parameter = {} 122 | self._body: bytes = b"" # Request body 123 | # A dictionary if the `Content-Type` in request header is `application/json` 124 | self.json: Dict[str, Any] = None 125 | self.environment = {} 126 | self.reader: RequestBodyReader = None # A stream reader 127 | 128 | @property 129 | def cookies(self) -> Cookies: 130 | return self.__cookies 131 | 132 | @property 133 | def parameters(self) -> Dict[str, List[str]]: 134 | return self.__parameters 135 | 136 | @parameters.setter 137 | def parameters(self, val: Dict[str, List[str]]): 138 | self.__parameters = val 139 | self.__parameter = {} 140 | for k, v in self.__parameters.items(): 141 | self.__parameter[k] = v[0] 142 | 143 | @property 144 | def parameter(self) -> Dict[str, str]: 145 | return self.__parameter 146 | 147 | @property 148 | def host(self) -> str: 149 | if "Host" not in self.headers: 150 | return "" 151 | else: 152 | return self.headers["Host"] 153 | 154 | @property 155 | def body(self) -> bytes: 156 | return self._body 157 | 158 | @property 159 | def 
content_type(self) -> str: 160 | if "Content-Type" not in self.headers: 161 | return "" 162 | else: 163 | return self.headers["Content-Type"] 164 | 165 | @property 166 | def content_length(self) -> int: 167 | if "Content-Length" not in self.headers: 168 | return None 169 | else: 170 | return int(self.headers["Content-Length"]) 171 | 172 | def get_parameter(self, key: str, default: str = None) -> str: 173 | if key not in self.parameters.keys(): 174 | return default 175 | else: 176 | return self.parameter[key] 177 | 178 | @abstractmethod 179 | def get_session(self, create: bool = False) -> HttpSession: 180 | return NotImplemented 181 | 182 | 183 | class MultipartFile: 184 | """Multipart file""" 185 | 186 | def __init__(self, name: str = "", 187 | required: bool = False, 188 | filename: str = "", 189 | content_type: str = "", 190 | content: bytes = None): 191 | self.__name = name 192 | self.__required = required 193 | self.__filename = filename 194 | self.__content_type = content_type 195 | self.__content = content 196 | 197 | @property 198 | def name(self) -> str: 199 | return self.__name 200 | 201 | @property 202 | def _required(self) -> bool: 203 | return self.__required 204 | 205 | @property 206 | def filename(self) -> str: 207 | return self.__filename 208 | 209 | @property 210 | def content_type(self) -> str: 211 | return self.__content_type 212 | 213 | @property 214 | def content(self) -> bytes: 215 | return self.__content 216 | 217 | @property 218 | def is_empty(self) -> bool: 219 | return self.__content is None or len(self.__content) == 0 220 | 221 | def save_to_file(self, file_path: str) -> None: 222 | if self.__content is not None and len(self.__content) > 0: 223 | with open(file_path, "wb") as f: 224 | f.write(self.__content) 225 | 226 | 227 | class ParamStringValue(str): 228 | 229 | def __init__(self, name: str = "", 230 | default: str = "", 231 | required: bool = False): 232 | self.__name = name 233 | self.__required = required 234 | 235 | @property 236 | 
def name(self) -> str: 237 | return self.__name 238 | 239 | @property 240 | def _required(self) -> bool: 241 | return self.__required 242 | 243 | def __new__(cls, name="", default="", **kwargs): 244 | assert isinstance(default, str) 245 | obj = super().__new__(cls, default) 246 | return obj 247 | 248 | 249 | class Parameter(ParamStringValue): 250 | pass 251 | 252 | 253 | class PathValue(str): 254 | 255 | def __init__(self, name: str = "", _value: str = ""): 256 | self.__name = name 257 | 258 | @property 259 | def name(self): 260 | return self.__name 261 | 262 | def __new__(cls, name: str = "", _value: str = "", **kwargs): 263 | assert isinstance(_value, str) 264 | obj = super().__new__(cls, _value) 265 | return obj 266 | 267 | 268 | class Parameters(list): 269 | 270 | def __init__(self, name: str = "", default: List[str] = [], required: bool = False): 271 | self.__name = name 272 | self.__required = required 273 | 274 | @property 275 | def name(self) -> str: 276 | return self.__name 277 | 278 | @property 279 | def _required(self) -> bool: 280 | return self.__required 281 | 282 | def __new__(cls, name: str = "", default: List[str] = [], **kwargs): 283 | obj = super().__new__(cls) 284 | obj.extend(default) 285 | return obj 286 | 287 | 288 | class ModelDict(dict): 289 | pass 290 | 291 | 292 | class Environment(dict): 293 | pass 294 | 295 | 296 | class RegGroups(tuple): 297 | pass 298 | 299 | 300 | class RegGroup(str): 301 | 302 | def __init__(self, group=0, **kwargs): 303 | self._group = group 304 | 305 | @property 306 | def group(self) -> int: 307 | return self._group 308 | 309 | def __new__(cls, group=0, **kwargs): 310 | if "_value" not in kwargs: 311 | val = "" 312 | else: 313 | val = kwargs["_value"] 314 | obj = super().__new__(cls, val) 315 | return obj 316 | 317 | 318 | class Header(ParamStringValue): 319 | pass 320 | 321 | 322 | class JSONBody(dict): 323 | pass 324 | 325 | 326 | class BytesBody(bytes): 327 | pass 328 | 329 | 330 | """ 331 | " The folowing beans 
are used in Response 332 | """ 333 | 334 | 335 | class StaticFile: 336 | 337 | def __init__(self, file_path, content_type="application/octet-stream"): 338 | self.file_path = file_path 339 | self.content_type = content_type 340 | 341 | 342 | class Response: 343 | """Response""" 344 | 345 | def __init__(self, 346 | status_code: int = 200, 347 | headers: Dict[str, str] = None, 348 | body: Union[str, dict, StaticFile, bytes] = ""): 349 | self.status_code = status_code 350 | self.__headers = headers if headers is not None else {} 351 | self.__body = "" 352 | self.__cookies = Cookies() 353 | self.__set_body(body) 354 | 355 | @property 356 | def cookies(self) -> http.cookies.SimpleCookie: 357 | return self.__cookies 358 | 359 | @cookies.setter 360 | def cookies(self, val: http.cookies.SimpleCookie) -> None: 361 | assert isinstance(val, http.cookies.SimpleCookie) 362 | self.__cookies = val 363 | 364 | @property 365 | def body(self): 366 | return self.__body 367 | 368 | @body.setter 369 | def body(self, val: Union[str, dict, StaticFile, bytes]) -> None: 370 | self.__set_body(val) 371 | 372 | def __set_body(self, val): 373 | assert val is None \ 374 | or isinstance(val, str) \ 375 | or isinstance(val, dict) \ 376 | or isinstance(val, StaticFile) \ 377 | or isinstance(val, bytes) \ 378 | or isinstance(val, bytearray), \ 379 | "Body type is not supported." 
380 | self.__body = val 381 | 382 | @property 383 | def headers(self) -> Dict[str, Union[list, str]]: 384 | return self.__headers 385 | 386 | def set_header(self, key: str, value: str) -> None: 387 | self.__headers[key] = value 388 | 389 | def add_header(self, key: str, value: Union[str, list]) -> None: 390 | if key not in self.__headers.keys(): 391 | self.__headers[key] = value 392 | return 393 | if not isinstance(self.__headers[key], list): 394 | self.__headers[key] = [self.__headers[key]] 395 | if isinstance(value, list): 396 | self.__headers[key].extend(value) 397 | else: 398 | self.__headers[key].append(value) 399 | 400 | def add_headers(self, headers: Dict[str, Union[str, List[str]]] = {}) -> None: 401 | if headers is not None: 402 | for k, v in headers.items(): 403 | self.add_header(k, v) 404 | 405 | @abstractmethod 406 | def send_error(self, status_code: int, message: str = ""): 407 | return NotImplemented 408 | 409 | @abstractmethod 410 | def send_redirect(self, url: str): 411 | return NotImplemented 412 | 413 | @abstractmethod 414 | def send_response(self): 415 | return NotImplemented 416 | 417 | @abstractmethod 418 | def write_bytes(self, data: bytes): 419 | pass 420 | 421 | @abstractmethod 422 | def close(self): 423 | pass 424 | 425 | 426 | class HttpError(Exception): 427 | 428 | def __init__(self, code: int = 400, message: str = "", explain: str = ""): 429 | super().__init__("HTTP_ERROR[%d] %s" % (code, message)) 430 | self.code: int = code 431 | self.message: str = message 432 | self.explain: str = explain 433 | 434 | 435 | class Redirect: 436 | 437 | def __init__(self, url: str): 438 | self.__url = url 439 | 440 | @property 441 | def url(self) -> str: 442 | return self.__url 443 | 444 | 445 | """ 446 | " Use both in request and response 447 | """ 448 | 449 | 450 | class Headers(dict): 451 | 452 | def __init__(self, headers: Dict[str, Union[str, List[str]]] = {}): 453 | self.update(headers) 454 | 455 | 456 | class Cookie(http.cookies.Morsel): 457 | 
458 | def __init__(self, 459 | name: str = "", 460 | default: str = "", 461 | default_options: Dict[str, str] = {}, 462 | required: bool = False): 463 | super().__init__() 464 | self.__name = name 465 | self.__required = required 466 | if name is not None and name != "": 467 | self.set(name, default, default) 468 | self.update(default_options) 469 | 470 | @property 471 | def name(self) -> str: 472 | return self.__name 473 | 474 | @property 475 | def _required(self) -> bool: 476 | return self.__required 477 | 478 | 479 | class FilterContext: 480 | 481 | @property 482 | def request(self) -> Request: 483 | return NotImplemented 484 | 485 | @property 486 | def response(self) -> Response: 487 | return NotImplemented 488 | 489 | @abstractmethod 490 | def do_chain(self): 491 | return NotImplemented 492 | 493 | 494 | class WebsocketRequest: 495 | 496 | def __init__(self): 497 | self.headers: Dict[str, str] = {} # Request headers 498 | self.__cookies = Cookies() 499 | self.query_string: str = "" # Query String 500 | self.path_values: Dict[str, str] = {} 501 | self.reg_groups = () # If controller is matched via regexp, then ,all groups are save here 502 | self.path: str = "" # Path 503 | self.__parameters = {} # Parameters, key-value array, merged by query string and request body if the `Content-Type` in request header is `application/x-www-form-urlencoded` or `multipart/form-data` 504 | # Parameters, key-value, if more than one parameters with the same key, only the first one will be stored. 
505 | self.__parameter = {} 506 | 507 | @property 508 | def cookies(self) -> Cookies: 509 | return self.__cookies 510 | 511 | @property 512 | def parameters(self) -> Dict[str, List[str]]: 513 | return self.__parameters 514 | 515 | @parameters.setter 516 | def parameters(self, val: Dict[str, List[str]]): 517 | self.__parameters = val 518 | self.__parameter = {} 519 | for k, v in self.__parameters.items(): 520 | self.__parameter[k] = v[0] 521 | 522 | @property 523 | def parameter(self) -> Dict[str, str]: 524 | return self.__parameter 525 | 526 | def get_parameter(self, key: str, default: str = None) -> str: 527 | if key not in self.parameters.keys(): 528 | return default 529 | else: 530 | return self.parameter[key] 531 | 532 | 533 | class WebsocketSession: 534 | 535 | @property 536 | def id(self) -> str: 537 | return NotImplemented 538 | 539 | @property 540 | def request(self) -> WebsocketRequest: 541 | return NotImplemented 542 | 543 | @property 544 | def is_closed(self) -> bool: 545 | return NotImplemented 546 | 547 | @abstractmethod 548 | def send(self, message: Union[str, bytes], opcode: int = None, chunk_size: int = 0): 549 | return NotImplemented 550 | 551 | @abstractmethod 552 | def send_text(self, message: str, chunk_size: int = 0): 553 | return NotImplemented 554 | 555 | @abstractmethod 556 | def send_binary(self, binary: bytes, chunk_size: int = 0): 557 | return NotImplemented 558 | 559 | @abstractmethod 560 | def send_file(self, path: str, chunk_size: int = 0): 561 | return NotImplemented 562 | 563 | @abstractmethod 564 | def send_pone(self, message: bytes = b''): 565 | return NotImplemented 566 | 567 | @abstractmethod 568 | def send_ping(self, message: bytes = b''): 569 | return NotImplemented 570 | 571 | @abstractmethod 572 | def close(self, reason: str = ""): 573 | return NotImplemented 574 | 575 | 576 | class WebsocketCloseReason(str): 577 | 578 | def __init__(self, 579 | message: str = "", 580 | code: int = None, 581 | reason: str = '') -> None: 582 | 
self.__message: str = message 583 | self.__code: int = code 584 | self.__reason: str = reason 585 | 586 | @property 587 | def message(self) -> str: 588 | return self.__message 589 | 590 | @property 591 | def code(self) -> int: 592 | return self.__code 593 | 594 | @property 595 | def reason(self) -> str: 596 | return self.__reason 597 | 598 | def __new__(cls, message: str = "", code: int = "", reason: str = '', **kwargs): 599 | obj = super().__new__(cls, message) 600 | return obj 601 | 602 | 603 | class WebsocketHandler: 604 | 605 | @abstractmethod 606 | def on_handshake(self, request: WebsocketRequest = None): 607 | """ 608 | " 609 | " You can get path/headers/path_values/cookies/query_string/query_parameters from request. 610 | " 611 | " You should return a tuple means (http_status_code, headers) 612 | " 613 | " If status code in (0, None, 101), the websocket will be connected, or will return the status you return. 614 | " 615 | " All headers will be send to client 616 | " 617 | """ 618 | return None 619 | 620 | @abstractmethod 621 | def on_open(self, session: WebsocketSession = None): 622 | """ 623 | " 624 | " Will be called when the connection opened. 625 | " 626 | """ 627 | pass 628 | 629 | @abstractmethod 630 | def on_close(self, session: WebsocketSession = None, reason: WebsocketCloseReason = None): 631 | """ 632 | " 633 | " Will be called when the connection closed. 634 | " 635 | """ 636 | pass 637 | 638 | @abstractmethod 639 | def on_ping_message(self, session: WebsocketSession = None, message: bytes = b''): 640 | """ 641 | " 642 | " Will be called when receive a ping message. Will send all the message bytes back to client by default. 643 | " 644 | """ 645 | session.send_pone(message) 646 | 647 | @abstractmethod 648 | def on_pong_message(self, session: WebsocketSession = None, message: bytes = ""): 649 | """ 650 | " 651 | " Will be called when receive a pong message. 
652 | " 653 | """ 654 | pass 655 | 656 | @abstractmethod 657 | def on_text_message(self, session: WebsocketSession = None, message: str = ""): 658 | """ 659 | " 660 | " Will be called when receive a text message. 661 | " 662 | """ 663 | pass 664 | 665 | @abstractmethod 666 | def on_binary_message(self, session: WebsocketSession = None, message: bytes = b''): 667 | """ 668 | " 669 | " Will be called when receive a binary message if you have not consumed all the bytes in `on_binary_frame` 670 | " method. 671 | " 672 | """ 673 | pass 674 | 675 | @abstractmethod 676 | def on_binary_frame(self, session: WebsocketSession = None, fin: bool = False, frame_payload: bytes = b''): 677 | """ 678 | " 679 | " When server receive a fragmented message, this method will be called every time when a frame is received, 680 | " you can consume all the bytes in this method, e.g. save all bytes to a file. 681 | " 682 | " If you does not implement this method or return a True in this method, all the bytes will be cached in 683 | " memory and sent to your `on_binary_message` method after all frames are received. 
684 | " 685 | """ 686 | return True 687 | -------------------------------------------------------------------------------- /naja_atra/request_handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | -------------------------------------------------------------------------------- /naja_atra/request_handlers/http_request_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import html 26 | import re 27 | import http.client 28 | import email.parser 29 | import email.message 30 | import socketserver 31 | import asyncio 32 | import socket 33 | 34 | 35 | from typing import Any, Dict 36 | from http import HTTPStatus 37 | from urllib.parse import unquote 38 | from asyncio.streams import StreamReader, StreamWriter 39 | from http import HTTPStatus 40 | 41 | from .. 
import name, version 42 | from ..models import RequestBodyReader 43 | from ..utils import http_utils 44 | from ..utils.logger import get_logger 45 | from ..http_servers.routing_server import RoutingServer 46 | from .http_controller_handler import HTTPControllerHandler 47 | from .websocket_controller_handler import WebsocketControllerHandler 48 | 49 | 50 | _LINE_MAX_BYTES = 65536 51 | _MAXHEADERS = 100 52 | 53 | _logger = get_logger("naja_atra.request_handlers.http_request_handler") 54 | 55 | 56 | class RequestWriter: 57 | 58 | def __init__(self, writer: StreamWriter) -> None: 59 | self.writer: StreamWriter = writer 60 | 61 | def send(self, data: bytes): 62 | self.writer.write(data) 63 | 64 | 65 | class HttpRequestHandler: 66 | 67 | server_version = f"{name}/{version}" 68 | 69 | default_request_version = "HTTP/1.1" 70 | 71 | # The version of the HTTP protocol we support. 72 | # Set this to HTTP/1.1 to enable automatic keepalive 73 | protocol_version = "HTTP/1.1" 74 | 75 | # MessageClass used to parse headers 76 | _message_class = http.client.HTTPMessage 77 | 78 | # hack to maintain backwards compatibility 79 | responses = { 80 | v: (v.phrase, v.description) 81 | for v in HTTPStatus.__members__.values() 82 | } 83 | 84 | def __init__(self, reader: StreamReader, writer: StreamWriter, request_writer=None, routing_conf: RoutingServer = None) -> None: 85 | self.routing_conf: RoutingServer = routing_conf 86 | self.reader: StreamReader = reader 87 | self.writer: StreamWriter = writer 88 | self.request_writer: RequestWriter = request_writer if request_writer else RequestWriter( 89 | writer) 90 | 91 | self.requestline = '' 92 | self.request_version = '' 93 | self.command = '' 94 | self.path = '' 95 | self.request_path = '' 96 | self.query_string = '' 97 | self.query_parameters = {} 98 | self.headers = {} 99 | 100 | self.close_connection = True 101 | self._keep_alive = self.routing_conf.keep_alive 102 | self._connection_idle_time = routing_conf.connection_idle_time 103 | 
self._keep_alive_max_req = routing_conf.keep_alive_max_request 104 | self.req_count = 0 105 | 106 | async def parse_request(self): 107 | self.req_count += 1 108 | try: 109 | if hasattr(self.reader, "connection"): 110 | # For blocking io. asyncio.wait_for will not raise TimeoutError if the io is blocked. 111 | self.reader.connection.settimeout(self._connection_idle_time) 112 | raw_requestline = await asyncio.wait_for(self.reader.readline(), self._connection_idle_time) 113 | if hasattr(self.reader, "connection") and hasattr(self.reader, "timeout"): 114 | # For blocking io. Set the Original timeout to the connection. 115 | self.reader.connection.settimeout(self.reader.timeout) 116 | except asyncio.TimeoutError: 117 | _logger.warn("Wait for reading request line timeout. ") 118 | return False 119 | if len(raw_requestline) > _LINE_MAX_BYTES: 120 | self.requestline = '' 121 | self.request_version = '' 122 | self.command = '' 123 | self.send_error(HTTPStatus.REQUEST_URI_TOO_LONG) 124 | return False 125 | if not raw_requestline: 126 | self.close_connection = True 127 | return False 128 | self.command = None 129 | self.request_version = version = self.default_request_version 130 | self.close_connection = True 131 | requestline = str(raw_requestline, 'iso-8859-1') 132 | requestline = requestline.rstrip('\r\n') 133 | self.requestline = requestline 134 | words = requestline.split() 135 | if len(words) == 0: 136 | return False 137 | 138 | if len(words) >= 3: # Enough to determine protocol version 139 | version = words[-1] 140 | try: 141 | if not version.startswith('HTTP/'): 142 | raise ValueError 143 | base_version_number = version.split('/', 1)[1] 144 | version_number = base_version_number.split(".") 145 | # RFC 2145 section 3.1 says there can be only one "." 
and 146 | # - major and minor numbers MUST be treated as 147 | # separate integers; 148 | # - HTTP/2.4 is a lower version than HTTP/2.13, which in 149 | # turn is lower than HTTP/12.3; 150 | # - Leading zeros MUST be ignored by recipients. 151 | if len(version_number) != 2: 152 | raise ValueError 153 | version_number = int(version_number[0]), int(version_number[1]) 154 | except (ValueError, IndexError): 155 | self.send_error( 156 | HTTPStatus.BAD_REQUEST, 157 | f"Bad request version {version}") 158 | return False 159 | 160 | if version_number >= (2, 0): 161 | self.send_error( 162 | HTTPStatus.HTTP_VERSION_NOT_SUPPORTED, 163 | f"Invalid HTTP version {base_version_number}") 164 | return False 165 | self.request_version = version 166 | _logger.info(f"request version: {self.request_version}") 167 | if not 2 <= len(words) <= 3: 168 | self.send_error( 169 | HTTPStatus.BAD_REQUEST, 170 | "Bad request syntax (%r)" % requestline) 171 | return False 172 | command, path = words[:2] 173 | if len(words) == 2: 174 | self.close_connection = True 175 | if command != 'GET': 176 | self.send_error( 177 | HTTPStatus.BAD_REQUEST, 178 | "Bad HTTP/0.9 request type (%r)" % command) 179 | return False 180 | self.command, self.path = command, path 181 | 182 | self.request_path = self._get_request_path(self.path) 183 | 184 | self.query_string = self.__get_query_string(self.path) 185 | 186 | self.query_parameters = http_utils.decode_query_string( 187 | self.query_string) 188 | 189 | # Examine the headers and look for a Connection directive. 
190 | try: 191 | self.headers = await self.parse_headers() 192 | except http.client.LineTooLong as err: 193 | self.send_error( 194 | HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE, 195 | "Line too long", 196 | str(err)) 197 | return False 198 | except http.client.HTTPException as err: 199 | self.send_error( 200 | HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE, 201 | "Too many headers", 202 | str(err) 203 | ) 204 | return False 205 | 206 | conntype = self.headers.get('Connection', '') 207 | 208 | self.close_connection = not self._keep_alive or conntype.lower( 209 | ) != 'keep-alive' or self.protocol_version != "HTTP/1.1" 210 | 211 | # Examine the headers and look for an Expect directive 212 | expect = self.headers.get('Expect', "") 213 | if (expect.lower() == "100-continue" and 214 | self.protocol_version >= "HTTP/1.1" and 215 | self.request_version >= "HTTP/1.1"): 216 | if not self.handle_expect_100(): 217 | return False 218 | return True 219 | 220 | async def parse_headers(self): 221 | """Parses only RFC2822 headers from a file pointer. 222 | 223 | email Parser wants to see strings rather than bytes. 224 | But a TextIOWrapper around self.rfile would buffer too many bytes 225 | from the stream, bytes which we later need to read as bytes. 226 | So we read the correct bytes here, as bytes, for email Parser 227 | to parse. 228 | 229 | """ 230 | headers = [] 231 | while True: 232 | line = await self.reader.readline() 233 | if len(line) > _LINE_MAX_BYTES: 234 | raise http.client.LineTooLong("header line") 235 | headers.append(line) 236 | if len(headers) > _MAXHEADERS: 237 | raise http.client.HTTPException( 238 | f"got more than {_MAXHEADERS} headers") 239 | if line in (b'\r\n', b'\n', b''): 240 | break 241 | hstring = b''.join(headers).decode('iso-8859-1') 242 | 243 | return email.parser.Parser(_class=self._message_class).parsestr(hstring) 244 | 245 | def handle_expect_100(self): 246 | """Decide what to do with an "Expect: 100-continue" header. 
247 | 248 | If the client is expecting a 100 Continue response, we must 249 | respond with either a 100 Continue or a final response before 250 | waiting for the request body. The default is to always respond 251 | with a 100 Continue. You can behave differently (for example, 252 | reject unauthorized requests) by overriding this method. 253 | 254 | This method should either return True (possibly after sending 255 | a 100 Continue response) or send an error response and return 256 | False. 257 | 258 | """ 259 | self.send_response_only(HTTPStatus.CONTINUE) 260 | self.end_headers() 261 | return True 262 | 263 | def __get_query_string(self, ori_path: str): 264 | parts = ori_path.split('?') 265 | if len(parts) == 2: 266 | return parts[1] 267 | else: 268 | return "" 269 | 270 | def _get_request_path(self, ori_path: str): 271 | path = ori_path.split('?', 1)[0] 272 | path = path.split('#', 1)[0] 273 | path = http_utils.remove_url_first_slash(path) 274 | path = unquote(path) 275 | return path 276 | 277 | def send_error(self, code: int, message: str = None, explain: str = None, headers: Dict[str, str] = {}): 278 | try: 279 | shortmsg, longmsg = self.responses[code] 280 | except KeyError: 281 | shortmsg, longmsg = '???', '???' 282 | if message is None: 283 | message = shortmsg 284 | if explain is None: 285 | explain = longmsg 286 | self.log_error(f"code {code}, message {message}") 287 | self.send_response(code, message) 288 | self.send_header('Connection', 'close') 289 | 290 | # Message body is omitted for cases described in: 291 | # - RFC7230: 3.3. 1xx, 204(No Content), 304(Not Modified) 292 | # - RFC7231: 6.3.6. 
205(Reset Content) 293 | body = None 294 | if (code >= 200 and 295 | code not in (HTTPStatus.NO_CONTENT, 296 | HTTPStatus.RESET_CONTENT, 297 | HTTPStatus.NOT_MODIFIED)): 298 | try: 299 | content: Any = self.routing_conf.error_page(code, html.escape( 300 | message, quote=False), html.escape(explain, quote=False)) 301 | except: 302 | content: str = html.escape( 303 | message, quote=False) + ":" + html.escape(explain, quote=False) 304 | content_type, body = http_utils.decode_response_body_to_bytes( 305 | content) 306 | 307 | self.send_header("Content-Type", content_type) 308 | self.send_header('Content-Length', str(len(body))) 309 | if headers: 310 | for h_name, h_val in headers.items(): 311 | self.send_header(h_name, h_val) 312 | self.end_headers() 313 | 314 | if self.command != 'HEAD' and body: 315 | self.writer.write(body) 316 | 317 | def send_response(self, code, message=None): 318 | """Add the response header to the headers buffer and log the 319 | response code. 320 | 321 | Also send two standard headers with the server software 322 | version and the current date. 
323 | 324 | """ 325 | self.log_request(code) 326 | self.send_response_only(code, message) 327 | self.send_header('Server', self.server_version) 328 | self.send_header('Date', http_utils.date_time_string()) 329 | 330 | def send_header(self, keyword: str, value: str): 331 | """Send a MIME header to the headers buffer.""" 332 | if keyword.lower() == 'connection': 333 | if value.lower() == 'close': 334 | self.close_connection = True 335 | elif value.lower() == 'keep-alive': 336 | if self._keep_alive: 337 | self.close_connection = False 338 | else: 339 | _logger.warning( 340 | f"Keep Alive configuration is set to False, won't send keep-alive header.") 341 | return 342 | 343 | if self.request_version != 'HTTP/0.9': 344 | if not hasattr(self, '_headers_buffer'): 345 | self._headers_buffer = [] 346 | self._headers_buffer.append( 347 | f"{keyword}: {value}\r\n".encode('latin-1', errors='strict')) 348 | 349 | def end_headers(self): 350 | """Send the blank line ending the MIME headers.""" 351 | if self.request_version != 'HTTP/0.9': 352 | self._headers_buffer.append(b"\r\n") 353 | self.flush_headers() 354 | 355 | def flush_headers(self): 356 | if hasattr(self, '_headers_buffer'): 357 | self.writer.write(b"".join(self._headers_buffer)) 358 | self._headers_buffer = [] 359 | 360 | def send_response_only(self, code, message: str = None): 361 | """Send the response header only.""" 362 | if self.request_version != 'HTTP/0.9': 363 | if message is None: 364 | if code in self.responses: 365 | message = self.responses[code][0] 366 | else: 367 | message = '' 368 | if not hasattr(self, '_headers_buffer'): 369 | self._headers_buffer = [] 370 | self._headers_buffer \ 371 | .append(f"{self.protocol_version} {code} {message}\r\n".encode('latin-1', errors='strict')) 372 | 373 | def log_request(self, code='-', size='-'): 374 | if isinstance(code, HTTPStatus): 375 | code = code.value 376 | self.log_message('"%s" %s %s', 377 | self.requestline, str(code), str(size)) 378 | 379 | def 
log_error(self, format, *args): 380 | self.log_message(format, *args) 381 | 382 | def log_message(self, format, *args): 383 | _logger.info(f"{format % args}") 384 | 385 | def set_prefer_keep_alive_params(self): 386 | pass 387 | 388 | def set_alive_params(self): 389 | if "Keep-Alive" in self.headers: 390 | ka_header = self.headers["Keep-Alive"] 391 | timeout_match = re.match(r"^.*timeout=(\d+).*$", ka_header) 392 | if timeout_match: 393 | self._connection_idle_time = int(timeout_match.group(1)) 394 | max_match = re.match(r"^.*max=(\d+).*$", ka_header) 395 | if max_match: 396 | self._keep_alive_max_req = int(max_match.group(1)) 397 | 398 | async def handle_request(self): 399 | parse_request_success = await self.parse_request() 400 | if not parse_request_success: 401 | return 402 | self.set_alive_params() 403 | 404 | if self.request_version == "HTTP/1.1" and self.command == "GET" and "Upgrade" in self.headers and self.headers["Upgrade"] == "websocket": 405 | _logger.debug("This is a websocket connection. ") 406 | ws_handler = WebsocketControllerHandler(self) 407 | await ws_handler.handle_request() 408 | self.writer.write_eof() 409 | return 410 | 411 | await self.handle_http_request() 412 | while not self.close_connection: 413 | _logger.debug("Keep-Alive, read next request. ") 414 | parse_request_success = await self.parse_request() 415 | if not parse_request_success: 416 | _logger.debug("parse request fails, return. ") 417 | return 418 | if self.req_count >= self._keep_alive_max_req: 419 | self.send_response("Connection", "close") 420 | await self.handle_http_request() 421 | _logger.debug("Handle a keep-alive request successfully!") 422 | 423 | async def handle_http_request(self): 424 | try: 425 | http_handler = HTTPControllerHandler(self) 426 | await http_handler.handle_request() 427 | if self.writer.can_write_eof(): 428 | self.writer.write_eof() 429 | except socket.timeout as e: 430 | # a read or a write timed out. 
Discard this connection 431 | self.log_error("Request timed out: %r", e) 432 | self.close_connection = True 433 | return 434 | 435 | 436 | class SocketServerStreamRequestHandlerWraper(socketserver.StreamRequestHandler, RequestBodyReader): 437 | 438 | server_version = HttpRequestHandler.server_version 439 | 440 | # Wrapper method for readline 441 | async def readline(self): 442 | return self.rfile.readline(_LINE_MAX_BYTES) 443 | 444 | async def read(self, n: int = -1): 445 | return self.rfile.read(n) 446 | 447 | def write(self, data: bytes): 448 | self.wfile.write(data) 449 | 450 | def can_write_eof(self) -> bool: 451 | return True 452 | 453 | def write_eof(self): 454 | self.wfile.flush() 455 | 456 | def close(self): 457 | self.wfile.close() 458 | 459 | def handle(self) -> None: 460 | handler: HttpRequestHandler = HttpRequestHandler( 461 | self, self, request_writer=self.request, routing_conf=self.server) 462 | asyncio.run(handler.handle_request()) 463 | 464 | def finish(self) -> None: 465 | _logger.debug("Finish a socket connection.") 466 | return super().finish() 467 | -------------------------------------------------------------------------------- /naja_atra/request_handlers/http_session_local_impl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import threading 26 | import uuid 27 | import time 28 | 29 | from typing import Any, Dict, List, Tuple 30 | from threading import RLock 31 | from ..models import HttpSession, HttpSessionFactory 32 | from ..utils.logger import get_logger 33 | 34 | _logger = get_logger("naja_atra.request_handlers.http_session_local_impl") 35 | 36 | _SESSION_TIME_CLEANING_INTERVAL = 60 37 | 38 | 39 | def _get_from_dict(adict: Dict[str, Any], key: str) -> Any: 40 | if key not in adict: 41 | return None 42 | try: 43 | return adict[key] 44 | except KeyError: 45 | _logger.debug("key %s was deleted in other thread.") 46 | return None 47 | 48 | 49 | class LocalSessionHolder: 50 | 51 | def __init__(self): 52 | self.__sessions: Dict[str, HttpSession] = {} 53 | self.__session_lock = RLock() 54 | self.__started = False 55 | self.__clearing_thread = threading.Thread(target=self._clear_time_out_session_in_bg, daemon=True) 56 | 57 | def _start_cleaning(self): 58 | if not self.__started: 59 | self.__started = True 60 | self.__clearing_thread.start() 61 | 62 | def _clear_time_out_session_in_bg(self): 63 | while True: 64 | time.sleep(_SESSION_TIME_CLEANING_INTERVAL) 65 | self._clear_time_out_session() 66 | 67 | def _clear_time_out_session(self): 68 | timeout_sessions: List[HttpSession] = [] 69 | for v in self.__sessions.values(): 70 | if not v.is_valid: 71 | timeout_sessions.append(v) 72 | for session in timeout_sessions: 73 | session.invalidate() 74 | 75 | def clean_session(self, 
session_id: str): 76 | if session_id in self.__sessions: 77 | try: 78 | _logger.debug("session[#%s] is being cleaned" % session_id) 79 | sess = self.__sessions[session_id] 80 | if not sess.is_valid: 81 | del self.__sessions[session_id] 82 | except KeyError: 83 | _logger.debug("Session[#%s] in session cache is already deleted. " % session_id) 84 | 85 | def get_session(self, session_id: str) -> HttpSession: 86 | if not session_id: 87 | return None 88 | sess: HttpSession = _get_from_dict(self.__sessions, session_id) 89 | if sess and sess.is_valid: 90 | return sess 91 | else: 92 | return None 93 | 94 | def cache_session(self, session: HttpSession): 95 | if not session: 96 | return None 97 | sess: HttpSession = _get_from_dict(self.__sessions, session.id) 98 | 99 | if sess: 100 | if session is sess: 101 | return 102 | sess.invalidate() 103 | with self.__session_lock: 104 | self.__sessions[session.id] = session 105 | self._start_cleaning() 106 | 107 | return session 108 | 109 | 110 | class LocalSessionImpl(HttpSession): 111 | 112 | def __init__(self, id: str, creation_time: float, session_holder: LocalSessionHolder): 113 | super().__init__() 114 | self.__id = id 115 | self.__creation_time = creation_time 116 | self.__last_accessed_time = creation_time 117 | self.__is_new = True 118 | self.__attr_lock = RLock() 119 | self.__attrs = {} 120 | self.__session_holder = session_holder 121 | 122 | @property 123 | def id(self) -> str: 124 | return self.__id 125 | 126 | @property 127 | def creation_time(self) -> float: 128 | return self.__creation_time 129 | 130 | @property 131 | def last_accessed_time(self) -> float: 132 | return self.__last_accessed_time 133 | 134 | @property 135 | def is_new(self) -> bool: 136 | return self.__is_new 137 | 138 | def _set_last_accessed_time(self, last_acessed_time: float): 139 | self.__last_accessed_time = last_acessed_time 140 | self.__is_new = False 141 | 142 | @property 143 | def attribute_names(self) -> Tuple: 144 | return 
tuple(self.__attrs.keys()) 145 | 146 | def get_attribute(self, name: str) -> Any: 147 | return _get_from_dict(self.__attrs, name) 148 | 149 | def set_attribute(self, name: str, value: Any) -> None: 150 | with self.__attr_lock: 151 | self.__attrs[name] = value 152 | 153 | def invalidate(self) -> None: 154 | self._set_last_accessed_time(0) 155 | self.__session_holder.clean_session(session_id=self.id) 156 | 157 | 158 | class LocalSessionFactory(HttpSessionFactory): 159 | 160 | def __init__(self): 161 | self.__session_holder = LocalSessionHolder() 162 | self.__session_lock = RLock() 163 | 164 | def _create_local_session(self, session_id: str) -> HttpSession: 165 | if session_id: 166 | sid = session_id 167 | else: 168 | sid = uuid.uuid4().hex 169 | return LocalSessionImpl(sid, time.time(), self.__session_holder) 170 | 171 | def get_session(self, session_id: str, create: bool = False) -> HttpSession: 172 | sess: LocalSessionImpl = self.__session_holder.get_session(session_id) 173 | if sess: 174 | sess._set_last_accessed_time(time.time()) 175 | return sess 176 | if not create: 177 | return None 178 | with self.__session_lock: 179 | session = self._create_local_session(session_id) 180 | self.__session_holder.cache_session(session) 181 | return session 182 | 183 | -------------------------------------------------------------------------------- /naja_atra/request_handlers/model_bindings.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the 
following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import json 26 | from typing import Any, Dict, List, Type 27 | from http.cookies import BaseCookie, SimpleCookie 28 | from ..models import ModelDict, Environment, RegGroup, RegGroups, HttpError, RequestBodyReader, \ 29 | Headers, Response, Cookies, Cookie, JSONBody, BytesBody, Header, Parameters, PathValue, Parameter, \ 30 | MultipartFile, Request, HttpSession 31 | from ..utils.logger import get_logger 32 | 33 | _logger = get_logger("naja_atra.models.model_bindings") 34 | 35 | 36 | class ModelBinding: 37 | 38 | def __init__(self, request: Request, 39 | response: Response, 40 | arg: str, 41 | arg_type, 42 | val=None 43 | ) -> None: 44 | self.request = request 45 | self.response = response 46 | self.arg_name = arg 47 | self.arg_type = arg_type 48 | self.default_value = val 49 | self._kws = {} 50 | if self.default_value is not None: 51 | self._kws["val"] = self.default_value 52 | 53 | async def bind(self) -> Any: 54 | pass 55 | 56 | 57 | class RequestModelBinding(ModelBinding): 58 | 59 | async def bind(self) -> Request: 60 | return self.request 61 | 62 | 63 | class SessionModelBinding(ModelBinding): 64 | 65 | async def bind(self) -> HttpSession: 66 | return self.request.get_session(True) 67 | 68 | 69 | class ResponseModelBinding(ModelBinding): 70 | 71 | async def bind(self) -> Response: 72 | 
return self.response 73 | 74 | 75 | class HeadersModelBinding(ModelBinding): 76 | 77 | async def bind(self) -> Headers: 78 | return Headers(self.request.headers) 79 | 80 | 81 | class RegGroupsModelBinding(ModelBinding): 82 | 83 | async def bind(self) -> RegGroups: 84 | return RegGroups(self.request.reg_groups) 85 | 86 | 87 | class EnvironmentModelBinding(ModelBinding): 88 | 89 | async def bind(self) -> Environment: 90 | return Environment(self.request.environment) 91 | 92 | 93 | class HeaderModelBinding(ModelBinding): 94 | 95 | async def bind(self) -> Header: 96 | return self.__build_header(self.arg_name, **self._kws) 97 | 98 | def __build_header(self, key, val=Header()): 99 | name = val.name if val.name is not None and val.name != "" else key 100 | if val._required and name not in self.request.headers: 101 | raise HttpError(400, "Missing Header", 102 | f"Header[{name}] is required.") 103 | if name in self.request.headers: 104 | v = self.request.headers[name] 105 | return Header(name=name, default=v, required=val._required) 106 | else: 107 | return val 108 | 109 | 110 | class CookiesModelBinding(ModelBinding): 111 | 112 | async def bind(self) -> Cookies: 113 | return self.request.cookies 114 | 115 | 116 | class CookieModelBinding(ModelBinding): 117 | 118 | async def bind(self) -> Cookie: 119 | return self.__build_cookie(self.arg_name, **self._kws) 120 | 121 | def __build_cookie(self, key, val=None): 122 | name = val.name if val.name is not None and val.name != "" else key 123 | if val._required and name not in self.request.cookies: 124 | raise HttpError(400, "Missing Cookie", 125 | f"Cookie[{name}] is required.") 126 | if name in self.request.cookies: 127 | morsel = self.request.cookies[name] 128 | cookie = Cookie() 129 | cookie.set(morsel.key, morsel.value, morsel.coded_value) 130 | cookie.update(morsel) 131 | return cookie 132 | else: 133 | return val 134 | 135 | 136 | class MultipartFileModelBinding(ModelBinding): 137 | 138 | async def bind(self) -> 
MultipartFile: 139 | return self.__build_multipart(self.arg_name, **self._kws) 140 | 141 | def __build_multipart(self, key, val=MultipartFile()): 142 | name = val.name if val.name is not None and val.name != "" else key 143 | if val._required and name not in self.request.parameter.keys(): 144 | raise HttpError(400, "Missing Parameter", 145 | f"Parameter[{name}] is required.") 146 | if name in self.request.parameter.keys(): 147 | v = self.request.parameter[name] 148 | if isinstance(v, MultipartFile): 149 | return v 150 | else: 151 | raise HttpError( 152 | 400, None, f"Parameter[{name}] should be a file.") 153 | else: 154 | return val 155 | 156 | 157 | class ParameterModelBinding(ModelBinding): 158 | 159 | async def bind(self) -> Parameter: 160 | return self.__build_param(self.arg_name, **self._kws) 161 | 162 | def __build_param(self, key, val=Parameter()): 163 | name = val.name if val.name is not None and val.name != "" else key 164 | if val._required and name not in self.request.parameter: 165 | raise HttpError(400, "Missing Parameter", 166 | f"Parameter[{name}] is required.") 167 | if name in self.request.parameter: 168 | v = self.request.parameter[name] 169 | return Parameter(name=name, default=v, required=val._required) 170 | else: 171 | return val 172 | 173 | 174 | class PathValueModelBinding(ModelBinding): 175 | 176 | async def bind(self) -> PathValue: 177 | return self.__build_path_value(self.arg_name, **self._kws) 178 | 179 | def __build_path_value(self, key, val=PathValue()): 180 | # wildcard value 181 | if len(self.request.path_values) == 1 and "__path_wildcard" in self.request.path_values: 182 | if val.name: 183 | _logger.warning( 184 | f"Wildcard value, `name` of the PathValue:: [{val.name}] will be ignored. 
") 185 | return self.request.path_values["__path_wildcard"] 186 | 187 | # brace values 188 | name = val.name if val.name is not None and val.name != "" else key 189 | if name in self.request.path_values: 190 | return PathValue(name=name, _value=self.request.path_values[name]) 191 | else: 192 | raise HttpError( 193 | 500, None, f"path name[{name}] not in your url mapping!") 194 | 195 | 196 | class ParametersModelBinding(ModelBinding): 197 | 198 | async def bind(self) -> Parameters: 199 | return self.__build_params(self.arg_name, **self._kws) 200 | 201 | def __build_params(self, key, val=Parameters()): 202 | name = val.name if val.name is not None and val.name != "" else key 203 | if val._required and name not in self.request.parameters: 204 | raise HttpError(400, "Missing Parameter", 205 | f"Parameter[{name}] is required.") 206 | if name in self.request.parameters: 207 | v = self.request.parameters[name] 208 | return Parameters(name=name, default=v, required=val._required) 209 | else: 210 | return val 211 | 212 | 213 | class RegGroupModelBinding(ModelBinding): 214 | 215 | async def bind(self) -> RegGroup: 216 | return self.__build_reg_group(**self._kws) 217 | 218 | def __build_reg_group(self, val: RegGroup = RegGroup(group=0)): 219 | if val.group >= len(self.request.reg_groups): 220 | raise HttpError( 221 | 400, None, f"RegGroup required an element at {val.group}, but the reg length is only {len(self.request.reg_groups)}") 222 | return RegGroup(group=val.group, _value=self.request.reg_groups[val.group]) 223 | 224 | 225 | class JSONBodyModelBinding(ModelBinding): 226 | 227 | async def bind(self) -> Any: 228 | return self.__build_json_body() 229 | 230 | def __build_json_body(self): 231 | if "content-type" not in self.request._headers_keys_in_lowcase.keys() or \ 232 | not self.request._headers_keys_in_lowcase["content-type"].lower().startswith("application/json"): 233 | raise HttpError( 234 | 400, None, 'The content type of this request must be "application/json"') 235 
| return JSONBody(self.request.json) 236 | 237 | 238 | class RequestBodyReaderModelBinding(ModelBinding): 239 | 240 | async def bind(self) -> Any: 241 | return self.request.reader 242 | 243 | 244 | class BytesBodyModelBinding(ModelBinding): 245 | 246 | async def bind(self) -> Any: 247 | if not self.request._body: 248 | self.request._body = await self.request.reader.read() 249 | return BytesBody(self.request._body) 250 | 251 | 252 | class StrModelBinding(ModelBinding): 253 | 254 | async def bind(self) -> Any: 255 | return self.__build_str(self.arg_name, **self._kws) 256 | 257 | def __build_str(self, key, val=None): 258 | if key in self.request.parameter.keys(): 259 | return Parameter(name=key, default=self.request.parameter[key], required=False) 260 | elif val is None: 261 | return None 262 | else: 263 | return Parameter(name=key, default=val, required=False) 264 | 265 | 266 | class BoolModelBinding(ModelBinding): 267 | 268 | async def bind(self) -> Any: 269 | return self.__build_bool(self.arg_name, **self._kws) 270 | 271 | def __build_bool(self, key, val=None): 272 | if key in self.request.parameter.keys(): 273 | v = self.request.parameter[key] 274 | return v.lower() not in ("0", "false", "") 275 | else: 276 | return val 277 | 278 | 279 | class IntModelBinding(ModelBinding): 280 | 281 | async def bind(self) -> Any: 282 | return self.__build_int(self.arg_name, **self._kws) 283 | 284 | def __build_int(self, key, val=None): 285 | if key in self.request.parameter.keys(): 286 | try: 287 | return int(self.request.parameter[key]) 288 | except: 289 | raise HttpError( 290 | 400, None, f"Parameter[{key}] should be an int. 
") 291 | else: 292 | return val 293 | 294 | 295 | class FloatModelBinding(ModelBinding): 296 | 297 | async def bind(self) -> Any: 298 | return self.__build_float(self.arg_name, **self._kws) 299 | 300 | def __build_float(self, key, val=None): 301 | if key in self.request.parameter.keys(): 302 | try: 303 | return float(self.request.parameter[key]) 304 | except: 305 | raise HttpError( 306 | 400, None, f"Parameter[{key}] should be an float. ") 307 | else: 308 | return val 309 | 310 | 311 | class ListModelBinding(ModelBinding): 312 | 313 | async def bind(self) -> Any: 314 | return self.__build_list(self.arg_name, **self._kws) 315 | 316 | def __build_list(self, key, target_type=list, val=[]): 317 | if key in self.request.parameters.keys(): 318 | ori_list = self.request.parameters[key] 319 | else: 320 | ori_list = val 321 | 322 | if target_type == List[int]: 323 | try: 324 | return [int(p) for p in ori_list] 325 | except: 326 | raise HttpError( 327 | 400, None, f"One of the parameter[{key}] is not int. ") 328 | elif target_type == List[float]: 329 | try: 330 | return [float(p) for p in ori_list] 331 | except: 332 | raise HttpError( 333 | 400, None, f"One of the parameter[{key}] is not float. ") 334 | elif target_type == List[bool]: 335 | return [p.lower() not in ("0", "false", "") for p in ori_list] 336 | elif target_type in (List[dict], List[Dict]): 337 | try: 338 | return [json.loads(p) for p in ori_list] 339 | except: 340 | raise HttpError( 341 | 400, None, f"One of the parameter[{key}] is not JSON string. 
") 342 | elif target_type == List[Parameter]: 343 | return [Parameter(name=key, default=p, required=False) for p in ori_list] 344 | else: 345 | return ori_list 346 | 347 | 348 | class ModelDictModelBinding(ModelBinding): 349 | 350 | async def bind(self) -> Any: 351 | return self.__build_model_dict() 352 | 353 | def __build_model_dict(self): 354 | mdict = ModelDict() 355 | for k, v in self.request.parameters.items(): 356 | if len(v) == 1: 357 | mdict[k] = v[0] 358 | else: 359 | mdict[k] = v 360 | return mdict 361 | 362 | 363 | class DictModelBinding(ModelBinding): 364 | 365 | async def bind(self) -> Any: 366 | return self.__build_dict(self.arg_name, **self._kws) 367 | 368 | def __build_dict(self, key, val={}): 369 | if key in self.request.parameter.keys(): 370 | try: 371 | return json.loads(self.request.parameter[key]) 372 | except: 373 | raise HttpError( 374 | 400, None, f"Parameter[{key}] should be a JSON string.") 375 | else: 376 | return val 377 | 378 | 379 | class DefaultModelBinding(ModelBinding): 380 | 381 | async def bind(self) -> Any: 382 | return self.default_value 383 | 384 | 385 | class ModelBindingConf: 386 | 387 | def __init__(self) -> None: 388 | self.default_model_binding_type = DefaultModelBinding 389 | self.model_bingding_types: Dict[Type, Type[ModelBinding]] = { 390 | Request: RequestModelBinding, 391 | HttpSession: SessionModelBinding, 392 | Response: ResponseModelBinding, 393 | Headers: HeadersModelBinding, 394 | RegGroups: RegGroupsModelBinding, 395 | Environment: EnvironmentModelBinding, 396 | Header: HeaderModelBinding, 397 | Cookies: CookiesModelBinding, 398 | BaseCookie: CookiesModelBinding, 399 | SimpleCookie: CookiesModelBinding, 400 | Cookie: CookieModelBinding, 401 | MultipartFile: MultipartFileModelBinding, 402 | Parameter: ParameterModelBinding, 403 | PathValue: PathValueModelBinding, 404 | Parameters: ParametersModelBinding, 405 | RegGroup: RegGroupModelBinding, 406 | JSONBody: JSONBodyModelBinding, 407 | RequestBodyReader: 
RequestBodyReaderModelBinding, 408 | BytesBody: BytesBodyModelBinding, 409 | str: StrModelBinding, 410 | bool: BoolModelBinding, 411 | int: IntModelBinding, 412 | float: FloatModelBinding, 413 | list: ListModelBinding, 414 | List: ListModelBinding, 415 | List[str]: ListModelBinding, 416 | List[Parameter]: ListModelBinding, 417 | List[int]: ListModelBinding, List[float]: ListModelBinding, 418 | List[bool]: ListModelBinding, 419 | List[dict]: ListModelBinding, 420 | List[Dict]: ListModelBinding, 421 | ModelDict: ModelDictModelBinding, 422 | dict: DictModelBinding, 423 | Dict: DictModelBinding 424 | } 425 | -------------------------------------------------------------------------------- /naja_atra/request_handlers/websocket_controller_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | 26 | import asyncio 27 | import os 28 | import struct 29 | import errno 30 | 31 | from threading import Lock 32 | from base64 import b64encode 33 | from hashlib import sha1 34 | from typing import Dict, Tuple, Union, List 35 | from uuid import uuid4 36 | from socket import error as SocketError 37 | 38 | from ..utils.logger import get_logger 39 | from ..models import Headers, WebsocketCloseReason, WebsocketRequest, WebsocketSession 40 | from ..models import WEBSOCKET_OPCODE_BINARY, WEBSOCKET_OPCODE_CLOSE, WEBSOCKET_OPCODE_CONTINUATION, WEBSOCKET_OPCODE_PING, WEBSOCKET_OPCODE_PONG, WEBSOCKET_OPCODE_TEXT 41 | from ..models import DEFAULT_ENCODING 42 | 43 | 44 | _logger = get_logger("naja_atra.request_handlers.websocket_request_handler") 45 | 46 | 47 | ''' 48 | https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers 49 | 50 | https://datatracker.ietf.org/doc/html/rfc6455 51 | 52 | +-+-+-+-+-------+-+-------------+-------------------------------+ 53 | 0 1 2 3 54 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 55 | +-+-+-+-+-------+-+-------------+-------------------------------+ 56 | |F|R|R|R| opcode|M| Payload len | Extended payload length | 57 | |I|S|S|S| (4) |A| (7) | (16/64) | 58 | |N|V|V|V| |S| | (if payload len==126/127) | 59 | | |1|2|3| |K| | | 60 | +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + 61 | | Extended payload length continued, if payload len == 127 | 62 | + - - - - - - - - - - - - - - - +-------------------------------+ 63 | | |Masking-key, if MASK set to 1 | 64 | +-------------------------------+-------------------------------+ 65 | | Masking-key (continued) | Payload Data | 66 | 
+-------------------------------- - - - - - - - - - - - - - - - + 67 | : Payload Data continued ... : 68 | + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + 69 | | Payload Data continued ... | 70 | +---------------------------------------------------------------+ 71 | ''' 72 | 73 | FIN = 0x80 74 | OPCODE = 0x0f 75 | MASKED = 0x80 76 | PAYLOAD_LEN = 0x7f 77 | PAYLOAD_LEN_EXT16 = 0x7e 78 | PAYLOAD_LEN_EXT64 = 0x7f 79 | GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' 80 | _BUFFER_SIZE = 1024 * 1024 81 | 82 | OPTYPES = { 83 | WEBSOCKET_OPCODE_CONTINUATION: "CONTINUATION", 84 | WEBSOCKET_OPCODE_TEXT: "TEXT", 85 | WEBSOCKET_OPCODE_BINARY: "BINARY", 86 | WEBSOCKET_OPCODE_CLOSE: "CLOSE", 87 | WEBSOCKET_OPCODE_PING: "PING", 88 | WEBSOCKET_OPCODE_PONG: "PONE", 89 | } 90 | 91 | 92 | class _ContinuationMessageCache: 93 | 94 | def __init__(self, opcode: int) -> None: 95 | self.opcode: int = opcode 96 | self.message_bytes: bytearray = bytearray() 97 | 98 | 99 | class WebsocketException(Exception): 100 | 101 | def __init__(self, reason: WebsocketCloseReason = None, graceful: bool = False) -> None: 102 | super().__init__(reason) 103 | self.__graceful: bool = graceful 104 | self.__reason: WebsocketCloseReason = reason 105 | 106 | @property 107 | def is_graceful(self) -> bool: 108 | return self.__graceful 109 | 110 | @property 111 | def reason(self) -> WebsocketCloseReason: 112 | return self.__reason 113 | 114 | 115 | class WebsocketControllerHandler: 116 | 117 | def __init__(self, http_request_handler) -> None: 118 | self.http_request_handler = http_request_handler 119 | self.request_writer = http_request_handler.request_writer 120 | self.routing_conf = http_request_handler.routing_conf 121 | self.send_response = http_request_handler.send_response_only 122 | self.send_header = http_request_handler.send_header 123 | self.reader = http_request_handler.reader 124 | self.keep_alive = True 125 | self.handshake_done = False 126 | 127 | handler_class, path_values, regroups = 
self.routing_conf.get_websocket_handler( 128 | http_request_handler.request_path) 129 | self.handler = handler_class.ctrl_object if handler_class else None 130 | self.ws_request = WebsocketRequest() 131 | self.ws_request.headers = http_request_handler.headers 132 | self.ws_request.path = http_request_handler.request_path 133 | self.ws_request.query_string = http_request_handler.query_string 134 | self.ws_request.parameters = http_request_handler.query_parameters 135 | self.ws_request.path_values = path_values 136 | self.ws_request.reg_groups = regroups 137 | if "cookie" in self.ws_request.headers: 138 | self.ws_request.cookies.load(self.ws_request.headers["cookie"]) 139 | elif "Cookie" in self.ws_request.headers: 140 | self.ws_request.cookies.load(self.ws_request.headers["Cookie"]) 141 | self.session = WebsocketSessionImpl(self, self.ws_request) 142 | self.close_reason: WebsocketCloseReason = None 143 | 144 | self._continution_cache: _ContinuationMessageCache = None 145 | self._send_msg_lock = Lock() 146 | self._send_frame_lock = Lock() 147 | 148 | @property 149 | def response_headers(self): 150 | if hasattr(self.http_request_handler, "_headers_buffer"): 151 | return self.http_request_handler._headers_buffer 152 | else: 153 | return [] 154 | 155 | async def await_func(self, obj): 156 | if asyncio.iscoroutine(obj): 157 | return await obj 158 | return obj 159 | 160 | async def on_handshake(self) -> Tuple[int, Dict[str, List[str]]]: 161 | try: 162 | if not hasattr(self.handler, "on_handshake") or not callable(self.handler.on_handshake): 163 | return None, {} 164 | res = await self.await_func(self.handler.on_handshake(self.ws_request)) 165 | http_status_code = None 166 | headers = {} 167 | if not res: 168 | pass 169 | elif isinstance(res, int): 170 | http_status_code = res 171 | elif isinstance(res, dict) or isinstance(res, Headers): 172 | headers = res 173 | elif isinstance(res, tuple): 174 | for item in res: 175 | if isinstance(item, int) and not http_status_code: 
176 | http_status_code = item 177 | elif isinstance(item, dict) or isinstance(item, Headers): 178 | headers.update(item) 179 | else: 180 | _logger.warn(f"Endpoint[{self.ws_request.path}]") 181 | return http_status_code, headers 182 | except Exception as e: 183 | _logger.error(f"Error occurs when handshake. ") 184 | return 500, {} 185 | 186 | async def on_message(self, opcode: int, message_bytes: bytearray): 187 | try: 188 | if opcode == WEBSOCKET_OPCODE_CLOSE: 189 | _logger.info("Client asked to close connection.") 190 | if len(message_bytes) >= 2: 191 | code = struct.unpack(">H", message_bytes[0:2])[0] 192 | reason = message_bytes[2:].decode( 193 | 'UTF-8', errors="replace") 194 | else: 195 | code = None 196 | reason = '' 197 | raise WebsocketException(graceful=True, reason=WebsocketCloseReason( 198 | "Client asked to close connection.", code=code, reason=reason)) 199 | elif opcode == WEBSOCKET_OPCODE_TEXT and hasattr(self.handler, "on_text_message") and callable(self.handler.on_text_message): 200 | await self.await_func(self.handler.on_text_message(self.session, message_bytes.decode("UTF-8", errors="replace"))) 201 | elif opcode == WEBSOCKET_OPCODE_PING and hasattr(self.handler, "on_ping_message") and callable(self.handler.on_ping_message): 202 | await self.await_func(self.handler.on_ping_message(self.session, bytes(message_bytes))) 203 | elif opcode == WEBSOCKET_OPCODE_PONG and hasattr(self.handler, "on_pong_message") and callable(self.handler.on_pong_message): 204 | await self.await_func(self.handler.on_pong_message(self.session, bytes(message_bytes))) 205 | elif opcode == WEBSOCKET_OPCODE_BINARY and self._continution_cache.message_bytes and hasattr(self.handler, "on_binary_message") and callable(self.handler.on_binary_message): 206 | await self.await_func(self.handler.on_binary_message(self.session, bytes(message_bytes))) 207 | except Exception as e: 208 | _logger.error(f"Error occurs when on message!") 209 | self.close(f"Error occurs when on_message. 
{e}") 210 | 211 | async def on_continuation_frame(self, first_frame_opcode: int, fin: int, message_frame: bytearray): 212 | try: 213 | if first_frame_opcode == WEBSOCKET_OPCODE_BINARY and hasattr(self.handler, "on_binary_frame") and callable(self.handler.on_binary_frame): 214 | should_append_to_cache = await self.await_func(self.handler.on_binary_frame(self.session, bool(fin), bytes(message_frame))) 215 | if should_append_to_cache == True: 216 | self._continution_cache.message_bytes.extend(message_frame) 217 | else: 218 | self._continution_cache.message_bytes.extend(message_frame) 219 | except Exception as e: 220 | _logger.error(f"Error occurs when on message!") 221 | self.close(f"Error occurs when on_message. {e}") 222 | 223 | async def on_open(self): 224 | try: 225 | if hasattr(self.handler, "on_open") and callable(self.handler.on_open): 226 | await self.await_func(self.handler.on_open(self.session)) 227 | except Exception as e: 228 | _logger.error(f"Error occurs when on open!") 229 | self.close(f"Error occurs when on_open. 
{e}") 230 | 231 | async def on_close(self): 232 | try: 233 | if hasattr(self.handler, "on_close") and callable(self.handler.on_close): 234 | await self.await_func(self.handler.on_close(self.session, self.close_reason)) 235 | except Exception as e: 236 | _logger.error(f"Error occurs when on close!") 237 | 238 | async def handle_request(self): 239 | while self.keep_alive: 240 | try: 241 | if not self.handshake_done: 242 | await self.handshake() 243 | else: 244 | await self.read_next_message() 245 | except WebsocketException as e: 246 | if not e.is_graceful: 247 | _logger.warning( 248 | f"Something's wrong, close connection: {e.reason}") 249 | else: 250 | _logger.info(f"Close connection: {e.reason}") 251 | self.keep_alive = False 252 | self.close_reason = e.reason 253 | except: 254 | _logger.exception("Errors occur when handling message!") 255 | self.keep_alive = False 256 | self.close_reason = WebsocketCloseReason( 257 | "Errors occur when handling message!") 258 | 259 | await self.on_close() 260 | 261 | async def handshake(self): 262 | if self.handler: 263 | code, headers = await self.on_handshake() 264 | if code and code != 101: 265 | self.keep_alive = False 266 | self.send_response(code) 267 | else: 268 | self.send_response(101, "Switching Protocols") 269 | self.send_header("Upgrade", "websocket") 270 | self.send_header("Connection", "Upgrade") 271 | self.send_header("Sec-WebSocket-Accept", 272 | self.calculate_response_key()) 273 | if headers: 274 | for h_name, h_val in headers.items(): 275 | self.send_header(h_name, h_val) 276 | else: 277 | self.keep_alive = False 278 | self.send_response(404) 279 | 280 | ws_res_headers = b"".join(self.response_headers) + b"\r\n" 281 | _logger.debug(ws_res_headers) 282 | self.request_writer.send(ws_res_headers) 283 | self.handshake_done = True 284 | if self.keep_alive == True: 285 | await self.on_open() 286 | 287 | def calculate_response_key(self): 288 | key: str = self.ws_request.headers["Sec-WebSocket-Key"] if 
"Sec-WebSocket-Key" in self.ws_request.headers else self.ws_request.headers["Sec-Websocket-Key"] 289 | _logger.debug( 290 | f"Sec-WebSocket-Key: {key}") 291 | key_hash = sha1(key.encode(errors="replace") + 292 | GUID.encode(errors="replace")) 293 | response_key = b64encode(key_hash.digest()).strip() 294 | return response_key.decode('ASCII', errors="replace") 295 | 296 | async def read_bytes(self, num): 297 | return await self.reader.read(num) 298 | 299 | async def _read_message_content(self) -> Tuple[int, int, bytearray]: 300 | _logger.debug(f"Read next websocket[{self.ws_request.path}] message") 301 | try: 302 | b1, b2 = await self.read_bytes(2) 303 | except ConnectionResetError as e: 304 | raise WebsocketException( 305 | graceful=True, reason=WebsocketCloseReason("Client closed connection.")) 306 | except SocketError as e: 307 | if e.errno == errno.ECONNRESET: 308 | raise WebsocketException( 309 | graceful=True, reason=WebsocketCloseReason("Client closed connection.")) 310 | b1, b2 = 0, 0 311 | except ValueError as e: 312 | b1, b2 = 0, 0 313 | 314 | fin = b1 & FIN 315 | opcode = b1 & OPCODE 316 | masked = b2 & MASKED 317 | payload_length = b2 & PAYLOAD_LEN 318 | 319 | if not masked: 320 | raise WebsocketException( 321 | reason=WebsocketCloseReason("Client is not masked.")) 322 | 323 | if opcode not in OPTYPES.keys(): 324 | raise WebsocketException( 325 | reason=WebsocketCloseReason(f"Unknown opcode {opcode}.")) 326 | 327 | if opcode in (WEBSOCKET_OPCODE_PING, WEBSOCKET_OPCODE_PONG) and payload_length > 125: 328 | raise WebsocketException(reason=WebsocketCloseReason( 329 | f"Ping/Pong message payload is too large! The max length of the Ping/Pong messages is 125. 
but now is {payload_length}")) 330 | 331 | if payload_length == 126: 332 | hb = await self.reader.read(2) 333 | payload_length = struct.unpack(">H", hb)[0] 334 | elif payload_length == 127: 335 | qb = await self.reader.read(8) 336 | payload_length = struct.unpack(">Q", qb)[0] 337 | 338 | frame_bytes = bytearray() 339 | if payload_length > 0: 340 | masks = await self.read_bytes(4) 341 | payload = await self.read_bytes(payload_length) 342 | for encoded_byte in payload: 343 | frame_bytes.append(encoded_byte ^ masks[len(frame_bytes) % 4]) 344 | 345 | return fin, opcode, frame_bytes 346 | 347 | async def read_next_message(self): 348 | 349 | fin, opcode, frame_bytes = await self._read_message_content() 350 | 351 | if fin and opcode != WEBSOCKET_OPCODE_CONTINUATION: 352 | # A normal frame, handle message. 353 | await self.on_message(opcode, frame_bytes) 354 | return 355 | 356 | if not fin and opcode != WEBSOCKET_OPCODE_CONTINUATION: 357 | # Fragment message: first frame, try to create a cache object. 358 | if opcode not in (WEBSOCKET_OPCODE_TEXT, WEBSOCKET_OPCODE_BINARY): 359 | raise WebsocketException(reason=WebsocketCloseReason( 360 | f"Control({OPTYPES[opcode]}) frames MUST NOT be fragmented")) 361 | 362 | if self._continution_cache is not None: 363 | # Check if another fragment message is being read. 364 | raise WebsocketException(reason=WebsocketCloseReason( 365 | "Another continution message is not yet finished. Close connection for this error!")) 366 | 367 | self._continution_cache = _ContinuationMessageCache(opcode) 368 | 369 | if self._continution_cache is None: 370 | # When the first frame is not send, close connection. 371 | raise WebsocketException(reason=WebsocketCloseReason( 372 | "A continuation fragment frame is received, but the start fragment is not yet received. ")) 373 | 374 | await self.on_continuation_frame(self._continution_cache.opcode, fin, frame_bytes) 375 | 376 | if fin: 377 | # Fragment message: end of this message. 
378 | await self.on_message(self._continution_cache.opcode, self._continution_cache.message_bytes) 379 | self._continution_cache = None 380 | 381 | def send_message(self, message: Union[bytes, str], chunk_size: int = 0): 382 | if isinstance(message, bytes): 383 | self.send_bytes(WEBSOCKET_OPCODE_TEXT, 384 | message, chunk_size=chunk_size) 385 | elif isinstance(message, str): 386 | self.send_bytes(WEBSOCKET_OPCODE_TEXT, message.encode( 387 | DEFAULT_ENCODING, errors="replace"), chunk_size=chunk_size) 388 | else: 389 | _logger.error(f"Cannot send message[{message}. ") 390 | 391 | def send_ping(self, message: Union[str, bytes]): 392 | if isinstance(message, bytes): 393 | self.send_bytes(WEBSOCKET_OPCODE_PING, message) 394 | elif isinstance(message, str): 395 | self.send_bytes(WEBSOCKET_OPCODE_PING, message.encode( 396 | DEFAULT_ENCODING, errors="replace")) 397 | 398 | def send_pong(self, message: Union[str, bytes]): 399 | if isinstance(message, bytes): 400 | self.send_bytes(WEBSOCKET_OPCODE_PONG, message) 401 | elif isinstance(message, str): 402 | self.send_bytes(WEBSOCKET_OPCODE_PONG, message.encode( 403 | DEFAULT_ENCODING, errors="replace")) 404 | 405 | def send_bytes(self, opcode: int, payload: bytes, chunk_size: int = 0): 406 | if opcode not in OPTYPES.keys() or opcode == WEBSOCKET_OPCODE_CONTINUATION: 407 | raise WebsocketException(reason=WebsocketCloseReason( 408 | f"Cannot send message in a opcode {opcode}. ")) 409 | 410 | # Control frames MUST NOT be fragmented. 411 | c_size = chunk_size if opcode in ( 412 | WEBSOCKET_OPCODE_BINARY, WEBSOCKET_OPCODE_TEXT) else 0 413 | 414 | if c_size and c_size > 0: 415 | with self._send_msg_lock: 416 | # Make sure a fragmented message is sent completely. 
417 | self._send_bytes_no_lock(opcode, payload, chunk_size=c_size) 418 | else: 419 | self._send_bytes_no_lock(opcode, payload) 420 | 421 | def _send_bytes_no_lock(self, opcode: int, payload: bytes, chunk_size: int = 0): 422 | frame_size = chunk_size if chunk_size and chunk_size > 0 else None 423 | all_payloads = payload 424 | frame_bytes = b'' 425 | while all_payloads: 426 | op = WEBSOCKET_OPCODE_CONTINUATION if frame_bytes else opcode 427 | frame_bytes = all_payloads[0: frame_size] 428 | all_payloads = all_payloads[frame_size:] if frame_size else b'' 429 | fin = 0 if all_payloads else FIN 430 | self._send_frame(fin, op, frame_bytes) 431 | 432 | def _send_frame(self, fin: int, opcode: int, payload: bytes): 433 | with self._send_frame_lock: 434 | self.request_writer.send( 435 | self._create_frame_header(fin, opcode, len(payload))) 436 | self.request_writer.send(payload) 437 | 438 | def _create_frame_header(self, fin: int, opcode: int, payload_length: int) -> bytes: 439 | header = bytearray() 440 | # Normal payload 441 | if payload_length <= 125: 442 | header.append(fin | opcode) 443 | header.append(payload_length) 444 | 445 | # Extended payload 446 | elif payload_length >= 126 and payload_length <= 65535: 447 | header.append(fin | opcode) 448 | header.append(PAYLOAD_LEN_EXT16) 449 | header.extend(struct.pack(">H", payload_length)) 450 | 451 | # Huge extended payload 452 | elif payload_length < 18446744073709551616: 453 | header.append(fin | opcode) 454 | header.append(PAYLOAD_LEN_EXT64) 455 | header.extend(struct.pack(">Q", payload_length)) 456 | else: 457 | raise Exception( 458 | "Message is too big. 
Consider breaking it into chunks.") 459 | 460 | return header 461 | 462 | def send_file(self, path: str, chunk_size: int = 0): 463 | try: 464 | file_size = os.path.getsize(path) 465 | if not chunk_size or chunk_size < 0 or chunk_size > file_size: 466 | self._send_file_no_lock(path, file_size, file_size) 467 | else: 468 | with self._send_msg_lock: 469 | self._send_file_no_lock(path, file_size, chunk_size) 470 | except (OSError, ValueError): 471 | raise WebsocketException(reason=WebsocketCloseReason( 472 | f"File in {path} does not exist or is not accessible.")) 473 | 474 | def _send_file_no_lock(self, path: str, file_size: int, chunk_size: int): 475 | with open(path, 'rb') as in_file: 476 | remain_bytes = file_size 477 | opcode = WEBSOCKET_OPCODE_BINARY 478 | while remain_bytes > 0: 479 | with self._send_frame_lock: 480 | frame_size = min(remain_bytes, chunk_size) 481 | remain_bytes -= frame_size 482 | 483 | fin = 0 if remain_bytes > 0 else FIN 484 | 485 | self.request_writer.send( 486 | self._create_frame_header(fin, opcode, frame_size)) 487 | while frame_size > 0: 488 | buff_size = min(_BUFFER_SIZE, frame_size) 489 | frame_size -= buff_size 490 | 491 | data = in_file.read(buff_size) 492 | self.request_writer.send(data) 493 | # After the first frame, the opcode of other frames is continuation forever. 
494 | opcode = WEBSOCKET_OPCODE_CONTINUATION 495 | 496 | def close(self, reason: str = ""): 497 | self.send_bytes(WEBSOCKET_OPCODE_CLOSE, reason.encode( 498 | DEFAULT_ENCODING, errors="replace")) 499 | self.keep_alive = False 500 | self.close_reason = WebsocketCloseReason( 501 | "Server asked to close connection.") 502 | 503 | 504 | class WebsocketSessionImpl(WebsocketSession): 505 | 506 | def __init__(self, handler: WebsocketControllerHandler, request: WebsocketRequest) -> None: 507 | self.__id = uuid4().hex 508 | self.__handler = handler 509 | self.__request = request 510 | 511 | @ property 512 | def id(self) -> str: 513 | return self.__id 514 | 515 | @ property 516 | def request(self) -> WebsocketRequest: 517 | return self.__request 518 | 519 | @ property 520 | def is_closed(self) -> bool: 521 | return not self.__handler.keep_alive 522 | 523 | def send_ping(self, message: bytes = b''): 524 | self.__handler.send_ping(message) 525 | 526 | def send_pone(self, message: bytes = b''): 527 | self.__handler.send_pong(message) 528 | 529 | def send(self, message: Union[str, bytes], opcode: int = WEBSOCKET_OPCODE_TEXT, chunk_size: int = 0): 530 | if isinstance(message, bytes): 531 | msg = message 532 | elif isinstance(message, str): 533 | msg = message.encode(DEFAULT_ENCODING, errors="replace") 534 | else: 535 | raise WebsocketException(reason=WebsocketCloseReason( 536 | f"message {message} is not a string nor a bytes object, cannot send it to client. 
")) 537 | self.__handler.send_bytes( 538 | opcode if opcode is None else WEBSOCKET_OPCODE_TEXT, msg, chunk_size=chunk_size) 539 | 540 | def send_text(self, message: str, chunk_size: int = 0): 541 | self.__handler.send_message(message, chunk_size=chunk_size) 542 | 543 | def send_binary(self, binary: bytes, chunk_size: int = 0): 544 | self.__handler.send_bytes( 545 | WEBSOCKET_OPCODE_BINARY, binary, chunk_size=chunk_size) 546 | 547 | def send_file(self, path: str, chunk_size: int = 0): 548 | self.__handler.send_file(path, chunk_size=chunk_size) 549 | 550 | def close(self, reason: str = ""): 551 | self.__handler.close(reason) 552 | -------------------------------------------------------------------------------- /naja_atra/server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import os 26 | import sys 27 | import threading 28 | import importlib 29 | import re 30 | 31 | from ssl import PROTOCOL_TLS_SERVER, SSLContext 32 | from typing import Dict 33 | 34 | from .http_servers.http_server import HttpServer 35 | 36 | from .app_conf import AppConf 37 | from .utils.logger import get_logger 38 | 39 | 40 | _logger = get_logger("naja_atra.server") 41 | __lock = threading.Lock() 42 | _server: HttpServer = None 43 | 44 | 45 | def _is_match(string="", regx=r""): 46 | if not regx: 47 | return True 48 | pattern = re.compile(regx) 49 | match = pattern.match(string) 50 | return True if match else False 51 | 52 | 53 | def _to_module_name(fpath="", regx=r""): 54 | fname, fext = os.path.splitext(fpath) 55 | 56 | if fext != ".py": 57 | return 58 | mname = fname.replace(os.path.sep, '.') 59 | if _is_match(fpath, regx) or _is_match(fname, regx) or _is_match(mname, regx): 60 | return mname 61 | 62 | 63 | def _load_all_modules(work_dir, pkg, regx): 64 | abs_folder = os.path.join(work_dir, pkg) 65 | if os.path.isfile(abs_folder): 66 | return [_to_module_name(pkg, regx)] 67 | if not os.path.exists(abs_folder): 68 | if abs_folder.endswith(".py"): 69 | _logger.warning( 70 | f"Cannot find package {pkg}, file [{abs_folder}] is not exist") 71 | return [] 72 | if os.path.isfile(abs_folder + ".py"): 73 | return [_to_module_name(pkg + ".py", regx)] 74 | else: 75 | _logger.warning( 76 | f"Cannot find package {pkg}, file [{abs_folder}] is not exist") 77 | return [] 78 | 79 | modules = [] 80 | folders = [] 81 | all_files = os.listdir(abs_folder) 82 | for f in all_files: 83 | if os.path.isfile(os.path.join(abs_folder, f)): 84 | mname = _to_module_name(os.path.join(pkg, f), regx) 85 
| if mname: 86 | modules.append(mname) 87 | elif f != "__pycache__": 88 | folders.append(os.path.join(pkg, f)) 89 | 90 | for folder in folders: 91 | modules += _load_all_modules(work_dir, folder, regx) 92 | return modules 93 | 94 | 95 | def _import_module(mname): 96 | try: 97 | importlib.import_module(mname) 98 | except Exception as e: 99 | _logger.warning(f"Import moudle [{mname}] error! {e}") 100 | 101 | 102 | def scan(base_dir: str = "", regx: str = r"", project_dir: str = "") -> None: 103 | """ 104 | Scan the given directory to import controllers. 105 | 106 | - base_dir: the directory to scan. 107 | - regx: only include the modules that match this regular expression, if absent, all files will be included. 108 | - project_dir: the project directory, default to the entrance file directory. 109 | 110 | """ 111 | if project_dir: 112 | work_dir = project_dir 113 | else: 114 | mpath = os.path.dirname(sys.modules['__main__'].__file__) 115 | _logger.debug( 116 | f"Project directory is not set, use the directory where the main module is in. 
{mpath}") 117 | work_dir = mpath 118 | modules = _load_all_modules(work_dir, base_dir, regx) 119 | 120 | for mname in modules: 121 | _logger.info(f"Import controllers from module: {mname}") 122 | _import_module(mname) 123 | 124 | 125 | def _prepare_server(host: str = "", 126 | port: int = 9090, 127 | ssl: bool = False, 128 | ssl_protocol: int = PROTOCOL_TLS_SERVER, 129 | ssl_check_hostname: bool = False, 130 | keyfile: str = "", 131 | certfile: str = "", 132 | keypass: str = "", 133 | ssl_context: SSLContext = None, 134 | resources: Dict[str, str] = {}, 135 | connection_idle_time=None, 136 | keep_alive=True, 137 | keep_alive_max_request=None, 138 | gzip_content_types=set(), 139 | gzip_compress_level=9, 140 | prefer_coroutine=False, 141 | app_conf: AppConf = None 142 | ) -> None: 143 | with __lock: 144 | global _server 145 | if _server is not None: 146 | _server.shutdown() 147 | _server = HttpServer(host=(host, port), 148 | ssl=ssl, 149 | ssl_protocol=ssl_protocol, 150 | ssl_check_hostname=ssl_check_hostname, 151 | keyfile=keyfile, 152 | certfile=certfile, 153 | keypass=keypass, 154 | ssl_context=ssl_context, 155 | resources=resources, 156 | connection_idle_time=connection_idle_time, 157 | keep_alive=keep_alive, 158 | keep_alive_max_request=keep_alive_max_request, 159 | gzip_content_types=gzip_content_types, 160 | gzip_compress_level=gzip_compress_level, 161 | prefer_corountine=prefer_coroutine, 162 | app_conf=app_conf) 163 | 164 | 165 | def start(host: str = "", 166 | port: int = 9090, 167 | ssl: bool = False, 168 | ssl_protocol: int = PROTOCOL_TLS_SERVER, 169 | ssl_check_hostname: bool = False, 170 | keyfile: str = "", 171 | certfile: str = "", 172 | keypass: str = "", 173 | ssl_context: SSLContext = None, 174 | resources: Dict[str, str] = {}, 175 | connection_idle_time=None, 176 | keep_alive=True, 177 | keep_alive_max_request=None, 178 | gzip_content_types=set(), 179 | gzip_compress_level=9, 180 | prefer_coroutine=False, 181 | app_conf: AppConf = None) -> None: 
182 | _prepare_server( 183 | host=host, 184 | port=port, 185 | ssl=ssl, 186 | ssl_protocol=ssl_protocol, 187 | ssl_check_hostname=ssl_check_hostname, 188 | keyfile=keyfile, 189 | certfile=certfile, 190 | keypass=keypass, 191 | ssl_context=ssl_context, 192 | resources=resources, 193 | connection_idle_time=connection_idle_time, 194 | keep_alive=keep_alive, 195 | keep_alive_max_request=keep_alive_max_request, 196 | gzip_content_types=gzip_content_types, 197 | gzip_compress_level=gzip_compress_level, 198 | prefer_coroutine=prefer_coroutine, 199 | app_conf=app_conf 200 | ) 201 | # start the server 202 | _server.start() 203 | 204 | 205 | async def start_async(host: str = "", 206 | port: int = 9090, 207 | ssl: bool = False, 208 | ssl_protocol: int = PROTOCOL_TLS_SERVER, 209 | ssl_check_hostname: bool = False, 210 | keyfile: str = "", 211 | certfile: str = "", 212 | keypass: str = "", 213 | ssl_context: SSLContext = None, 214 | resources: Dict[str, str] = {}, 215 | connection_idle_time=None, 216 | keep_alive=True, 217 | keep_alive_max_request=None, 218 | gzip_content_types=set(), 219 | gzip_compress_level=9, 220 | prefer_coroutine=True, 221 | app_conf: AppConf = None) -> None: 222 | _prepare_server( 223 | host=host, 224 | port=port, 225 | ssl=ssl, 226 | ssl_protocol=ssl_protocol, 227 | ssl_check_hostname=ssl_check_hostname, 228 | keyfile=keyfile, 229 | certfile=certfile, 230 | keypass=keypass, 231 | ssl_context=ssl_context, 232 | resources=resources, 233 | connection_idle_time=connection_idle_time, 234 | keep_alive=keep_alive, 235 | keep_alive_max_request=keep_alive_max_request, 236 | gzip_content_types=gzip_content_types, 237 | gzip_compress_level=gzip_compress_level, 238 | prefer_coroutine=prefer_coroutine, 239 | app_conf=app_conf 240 | ) 241 | 242 | # start the server 243 | await _server.start_async() 244 | 245 | 246 | def is_ready() -> bool: 247 | return _server and _server.ready 248 | 249 | 250 | def stop() -> None: 251 | with __lock: 252 | global _server 253 | if 
_server is not None: 254 | _logger.info("Shutting down server...") 255 | _server.shutdown() 256 | _server = None 257 | else: 258 | _logger.warning("Server is not ready yet.") 259 | -------------------------------------------------------------------------------- /naja_atra/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | -------------------------------------------------------------------------------- /naja_atra/utils/http_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | """ 5 | Copyright (c) 2018 Keijack Wu 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | """ 25 | 26 | import inspect 27 | import time 28 | import os 29 | import email 30 | import json 31 | import re 32 | from collections import OrderedDict 33 | from typing import Any, Tuple, Union 34 | from urllib.parse import unquote, quote 35 | from ..models import HttpError, StaticFile, DEFAULT_ENCODING 36 | 37 | from .logger import get_logger 38 | 39 | 40 | _logger = get_logger("naja_atra.utils.http_utils") 41 | 42 | 43 | def remove_url_first_slash(url: str): 44 | _url = url 45 | while _url.startswith("/"): 46 | _url = url[1:] 47 | return _url 48 | 49 | 50 | def get_function_args(func, default_type=str): 51 | argspec = inspect.getfullargspec(func) 52 | # ignore first argument like `self` or `clz` in object methods or class methods 53 | start = 1 if inspect.ismethod(func) else 0 54 | if argspec.defaults is None: 55 | args = argspec.args[start:] 56 | else: 57 | args = argspec.args[start: len(argspec.args) - len(argspec.defaults)] 58 | arg_turples = [] 59 | for arg in args: 60 | if arg in argspec.annotations: 61 | ty = argspec.annotations[arg] 62 | else: 63 | ty = default_type 64 | arg_turples.append((arg, ty)) 65 | return arg_turples 66 | 67 | 68 | def get_function_kwargs(func, default_type=str): 69 | argspec = inspect.getfullargspec(func) 70 | if argspec.defaults is None: 71 | return [] 72 | 73 | kwargs = OrderedDict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults)) 74 | kwarg_turples = [] 75 | for k, v in kwargs.items(): 76 | if k in argspec.annotations: 77 | k_anno = argspec.annotations[k] 78 | else: 79 | k_anno = default_type 80 | kwarg_turples.append((k, v, k_anno)) 81 | return kwarg_turples 82 | 83 | 84 | def break_into(txt: str, separator: str): 85 | try: 86 | idx = txt.index(separator) 87 | return txt[0: idx], txt[idx + len(separator):] 88 | except ValueError: 89 | return txt, None 90 | 91 | 92 | def put_to(params, key, val): 93 | if key not in params.keys(): 94 | params[key] = [val] 95 | else: 96 | params[key].append(val) 97 | 98 | 99 | def 
decode_query_string(query_string: str): 100 | params = {} 101 | if not query_string: 102 | return params 103 | pairs = query_string.split("&") 104 | for item in pairs: 105 | key, val = break_into(item, "=") 106 | if val is None: 107 | val = "" 108 | put_to(params, unquote(key), unquote(val)) 109 | 110 | return params 111 | 112 | 113 | def date_time_string(timestamp=None): 114 | if timestamp is None: 115 | timestamp = time.time() 116 | return email.utils.formatdate(timestamp, usegmt=True) 117 | 118 | 119 | def decode_response_body(raw_body: Any) -> Tuple[str, Union[str, bytes, StaticFile]]: 120 | content_type = "text/plain; chartset=utf8" 121 | if raw_body is None: 122 | body = "" 123 | elif isinstance(raw_body, dict): 124 | content_type = "application/json; charset=utf8" 125 | body = json.dumps(raw_body, ensure_ascii=False) 126 | elif isinstance(raw_body, str): 127 | body = raw_body.strip() 128 | if body.startswith(""): 129 | content_type = "text/xml; charset=utf8" 130 | elif body.lower().startswith(""): 131 | content_type = "text/html; charset=utf8" 132 | elif body.lower().startswith(""): 133 | content_type = "text/html; charset=utf8" 134 | else: 135 | content_type = "text/plain; charset=utf8" 136 | elif isinstance(raw_body, StaticFile): 137 | if not os.path.isfile(raw_body.file_path): 138 | _logger.error(f"Cannot find file[{raw_body.file_path}] specified in StaticFile body.") 139 | raise HttpError(404, explain="Cannot find file for this url.") 140 | else: 141 | body = raw_body 142 | content_type = body.content_type 143 | elif isinstance(raw_body, bytes): 144 | body = raw_body 145 | content_type = "application/octet-stream" 146 | else: 147 | body = raw_body 148 | return content_type, body 149 | 150 | 151 | def decode_response_body_to_bytes(raw_body: Any) -> Tuple[str, bytes]: 152 | content_type, body = decode_response_body(raw_body) 153 | if body is None: 154 | byte_body = b'' 155 | elif isinstance(body, str): 156 | byte_body = body.encode(DEFAULT_ENCODING, 
'replace') 157 | elif isinstance(body, bytes): 158 | byte_body = body 159 | elif isinstance(body, StaticFile): 160 | with open(body.file_path, "rb") as in_file: 161 | byte_body = in_file.read() 162 | else: 163 | raise HttpError(400, explain="Cannot read body into bytes!") 164 | return content_type, byte_body 165 | 166 | def get_path_reg_pattern(url): 167 | _url: str = url 168 | path_names = re.findall("(?u)\\{\\w+\\}", _url) 169 | if len(path_names) == 0: 170 | if _url.startswith("**"): 171 | _url = _url[2: ] 172 | assert _url.find("*") < 0, "You can only config a * or ** at the start or end of a path." 173 | _url = f'^([\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+){_url}$' 174 | return _url, [quote("__path_wildcard")] 175 | elif _url.startswith("*"): 176 | _url = _url[1: ] 177 | assert _url.find("*") < 0, "You can only config a * or ** at the start or end of a path." 178 | _url = f'^([\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+){_url}$' 179 | return _url, [quote("__path_wildcard")] 180 | elif _url.endswith("**"): 181 | _url = _url[0: -2] 182 | assert _url.find("*") < 0, "You can only config a * or ** at the start or end of a path." 183 | _url = f'^{_url}([\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+)$' 184 | return _url, [quote("__path_wildcard")] 185 | elif _url.endswith("*"): 186 | _url = _url[0: -1] 187 | assert _url.find("*") < 0, "You can only config a * or ** at the start or end of a path." 
188 | _url = f'^{_url}([\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+)$' 189 | return _url, [quote("__path_wildcard")] 190 | else: 191 | # normal url 192 | return None, path_names 193 | for name in path_names: 194 | _url = _url.replace(name, "([\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+)") 195 | _url = f"^{_url}$" 196 | 197 | quoted_names = [] 198 | for name in path_names: 199 | name = name[1: -1] 200 | quoted_names.append(quote(name)) 201 | return _url, quoted_names -------------------------------------------------------------------------------- /naja_atra/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (c) 2018 Keijack Wu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | from abc import abstractmethod 26 | import sys 27 | import time 28 | import logging 29 | import asyncio 30 | from threading import Thread 31 | from typing import Dict, List, Tuple 32 | 33 | 34 | class LazyCalledLogger(logging.Logger): 35 | 36 | def do_call_handlers(self, record): 37 | super().callHandlers(record) 38 | 39 | @abstractmethod 40 | def callHandlers(self, record): 41 | pass 42 | 43 | 44 | class LazyCalledLoggerThread: 45 | 46 | daemon_threads = True 47 | 48 | def __init__(self) -> None: 49 | self.coroutine_loop = None 50 | self.coroutine_thread = None 51 | 52 | def coroutine_main(self): 53 | self.coroutine_loop = loop = asyncio.new_event_loop() 54 | try: 55 | loop.run_forever() 56 | finally: 57 | loop.run_until_complete(loop.shutdown_asyncgens()) 58 | loop.close() 59 | 60 | def start(self): 61 | if self.coroutine_thread is not None: 62 | while not self.coroutine_loop: 63 | # wait for the loop ready 64 | time.sleep(0.1) 65 | return 66 | self.coroutine_thread = Thread(target=self.coroutine_main, name="logger-thread", daemon=self.daemon_threads) 67 | self.coroutine_thread.start() 68 | 69 | while not self.coroutine_loop: 70 | # wait for the loop ready 71 | time.sleep(0.1) 72 | 73 | def stop(self): 74 | if self.coroutine_loop: 75 | self.coroutine_loop.call_soon_threadsafe(self.coroutine_loop.stop) 76 | self.coroutine_thread.join() 77 | self.coroutine_loop = None 78 | self.coroutine_thread = None 79 | 80 | async def _call(self, logger: LazyCalledLogger, record): 81 | logger.do_call_handlers(record) 82 | 83 | def call_logger_handler(self, logger: LazyCalledLogger, record): 84 | self.start() 85 | asyncio.run_coroutine_threadsafe(self._call(logger, record), self.coroutine_loop) 86 | 87 | 88 | class CachingLogger(LazyCalledLogger): 89 | 90 | logger_thread: LazyCalledLoggerThread = LazyCalledLoggerThread() 91 | 92 | def callHandlers(self, record): 93 | CachingLogger.logger_thread.call_logger_handler(self, record) 94 | 95 | 96 | class 
LoggerFactory: 97 | 98 | DEFAULT_LOG_FORMAT: str = '[%(asctime)s]-[%(threadName)s]-[%(name)s:%(lineno)d] %(levelname)-4s: %(message)s' 99 | 100 | DEFAULT_DATE_FORMAT: str = '%Y-%m-%d %H:%M:%S' 101 | 102 | _LOG_LVELS: Tuple[str] = ("DEBUG", "INFO", "WARN", "ERROR") 103 | 104 | def __init__(self, log_level: str = "INFO", log_format: str = DEFAULT_LOG_FORMAT, date_format: str = DEFAULT_DATE_FORMAT) -> None: 105 | self.__cache_loggers: Dict[str, CachingLogger] = {} 106 | self._log_level = log_level.upper() if log_level and log_level.upper() in self._LOG_LVELS else "INFO" 107 | self.log_format = log_format 108 | self.date_format = date_format 109 | 110 | self._handlers: List[logging.Handler] = [] 111 | 112 | @property 113 | def handlers(self) -> List[logging.Handler]: 114 | if self._handlers: 115 | return self._handlers 116 | _handler = logging.StreamHandler(sys.stdout) 117 | _formatter_ = logging.Formatter(fmt=self.log_format, datefmt=self.date_format) 118 | _handler.setFormatter(_formatter_) 119 | _handler.setLevel(self._log_level) 120 | self._handlers.append(_handler) 121 | return self._handlers 122 | 123 | @property 124 | def log_level(self): 125 | return self._log_level 126 | 127 | @log_level.setter 128 | def log_level(self, log_level: str): 129 | self._log_level = log_level.upper() if log_level and log_level.upper() in self._LOG_LVELS else "INFO" 130 | _logger_ = self.get_logger("Logger") 131 | _logger_.info(f"global logger set to {self._log_level}") 132 | for h in self._handlers: 133 | h.setLevel(self._log_level) 134 | for l in self.__cache_loggers.values(): 135 | l.setLevel(self._log_level) 136 | 137 | def add_handler(self, handler: logging.Handler) -> None: 138 | if self.__cache_loggers: 139 | self._handlers.append(handler) 140 | else: 141 | self.handlers.append(handler) 142 | for l in self.__cache_loggers.values(): 143 | l.addHandler(handler) 144 | 145 | def remove_handler(self, handler: logging.Handler) -> None: 146 | if handler in self._handlers: 147 | 
self._handlers.remove(handler) 148 | for l in self.__cache_loggers.values(): 149 | l.removeHandler(handler) 150 | 151 | def set_handler(self, handler: logging.Handler) -> None: 152 | self._handlers.clear() 153 | self._handlers.append(handler) 154 | for l in self.__cache_loggers.values(): 155 | for hdlr in l.handlers: 156 | l.removeHandler(hdlr) 157 | l.addHandler(handler) 158 | 159 | def get_logger(self, tag: str = "naja_atra") -> logging.Logger: 160 | if tag not in self.__cache_loggers: 161 | self.__cache_loggers[tag] = CachingLogger(tag, self._log_level) 162 | for hdlr in self.handlers: 163 | self.__cache_loggers[tag].addHandler(hdlr) 164 | return self.__cache_loggers[tag] 165 | 166 | 167 | _default_logger_factory: LoggerFactory = LoggerFactory() 168 | 169 | _logger_factories: Dict[str, LoggerFactory] = {} 170 | 171 | 172 | def get_logger_factory(tag: str = "") -> LoggerFactory: 173 | if not tag: 174 | return _default_logger_factory 175 | if tag not in _logger_factories: 176 | _logger_factories[tag] = LoggerFactory() 177 | return _logger_factories[tag] 178 | 179 | 180 | def set_level(level) -> None: 181 | _default_logger_factory.log_level = level 182 | 183 | 184 | def add_handler(handler: logging.Handler) -> None: 185 | _default_logger_factory.add_handler(handler) 186 | 187 | 188 | def remove_handler(handler: logging.Handler) -> None: 189 | _default_logger_factory.remove_handler(handler) 190 | 191 | 192 | def set_handler(handler: logging.Handler) -> None: 193 | _default_logger_factory.set_handler(handler) 194 | 195 | 196 | def get_logger(tag: str = "naja_atra", factory: str = "") -> logging.Logger: 197 | return get_logger_factory(factory).get_logger(tag) 198 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=65", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | 
[project] 9 | name = "naja-atra" 10 | description = "This is a simple http server, use MVC like design." 11 | readme = "README.md" 12 | authors = [ 13 | { name = "keijack", email = "keijack.wu@gmail.com" } 14 | ] 15 | requires-python = ">=3.7" 16 | keywords = ["http-server", "websocket", "http", "web", "web-server"] 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: MIT License", 21 | "Operating System :: OS Independent", 22 | ] 23 | dynamic = ["version"] 24 | 25 | [project.optional-dependencies] 26 | test = ["websocket-client", "pytest"] 27 | dev = ["websocket-client"] 28 | 29 | [tool.setuptools.packages.find] 30 | include=["naja_atra*"] 31 | 32 | [tool.setuptools.dynamic] 33 | version = {attr = "naja_atra.version"} 34 | 35 | [project.urls] 36 | homepage = "https://github.com/naja-atra/naja-atra" 37 | repository = "https://github.com/naja-atra/naja-atra" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from setuptools import setup 3 | 4 | setup() 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keijack/python-simple-http-server/3bb7859bc849d790b5e8b48c6cfb7680ba96a77c/tests/__init__.py -------------------------------------------------------------------------------- /tests/ctrls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keijack/python-simple-http-server/3bb7859bc849d790b5e8b48c6cfb7680ba96a77c/tests/ctrls/__init__.py -------------------------------------------------------------------------------- /tests/ctrls/my_controllers.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import time 5 | from typing import List, OrderedDict 6 | 7 | from naja_atra import BytesBody, FilterContext, ModelDict, Redirect, RegGroup, RequestBodyReader, request_filter 8 | from naja_atra import Headers 9 | from naja_atra import HttpError 10 | from naja_atra import JSONBody 11 | from naja_atra import Header 12 | from naja_atra import Parameters 13 | from naja_atra import Cookie 14 | from naja_atra import Cookies 15 | from naja_atra import PathValue 16 | from naja_atra import Parameter 17 | from naja_atra import MultipartFile 18 | from naja_atra import Response 19 | from naja_atra import Request 20 | from naja_atra import HttpSession 21 | from naja_atra import request_map, route 22 | from naja_atra import controller 23 | from naja_atra import error_message 24 | from naja_atra.app_conf import get_app_conf 25 | import os 26 | import naja_atra.utils.logger as logger 27 | 28 | 29 | _logger = logger.get_logger("my_test_main") 30 | 31 | 32 | _logger = logger.get_logger("controller") 33 | 34 | _app = get_app_conf("2") 35 | 36 | 37 | @request_map("/") 38 | @request_map("/index") 39 | @_app.route("/") 40 | def my_ctrl(): 41 | # You can return a dictionary, a string or a `simple_http_server.simple_http_server.Response` object. 
42 | return {"code": 0, "message": "success"} 43 | 44 | 45 | @request_map("/say_hello", method=["GET", "POST"]) 46 | def my_ctrl2(name, name2=Parameter("name", default="KEIJACK")): 47 | """name and name2 is the same""" 48 | return f"hello, {name}, {name2}" 49 | 50 | 51 | @request_map("/error") 52 | def my_ctrl3(): 53 | raise HttpError(400, "Parameter Error!", "Test Parameter Error!") 54 | 55 | 56 | @request_map("/sleep") 57 | def sleep_secs(secs: int = 10): 58 | _logger.info(f"Sleep {secs} secondes...") 59 | time.sleep(secs) 60 | return { 61 | "message": "OK" 62 | } 63 | 64 | 65 | async def say(sth: str = ""): 66 | _logger.info(f"Say: {sth}") 67 | return f"Success! {sth}" 68 | 69 | 70 | @request_map("/中文/coroutine") 71 | async def coroutine_ctrl(hey: str = "Hey!"): 72 | # raise RuntimeError 73 | return await say(hey) 74 | 75 | 76 | @request_map("/exception") 77 | def exception_ctrl(): 78 | raise Exception("some error occurs!") 79 | 80 | 81 | @request_map("/upload", method="POST") 82 | def my_upload(img=MultipartFile("img"), txt=Parameter("中文text", required=False, default="DEFAULT"), req=Request()): 83 | for k, v in req.parameter.items(): 84 | print("%s (%s)====> %s " % (k, str(type(k)), v)) 85 | print(txt) 86 | 87 | root = os.path.dirname(os.path.abspath(__file__)) 88 | img.save_to_file(root + "/imgs/" + img.filename) 89 | return f"upload ok! 
{txt} " 90 | 91 | 92 | @request_map("/post_txt", method=["GET", "POST"]) 93 | def normal_form_post(txt=Parameter("中文txt", required=False, default="DEFAULT"), req=Request(), bd=BytesBody()): 94 | for k, v in req.parameter.items(): 95 | print("%s ====> %s " % (k, v)) 96 | print(req.body) 97 | print(bd) 98 | return f"hi, {txt}" 99 | 100 | 101 | @request_map("/post_json", method="POST") 102 | def post_json(host=Header("Host"), json=JSONBody()): 103 | print(f"Host: {host}") 104 | print(json) 105 | return OrderedDict(json) 106 | 107 | 108 | @request_map("/cookies") 109 | def set_headers(res: Response, headers: Headers, cookies: Cookies, cookie=Cookie("sc")): 110 | print("==================cookies==========") 111 | print(cookies) 112 | print("==================cookies==========") 113 | print(cookie) 114 | res.add_header( 115 | "Set-Cookie", "sc=keijack; Expires=Web, 31 Oct 2018 00:00:00 GMT;") 116 | res.add_header("Set-Cookie", "sc=keijack2;") 117 | res.body = "OK!" 118 | 119 | 120 | @request_map("tuple") 121 | async def tuple_results(): 122 | return 200, Headers({"MyHeader": "my header"}), "hello tuple result!" 123 | 124 | 125 | @request_map("session") 126 | def test_session(session: HttpSession, invalid=False): 127 | ins = session.get_attribute("in-session") 128 | if not ins: 129 | session.set_attribute("in-session", "Hello, Session!") 130 | 131 | _logger.info("session id: %s" % session.id) 132 | if invalid: 133 | _logger.info("session[%s] is being invalidated. 
" % session.id) 134 | session.invalidate() 135 | return "%s" % str(ins) 136 | 137 | 138 | @request_map("tuple_cookie") 139 | def tuple_with_cookies(headers=Headers(), all_cookies=Cookies(), cookie_sc=Cookie("sc")): 140 | print("=====>headers") 141 | print(headers) 142 | print("=====> cookies ") 143 | print(all_cookies) 144 | print("=====> cookie sc ") 145 | print(cookie_sc) 146 | print("======<") 147 | import datetime 148 | expires = datetime.datetime(2018, 12, 31) 149 | 150 | cks = Cookies() 151 | # cks = cookies.SimpleCookie() # you could also use the build-in cookie objects 152 | cks["ck1"] = "keijack" 153 | cks["ck1"]["path"] = "/" 154 | cks["ck1"]["expires"] = expires.strftime(Cookies.EXPIRE_DATE_FORMAT) 155 | 156 | return 200, Header({"xx": "yyy"}), cks, "OK" 157 | 158 | 159 | @request_map("header_echo") 160 | def header_echo(headers: Headers): 161 | return 200, headers, "" 162 | 163 | 164 | @request_filter("/abcde/**") 165 | def fil(ctx: FilterContext): 166 | print("---------- through filter ---------------") 167 | ctx.do_chain() 168 | 169 | 170 | @request_filter(regexp="^/abcd") 171 | @request_filter("/tuple") 172 | async def filter_tuple(ctx: FilterContext): 173 | print("---------- through filter async ---------------") 174 | # add a header to request header 175 | ctx.request.headers["filter-set"] = "through filter" 176 | if "user_name" not in ctx.request.parameter: 177 | ctx.response.send_redirect("/index") 178 | elif "pass" not in ctx.request.parameter: 179 | ctx.response.send_error(400, "pass should be passed") 180 | # you can also raise a HttpError 181 | # raise HttpError(400, "pass should be passed") 182 | else: 183 | # you should always use do_chain method to go to the next 184 | res: Response = ctx.response 185 | res.add_header("Access-Control-Allow-Origin", "*") 186 | res.add_header("Access-Control-Allow-Methods", "*") 187 | res.add_header("Access-Control-Allow-Headers", "*") 188 | res.add_header("Res-Filter-Header", "from-filter") 189 | 
ctx.do_chain() 190 | 191 | 192 | @request_map("/redirect") 193 | def redirect(): 194 | return Redirect("/index") 195 | 196 | 197 | @request_map("/params") 198 | def my_ctrl4(user_name, 199 | password=Parameter(name="passwd", required=True), 200 | remember_me=True, 201 | locations=[], 202 | json_param={}, 203 | lcs=Parameters(name="locals", required=True), 204 | content_type=Header("Content-Type", default="application/json"), 205 | connection=Header("Connection"), 206 | ua=Header("User-Agent"), 207 | headers=Headers() 208 | ): 209 | return f""" 210 | 211 | 212 | show all params! 213 | 214 | 215 |

user_name: {user_name}

216 |

password: {password}

217 |

remember_me: {remember_me}

218 |

locations: {locations}

219 |

json_param: {json_param}

220 |

locals: {lcs}

221 |

conent_type: {content_type}

222 |

user_agent: {ua}

223 |

connection: {connection}

224 |

all Headers: {headers}

225 | 226 | 227 | """ 228 | 229 | 230 | @request_map("/int_status_code") 231 | def return_int(status_code=200): 232 | return status_code 233 | 234 | 235 | @request_map("/path_values/{pval}/{path_val}/x") 236 | def my_path_val_ctr(pval: PathValue, path_val=PathValue()): 237 | return f"{pval}, {path_val}" 238 | 239 | 240 | @controller(args=["my-ctr"], kwargs={"desc": "desc"}) 241 | @route("/obj") 242 | class MyController: 243 | 244 | def __init__(self, name, desc="") -> None: 245 | self._name = f"ctr object[#{name}]:{desc}" 246 | 247 | @route("/hello", method="GET, POST") 248 | @request_map 249 | def my_ctrl_mth(self, model: ModelDict): 250 | return {"message": f"hello, {model['name']}, {self._name} says. "} 251 | 252 | @request_map("/hello2", method=("GET", "POST")) 253 | def my_ctr_mth2(self, name: str, i: List[int]): 254 | return f"{self._name}{self._name}: {name}, {i}" 255 | 256 | @route(regexp="^(reg/(.+))$", method="GET") 257 | def my_reg_ctr(self, reg_group: RegGroup = RegGroup(1)): 258 | return f"{self._name}, {reg_group.group},{reg_group}" 259 | 260 | 261 | @error_message("400") 262 | def my_40x_page(message: str, explain=""): 263 | return f"code:400, message: {message}, explain: {explain}" 264 | 265 | 266 | @error_message 267 | def my_other_error_page(code, message, explain=""): 268 | return f"{code}-{message}-{explain}" 269 | 270 | 271 | @request_map("abcde/**") 272 | def wildcard_match(path_val=PathValue()): 273 | return f"path values{path_val}" 274 | 275 | 276 | """ 277 | curl -X PUT --data-binary "@/data1/clamav/scan/trojans/000.exe" \ 278 | -H "Content-Type: application/octet-stream" \ 279 | http://10.0.2.16:9090/put/file 280 | """ 281 | 282 | 283 | @request_map("/put/file", method="PUT") 284 | async def reader_test( 285 | content_type: Header = Header("Content-Type"), 286 | reader: RequestBodyReader = None): 287 | buf = 1024 * 1024 288 | folder = os.path.dirname(os.path.abspath(__file__)) + "/tmp" 289 | if not os.path.isdir(folder): 290 | 
os.mkdir(folder) 291 | _logger.info(f"content-type:: {content_type}") 292 | with open(f"{folder}/target_file", "wb") as outfile: 293 | while True: 294 | _logger.info("read file") 295 | data = await reader.read(buf) 296 | _logger.info(f"read data {len(data)} and write") 297 | if data == b'': 298 | break 299 | outfile.write(data) 300 | return None 301 | 302 | 303 | @route("/res/write/bytes") 304 | def res_writer(response: Response): 305 | response.status_code = 200 306 | response.add_header("Content-Type", "application/octet-stream") 307 | response.write_bytes(b'abcd') 308 | response.write_bytes(bytearray(b'efg')) 309 | response.close() 310 | 311 | 312 | @route("/param/narrowing", params="a=b") 313 | def params_narrowing(): 314 | return "a=b" 315 | 316 | 317 | @route("/param/narrowing", params="a!=b") 318 | def params_narrowing2(): 319 | return "a!=b" 320 | 321 | 322 | @controller 323 | @request_map(url="/page", params=("a=b", )) 324 | class IndexPage: 325 | 326 | @request_map("/index", method='GET', params="x=y", match_all_params_expressions=False) 327 | def index_page(self): 328 | return "你好你好,世界!" 
329 | 330 | 331 | @route(url="/header_narrowing", method="POST", headers="Content-Type^=text/") 332 | def header_narrowing(): 333 | return "a^=b" 334 | -------------------------------------------------------------------------------- /tests/ctrls/my_controllers_model_binding.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any 3 | from naja_atra.request_handlers.model_bindings import ModelBinding 4 | from naja_atra import model_binding, default_model_binding 5 | from naja_atra import HttpError, route 6 | from naja_atra.utils.logger import get_logger 7 | 8 | _logger = get_logger("my_ctrls_model_bindings") 9 | 10 | 11 | class Person: 12 | 13 | def __init__(self, name: str = "", sex: int = "", age: int = 0) -> None: 14 | self.name = name 15 | self.sex = sex 16 | self.age = age 17 | 18 | 19 | class Dog: 20 | 21 | def __init__(self, name="a dog") -> None: 22 | self.name = name 23 | 24 | def wang(self): 25 | name = self.name 26 | if hasattr(self, "name"): 27 | name = self.name 28 | _logger.info(f"wang, I'm {name}, with attrs {self.__dict__}") 29 | return name 30 | 31 | 32 | @model_binding(Person) 33 | class PersonModelBinding(ModelBinding): 34 | 35 | async def bind(self) -> Any: 36 | name = self.request.get_parameter("name", "no-one") 37 | sex = self.request.get_parameter("sex", "secret") 38 | try: 39 | age = int(self.request.get_parameter("age", "")) 40 | except: 41 | raise HttpError(400, "Age is required, and must be an integer") 42 | return Person(name, sex, age) 43 | 44 | 45 | @default_model_binding 46 | class SetAttrModelBinding(ModelBinding): 47 | 48 | def bind(self) -> Any: 49 | try: 50 | obj = self.arg_type() 51 | for k, v in self.request.parameter.items(): 52 | setattr(obj, k, v) 53 | return obj 54 | except Exception as e: 55 | _logger.warning( 56 | f"Cannot create Object with given type {self.arg_type}. 
", stack_info=True) 57 | return self.default_value 58 | 59 | 60 | @route("/model_binding/person") 61 | def test_model_binding(person: Person): 62 | return { 63 | "name": person.name, 64 | "sex": person.sex, 65 | "age": person.age, 66 | } 67 | 68 | 69 | @route("/model_binding/dog") 70 | def test_model_binding_dog(dog: Dog): 71 | return { 72 | "name": dog.wang() 73 | } 74 | -------------------------------------------------------------------------------- /tests/ctrls/ws_controllers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | import shutil 6 | from uuid import uuid4 7 | from naja_atra import WebsocketCloseReason, WebsocketHandler, WebsocketRequest, WebsocketSession, websocket_handler, websocket_message, websocket_handshake, websocket_open, websocket_close, WEBSOCKET_MESSAGE_TEXT 8 | import naja_atra.utils.logger as logger 9 | 10 | _logger = logger.get_logger("ws_test") 11 | 12 | 13 | @websocket_handler(endpoint="/ws/{path_val}") 14 | class WSHandler(WebsocketHandler): 15 | 16 | def __init__(self) -> None: 17 | self.uuid: str = uuid4().hex 18 | 19 | def on_handshake(self, request: WebsocketRequest): 20 | return 0, {} 21 | 22 | async def on_open(self, session: WebsocketSession): 23 | 24 | _logger.info(f">>{session.id}<< open! 
{session.request.path_values}") 25 | 26 | def on_text_message(self, session: WebsocketSession, message: str): 27 | _logger.info( 28 | f">>{session.id}::{self.uuid}<< on text message: {message}") 29 | session.send(f"{session.request.path_values['path_val']}-{message}") 30 | if message == "close": 31 | session.close() 32 | 33 | def on_close(self, session: WebsocketSession, reason: WebsocketCloseReason): 34 | _logger.info( 35 | f">>{session.id}<< close::{reason.message}-{reason.code}-{reason.reason}") 36 | 37 | def on_binary_frame(self, session: WebsocketSession = None, fin: bool = False, frame_data: bytes = b''): 38 | _logger.info(f"Fin => {fin}, Data: {frame_data}") 39 | return True 40 | 41 | def on_binary_message(self, session: WebsocketSession = None, message: bytes = b''): 42 | _logger.info(f'Binary Message:: {message}') 43 | tmp_folder = os.path.dirname(os.path.abspath(__file__)) + "/tmp" 44 | is_dir_tmp_folder = os.path.isdir(tmp_folder) 45 | if not is_dir_tmp_folder: 46 | os.mkdir(tmp_folder) 47 | session.send( 48 | 'binary-message-received, and this is some message for the long size.', chunk_size=10) 49 | tmp_file_path = f"{tmp_folder}/ws_bi_re.tmp" 50 | with open(tmp_file_path, 'wb') as out_file: 51 | out_file.write(message) 52 | 53 | session.send_file(tmp_file_path) 54 | session.send_file(tmp_file_path, chunk_size=10) 55 | if is_dir_tmp_folder: 56 | os.remove(tmp_file_path) 57 | else: 58 | shutil.rmtree(tmp_folder) 59 | 60 | 61 | @websocket_handler(regexp="^/ws-reg/([a-zA-Z0-9]+)$", singleton=False) 62 | class WSRegHander(WebsocketHandler): 63 | 64 | def __init__(self) -> None: 65 | self.uuid: str = uuid4().hex 66 | 67 | def on_text_message(self, session: WebsocketSession, message: str): 68 | _logger.info( 69 | f">>{session.id}::{self.uuid}<< on text message: {message}") 70 | _logger.info(f"{session.request.reg_groups}") 71 | session.send(f"{session.request.reg_groups[0]}-{message}") 72 | 73 | 74 | @websocket_handshake(endpoint="/ws-fun/{path_val}") 75 | 
def ws_handshake(request: WebsocketRequest): 76 | return 0, {} 77 | 78 | 79 | @websocket_open(endpoint="/ws-fun/{path_val}") 80 | def ws_open(session: WebsocketSession): 81 | _logger.info(f">>{session.id}<< open! {session.request.path_values}") 82 | 83 | 84 | @websocket_close(endpoint="/ws-fun/{path_val}") 85 | def ws_close(session: WebsocketSession, reason: WebsocketCloseReason): 86 | _logger.info( 87 | f">>{session.id}<< close::{reason.message}-{reason.code}-{reason.reason}") 88 | 89 | 90 | @websocket_message(endpoint="/ws-fun/{path_val}", message_type=WEBSOCKET_MESSAGE_TEXT) 91 | async def ws_text(session: WebsocketSession, message: str): 92 | _logger.info(f">>{session.id}<< on text message: {message}") 93 | session.send(f"{session.request.path_values['path_val']}-{message}") 94 | if message == "close": 95 | session.close() 96 | -------------------------------------------------------------------------------- /tests/static/a.txt: -------------------------------------------------------------------------------- 1 | hello world! -------------------------------------------------------------------------------- /tests/static/inner/b.txt: -------------------------------------------------------------------------------- 1 | hello, inner! -------------------------------------------------------------------------------- /tests/static/inner/y.ini: -------------------------------------------------------------------------------- 1 | [my] 2 | hello = inner -------------------------------------------------------------------------------- /tests/static/x.ini: -------------------------------------------------------------------------------- 1 | [conf] 2 | hello = world -------------------------------------------------------------------------------- /tests/static/中文.txt: -------------------------------------------------------------------------------- 1 | hello world! 
-------------------------------------------------------------------------------- /tests/test_all_ctrls.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from gzip import GzipFile 4 | import os 5 | import json 6 | import websocket 7 | import unittest 8 | import urllib.request 9 | import urllib.error 10 | import http.client 11 | from typing import Dict 12 | from threading import Thread 13 | from time import sleep 14 | 15 | from naja_atra.utils.logger import get_logger, set_level 16 | import naja_atra.server as server 17 | 18 | set_level("DEBUG") 19 | 20 | _logger = get_logger("http_test") 21 | 22 | 23 | class ThreadingServerTest(unittest.TestCase): 24 | 25 | PORT = 9090 26 | 27 | WAIT_COUNT = 10 28 | 29 | COROUTINE = False 30 | 31 | @classmethod 32 | def start_server(cls): 33 | cls.tearDownClass() 34 | _logger.info("start server in background. ") 35 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 36 | server.scan(project_dir=root, base_dir="tests/ctrls", 37 | regx=r'.*controllers.*') 38 | server.start( 39 | port=cls.PORT, 40 | resources={"/public/*": f"{root}/tests/static"}, 41 | gzip_content_types={"text/plain"}, 42 | prefer_coroutine=cls.COROUTINE) 43 | 44 | @classmethod 45 | def setUpClass(cls): 46 | Thread(target=cls.start_server, daemon=True, name="t").start() 47 | retry = 0 48 | while not server.is_ready(): 49 | sleep(1) 50 | retry = retry + 1 51 | _logger.info( 52 | f"server is not ready wait. 
{retry}/{cls.WAIT_COUNT} ") 53 | if retry >= cls.WAIT_COUNT: 54 | raise Exception("Server start wait timeout.") 55 | 56 | @classmethod 57 | def tearDownClass(cls): 58 | try: 59 | server.stop() 60 | except: 61 | pass 62 | 63 | @classmethod 64 | def visit(cls, ctx_path, headers: Dict[str, str] = {}, data=None, return_type: str = "TEXT"): 65 | req: urllib.request.Request = urllib.request.Request( 66 | f"http://127.0.0.1:{cls.PORT}/{ctx_path}") 67 | for k, v in headers.items(): 68 | req.add_header(k, v) 69 | res: http.client.HTTPResponse = urllib.request.urlopen(req, data=data) 70 | 71 | if return_type == "RESPONSE": 72 | return res 73 | elif return_type == "HEADERS": 74 | headers = res.headers 75 | res.close() 76 | return headers 77 | elif return_type == "JSON": 78 | txt = res.read().decode("utf-8") 79 | res.close() 80 | return json.loads(txt) 81 | else: 82 | txt = res.read().decode("utf-8") 83 | res.close() 84 | return txt 85 | 86 | def test_header_echo(self): 87 | res: http.client.HTTPResponse = self.visit( 88 | f"header_echo", headers={"X-KJ-ABC": "my-headers"}, return_type="RESPONSE") 89 | assert "X-Kj-Abc" in res.headers 90 | assert res.headers["X-Kj-Abc"] == "my-headers" 91 | 92 | def test_static(self): 93 | txt = self.visit("public/a.txt") 94 | assert txt == "hello world!" 95 | 96 | def test_gzip(self): 97 | res: http.client.HTTPResponse = self.visit( 98 | f"public/a.txt", headers={"Accept-Encoding": "gzip ,deflate"}, return_type="RESPONSE") 99 | content_encoding = res.info().get("Content-Encoding") 100 | assert "gzip" == content_encoding 101 | f = GzipFile(fileobj=res) 102 | txt = f.read().decode() 103 | assert txt == "hello world!" 
104 | 105 | def test_path_value(self): 106 | pval = "abc" 107 | path_val = "xyz" 108 | txt = self.visit(f"path_values/{pval}/{path_val}/x") 109 | assert txt == f"{pval}, {path_val}" 110 | 111 | def test_error(self): 112 | try: 113 | self.visit("error") 114 | except urllib.error.HTTPError as err: 115 | assert err.code == 400 116 | error_msg = err.read().decode("utf-8") 117 | _logger.info(error_msg) 118 | assert error_msg == "code:400, message: Parameter Error!, explain: Test Parameter Error!" 119 | 120 | def test_coroutine(self): 121 | txt = self.visit(f"%E4%B8%AD%E6%96%87/coroutine?hey=KJ2") 122 | assert txt == "Success! KJ2" 123 | 124 | def test_post_json(self): 125 | data_dict = { 126 | "code": 0, 127 | "msg": "xxx" 128 | } 129 | res: str = self.visit(f"post_json", headers={ 130 | "Content-Type": "application/json"}, data=json.dumps(data_dict).encode(errors="replace")) 131 | res_dict: dict = json.loads(res) 132 | assert data_dict["code"] == res_dict["code"] 133 | assert data_dict["msg"] == res_dict["msg"] 134 | 135 | def test_filter(self): 136 | res: http.client.HTTPResponse = self.visit( 137 | f"tuple?user_name=kj&pass=wu", return_type="RESPONSE") 138 | assert "Res-Filter-Header" in res.headers 139 | assert res.headers["Res-Filter-Header"] == "from-filter" 140 | 141 | def test_exception(self): 142 | try: 143 | self.visit("exception") 144 | except urllib.error.HTTPError as err: 145 | assert err.code == 500 146 | error_msg = err.read().decode("utf-8") 147 | _logger.info(error_msg) 148 | assert error_msg == '500-Internal Server Error-some error occurs!' 149 | 150 | def test_res_write_bytes(self): 151 | body = self.visit("res/write/bytes") 152 | assert body == 'abcdefg' 153 | 154 | def test_ws(self): 155 | ws = websocket.WebSocket() 156 | path_val = "test-ws" 157 | msg = "hello websocket!" 
158 | ws.connect(f"ws://127.0.0.1:{self.PORT}/ws/{path_val}") 159 | ws.send(msg) 160 | txt = ws.recv() 161 | ws.close() 162 | assert txt == f"{path_val}-{msg}" 163 | 164 | def test_ws_fun(self): 165 | ws = websocket.WebSocket() 166 | path_val = "test-ws-fun" 167 | msg = "hello websocket!" 168 | ws.connect(f"ws://127.0.0.1:{self.PORT}/ws-fun/{path_val}") 169 | ws.send(msg) 170 | txt = ws.recv() 171 | ws.close() 172 | assert txt == f"{path_val}-{msg}" 173 | 174 | def test_ws_continuation(self): 175 | ws = websocket.WebSocket() 176 | path_val = "test-ws" 177 | 178 | ws.connect(f"ws://127.0.0.1:{self.PORT}/ws/{path_val}") 179 | msg0 = "Hello " 180 | frame0 = websocket.ABNF.create_frame( 181 | msg0, websocket.ABNF.OPCODE_TEXT, 0) 182 | ws.send_frame(frame0) 183 | msg1 = "Websocket " 184 | frame1 = websocket.ABNF.create_frame( 185 | msg1, websocket.ABNF.OPCODE_CONT, 0) 186 | ws.send_frame(frame1) 187 | msg2 = "Frames!" 188 | frame2 = websocket.ABNF.create_frame( 189 | msg2, websocket.ABNF.OPCODE_CONT, 1) 190 | ws.send_frame(frame2) 191 | 192 | txt = ws.recv() 193 | ws.close() 194 | assert txt == f"{path_val}-{msg0 + msg1 + msg2}" 195 | 196 | def test_ws_bytes_continuation(self): 197 | ws = websocket.WebSocket() 198 | path_val = "test-ws" 199 | 200 | ws.connect(f"ws://127.0.0.1:{self.PORT}/ws/{path_val}") 201 | msg0 = "Hello " 202 | frame0 = websocket.ABNF.create_frame( 203 | msg0, websocket.ABNF.OPCODE_BINARY, 0) 204 | ws.send_frame(frame0) 205 | msg1 = "Websocket " 206 | frame1 = websocket.ABNF.create_frame( 207 | msg1, websocket.ABNF.OPCODE_CONT, 0) 208 | ws.send_frame(frame1) 209 | msg2 = "Frames!" 210 | frame2 = websocket.ABNF.create_frame( 211 | msg2, websocket.ABNF.OPCODE_CONT, 1) 212 | ws.send_frame(frame2) 213 | 214 | txt: str = ws.recv() 215 | bs: bytes = ws.recv() 216 | bs2: bytes = ws.recv() 217 | ws.close() 218 | assert txt == "binary-message-received, and this is some message for the long size." 
219 | assert bs.decode() == bs2.decode() == msg0 + msg1 + msg2 220 | 221 | def test_ws_regexp(self): 222 | ws = websocket.WebSocket() 223 | path_val = "wstest" 224 | msg = 'hello, reg' 225 | 226 | ws.connect(f"ws://127.0.0.1:{self.PORT}/ws-reg/{path_val}") 227 | ws.send(msg) 228 | 229 | txt: str = ws.recv() 230 | print(txt) 231 | ws.close() 232 | assert txt == f"{path_val}-{msg}" 233 | 234 | def test_params_narrowing(self): 235 | body = self.visit("param/narrowing?a=b") 236 | assert body == 'a=b' 237 | body = self.visit("param/narrowing?a=c") 238 | assert body == 'a!=b' 239 | 240 | def test_model_binding(self): 241 | name = "keijack" 242 | sex = "male" 243 | age = 18 244 | res: Dict = self.visit( 245 | f"model_binding/person?name={name}&sex={sex}&age={age}", return_type="JSON") 246 | assert res["name"] == name 247 | assert res["sex"] == sex 248 | assert res["age"] == age 249 | 250 | 251 | class CoroutineServerTest(ThreadingServerTest): 252 | 253 | PORT = 9091 254 | 255 | COROUTINE = True 256 | -------------------------------------------------------------------------------- /upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm dist/* 4 | 5 | python3 -m build 6 | 7 | python3 -m twine upload dist/* --------------------------------------------------------------------------------