├── .gitignore
├── LICENSE
├── README.md
├── README.zh_cn.md
├── debug_main.py
├── docs
│   └── README.md
├── naja_atra
│   ├── __init__.py
│   ├── __main__.py
│   ├── app_conf.py
│   ├── http_servers
│   │   ├── __init__.py
│   │   ├── coroutine_http_server.py
│   │   ├── http_server.py
│   │   ├── routing_server.py
│   │   └── threading_http_server.py
│   ├── models.py
│   ├── request_handlers
│   │   ├── __init__.py
│   │   ├── http_controller_handler.py
│   │   ├── http_request_handler.py
│   │   ├── http_session_local_impl.py
│   │   ├── model_bindings.py
│   │   └── websocket_controller_handler.py
│   ├── server.py
│   └── utils
│       ├── __init__.py
│       ├── http_utils.py
│       └── logger.py
├── pyproject.toml
├── setup.py
├── tests
│   ├── __init__.py
│   ├── ctrls
│   │   ├── __init__.py
│   │   ├── my_controllers.py
│   │   ├── my_controllers_model_binding.py
│   │   └── ws_controllers.py
│   ├── static
│   │   ├── a.txt
│   │   ├── inner
│   │   │   ├── b.txt
│   │   │   └── y.ini
│   │   ├── x.ini
│   │   └── 中文.txt
│   └── test_all_ctrls.py
└── upload.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | v1/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | # for development only
109 | tests/certs/
110 | temp/
111 |
112 | # vscode
113 | .vscode/
114 | **/tmp/**
115 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Keijack Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Naja-Atra
2 |
3 | As more features have been added to this project, it is no longer that simple, so it has been renamed to `naja-atra` for continued maintenance. You can find the new code [here](https://www.github.com/naja-atra/naja-atra), along with some new extensions.
4 |
--------------------------------------------------------------------------------
/debug_main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # To run this file, please install the following packages first:
4 | # python3 -m pip install werkzeug 'uvicorn[standard]'
5 | #
6 |
# Standard library.
import os
import signal

# Project modules (kept in their original relative order).
from naja_atra import request_map
import naja_atra.server as server
from naja_atra.http_servers.http_server import HttpServer
from naja_atra.utils.logger import get_logger, set_level
from naja_atra import get_app_conf

# This script is a debugging entry point, so turn on DEBUG-level logging.
set_level("DEBUG")
19 |
20 |
@request_map("/stop")
def stop():
    """Stop the module-level server and return a short HTML confirmation page.

    The page title and body are short Chinese "shut down" / "shut down
    successfully" messages.
    """
    server.stop()
    # Kept on one line: a plain "..." literal cannot span multiple lines.
    return "<!DOCTYPE html><html><head><title>关闭</title></head><body>关闭成功!</body></html>"
25 |
26 |
_logger = get_logger("http_test")
# Absolute directory containing this script; used to locate tests/static.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Assigned by start_via_class(); remains None when only start_server() runs.
# The annotation is a string so it is not evaluated at import time.
_server: "HttpServer | None" = None

# App configuration looked up by the name "2"; passed to HttpServer in
# start_via_class().
app = get_app_conf("2")
33 |
34 |
@app.route("/")
def idx():
    """Return a greeting dict for the root path of ``app``."""
    greeting = "hello, world!"
    return {"msg": greeting}
38 |
39 |
def start_via_class():
    """Create an ``HttpServer`` on host ``('', 9091)``, store it in ``_server``, and start it.

    Requests matching ``/p/**`` are served from ``tests/static``.
    """
    global _server

    static_routes = {"/p/**": f"{PROJECT_ROOT}/tests/static"}
    _server = HttpServer(host=('', 9091), resources=static_routes, app_conf=app)
    _server.start()
47 |
48 |
def start_server():
    """Scan the test controllers and start the module-level server on port 9090.

    Every static-resource pattern below is mapped to ``tests/static``; gzip
    is enabled for ``image/x-icon`` and ``text/plain`` at level 9.
    """
    _logger.info("start server in background. ")

    server.scan(base_dir="tests/ctrls", regx=r'.*controllers.*')

    static_dir = f"{PROJECT_ROOT}/tests/static"
    # Same patterns, in the same order, all pointing at the same directory.
    resource_patterns = ("/public/*", "/*", '/inn/**', '**.txt', '*.ini')
    resources = {pattern: static_dir for pattern in resource_patterns}

    server.start(
        # port=9443,
        port=9090,
        resources=resources,
        # ssl=True,
        # certfile=f"{PROJECT_ROOT}/tests/certs/fullchain.pem",
        # keyfile=f"{PROJECT_ROOT}/tests/certs/privkey.pem",
        gzip_content_types={"image/x-icon", "text/plain"},
        gzip_compress_level=9,
        prefer_coroutine=False,
    )
68 |
69 |
def on_sig_term(signum, frame):
    """Signal handler: stop the module-level server, then the class-based one if it was created."""
    server.stop()
    if not _server:
        return
    _server.shutdown()
74 |
75 |
if __name__ == "__main__":
    # Ctrl+C (SIGINT) and SIGTERM both go through the same shutdown handler.
    signal.signal(signal.SIGINT, on_sig_term)
    signal.signal(signal.SIGTERM, on_sig_term)

    # Alternative entry point for manual testing: start_via_class()
    start_server()
87 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # python-simple-http-server
2 |
3 | [![PyPI version](https://badge.fury.io/py/simple-http-server.svg)](https://badge.fury.io/py/simple-http-server)
4 |
5 | ## Description
6 |
7 | This is a simple HTTP server with an MVC-like design.
8 |
9 | ## Support Python Version
10 |
11 | Python 3.7+
12 |
13 | ## Why choose
14 |
15 | * Lightweight.
16 | * Functional programming.
17 | * Filter chain support.
18 | * Session support, including distributed sessions via [this extension](https://github.com/keijack/python-simple-http-server-redis-session).
19 | * You can use [this extension](https://github.com/keijack/python-simple-http-server-jinja) to support `jinja` views.
20 | * Spring MVC-like request mapping.
21 | * SSL support.
22 | * Websocket support.
23 | * Easy to use.
24 | * Free-style controller writing.
25 | * Easily integrated with WSGI servers.
26 | * Easily integrated with ASGI servers. Websocket is supported when the ASGI server enables websocket functions.
27 | * Coroutine mode support.
28 |
29 | ## Dependencies
30 |
31 | No other dependencies are needed to run this project. However, if you want to run the unit tests in the `tests` folder, you need to install `websocket-client` via pip:
32 |
33 | ```shell
34 | python3 -m pip install websocket-client
35 | ```
36 |
37 | ## How to use
38 |
39 | ### Install
40 |
41 | ```shell
42 | python3 -m pip install simple_http_server
43 | ```
44 |
45 | ### Minimum code / component requirement setup
46 |
47 | Minimum code to get things started should have at least one controller function,
48 | using the route and server modules from simple_http_server
49 |
50 | ```python
51 | from simple_http_server import route, server
52 |
53 | @route("/")
54 | def index():
55 | return {"hello": "world"}
56 |
57 | server.start(port=9090)
58 | ```
59 |
60 | ### Write Controllers
61 |
62 | ```python
63 |
64 | from simple_http_server import request_map
65 | from simple_http_server import Response
66 | from simple_http_server import MultipartFile
67 | from simple_http_server import Parameter
68 | from simple_http_server import Parameters
69 | from simple_http_server import Header
70 | from simple_http_server import JSONBody
71 | from simple_http_server import HttpError
72 | from simple_http_server import StaticFile
73 | from simple_http_server import Headers
74 | from simple_http_server import Cookies
75 | from simple_http_server import Cookie
76 | from simple_http_server import Redirect
77 | from simple_http_server import ModelDict
78 |
79 | # request_map has an alias name `route`, you can select the one you familiar with.
80 | @request_map("/index")
81 | def my_ctrl():
82 | return {"code": 0, "message": "success"} # You can return a dictionary, a string or a `simple_http_server.simple_http_server.Response` object.
83 |
84 |
85 | @route("/say_hello", method=["GET", "POST"])
86 | def my_ctrl2(name, name2=Parameter("name", default="KEIJACK"), model=ModelDict()):
87 | """name and name2 is the same"""
88 | name == name2 # True
89 | name == model["name"] # True
90 | return "hello, %s, %s" % (name, name2)
91 |
92 |
93 | @request_map("/error")
94 | def my_ctrl3():
95 | return Response(status_code=500)
96 |
97 |
98 | @request_map("/exception")
99 | def exception_ctrl():
100 | raise HttpError(400, "Exception")
101 |
102 | @request_map("/upload", method="GET")
103 | def show_upload():
104 | root = os.path.dirname(os.path.abspath(__file__))
105 | return StaticFile("%s/my_dev/my_test_index.html" % root, "text/html; charset=utf-8")
106 |
107 |
108 | @request_map("/upload", method="POST")
109 | def my_upload(img=MultipartFile("img")):
110 | root = os.path.dirname(os.path.abspath(__file__))
111 | img.save_to_file(root + "/my_dev/imgs/" + img.filename)
112 | return "upload ok!"
113 |
114 |
115 | @request_map("/post_txt", method="POST")
116 | def normal_form_post(txt):
117 | return "hi, %s" % txt
118 |
119 | @request_map("/tuple")
120 | def tuple_results():
121 | # The order here is not important, we consider the first `int` value as status code,
122 | # All `Headers` object will be sent to the response
123 | # And the first valid object whose type in (str, unicode, dict, StaticFile, bytes) will
124 | # be considered as the body
125 | return 200, Headers({"my-header": "headers"}), {"success": True}
126 |
127 | """
128 | " Cookie_sc will not be written to response. It's just some kind of default
129 | " value
130 | """
131 | @request_map("tuple_cookie")
132 | def tuple_with_cookies(all_cookies=Cookies(), cookie_sc=Cookie("sc")):
133 | print("=====> cookies ")
134 | print(all_cookies)
135 | print("=====> cookie sc ")
136 | print(cookie_sc)
137 | print("======<")
138 | import datetime
139 | expires = datetime.datetime(2018, 12, 31)
140 |
141 | cks = Cookies()
142 | # cks = cookies.SimpleCookie() # you could also use the build-in cookie objects
143 | cks["ck1"] = "keijack"
144 | cks["ck1"]["path"] = "/"
145 | cks["ck1"]["expires"] = expires.strftime(Cookies.EXPIRE_DATE_FORMAT)
146 | # You can ignore status code, headers, cookies even body in this tuple.
147 | return Header({"xx": "yyy"}), cks, "OK"
148 |
149 | """
150 | " If you visit /a/b/xyz/x,this controller function will be called, and `path_val` will be `xyz`
151 | """
152 | @request_map("/a/b/{path_val}/x")
153 | def my_path_val_ctr(path_val=PathValue()):
154 | return f"{path_val}"
155 |
156 | @request_map("/star/*") # /star/c will find this controller, but /star/c/d not.
157 | @request_map("*/star") # /c/star will find this controller, but /c/d/star not.
158 | def star_path(path_val=PathValue()):
159 | return f"{path_val}"
160 |
161 | @request_map("/star/**") # Both /star/c and /star/c/d will find this controller.
162 | @request_map("**/star") # Both /c/star and /c/d/star will find this controller.
163 | def star_path(path_val=PathValue()):
164 | return f"{path_val}"
165 |
166 | @request_map("/redirect")
167 | def redirect():
168 | return Redirect("/index")
169 |
170 | @request_map("session")
171 | def test_session(session=Session(), invalid=False):
172 | ins = session.get_attribute("in-session")
173 | if not ins:
174 | session.set_attribute("in-session", "Hello, Session!")
175 |
176 | __logger.info("session id: %s" % session.id)
177 | if invalid:
178 | __logger.info("session[%s] is being invalidated. " % session.id)
179 | session.invalidate()
180 | return "%s" % str(ins)
181 |
182 | # use coroutine, these controller functions will work both in a coroutine mode or threading mode.
183 |
184 | async def say(sth: str = ""):
185 | _logger.info(f"Say: {sth}")
186 | return f"Success! {sth}"
187 |
188 | @request_map("/中文/coroutine")
189 | async def coroutine_ctrl(hey: str = "Hey!"):
190 | return await say(hey)
191 |
192 | @route("/res/write/bytes")
193 | def res_writer(response: Response):
194 | response.status_code = 200
195 | response.add_header("Content-Type", "application/octet-stream")
196 | response.write_bytes(b'abcd')
197 | response.write_bytes(bytearray(b'efg'))
198 | response.close()
199 | ```
200 |
201 | Besides using default values, you can also use type annotations to specify your controller function's parameters.
202 |
203 | ```python
204 | @request_map("/say_hello/to/{name}", method=["GET", "POST", "PUT"])
205 | def your_ctroller_function(
206 | user_name: str, # req.parameter["user_name"],400 error will raise when there's no such parameter in the query string.
207 | password: str, # req.parameter["password"],400 error will raise when there's no such parameter in the query string.
208 | skills: list, # req.parameters["skills"],400 error will raise when there's no such parameter in the query string.
209 | all_headers: Headers, # req.headers
210 | user_token: Header, # req.headers["user_token"],400 error will raise when there's no such parameter in the request headers.
211 | all_cookies: Cookies, # req.cookies, return all cookies
212 | user_info: Cookie, # req.cookies["user_info"],400 error will raise when there's no such parameter in the cookies.
213 | name: PathValue, # req.path_values["name"],get the {name} value from your path.
214 | session: Session # req.getSession(True),get the session, if there is no sessions, create one.
215 | ):
216 | return "Hello, World!"
217 |
218 | # you can use `params` to narrow the controller mapping, the following examples shows only the `params` mapping, ignoring the
219 | # `headers` examples for the usage is almost the same as the `params`.
220 | @request("/exact_params", method="GET", params="a=b")
221 | def exact_params(a: str):
222 | print(f"{a}") # b
223 | return {"result": "ok"}
224 |
225 | @request("/exact_params", method="GET", params="a!=b")
226 | def exact_not_params(a: str):
227 | print(f"{a}") # b
228 | return {"result": "ok"}
229 |
230 | @request("/exact_params", method="GET", params="a^=b")
231 | def exact_startwith_params(a: str):
232 | print(f"{a}") # b
233 | return {"result": "ok"}
234 |
235 | @request("/exact_params", method="GET", params="!a")
236 | def no_params():
237 | return {"result": "ok"}
238 |
239 | @request("/exact_params", method="GET", params="a")
240 | def must_has_params():
241 | return {"result": "ok"}
242 |
243 | # If multiple expressions are set, all expressions must be matched to enter this controller function.
244 | @request("/exact_params", method="GET", params=["a=b", "c!=d"])
245 | def multipul_params():
246 | return {"result": "ok"}
247 |
248 | # You can set `match_all_params_expressions` to False to make that the url can enter this controller function even only one expression is matched.
249 | @request("/exact_params", method="GET", params=["a=b", "c!=d"], match_all_params_expressions=False)
250 | def multipul_params():
251 | return {"result": "ok"}
252 | ```
253 |
254 | We recommend using functional programming to write controller functions, but if you really want to use objects, you can use `@request_map` on a class method. In this case, a new MyController object will be created every time a new request comes in.
255 |
256 | ```python
257 |
258 | class MyController:
259 |
260 | def __init__(self) -> None:
261 | self._name = "ctr object"
262 |
263 | @request_map("/obj/say_hello", method="GET")
264 | def my_ctrl_mth(self, name: str):
265 | return {"message": f"hello, {name}, {self._name} says. "}
266 |
267 | ```
268 |
269 | If you want a singleton, you can add a `@controller` decorator to the class.
270 |
271 | ```python
272 |
273 | @controller
274 | class MyController:
275 |
276 | def __init__(self) -> None:
277 | self._name = "ctr object"
278 |
279 | @request_map("/obj/say_hello", method="GET")
280 | def my_ctrl_mth(self, name: str):
281 | return {"message": f"hello, {name}, {self._name} says. "}
282 |
283 | ```
284 |
285 | You can also add `@request_map` to your class; its path will then be used as the prefix of the URLs of the class's methods.
286 |
287 | ```python
288 |
289 | @controller
290 | @request_map("/obj", method="GET")
291 | class MyController:
292 |
293 | def __init__(self) -> None:
294 | self._name = "ctr object"
295 |
296 | @request_map
297 | def my_ctrl_default_mth(self, name: str):
298 | return {"message": f"hello, {name}, {self._name} says. "}
299 |
300 | @request_map("/say_hello", method=("GET", "POST"))
301 | def my_ctrl_mth(self, name: str):
302 | return {"message": f"hello, {name}, {self._name} says. "}
303 |
304 | ```
305 |
306 | You can specify the `init` variables in `@controller` decorator.
307 |
308 | ```python
309 |
310 | @controller(args=["ctr_name"], kwargs={"desc": "this is a key word argument"})
311 | @request_map("/obj", method="GET")
312 | class MyController:
313 |
314 | def __init__(self, name, desc="") -> None:
315 | self._name = f"ctr[{name}] - {desc}"
316 |
317 | @request_map
318 | def my_ctrl_default_mth(self, name: str):
319 | return {"message": f"hello, {name}, {self._name} says. "}
320 |
321 | @request_map("/say_hello", method=("GET", "POST"))
322 | def my_ctrl_mth(self, name: str):
323 | return {"message": f"hello, {name}, {self._name} says. "}
324 |
325 | ```
326 |
327 | From `0.7.0`, `@request_map` supports regular expression mapping.
328 |
329 | ```python
330 | # The url `/reg/abcef/aref/xxx` maps to the following controller:
331 | @route(regexp="^(reg/(.+))$", method="GET")
332 | def my_reg_ctr(reg_groups: RegGroups, reg_group: RegGroup = RegGroup(1)):
333 | print(reg_groups) # will output ("reg/abcef/aref/xxx", "abcef/aref/xxx")
334 | print(reg_group) # will output "abcef/aref/xxx"
335 | return f"{self._name}, {reg_group.group},{reg_group}"
336 | ```
337 | Regular expression mapping a class:
338 |
339 | ```python
340 | @controller(args=["ctr_name"], kwargs={"desc": "this is a key word argument"})
341 | @request_map("/obj", method="GET") # regexp do not work here, method will still available
342 | class MyController:
343 |
344 | def __init__(self, name, desc="") -> None:
345 | self._name = f"ctr[{name}] - {desc}"
346 |
347 | @request_map
348 | def my_ctrl_default_mth(self, name: str):
349 | return {"message": f"hello, {name}, {self._name} says. "}
350 |
351 | @route(regexp="^(reg/(.+))$") # prefix `/obj` from class decorator will be ignored, but `method`(GET in this example) from class decorator will still work.
352 | def my_ctrl_mth(self, name: str):
353 | return {"message": f"hello, {name}, {self._name} says. "}
354 |
355 | ```
356 |
357 | ### Session
358 |
359 | By default, sessions are stored locally. You can extend the `SessionFactory` and `Session` classes to implement your own session storage (e.g. storing all data in Redis or Memcached).
360 |
361 | ```python
362 | from simple_http_server import Session, SessionFactory, set_session_factory
363 |
364 | class MySessionImpl(Session):
365 |
366 | def __init__(self):
367 | super().__init__()
368 | # your own implementation
369 |
370 | @property
371 | def id(self) -> str:
372 | # your own implementation
373 |
374 | @property
375 | def creation_time(self) -> float:
376 | # your own implementation
377 |
378 | @property
379 | def last_accessed_time(self) -> float:
380 | # your own implementation
381 |
382 | @property
383 | def is_new(self) -> bool:
384 | # your own implementation
385 |
386 | @property
387 | def attribute_names(self) -> Tuple:
388 | # your own implementation
389 |
390 | def get_attribute(self, name: str) -> Any:
391 | # your own implementation
392 |
393 | def set_attribute(self, name: str, value: Any) -> None:
394 | # your own implementation
395 |
396 | def invalidate(self) -> None:
397 | # your own implementation
398 |
399 | class MySessionFacImpl(SessionFactory):
400 |
401 | def __init__(self):
402 | super().__init__()
403 | # your own implementation
404 |
405 |
406 | def get_session(self, session_id: str, create: bool = False) -> Session:
407 | # your own implementation
408 | return MySessionImpl()
409 |
410 | set_session_factory(MySessionFacImpl())
411 |
412 | ```
413 |
414 | There is an official Redis implementation here: https://github.com/keijack/python-simple-http-server-redis-session.git
415 |
416 | ### Websocket
417 |
418 | To handle a websocket session, you should handle multiple events, so it's more reasonable to use a class rather than functions to do it.
419 |
420 | In this framework, you should use `@websocket_handler` to decorate the class you want to handle websocket session. Specific event listener methods should be defined in a fixed way. However, the easiest way to do it is to inherit `simple_http_server.WebsocketHandler` class, and choose the event you want to implement. But this inheritance is not compulsory.
421 |
422 | You can configure `endpoint` or `regexp` in `@websocket_handler` to set up which URL the class should handle. There is also a `singleton` field, which is set to `True` by default, meaning that all connections are handled by ONE object of this class. If this field is set to `False`, a new object will be created every time a `WebsocketSession` tries to connect.
423 |
424 | ```python
425 | from simple_http_server import WebsocketHandler, WebsocketRequest,WebsocketSession, websocket_handler
426 |
427 | @websocket_handler(endpoint="/ws/{path_val}")
428 | class WSHandler(WebsocketHandler):
429 |
430 | def on_handshake(self, request: WebsocketRequest):
431 | """
432 | "
433 | " You can get path/headers/path_values/cookies/query_string/query_parameters from request.
434 | "
435 | " You should return a tuple means (http_status_code, headers)
436 | "
437 | " If status code in (0, None, 101), the websocket will be connected, or will return the status you return.
438 | "
439 | " All headers will be send to client
440 | "
441 | """
442 | _logger.info(f">>{session.id}<< open! {request.path_values}")
443 | return 0, {}
444 |
445 | def on_open(self, session: WebsocketSession):
446 | """
447 | "
448 | " Will be called when the connection opened.
449 | "
450 | """
451 | _logger.info(f">>{session.id}<< open! {session.request.path_values}")
452 |
453 | def on_close(self, session: WebsocketSession, reason: str):
454 | """
455 | "
456 | " Will be called when the connection closed.
457 | "
458 | """
459 | _logger.info(f">>{session.id}<< close::{reason}")
460 |
461 | def on_ping_message(self, session: WebsocketSession = None, message: bytes = b''):
462 | """
463 | "
464 | " Will be called when receive a ping message. Will send all the message bytes back to client by default.
465 | "
466 | """
467 | session.send_pone(message)
468 |
469 | def on_pong_message(self, session: WebsocketSession = None, message: bytes = ""):
470 | """
471 | "
472 | " Will be called when receive a pong message.
473 | "
474 | """
475 | pass
476 |
477 | def on_text_message(self, session: WebsocketSession, message: str):
478 | """
479 | "
480 | " Will be called when receive a text message.
481 | "
482 | """
483 | _logger.info(f">>{session.id}<< on text message: {message}")
484 | session.send(message)
485 |
486 | def on_binary_message(self, session: WebsocketSession = None, message: bytes = b''):
487 | """
488 | "
489 | " Will be called when receive a binary message if you have not consumed all the bytes in `on_binary_frame`
490 | " method.
491 | "
492 | """
493 | pass
494 |
495 | def on_binary_frame(self, session: WebsocketSession = None, fin: bool = False, frame_payload: bytes = b''):
496 | """
497 | "
498 | " If you are sending a continuation binary message to server, this will be called every time a frame is
499 | " received. You can consume all the bytes in this method, e.g. save all bytes to a file. By doing so,
500 | " you should not return any value from this method.
501 | "
502 | " If you do not implement this method, or return True from it, all the bytes will be cached in
503 | " memory and be sent to your `on_binary_message` method.
504 | "
505 | """
506 | return True
507 |
508 | @websocket_handler(regexp="^/ws-reg/([a-zA-Z0-9]+)$", singleton=False)
509 | class WSHandler(WebsocketHandler):
510 |
511 | """
512 | " You code here
513 | """
514 |
515 | ```
516 |
517 | But if you want to only handle one event, you can also use a function to handle it.
518 |
519 | ```python
520 |
521 | from simple_http_server import WebsocketCloseReason, WebsocketHandler, WebsocketRequest, WebsocketSession, websocket_message, websocket_handshake, websocket_open, websocket_close, WEBSOCKET_MESSAGE_TEXT
522 |
523 | @websocket_handshake(endpoint="/ws-fun/{path_val}")
524 | def ws_handshake(request: WebsocketRequest):
525 | return 0, {}
526 |
527 |
528 | @websocket_open(endpoint="/ws-fun/{path_val}")
529 | def ws_open(session: WebsocketSession):
530 | _logger.info(f">>{session.id}<< open! {session.request.path_values}")
531 |
532 |
533 | @websocket_close(endpoint="/ws-fun/{path_val}")
534 | def ws_close(session: WebsocketSession, reason: WebsocketCloseReason):
535 | _logger.info(
536 | f">>{session.id}<< close::{reason.message}-{reason.code}-{reason.reason}")
537 |
538 |
539 | @websocket_message(endpoint="/ws-fun/{path_val}", message_type=WEBSOCKET_MESSAGE_TEXT)
540 | # You can define a function in a sync or async way.
541 | async def ws_text(session: WebsocketSession, message: str):
542 | _logger.info(f">>{session.id}<< on text message: {message}")
543 | session.send(f"{session.request.path_values['path_val']}-{message}")
544 | if message == "close":
545 | session.close()
546 | ```
547 |
548 | ### Error pages
549 |
550 | You can use `@error_message` to specify your own error page. See:
551 |
552 | ```python
553 | from simple_http_server import error_message
554 | # map specified codes
555 | @error_message("403", "404")
556 | def my_40x_page(message: str, explain=""):
557 | return f"""
558 | <html>
559 | <head>
560 | <title>发生错误!</title>
561 | </head>
562 | <body>
563 | message: {message}, explain: {explain}
564 | </body>
565 | </html>
566 | """
567 |
568 | # map specified code ranges
569 | @error_message("40x", "50x")
570 | def my_error_message(code, message, explain=""):
571 | return f"{code}-{message}-{explain}"
572 |
573 | # map all error page
574 | @error_message
575 | def my_error_message(code, message, explain=""):
576 | return f"{code}-{message}-{explain}"
577 | ```
578 |
579 | ### Write filters
580 |
581 | This server supports filters; you can use the `request_filter` decorator to define your filters.
582 |
583 | ```python
584 | from simple_http_server import request_filter
585 |
586 | @request_filter("/tuple/**") # use wildcard
587 | @request_filter(regexp="^/tuple") # use regular expression
588 | def filter_tuple(ctx):
589 | print("---------- through filter ---------------")
590 | # add a header to request header
591 | ctx.request.headers["filter-set"] = "through filter"
592 | if "user_name" not in ctx.request.parameter:
593 | ctx.response.send_redirect("/index")
594 | elif "pass" not in ctx.request.parameter:
595 | ctx.response.send_error(400, "pass should be passed")
596 | # you can also raise a HttpError
597 | # raise HttpError(400, "pass should be passed")
598 | else:
599 | # you should always use do_chain method to go to the next
600 | ctx.do_chain()
601 | ```
602 |
603 | ### Start your server
604 |
605 | ```python
606 | # If you place controller functions in other files, you should import them here.
607 |
608 | import simple_http_server.server as server
609 | import my_test_ctrl
610 |
611 |
612 | def main(*args):
613 | # The following method can import several controller files once.
614 | server.scan("my_ctr_pkg", r".*controller.*")
615 | server.start()
616 |
617 | if __name__ == "__main__":
618 | main()
619 | ```
620 |
621 | If you want to specify the host and port:
622 |
623 | ```python
624 | server.start(host="", port=8080)
625 | ```
626 |
627 | If you want to specify the resources path:
628 |
629 | ```python
630 | server.start(resources={"/path_prefix/*": "/absolute/dir/root/path", # Match the files in the given folder with a special path prefix.
631 | "/path_prefix/**": "/absolute/dir/root/path", # Match all the files in the given folder and its sub-folders with a special path prefix.
632 | "*.suffix": "/absolute/dir/root/path", # Match the specific files in the given folder.
633 | "**.suffix": "/absolute/dir/root/path", # Match the specific files in the given folder and its sub-folders.
634 | })
635 | ```
636 |
637 | If you want to use ssl:
638 |
639 | ```python
640 | server.start(host="",
641 | port=8443,
642 | ssl=True,
643 | ssl_protocol=ssl.PROTOCOL_TLS_SERVER, # Optional, default is ssl.PROTOCOL_TLS_SERVER, which will auto detect the highest protocol version that both server and client support.
644 | ssl_check_hostname=False, # Optional, if set to True, the connection cannot be established when the hostname does not match the certificate; default is False.
645 | keyfile="/path/to/your/keyfile.key",
646 | certfile="/path/to/your/certfile.cert",
647 | keypass="", # Optional, your private key's password
648 | )
649 | ```
650 |
651 | ### Coroutine
652 |
653 | From `0.12.0`, you can use coroutine tasks rather than threads to handle requests. Set the `prefer_coroutine` parameter of the start method to enable coroutine mode.
654 |
655 | ```python
656 | server.start(prefer_coroutine=True)
657 | ```
658 |
659 | From `0.13.0`, coroutine mode uses the coroutine server, that means all requests will use the async I/O rather than block I/O. So you can now use `async def` to define all your controllers including the Websocket event callback methods.
660 |
661 | If you start the server inside an async function, you can call its async version; by doing this, the server will use the same event loop as your other async functions.
662 |
663 | ```python
664 | await server.start_async(prefer_coroutine=True)
665 | ```
666 |
667 | ## Logger
668 |
669 | The default logger writes logs to the screen; you can specify a logger handler to write them to a file.
670 |
671 | ```python
672 | import simple_http_server.logger as logger
673 | import logging.handlers
674 |
675 | _formatter = logging.Formatter(fmt='[%(asctime)s]-[%(name)s]-%(levelname)-4s: %(message)s')
676 | _handler = logging.handlers.TimedRotatingFileHandler("/var/log/simple_http_server.log", when="midnight", backupCount=7)
677 | _handler.setFormatter(_formatter)
678 | _handler.setLevel("INFO")
679 |
680 | logger.set_handler(_handler)
681 | ```
682 |
683 | If you want to add a handler rather than replace the inner one, you can use:
684 |
685 | ```python
686 | logger.add_handler(_handler)
687 | ```
688 |
689 | If you want to change the logger level:
690 |
691 | ```python
692 | logger.set_level("DEBUG")
693 | ```
694 |
695 | You can get a standalone logger, independent of the framework's own logger, via the new class `logger.LoggerFactory`.
696 |
697 | ```python
698 | import simple_http_server.logger as logger
699 |
700 | log = logger.get_logger("my_service", "my_log_fac")
701 |
702 | # If you want to set a different log level to this logger factory:
703 |
704 | log_fac = logger.get_logger_factory("my_log_fac")
705 | log_fac.log_level = "DEBUG"
706 | log = log_fac.get_logger("my_service")
707 |
708 | log.info(...)
709 |
710 | ```
711 |
712 |
713 | ## WSGI Support
714 |
715 | You can use this module in WSGI apps.
716 |
717 | ```python
718 | import simple_http_server.server as server
719 | import os
720 | from simple_http_server import request_map
721 |
722 |
723 | # scan all your controllers
724 | server.scan("tests/ctrls", r'.*controllers.*')
725 | # or define a new controller function here
726 | @request_map("/hello_wsgi")
727 | def my_controller(name: str):
728 | return 200, "Hello, WSGI!"
729 | # resources is optional
730 | wsgi_proxy = server.init_wsgi_proxy(resources={"/public/*": "/your/static/files/path"})
731 |
732 | # wsgi app entrance.
733 | def simple_app(environ, start_response):
734 | return wsgi_proxy.app_proxy(environ, start_response)
735 |
736 | # If your entrance is async:
737 | async def simple_app(environ, start_response):
738 | return await wsgi_proxy.async_app_proxy(environ, start_response)
739 | ```
740 |
741 | ## ASGI Support
742 |
743 | You can use this module in an ASGI server, taking `uvicorn` as an example:
744 |
745 | ```python
746 |
747 | import asyncio
748 | import uvicorn
749 | import simple_http_server.server as server
750 | from simple_http_server.server import ASGIProxy
751 |
752 |
753 | asgi_proxy: ASGIProxy = None
754 | init_asgi_proxy_lock: asyncio.Lock = asyncio.Lock()
755 |
756 |
757 | async def init_asgi_proxy():
758 | global asgi_proxy
759 | if asgi_proxy is None:
760 | async with init_asgi_proxy_lock:
761 | if asgi_proxy is None:
762 | server.scan(base_dir="tests/ctrls", regx=r'.*controllers.*')
763 | asgi_proxy = server.init_asgi_proxy(resources={"/public/*": "tests/static"})
764 |
765 | async def app(scope, receive, send):
766 | await init_asgi_proxy()
767 | await asgi_proxy.app_proxy(scope, receive, send)
768 |
769 | def main():
770 | config = uvicorn.Config("main:app", host="0.0.0.0", port=9090, log_level="info")
771 | asgi_server = uvicorn.Server(config)
772 | asgi_server.run()
773 |
774 | if __name__ == "__main__":
775 | main()
776 |
777 | ```
778 |
779 | ## Thanks
780 |
781 | The WebSocket processing code comes from the following project: https://github.com/Pithikos/python-websocket-server
782 |
--------------------------------------------------------------------------------
/naja_atra/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | from .app_conf import *
26 | from .models import *
27 |
# Distribution name and release version of the naja-atra package.
name, version = "naja-atra", "1.0.2"
30 |
--------------------------------------------------------------------------------
/naja_atra/__main__.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import signal
5 | import getopt
6 | from . import server
7 | from .utils import logger
8 |
# Logger used by the command line entry point (signal handling messages).
_logger = logger.get_logger("naja_atra.__main__")
10 |
11 |
def print_help():
    """Print the command line usage of ``python3 -m naja_atra`` to stdout."""
    print("""
    python3 -m naja_atra [options]

    Options:
    -h  --help          Show this message and exit.
    -p  --port          Specify alternate port (default: 9090)
    -b  --bind          Specify alternate bind address (default: all interfaces)
    -s  --scan          Scan this path to find controllers; if absent, the current working directory is scanned
        --regex         Only import the files that match this regular expression.
    -r  --resources     Specify the resource directory
        --loglevel      Specify log level (default: info)
    """)
25 |
26 |
def on_sig_term(signum, frame):
    """Signal handler: log the received signal number and stop the server.

    ``frame`` is part of the ``signal.signal`` handler signature and is unused.
    """
    message = f"Receive signal [{signum}], stop server now..."
    _logger.info(message)
    server.stop()
30 |
31 |
def main(argv):
    """Command line entry point: parse ``argv``, scan controllers and start the server.

    :param argv: command line arguments without the program name; the
        accepted options are the ones listed by ``print_help``.

    Any error while parsing the options or starting the server is printed
    together with the usage text instead of being raised.
    """
    try:
        opts = dict(getopt.getopt(argv, "p:b:s:r:h",
                                  ["port=", "bind=", "scan=", "resources=", "regex=", "loglevel=", "help"])[0])
        if "-h" in opts or "--help" in opts:
            print_help()
            return
        port = int(opts.get("-p", opts.get("--port", "9090")))
        scan_path = opts.get("-s", opts.get("--scan", os.getcwd()))
        regex = opts.get("--regex", "")
        res_dir = opts.get("-r", opts.get("--resources", ""))
        binding_host = opts.get("-b", opts.get("--bind", "0.0.0.0"))
        log_level = opts.get("--loglevel", "")

        if log_level:
            logger.set_level(log_level)
        signal.signal(signal.SIGTERM, on_sig_term)
        signal.signal(signal.SIGINT, on_sig_term)
        server.scan(regx=regex, project_dir=scan_path)
        # Only serve static files when a resource directory was given explicitly.
        # The mapping presumably reaches RoutingServer.add_res_conf, which turns an
        # empty directory into os.path.sep (the filesystem root) and would thus
        # expose every readable file on the host.
        resources = {"/**": res_dir} if res_dir else {}
        server.start(
            host=binding_host,
            port=port,
            resources=resources,
            keep_alive=False,
            prefer_coroutine=False)
    except Exception as e:
        print(f"Start server error: {e}")
        print_help()
61 |
62 |
if __name__ == "__main__":
    # Script entry (``python3 -m naja_atra``): forward the arguments without the program name.
    main(sys.argv[1:])
65 |
--------------------------------------------------------------------------------
/naja_atra/http_servers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
--------------------------------------------------------------------------------
/naja_atra/http_servers/coroutine_http_server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | Copyright (c) 2018 Keijack Wu
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 |
26 |
27 | from .routing_server import RoutingServer
28 | from ..request_handlers.model_bindings import ModelBindingConf
29 | from ..request_handlers.http_request_handler import HttpRequestHandler
30 |
31 |
32 | import asyncio
33 | import threading
34 | from asyncio.base_events import Server
35 | from asyncio.streams import StreamReader, StreamWriter
36 | from ssl import SSLContext
37 | from time import sleep
38 |
39 | from ..utils.logger import get_logger
40 |
41 |
# Module-level logger for the asyncio based HTTP server.
_logger = get_logger("naja_atra.http_servers.coroutine_http_server")
43 |
44 |
class CoroutineHTTPServer(RoutingServer):
    """Routing server that handles every client connection as an asyncio coroutine.

    Connections are accepted by ``asyncio.start_server``; each one is served
    by an ``HttpRequestHandler`` bound to this routing table.
    """

    def __init__(self, host: str = '', port: int = 9090, ssl: SSLContext = None, res_conf=None, model_binding_conf: ModelBindingConf = None) -> None:
        """
        :param host: address to bind (passed to ``asyncio.start_server``).
        :param port: TCP port to listen on.
        :param ssl: optional SSL context passed to ``asyncio.start_server``.
        :param res_conf: static resource mapping, forwarded to ``RoutingServer``.
        :param model_binding_conf: forwarded to ``RoutingServer``. When omitted a
            fresh ``ModelBindingConf`` is created, instead of sharing one
            default-argument instance between every server.
        """
        RoutingServer.__init__(
            self,
            res_conf if res_conf is not None else {},
            model_binding_conf=model_binding_conf if model_binding_conf is not None else ModelBindingConf())
        self.host: str = host
        self.port: int = port
        self.ssl: SSLContext = ssl
        # Set by start_async(); None until the server has been started.
        self.server: Server = None
        # Holds one lazily created event loop per calling thread (see _get_event_loop).
        self.__thread_local = threading.local()

    async def callback(self, reader: StreamReader, writer: StreamWriter):
        """Serve one client connection and always close its writer afterwards."""
        try:
            handler = HttpRequestHandler(reader, writer, routing_conf=self)
            await handler.handle_request()
        finally:
            # Close even when the handler raises; otherwise the transport leaks.
            _logger.debug("Connection ends, close the writer.")
            writer.close()

    async def start_async(self):
        """Start listening and serve until the server is closed or the task is cancelled."""
        self.server = await asyncio.start_server(
            self.callback, host=self.host, port=self.port, ssl=self.ssl)
        async with self.server:
            try:
                await self.server.serve_forever()
            except asyncio.CancelledError:
                _logger.debug(
                    "Some requests are lost for the reason that the server is shutted down.")
            finally:
                await self.server.wait_closed()

    def _get_event_loop(self) -> asyncio.AbstractEventLoop:
        """Return the event loop cached for the calling thread, creating it on first use."""
        if not hasattr(self.__thread_local, "event_loop"):
            try:
                self.__thread_local.event_loop = asyncio.new_event_loop()
            except Exception:
                self.__thread_local.event_loop = asyncio.get_event_loop()
        return self.__thread_local.event_loop

    def start(self):
        """Run ``start_async`` to completion on this thread's event loop."""
        self._get_event_loop().run_until_complete(self.start_async())

    def _shutdown(self):
        """Close the listening server and stop its event loop.

        Both steps are scheduled on the server's own loop, in the same order as
        before (close, then stop): ``Server.close`` is not thread-safe, and this
        method is expected to be called from another thread.
        """
        _logger.debug("Try to shutdown server.")
        if self.server is None:
            _logger.debug("Server is not started, nothing to shutdown.")
            return
        loop = self.server.get_loop()
        loop.call_soon_threadsafe(self.server.close)
        loop.call_soon_threadsafe(loop.stop)

    def shutdown(self):
        """Wait three seconds (blocking the caller), then shut the server down."""
        for remaining in range(3, 0, -1):
            sleep(1)
            _logger.debug(f"couting to shutdown: {remaining}")
        _logger.debug("shutdown server....")
        self._shutdown()
100 |
--------------------------------------------------------------------------------
/naja_atra/http_servers/http_server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | Copyright (c) 2018 Keijack Wu
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 |
26 |
27 | from ssl import PROTOCOL_TLS_SERVER, SSLContext
28 |
29 | from typing import Dict, Tuple
30 | from .coroutine_http_server import CoroutineHTTPServer
31 | from .threading_http_server import ThreadingHTTPServer
32 |
33 |
34 | from ..app_conf import _ControllerFunction, _WebsocketHandlerClass, AppConf, get_app_conf
35 |
36 |
37 | from ..utils.logger import get_logger
38 |
39 |
# Module-level logger for the dispatcher HTTP server.
_logger = get_logger("naja_atra.http_servers.http_server")
41 |
42 |
class HttpServer:
    """Dispatcher HTTP server.

    Builds the underlying server (``CoroutineHTTPServer`` or
    ``ThreadingHTTPServer``), configures SSL, gzip and keep-alive on it, and
    registers every filter, controller, websocket handler and error page of
    the application configuration.
    """

    def map_filter(self, filter_conf):
        """Register a filter on the underlying server."""
        self.server.map_filter(filter_conf)

    def map_controller(self, ctrl: _ControllerFunction):
        """Register a controller function on the underlying server."""
        self.server.map_controller(ctrl)

    def map_websocket_handler(self, handler: _WebsocketHandlerClass):
        """Register a websocket handler class on the underlying server."""
        self.server.map_websocket_handler(handler)

    def map_error_page(self, code, func):
        """Register an error page function for ``code`` on the underlying server."""
        self.server.map_error_page(code, func)

    def __init__(self,
                 host: Tuple[str, int] = ('', 9090),
                 ssl: bool = False,
                 ssl_protocol: int = PROTOCOL_TLS_SERVER,
                 ssl_check_hostname: bool = False,
                 keyfile: str = "",
                 certfile: str = "",
                 keypass: str = "",
                 ssl_context: SSLContext = None,
                 resources: Dict[str, str] = None,
                 prefer_corountine=False,
                 max_workers: int = None,
                 connection_idle_time=None,
                 keep_alive=True,
                 keep_alive_max_request=None,
                 gzip_content_types=None,
                 gzip_compress_level=9,
                 app_conf: AppConf = None):
        """
        :param host: ``(address, port)`` to listen on.
        :param ssl: enable TLS. ``ssl_context`` is used when given; otherwise a
            context is built from ``ssl_protocol``, ``ssl_check_hostname``,
            ``keyfile``, ``certfile`` and ``keypass``.
        :param resources: static resource mapping ``{url_pattern: directory}``;
            ``None`` means no static resources.
        :param prefer_corountine: build a ``CoroutineHTTPServer`` instead of a
            ``ThreadingHTTPServer`` (``max_workers`` is only passed to the latter).
        :param gzip_content_types: content types to gzip. The collection is copied,
            so the server never shares (and mutates) the caller's or a default set.
        :param app_conf: application configuration; ``get_app_conf()`` when omitted.
        :raises AssertionError: if ``ssl`` is enabled without ``ssl_context`` and
            ``keyfile``/``certfile`` are missing.
        """
        self.host = host
        self.__ready = False

        self.ssl = ssl
        if not ssl:
            self.ssl_ctx = None
        elif ssl_context:
            self.ssl_ctx = ssl_context
        else:
            assert keyfile and certfile, "keyfile and certfile should be provided. "
            ssl_ctx = SSLContext(protocol=ssl_protocol)
            ssl_ctx.check_hostname = ssl_check_hostname
            ssl_ctx.load_cert_chain(
                certfile=certfile, keyfile=keyfile, password=keypass)
            self.ssl_ctx = ssl_ctx

        res_conf = resources if resources is not None else {}
        appconf = app_conf or get_app_conf()
        if prefer_corountine:
            _logger.info(
                f"Start server in corouting mode, listen to port: {self.host[1]}")
            # The coroutine server receives the SSL context through its constructor.
            self.server = CoroutineHTTPServer(
                self.host[0], self.host[1], self.ssl_ctx, res_conf, model_binding_conf=appconf.model_binding_conf)
        else:
            _logger.info(
                f"Start server in threading mixed mode, listen to port {self.host[1]}")
            self.server = ThreadingHTTPServer(
                self.host, res_conf, model_binding_conf=appconf.model_binding_conf, max_workers=max_workers)
            if self.ssl_ctx:
                # The threading server's listening socket is wrapped in place.
                self.server.socket = self.ssl_ctx.wrap_socket(
                    self.server.socket, server_side=True)

        self.server.gzip_compress_level = gzip_compress_level
        # Copy: RoutingServer.extend_gzip_content_types adds to this set in place,
        # which would otherwise leak into a shared default or the caller's object.
        self.server.gzip_content_types = set(gzip_content_types) if gzip_content_types else set()

        for ft in appconf._get_filters():
            self.map_filter(ft)
        for ctr in appconf._get_request_mappings():
            self.map_controller(ctr)
        for wshandler in appconf._get_websocket_handlers():
            self.map_websocket_handler(wshandler)
        for code, func in appconf._get_error_pages().items():
            self.map_error_page(code, func)

        self.server.keep_alive = keep_alive
        self.server.connection_idle_time = connection_idle_time
        self.server.keep_alive_max_request = keep_alive_max_request
        self.server.session_factory = appconf.session_factory

    @property
    def ready(self):
        """True once ``start``/``start_async`` has been called and has not failed."""
        return self.__ready

    def resources(self, res=None):
        """Replace the static resource mapping of the underlying server (``None`` clears it)."""
        self.server.res_conf = res if res is not None else {}

    def start(self):
        """Start the underlying server; ``ready`` is reset if starting fails."""
        try:
            self.__ready = True
            self.server.start()
        except BaseException:
            self.__ready = False
            raise

    async def start_async(self):
        """Async variant of ``start`` that runs on the caller's event loop."""
        try:
            self.__ready = True
            await self.server.start_async()
        except BaseException:
            self.__ready = False
            raise

    def shutdown(self):
        """Shut down the underlying server; intended to be called from a separate thread."""
        self.server.shutdown()
161 |
--------------------------------------------------------------------------------
/naja_atra/http_servers/routing_server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | Copyright (c) 2018 Keijack Wu
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 |
26 |
27 | from abc import abstractmethod
28 | import json
29 | import os
30 | import re
31 |
32 |
33 | from urllib.parse import unquote
34 |
35 | from typing import Any, Callable, Dict, List, Set, Tuple, Union
36 |
37 | from ..models import StaticFile, HttpSessionFactory
38 | from ..request_handlers.model_bindings import ModelBindingConf
39 | from ..app_conf import _WebsocketHandlerClass, _ControllerFunction
40 |
41 | from ..utils.http_utils import remove_url_first_slash, get_function_args, get_function_kwargs, get_path_reg_pattern
42 | from ..utils.logger import get_logger
43 |
# Module-level logger for URL routing and static resource resolution.
_logger = get_logger("naja_atra.http_servers.routing_server")
45 |
46 |
47 | _EXT_CONTENT_TYPE = {
48 | ".html": "text/html",
49 | ".htm": "text/html",
50 | ".xhtml": "text/html",
51 | ".css": "text/css",
52 | ".js": "text/javascript",
53 | ".jpg": "image/jpeg",
54 | ".jpeg": "image/jpeg",
55 | ".png": "image/png",
56 | ".ico": "image/x-icon",
57 | ".svg": "image/svg+xml",
58 | ".gif": "image/gif",
59 | ".avif": "image/avif",
60 | ".avifs": "image/avif",
61 | ".webp": "image/webp",
62 | ".pdf": "application/pdf",
63 | ".json": "application/json",
64 | ".mp4": "video/mp4",
65 | ".mp3": "video/mp3",
66 | ".txt": "text/plain"
67 | }
68 |
69 |
70 | class RoutingServer:
71 |
72 | HTTP_METHODS = ["OPTIONS", "GET", "HEAD",
73 | "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
74 |
75 | def __init__(self, res_conf={}, model_binding_conf: ModelBindingConf = ModelBindingConf()):
76 | self.method_url_mapping: Dict[str,
77 | Dict[str, List[_ControllerFunction]]] = {"_": {}}
78 | self.path_val_url_mapping: Dict[str, Dict[str, List[Tuple[_ControllerFunction, List[str]]]]] = {
79 | "_": {}}
80 | self.method_regexp_mapping: Dict[str, Dict[str, List[_ControllerFunction]]] = {
81 | "_": {}}
82 | for mth in self.HTTP_METHODS:
83 | self.method_url_mapping[mth] = {}
84 | self.path_val_url_mapping[mth] = {}
85 | self.method_regexp_mapping[mth] = {}
86 |
87 | self.filter_mapping = {}
88 | self._res_conf = []
89 | self.add_res_conf(res_conf)
90 |
91 | self.ws_mapping: Dict[str, _ControllerFunction] = {}
92 | self.ws_path_val_mapping: Dict[str, _ControllerFunction] = {}
93 | self.ws_regx_mapping: Dict[str, _ControllerFunction] = {}
94 |
95 | self.error_page_mapping = {}
96 | self.keep_alive = True
97 | self.__connection_idle_time: float = 60
98 | self.__keep_alive_max_request: int = 10
99 | self.session_factory: HttpSessionFactory = None
100 | self.model_binding_conf = model_binding_conf
101 | self.gzip_content_types: Set[str] = set()
102 | self.gzip_compress_level = 9
103 |
104 | @property
105 | def connection_idle_time(self):
106 | return self.__connection_idle_time
107 |
108 | @connection_idle_time.setter
109 | def connection_idle_time(self, val: float):
110 | if (isinstance(val, float) or isinstance(val, int)) and val > 0:
111 | self.__connection_idle_time = val
112 |
113 | @property
114 | def keep_alive_max_request(self):
115 | return self.__keep_alive_max_request
116 |
117 | @keep_alive_max_request.setter
118 | def keep_alive_max_request(self, val):
119 | if isinstance(val, int) and val > 0:
120 | self.__keep_alive_max_request = val
121 |
122 | def extend_gzip_content_types(self, content_types: Union[Set[str], List[str]]):
123 | for ctype in content_types:
124 | self.gzip_content_types.add(ctype.lower())
125 |
126 | def put_to_method_url_mapping(self, method, url, ctrl):
127 | if url not in self.method_url_mapping[method]:
128 | self.method_url_mapping[method][url] = []
129 | self.method_url_mapping[method][url].insert(0, ctrl)
130 |
131 | def put_to_path_val_url_mapping(self, method, path_pattern, ctrl, path_names):
132 | if path_pattern not in self.path_val_url_mapping[method]:
133 | self.path_val_url_mapping[method][path_pattern] = []
134 | self.path_val_url_mapping[method][path_pattern].insert(
135 | 0, (ctrl, path_names))
136 |
137 | def put_to_method_regexp_mapping(self, method, regexp, ctrl):
138 | if regexp not in self.method_regexp_mapping[method]:
139 | self.method_regexp_mapping[method][regexp] = []
140 | self.method_regexp_mapping[method][regexp].insert(0, ctrl)
141 |
142 | @property
143 | def res_conf(self):
144 | return self._res_conf
145 |
146 | @res_conf.setter
147 | def res_conf(self, val: Dict[str, str]):
148 | self._res_conf.clear()
149 | self.add_res_conf(val)
150 |
151 | def add_res_conf(self, val: Dict[str, str]):
152 | if not val or not isinstance(val, dict):
153 | return
154 | for k, v in val.items():
155 | res_k = k
156 | if not res_k.startswith("*") and res_k.endswith('/'):
157 | # xxx/ equals xxx/*
158 | res_k = res_k + "*"
159 | if res_k.startswith('*'):
160 | suffix = res_k[2:] if res_k.startswith("**") else res_k[1:]
161 | assert suffix.find('/') < 0 and suffix.find(
162 | '*') < 0, "If a resource path starts with *, only suffix can be configurated. "
163 | if res_k.startswith('**.'):
164 | # **.xxx
165 | suffix = res_k[3:]
166 | key = f'^[\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+\\.{suffix}$'
167 | elif res_k.startswith('*.'):
168 | # *.xxx
169 | suffix = res_k[2:]
170 | key = f'^[\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+\\.{suffix}$'
171 | elif res_k.endswith("/**"):
172 | # xx/**
173 | prefix = res_k[0:-2]
174 | while prefix.startswith('/'):
175 | prefix = prefix[1:]
176 | assert prefix.find(
177 | "*") < 0, "You can only config a * or ** at the start or end of a path."
178 | key = f'^{prefix}([\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+)$'
179 | elif res_k.endswith("/*"):
180 | # xx/*
181 | prefix = res_k[0:-1]
182 | while prefix.startswith('/'):
183 | prefix = prefix[1:]
184 | assert prefix.find(
185 | "*") < 0, "You can only config a * or ** at the start or end of a path."
186 | key = f'^{prefix}([\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+)$'
187 |
188 | if v.endswith(os.path.sep):
189 | val = v
190 | else:
191 | val = v + os.path.sep
192 | self._res_conf.append((key, val))
193 |
194 | def map_controller(self, ctrl: _ControllerFunction):
195 | url = ctrl.url
196 | regexp = ctrl.regexp
197 | method = ctrl.method
198 | _logger.debug(
199 | f"map url {url}|{regexp} with method::{method}, headers::{ctrl.headers} and params::{ctrl.params} to function {ctrl.func}. ")
200 | assert method is None or method == "" or method.upper() in self.HTTP_METHODS
201 | _method = method.upper() if method is not None and method != "" else "_"
202 | if regexp:
203 | self.put_to_method_regexp_mapping(_method, regexp, ctrl)
204 | else:
205 | _url = remove_url_first_slash(url)
206 |
207 | path_pattern, path_names = get_path_reg_pattern(_url)
208 | if path_pattern is None:
209 | self.put_to_method_url_mapping(_method, _url, ctrl)
210 | else:
211 | self.put_to_path_val_url_mapping(
212 | _method, path_pattern, ctrl, path_names)
213 |
214 | def _res_(self, fpath: str):
215 | fext = os.path.splitext(fpath)[1]
216 | ext = fext.lower()
217 | content_type = _EXT_CONTENT_TYPE.get(ext, "application/octet-stream")
218 | return StaticFile(fpath, content_type)
219 |
220 | def get_url_controllers(self, path: str = "", method: str = "") -> List[Tuple[_ControllerFunction, Dict, List]]:
221 | # explicitly url matching
222 | if path in self.method_url_mapping[method]:
223 | return [(ctrl, {}, ()) for ctrl in self.method_url_mapping[method][path]]
224 | elif path in self.method_url_mapping["_"]:
225 | return [(ctrl, {}, ()) for ctrl in self.method_url_mapping["_"][path]]
226 |
227 | # url with path value matching
228 | path_val_res = self.__try_get_from_path_val(path, method)
229 | if path_val_res is None:
230 | path_val_res = self.__try_get_from_path_val(path, "_")
231 | if path_val_res is not None:
232 | return path_val_res
233 |
234 | # regexp
235 | regexp_res = self.__try_get_ctrl_from_regexp(path, method)
236 | if regexp_res is None:
237 | regexp_res = self.__try_get_ctrl_from_regexp(path, "_")
238 | if regexp_res is not None:
239 | return regexp_res
240 | # static files
241 | for k, v in self.res_conf:
242 | match_static_path_conf = re.match(k, path)
243 | _logger.debug(
244 | f"{path} macth static file conf {k} ? {match_static_path_conf}")
245 | if match_static_path_conf:
246 | if match_static_path_conf.groups():
247 | fpath = f"{v}{match_static_path_conf.group(1)}"
248 | else:
249 | fpath = f"{v}{path}"
250 |
251 | def static_fun():
252 | return self._res_(fpath)
253 | return [(_ControllerFunction(func=static_fun), {}, ())]
254 | return []
255 |
256 | def __try_get_ctrl_from_regexp(self, path, method):
257 | for regex, ctrls in self.method_regexp_mapping[method].items():
258 | m = re.match(regex, f"/{path}") or re.match(regex, path)
259 | _logger.debug(
260 | f"regexp::pattern::[{regex}] => path::[{path}] match? {m is not None}")
261 | if m:
262 | res = []
263 | grps = tuple([unquote(v) for v in m.groups()])
264 | for ctrl in ctrls:
265 | res.append((ctrl, [], grps))
266 | return res
267 | return None
268 |
269 | def __try_get_from_path_val(self, path, method):
270 | for patterns, val in self.path_val_url_mapping[method].items():
271 | m = re.match(patterns, path)
272 | _logger.debug(
273 | f"url with path value::pattern::[{patterns}] => path::[{path}] match? {m is not None}")
274 | if m:
275 | res = []
276 | for ctrl_fun, path_names in val:
277 | path_values = {}
278 | for idx in range(len(path_names)):
279 | key = unquote(path_names[idx])
280 | path_values[key] = unquote(m.groups()[idx])
281 | res.append((ctrl_fun, path_values, ()))
282 | return res
283 | return None
284 |
285 | def map_filter(self, filter_conf: Dict[str, Any]):
286 | # {"path": p, "url_pattern": r, "func": filter_fun}
287 | path = filter_conf["path"] if "path" in filter_conf else ""
288 | regexp = filter_conf["url_pattern"]
289 | filter_fun = filter_conf["func"]
290 | if path:
291 | regexp = get_path_reg_pattern(path)[0]
292 | if not regexp:
293 | regexp = f"^{path}$"
294 | _logger.debug(
295 | f"[path: {path}] map url regexp {regexp} to function: {filter_fun}")
296 | self.filter_mapping[regexp] = filter_fun
297 |
298 | def get_matched_filters(self, path):
299 | return self._get_matched_filters(remove_url_first_slash(path)) + self._get_matched_filters(path)
300 |
301 | def _get_matched_filters(self, path):
302 | available_filters = []
303 | for regexp, val in self.filter_mapping.items():
304 | m = re.match(regexp, path)
305 | _logger.debug(
306 | f"filter:: [{regexp}], path:: [{path}] match? {m is not None}")
307 | if m:
308 | available_filters.append(val)
309 | return available_filters
310 |
311 | def map_websocket_handler(self, handler: _WebsocketHandlerClass):
312 | url = handler.url
313 | regexp = handler.regexp
314 | _logger.debug(
315 | f"map url {url}|{regexp} to controller class {handler.cls}")
316 | if regexp:
317 | self.ws_regx_mapping[regexp] = handler
318 | else:
319 | url = remove_url_first_slash(url)
320 | path_pattern, path_names = get_path_reg_pattern(url)
321 | if path_pattern is None:
322 | self.ws_mapping[url] = handler
323 | else:
324 | self.ws_path_val_mapping[path_pattern] = (handler, path_names)
325 |
326 | def get_websocket_handler(self, path):
327 | # explicitly mapping
328 | if path in self.ws_mapping:
329 | return self.ws_mapping[path], {}, ()
330 |
331 | # path value mapping
332 | handler, path_vals = self.__try_get_ws_handler_from_path_val(path)
333 | if handler is not None:
334 | return handler, path_vals, ()
335 | # regexp mapping
336 | return self.__try_get_ws_hanlder_from_regexp(path)
337 |
338 | def __try_get_ws_hanlder_from_regexp(self, path):
339 | for regex, handler in self.ws_regx_mapping.items():
340 | m = re.match(regex, f"/{path}") or re.match(regex, path)
341 | _logger.debug(
342 | f"regexp::pattern::[{regex}] => path::[{path}] match? {m is not None}")
343 | if m:
344 | return handler, {}, tuple([unquote(v) for v in m.groups()])
345 | return None, {}, ()
346 |
347 | def __try_get_ws_handler_from_path_val(self, path):
348 | for patterns, val in self.ws_path_val_mapping.items():
349 | m = re.match(patterns, path)
350 | _logger.debug(
351 | f"websocket endpoint with path value::pattern::[{patterns}] => path::[{path}] match? {m is not None}")
352 | if m:
353 | handler, path_names = val
354 | path_values = {}
355 | for idx in range(len(path_names)):
356 | key = unquote(path_names[idx])
357 | path_values[key] = unquote(m.groups()[idx])
358 | return handler, path_values
359 | return None, {}
360 |
361 | def map_error_page(self, code: str, error_page_fun: Callable):
362 | if not code:
363 | c = "_"
364 | else:
365 | c = str(code).lower()
366 | self.error_page_mapping[c] = error_page_fun
367 |
368 | def _default_error_page(self, code: int, message: str = "", explain: str = ""):
369 | return json.dumps({
370 | "code": code,
371 | "message": message,
372 | "explain": explain
373 | })
374 |
def error_page(self, code: int, message: str = "", explain: str = ""):
    """Render the error page for `code` and return the page function's result.

    The page function is looked up in ``self.error_page_mapping`` in order:
      1. the exact code (e.g. ``"404"``);
      2. for codes above 200, a class key made of the first two digits plus
         ``"x"`` (e.g. ``"40x"``);
      3. the catch-all ``"_"`` entry, whenever nothing more specific matched;
      4. ``self._default_error_page``.

    Arguments are injected by declared type. The first int-typed (or untyped)
    slot receives `code`. The following str-typed (or untyped) slots receive
    `message`, then `explain`. For keyword parameters without an annotation,
    the type of their default value decides. Positional slots left unfilled
    get ``None``. Keyword slots left unfilled keep their declared default.
    """
    # HTTPStatus is an IntEnum whose str() is "HTTPStatus.NOT_FOUND" before
    # Python 3.11, which never matches keys like "404"; use the plain number.
    try:
        numeric_code = int(code)
    except (TypeError, ValueError):
        numeric_code = None
    c = str(numeric_code) if numeric_code is not None else str(code)

    mapping = self.error_page_mapping
    func = mapping.get(c)
    if not func and numeric_code is not None and numeric_code > 200:
        func = mapping.get(c[0:2] + "x")
    if not func:
        # Catch-all page: consulted whenever no more specific page is mapped.
        func = mapping.get("_")
    if not func:
        func = self._default_error_page
    _logger.debug(f"error page function:: {func}")

    # Each value is handed out at most once; None marks "already consumed".
    co = code
    msg = message
    exp = explain

    args_def = get_function_args(func, None)
    kwargs_def = get_function_kwargs(func, None)

    args = []
    for n, t in args_def:
        _logger.debug(f"set value to error_page function -> {n}")
        if co is not None and (t is None or t == int):
            args.append(co)
            co = None
        elif msg is not None and (t is None or t == str):
            args.append(msg)
            msg = None
        elif exp is not None and (t is None or t == str):
            args.append(exp)
            exp = None
        else:
            args.append(None)

    kwargs = {}
    for n, v, t in kwargs_def:
        if co is not None and ((t is None and isinstance(v, int)) or t == int):
            kwargs[n] = co
            co = None
        elif msg is not None and ((t is None and isinstance(v, str)) or t == str):
            kwargs[n] = msg
            msg = None
        elif exp is not None and ((t is None and isinstance(v, str)) or t == str):
            kwargs[n] = exp
            exp = None
        else:
            kwargs[n] = v

    # Empty *args / **kwargs are harmless, so a single call form covers all cases.
    return func(*args, **kwargs)
445 |
@abstractmethod
def start(self):
    """Start the server; concrete server classes provide the implementation."""
449 |
@abstractmethod
async def start_async(self):
    """Coroutine variant of ``start``; concrete server classes provide the implementation."""
453 |
@abstractmethod
def shutdown(self):
    """Stop the server; concrete server classes provide the implementation."""
457 |
--------------------------------------------------------------------------------
/naja_atra/http_servers/threading_http_server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | Copyright (c) 2018 Keijack Wu
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 |
26 |
27 | from .routing_server import RoutingServer
28 | from ..request_handlers.model_bindings import ModelBindingConf
29 | from ..request_handlers.http_request_handler import SocketServerStreamRequestHandlerWraper
30 |
31 |
32 | import socket
33 | import threading
34 | from concurrent.futures import ThreadPoolExecutor
35 | from socketserver import TCPServer
36 |
37 | from ..utils.logger import get_logger
38 |
39 |
40 | _logger = get_logger("naja_atra.http_servers.threading_http_server")
41 |
42 |
class ThreadingHTTPServer(TCPServer, RoutingServer):
    """TCP-based HTTP server that handles each accepted connection on a thread pool.

    Routing is inherited from ``RoutingServer``; the socket accept loop comes
    from ``socketserver.TCPServer``. Connections are processed by
    ``SocketServerStreamRequestHandlerWraper`` on a ``ThreadPoolExecutor``
    with at most ``max_workers`` threads (``_default_max_workers`` if unset).
    """

    allow_reuse_address = 1  # Seems to make sense in testing environment

    _default_max_workers = 50

    def server_bind(self):
        """Override server_bind to store the server name."""
        TCPServer.server_bind(self)
        host, port = self.server_address[:2]
        self.server_name = socket.getfqdn(host)
        self.server_port = port

    def __init__(self, addr, res_conf=None, model_binding_conf: ModelBindingConf = None, max_workers: int = None):
        """Create the server and bind it to `addr`.

        Args:
            addr: ``(host, port)`` address passed to ``TCPServer``.
            res_conf: static resource configuration for ``RoutingServer``;
                defaults to a fresh empty dict.
            model_binding_conf: model binding configuration; defaults to a
                fresh ``ModelBindingConf()`` per server.
            max_workers: thread pool size; falsy means ``_default_max_workers``.
        """
        # Defaults are built per instance: `{}` or `ModelBindingConf()` in the
        # signature would be evaluated once and shared by every server.
        if res_conf is None:
            res_conf = {}
        if model_binding_conf is None:
            model_binding_conf = ModelBindingConf()
        RoutingServer.__init__(
            self, res_conf, model_binding_conf=model_binding_conf)
        self.max_workers = max_workers or self._default_max_workers
        # The pool must exist before TCPServer.__init__, which calls
        # server_close() (and thus threadpool.shutdown) if binding fails.
        self.threadpool: ThreadPoolExecutor = ThreadPoolExecutor(
            thread_name_prefix="ReqThread",
            max_workers=self.max_workers)
        TCPServer.__init__(self, addr, SocketServerStreamRequestHandlerWraper)

    def process_request_thread(self, request, client_address):
        """Handle one connection on a pool thread, always closing the request."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    # override
    def process_request(self, request, client_address):
        """Hand the accepted connection to the thread pool instead of handling it inline."""
        self.threadpool.submit(
            self.process_request_thread, request, client_address)

    def server_close(self):
        """Close the listening socket, then wait for in-flight pool work to finish."""
        super().server_close()
        self.threadpool.shutdown(True)

    def start(self):
        """Run the accept loop (``serve_forever``) in the calling thread."""
        self.serve_forever()

    async def start_async(self):
        """Run ``start`` directly.

        ``serve_forever`` is synchronous, so the awaiting event loop is blocked
        until the server stops.
        """
        self.start()

    def _shutdown(self) -> None:
        _logger.debug("shutdown http server in a seperate thread..")
        super().shutdown()

    def shutdown(self) -> None:
        """Request shutdown from a separate, non-daemon thread.

        ``BaseServer.shutdown`` blocks until ``serve_forever`` returns. Running
        it on its own thread lets this be called from within the serving
        thread itself without deadlocking.
        """
        threading.Thread(target=self._shutdown, daemon=False).start()
94 |
--------------------------------------------------------------------------------
/naja_atra/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import http.cookies
25 | import time
26 | from abc import abstractmethod
27 | from typing import Any, Dict, List, Tuple, Union
28 |
29 |
# Default text encoding name.
DEFAULT_ENCODING: str = "UTF-8"

# Cookie name used for the HTTP session identifier.
SESSION_COOKIE_NAME: str = "PY_SIM_HTTP_SER_SESSION_ID"

# WebSocket frame opcodes (values as in RFC 6455).
WEBSOCKET_OPCODE_CONTINUATION: int = 0  # 0x0
WEBSOCKET_OPCODE_TEXT: int = 1  # 0x1
WEBSOCKET_OPCODE_BINARY: int = 2  # 0x2
WEBSOCKET_OPCODE_CLOSE: int = 8  # 0x8
WEBSOCKET_OPCODE_PING: int = 9  # 0x9
WEBSOCKET_OPCODE_PONG: int = 10  # 0xA

# Message-kind tags; each value is simply the constant's own name.
WEBSOCKET_MESSAGE_TEXT: str = "WEBSOCKET_MESSAGE_TEXT"
WEBSOCKET_MESSAGE_BINARY: str = "WEBSOCKET_MESSAGE_BINARY"
WEBSOCKET_MESSAGE_BINARY_FRAME: str = "WEBSOCKET_MESSAGE_BINARY_FRAME"
WEBSOCKET_MESSAGE_PING: str = "WEBSOCKET_MESSAGE_PING"
WEBSOCKET_MESSAGE_PONG: str = "WEBSOCKET_MESSAGE_PONG"
46 |
47 |
class HttpSession:
    """Base HTTP session; every member is a placeholder for a concrete session to override.

    The class does not derive from ``abc.ABC``, so ``@abstractmethod`` here is
    documentation only and is not enforced at instantiation.
    """

    def __init__(self):
        # Seconds of inactivity after which ``is_valid`` turns False (30 minutes).
        self.max_inactive_interval: int = 1800

    @property
    def id(self) -> str:
        return ""

    @property
    def creation_time(self) -> float:
        return 0

    @property
    def last_accessed_time(self) -> float:
        return 0

    @property
    def attribute_names(self) -> Tuple:
        return ()

    @property
    def is_new(self) -> bool:
        return False

    @property
    def is_valid(self) -> bool:
        """True while less than ``max_inactive_interval`` seconds have passed since last access."""
        idle_seconds = time.time() - self.last_accessed_time
        return idle_seconds < self.max_inactive_interval

    @abstractmethod
    def get_attribute(self, name: str) -> Any:
        return NotImplemented

    @abstractmethod
    def set_attribute(self, name: str, value: str) -> None:
        return NotImplemented

    @abstractmethod
    def invalidate(self) -> None:
        return NotImplemented
88 |
89 |
class HttpSessionFactory:
    """Source of ``HttpSession`` objects; a concrete factory overrides ``get_session``."""

    @abstractmethod
    def get_session(self, session_id: str, create: bool = False) -> HttpSession:
        """Return the session for `session_id` (`create` presumably requests a new one when absent)."""
        return NotImplemented
95 |
96 |
class Cookies(http.cookies.SimpleCookie):
    """``SimpleCookie`` that also carries the date format used for cookie expiry values."""

    EXPIRE_DATE_FORMAT = "%a, %d %b %Y %H:%M:%S GMT"
99 |
100 |
class RequestBodyReader:
    """Asynchronous reader for a request body; concrete readers override ``read``."""

    @abstractmethod
    async def read(self, n: int = -1) -> bytes:
        """Read up to `n` bytes (``-1`` presumably means "until the end")."""
        return NotImplemented
106 |
107 |
class Request:
    """Incoming HTTP request as handed to controllers."""

    def __init__(self):
        self.method: str = ""  # GET, POST, PUT, DELETE, HEAD, etc.
        self.headers: Dict[str, str] = {}  # request headers
        self.__cookies = Cookies()
        self.query_string: str = ""  # raw query string
        self.path_values: Dict[str, str] = {}
        # Capture groups, filled when the controller was matched via a regexp.
        self.reg_groups = ()
        self.path: str = ""
        # Every value per key, merged from the query string and, for
        # `application/x-www-form-urlencoded` / `multipart/form-data` bodies,
        # from the request body.
        self.__parameters = {}
        # First value per key only (derived from __parameters by the setter).
        self.__parameter = {}
        self._body: bytes = b""  # raw request body
        # Decoded body when the `Content-Type` is `application/json`.
        self.json: Dict[str, Any] = None
        self.environment = {}
        self.reader: RequestBodyReader = None  # stream reader for the body

    @property
    def cookies(self) -> Cookies:
        return self.__cookies

    @property
    def parameters(self) -> Dict[str, List[str]]:
        """All values per parameter name."""
        return self.__parameters

    @parameters.setter
    def parameters(self, val: Dict[str, List[str]]):
        # Rebuild the single-value view alongside the multi-value mapping.
        self.__parameters = val
        self.__parameter = {}
        for name, values in self.__parameters.items():
            self.__parameter[name] = values[0]

    @property
    def parameter(self) -> Dict[str, str]:
        """First value per parameter name."""
        return self.__parameter

    @property
    def host(self) -> str:
        """Value of the `Host` header, or "" when absent."""
        return self.headers["Host"] if "Host" in self.headers else ""

    @property
    def body(self) -> bytes:
        return self._body

    @property
    def content_type(self) -> str:
        """Value of the `Content-Type` header, or "" when absent."""
        return self.headers["Content-Type"] if "Content-Type" in self.headers else ""

    @property
    def content_length(self) -> int:
        """`Content-Length` as an int, or None when the header is absent."""
        if "Content-Length" in self.headers:
            return int(self.headers["Content-Length"])
        return None

    def get_parameter(self, key: str, default: str = None) -> str:
        """First value of parameter `key`, or `default` when it is not present."""
        return self.parameter[key] if key in self.parameters.keys() else default

    @abstractmethod
    def get_session(self, create: bool = False) -> HttpSession:
        return NotImplemented
181 |
182 |
class MultipartFile:
    """A file part received in a multipart request, with its metadata and raw bytes."""

    def __init__(self, name: str = "",
                 required: bool = False,
                 filename: str = "",
                 content_type: str = "",
                 content: bytes = None):
        self.__name, self.__required = name, required
        self.__filename, self.__content_type = filename, content_type
        self.__content = content

    @property
    def name(self) -> str:
        return self.__name

    @property
    def _required(self) -> bool:
        return self.__required

    @property
    def filename(self) -> str:
        return self.__filename

    @property
    def content_type(self) -> str:
        return self.__content_type

    @property
    def content(self) -> bytes:
        return self.__content

    @property
    def is_empty(self) -> bool:
        """True when there is no content or the content has zero length."""
        data = self.__content
        return data is None or len(data) == 0

    def save_to_file(self, file_path: str) -> None:
        """Write the content to `file_path` in binary mode; does nothing when empty."""
        if self.is_empty:
            return
        with open(file_path, "wb") as out:
            out.write(self.__content)
225 |
226 |
class ParamStringValue(str):
    """String-valued binding descriptor whose string value is `default`.

    Also records the bound `name` and whether the value is `required`.
    """

    def __init__(self, name: str = "",
                 default: str = "",
                 required: bool = False):
        # `default` is consumed by __new__ as the str value itself.
        self.__name = name
        self.__required = required

    @property
    def name(self) -> str:
        return self.__name

    @property
    def _required(self) -> bool:
        return self.__required

    def __new__(cls, name="", default="", required=False, **kwargs):
        # `required` is accepted here as well so that the positional form
        # ParamStringValue(name, default, required), which __init__ allows,
        # does not fail in __new__ with "too many positional arguments".
        assert isinstance(default, str)
        return super().__new__(cls, default)
247 |
248 |
class Parameter(ParamStringValue):
    """``ParamStringValue`` specialisation naming a request parameter."""
251 |
252 |
class PathValue(str):
    """String whose value is `_value`, tagged with the path-variable `name` it binds to."""

    def __init__(self, name: str = "", _value: str = ""):
        self.__name = name

    @property
    def name(self):
        return self.__name

    def __new__(cls, name: str = "", _value: str = "", **kwargs):
        assert isinstance(_value, str)
        return super().__new__(cls, _value)
266 |
267 |
class Parameters(list):
    """List of parameter values, pre-filled from `default`, tagged with `name`/`required`.

    ``list.__init__`` is deliberately not invoked by ``__init__``: it would
    reset the contents that ``__new__`` filled from `default`.
    """

    def __init__(self, name: str = "", default: List[str] = None, required: bool = False):
        self.__name = name
        self.__required = required

    @property
    def name(self) -> str:
        return self.__name

    @property
    def _required(self) -> bool:
        return self.__required

    def __new__(cls, name: str = "", default: List[str] = None, required: bool = False, **kwargs):
        # `required` is accepted so the positional form allowed by __init__
        # works; `default` is None rather than a shared mutable [] default.
        obj = super().__new__(cls)
        if default is not None:
            obj.extend(default)
        return obj
286 |
287 |
class ModelDict(dict):
    """Plain ``dict`` subclass used as a distinct type marker."""
290 |
291 |
class Environment(dict):
    """Plain ``dict`` subclass used as a distinct type marker."""
294 |
295 |
class RegGroups(tuple):
    """Plain ``tuple`` subclass used as a distinct type marker."""
298 |
299 |
class RegGroup(str):
    """String taken from the `_value` keyword (default ""), tagged with a `group` number.

    NOTE(review): `group` presumably identifies which regexp capture group
    the value came from; confirm against the model-binding code.
    """

    def __init__(self, group=0, **kwargs):
        self._group = group

    @property
    def group(self) -> int:
        return self._group

    def __new__(cls, group=0, **kwargs):
        return super().__new__(cls, kwargs.get("_value", ""))
316 |
317 |
class Header(ParamStringValue):
    """``ParamStringValue`` specialisation naming a request header."""
320 |
321 |
class JSONBody(dict):
    """Plain ``dict`` subclass used as a distinct type marker."""
324 |
325 |
class BytesBody(bytes):
    """Plain ``bytes`` subclass used as a distinct type marker."""
328 |
329 |
330 | """
331 | " The following models are used in Response
332 | """
333 |
334 |
class StaticFile:
    """A file on disk to be returned as a response body, with its content type."""

    def __init__(self, file_path, content_type="application/octet-stream"):
        self.file_path, self.content_type = file_path, content_type
340 |
341 |
class Response:
    """Outgoing HTTP response: status code, headers, cookies and body.

    The I/O methods at the bottom are decorated with ``@abstractmethod`` but
    the class is not an ``abc.ABC``; here they simply return
    ``NotImplemented`` or do nothing.
    """

    def __init__(self,
                 status_code: int = 200,
                 headers: Dict[str, str] = None,
                 body: Union[str, dict, StaticFile, bytes] = ""):
        self.status_code = status_code
        self.__headers = headers if headers is not None else {}
        self.__body = ""
        self.__cookies = Cookies()
        self.__set_body(body)

    @property
    def cookies(self) -> http.cookies.SimpleCookie:
        return self.__cookies

    @cookies.setter
    def cookies(self, val: http.cookies.SimpleCookie) -> None:
        assert isinstance(val, http.cookies.SimpleCookie)
        self.__cookies = val

    @property
    def body(self):
        return self.__body

    @body.setter
    def body(self, val: Union[str, dict, StaticFile, bytes]) -> None:
        self.__set_body(val)

    def __set_body(self, val):
        # Accepted body types: None, str, dict, StaticFile, bytes, bytearray.
        assert val is None \
            or isinstance(val, str) \
            or isinstance(val, dict) \
            or isinstance(val, StaticFile) \
            or isinstance(val, bytes) \
            or isinstance(val, bytearray), \
            "Body type is not supported."
        self.__body = val

    @property
    def headers(self) -> Dict[str, Union[list, str]]:
        return self.__headers

    def set_header(self, key: str, value: str) -> None:
        """Set header `key` to `value`, replacing any previous value."""
        self.__headers[key] = value

    def add_header(self, key: str, value: Union[str, list]) -> None:
        """Add `value` (a single value or a list of values) to header `key`.

        A header that receives more than one value is stored as a list. Lists
        are copied on the way in and before being grown, so a list owned by
        the caller is never mutated by later additions.
        """
        if key not in self.__headers:
            self.__headers[key] = list(value) if isinstance(value, list) else value
            return
        current = self.__headers[key]
        current = list(current) if isinstance(current, list) else [current]
        if isinstance(value, list):
            current.extend(value)
        else:
            current.append(value)
        self.__headers[key] = current

    def add_headers(self, headers: Dict[str, Union[str, List[str]]] = None) -> None:
        """Call ``add_header`` for every item of `headers`; None is a no-op."""
        if headers is not None:
            for k, v in headers.items():
                self.add_header(k, v)

    @abstractmethod
    def send_error(self, status_code: int, message: str = ""):
        return NotImplemented

    @abstractmethod
    def send_redirect(self, url: str):
        return NotImplemented

    @abstractmethod
    def send_response(self):
        return NotImplemented

    @abstractmethod
    def write_bytes(self, data: bytes):
        pass

    @abstractmethod
    def close(self):
        pass
424 |
425 |
class HttpError(Exception):
    """Exception carrying an HTTP status `code`, a short `message` and a longer `explain`."""

    def __init__(self, code: int = 400, message: str = "", explain: str = ""):
        self.code: int = code
        self.message: str = message
        self.explain: str = explain
        super().__init__("HTTP_ERROR[%d] %s" % (code, message))
433 |
434 |
class Redirect:
    """Read-only holder for a redirect target URL."""

    def __init__(self, url: str):
        self.__url = url

    @property
    def url(self) -> str:
        return self.__url
443 |
444 |
445 | """
446 | " Used in both requests and responses
447 | """
448 |
449 |
class Headers(dict):
    """Header mapping; each value is a string or a list of strings."""

    def __init__(self, headers: Dict[str, Union[str, List[str]]] = None):
        # None instead of a shared mutable `{}` default.
        super().__init__()
        if headers:
            self.update(headers)
454 |
455 |
class Cookie(http.cookies.Morsel):
    """``Morsel`` pre-populated with `name`=`default` plus any `default_options` attributes."""

    def __init__(self,
                 name: str = "",
                 default: str = "",
                 default_options: Dict[str, str] = None,
                 required: bool = False):
        super().__init__()
        self.__name = name
        self.__required = required
        if name is not None and name != "":
            # Morsel.set(key, value, coded_value): the raw value doubles as the coded one.
            self.set(name, default, default)
        # None instead of a shared mutable `{}` default.
        self.update(default_options or {})

    @property
    def name(self) -> str:
        return self.__name

    @property
    def _required(self) -> bool:
        return self.__required
477 |
478 |
class FilterContext:
    """Request/response pair plus a hook to continue processing.

    NOTE(review): presumably handed to request filters; every member is a
    placeholder returning ``NotImplemented`` for an implementation to override.
    """

    @property
    def request(self) -> Request:
        return NotImplemented

    @property
    def response(self) -> Response:
        return NotImplemented

    @abstractmethod
    def do_chain(self):
        """Continue with the next step of the chain (implementation-defined)."""
        return NotImplemented
492 |
493 |
class WebsocketRequest:
    """Handshake request data for a websocket connection."""

    def __init__(self):
        self.headers: Dict[str, str] = {}  # request headers
        self.__cookies = Cookies()
        self.query_string: str = ""  # raw query string
        self.path_values: Dict[str, str] = {}
        # Capture groups, filled when the endpoint was matched via a regexp.
        self.reg_groups = ()
        self.path: str = ""
        # Every value per key, merged from the query string and, for
        # `application/x-www-form-urlencoded` / `multipart/form-data` bodies,
        # from the request body.
        self.__parameters = {}
        # First value per key only (derived from __parameters by the setter).
        self.__parameter = {}

    @property
    def cookies(self) -> Cookies:
        return self.__cookies

    @property
    def parameters(self) -> Dict[str, List[str]]:
        """All values per parameter name."""
        return self.__parameters

    @parameters.setter
    def parameters(self, val: Dict[str, List[str]]):
        # Rebuild the single-value view alongside the multi-value mapping.
        self.__parameters = val
        self.__parameter = {}
        for name, values in self.__parameters.items():
            self.__parameter[name] = values[0]

    @property
    def parameter(self) -> Dict[str, str]:
        """First value per parameter name."""
        return self.__parameter

    def get_parameter(self, key: str, default: str = None) -> str:
        """First value of parameter `key`, or `default` when it is not present."""
        return self.parameter[key] if key in self.parameters.keys() else default
531 |
532 |
class WebsocketSession:
    """An open websocket connection; every member is a placeholder returning NotImplemented."""

    @property
    def id(self) -> str:
        return NotImplemented

    @property
    def request(self) -> WebsocketRequest:
        return NotImplemented

    @property
    def is_closed(self) -> bool:
        return NotImplemented

    @abstractmethod
    def send(self, message: Union[str, bytes], opcode: int = None, chunk_size: int = 0):
        return NotImplemented

    @abstractmethod
    def send_text(self, message: str, chunk_size: int = 0):
        return NotImplemented

    @abstractmethod
    def send_binary(self, binary: bytes, chunk_size: int = 0):
        return NotImplemented

    @abstractmethod
    def send_file(self, path: str, chunk_size: int = 0):
        return NotImplemented

    @abstractmethod
    def send_pone(self, message: bytes = b''):
        # Name kept as-is for compatibility; presumably sends a pong frame.
        return NotImplemented

    @abstractmethod
    def send_ping(self, message: bytes = b''):
        return NotImplemented

    @abstractmethod
    def close(self, reason: str = ""):
        return NotImplemented
574 |
575 |
class WebsocketCloseReason(str):
    """Websocket close information; the string value itself is `message`.

    `code` and `reason` are stored alongside and exposed as properties.
    """

    def __init__(self,
                 message: str = "",
                 code: int = None,
                 reason: str = '') -> None:
        self.__message: str = message
        self.__code: int = code
        self.__reason: str = reason

    @property
    def message(self) -> str:
        return self.__message

    @property
    def code(self) -> int:
        return self.__code

    @property
    def reason(self) -> str:
        return self.__reason

    def __new__(cls, message: str = "", code: int = None, reason: str = '', **kwargs):
        # Only `message` becomes the str value; the `code` default now matches
        # __init__ (None) instead of an int-annotated "".
        return super().__new__(cls, message)
601 |
602 |
class WebsocketHandler:
    """Callbacks for one websocket endpoint; override the events you care about."""

    @abstractmethod
    def on_handshake(self, request: WebsocketRequest = None):
        """Inspect the handshake request and decide whether to accept it.

        Path, headers, path values, cookies, query string and query
        parameters are available on `request`.

        Return a tuple ``(http_status_code, headers)``. A status code of
        0, None or 101 accepts the websocket connection; any other status is
        returned to the client instead. All returned headers are sent to the
        client.
        """
        return None

    @abstractmethod
    def on_open(self, session: WebsocketSession = None):
        """Called when the connection has been opened."""
        pass

    @abstractmethod
    def on_close(self, session: WebsocketSession = None, reason: WebsocketCloseReason = None):
        """Called when the connection has been closed."""
        pass

    @abstractmethod
    def on_ping_message(self, session: WebsocketSession = None, message: bytes = b''):
        """Called on a ping; by default echoes the payload back via ``session.send_pone``."""
        session.send_pone(message)

    @abstractmethod
    def on_pong_message(self, session: WebsocketSession = None, message: bytes = b''):
        """Called on a pong message."""
        # Default is b'' to match the bytes annotation and the other handlers.
        pass

    @abstractmethod
    def on_text_message(self, session: WebsocketSession = None, message: str = ""):
        """Called on a text message."""
        pass

    @abstractmethod
    def on_binary_message(self, session: WebsocketSession = None, message: bytes = b''):
        """Called on a binary message whose bytes were not consumed in ``on_binary_frame``."""
        pass

    @abstractmethod
    def on_binary_frame(self, session: WebsocketSession = None, fin: bool = False, frame_payload: bytes = b''):
        """Called for every frame of a fragmented binary message.

        Consume the bytes here (e.g. stream them to a file) to avoid
        buffering. If this method is not implemented or returns True, all
        frames are cached in memory and delivered to ``on_binary_message``
        once the whole message has arrived.
        """
        return True
687 |
--------------------------------------------------------------------------------
/naja_atra/request_handlers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
--------------------------------------------------------------------------------
/naja_atra/request_handlers/http_request_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import html
26 | import re
27 | import http.client
28 | import email.parser
29 | import email.message
30 | import socketserver
31 | import asyncio
32 | import socket
33 |
34 |
35 | from typing import Any, Dict
36 | from http import HTTPStatus
37 | from urllib.parse import unquote
38 | from asyncio.streams import StreamReader, StreamWriter
39 | from http import HTTPStatus
40 |
41 | from .. import name, version
42 | from ..models import RequestBodyReader
43 | from ..utils import http_utils
44 | from ..utils.logger import get_logger
45 | from ..http_servers.routing_server import RoutingServer
46 | from .http_controller_handler import HTTPControllerHandler
47 | from .websocket_controller_handler import WebsocketControllerHandler
48 |
49 |
# Longest request line accepted, in bytes (64 KiB).
_LINE_MAX_BYTES = 64 * 1024
# Presumably the cap on the number of request header lines.
_MAXHEADERS = 100

_logger = get_logger("naja_atra.request_handlers.http_request_handler")
54 |
55 |
class RequestWriter:
    """Adapter that forwards raw bytes to a ``StreamWriter``."""

    def __init__(self, writer: StreamWriter) -> None:
        self.writer: StreamWriter = writer

    def send(self, data: bytes):
        """Pass `data` to the wrapped writer's ``write`` (no drain is awaited)."""
        target = self.writer
        target.write(data)
63 |
64 |
65 | class HttpRequestHandler:
66 |
67 | server_version = f"{name}/{version}"
68 |
69 | default_request_version = "HTTP/1.1"
70 |
71 | # The version of the HTTP protocol we support.
72 | # Set this to HTTP/1.1 to enable automatic keepalive
73 | protocol_version = "HTTP/1.1"
74 |
75 | # MessageClass used to parse headers
76 | _message_class = http.client.HTTPMessage
77 |
78 | # hack to maintain backwards compatibility
79 | responses = {
80 | v: (v.phrase, v.description)
81 | for v in HTTPStatus.__members__.values()
82 | }
83 |
def __init__(self, reader: StreamReader, writer: StreamWriter, request_writer=None, routing_conf: RoutingServer = None) -> None:
    """Bind the handler to a connection's stream pair and routing configuration.

    `request_writer` falls back to a ``RequestWriter`` around `writer`.
    NOTE(review): `routing_conf` defaults to None but its attributes are read
    unconditionally below, so callers must always supply it.
    """
    self.routing_conf: RoutingServer = routing_conf
    self.reader: StreamReader = reader
    self.writer: StreamWriter = writer
    self.request_writer: RequestWriter = request_writer or RequestWriter(writer)

    # Request-line and header parse results.
    self.requestline = ''
    self.request_version = ''
    self.command = ''
    self.path = ''
    self.request_path = ''
    self.query_string = ''
    self.query_parameters = {}
    self.headers = {}

    # Connection management: closed by default; keep-alive settings are
    # copied from the routing configuration.
    self.close_connection = True
    self._keep_alive = routing_conf.keep_alive
    self._connection_idle_time = routing_conf.connection_idle_time
    self._keep_alive_max_req = routing_conf.keep_alive_max_request
    self.req_count = 0
105 |
106 | async def parse_request(self):
107 | self.req_count += 1
108 | try:
109 | if hasattr(self.reader, "connection"):
110 | # For blocking io. asyncio.wait_for will not raise TimeoutError if the io is blocked.
111 | self.reader.connection.settimeout(self._connection_idle_time)
112 | raw_requestline = await asyncio.wait_for(self.reader.readline(), self._connection_idle_time)
113 | if hasattr(self.reader, "connection") and hasattr(self.reader, "timeout"):
114 | # For blocking io. Set the Original timeout to the connection.
115 | self.reader.connection.settimeout(self.reader.timeout)
116 | except asyncio.TimeoutError:
117 | _logger.warn("Wait for reading request line timeout. ")
118 | return False
119 | if len(raw_requestline) > _LINE_MAX_BYTES:
120 | self.requestline = ''
121 | self.request_version = ''
122 | self.command = ''
123 | self.send_error(HTTPStatus.REQUEST_URI_TOO_LONG)
124 | return False
125 | if not raw_requestline:
126 | self.close_connection = True
127 | return False
128 | self.command = None
129 | self.request_version = version = self.default_request_version
130 | self.close_connection = True
131 | requestline = str(raw_requestline, 'iso-8859-1')
132 | requestline = requestline.rstrip('\r\n')
133 | self.requestline = requestline
134 | words = requestline.split()
135 | if len(words) == 0:
136 | return False
137 |
138 | if len(words) >= 3: # Enough to determine protocol version
139 | version = words[-1]
140 | try:
141 | if not version.startswith('HTTP/'):
142 | raise ValueError
143 | base_version_number = version.split('/', 1)[1]
144 | version_number = base_version_number.split(".")
145 | # RFC 2145 section 3.1 says there can be only one "." and
146 | # - major and minor numbers MUST be treated as
147 | # separate integers;
148 | # - HTTP/2.4 is a lower version than HTTP/2.13, which in
149 | # turn is lower than HTTP/12.3;
150 | # - Leading zeros MUST be ignored by recipients.
151 | if len(version_number) != 2:
152 | raise ValueError
153 | version_number = int(version_number[0]), int(version_number[1])
154 | except (ValueError, IndexError):
155 | self.send_error(
156 | HTTPStatus.BAD_REQUEST,
157 | f"Bad request version {version}")
158 | return False
159 |
160 | if version_number >= (2, 0):
161 | self.send_error(
162 | HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
163 | f"Invalid HTTP version {base_version_number}")
164 | return False
165 | self.request_version = version
166 | _logger.info(f"request version: {self.request_version}")
167 | if not 2 <= len(words) <= 3:
168 | self.send_error(
169 | HTTPStatus.BAD_REQUEST,
170 | "Bad request syntax (%r)" % requestline)
171 | return False
172 | command, path = words[:2]
173 | if len(words) == 2:
174 | self.close_connection = True
175 | if command != 'GET':
176 | self.send_error(
177 | HTTPStatus.BAD_REQUEST,
178 | "Bad HTTP/0.9 request type (%r)" % command)
179 | return False
180 | self.command, self.path = command, path
181 |
182 | self.request_path = self._get_request_path(self.path)
183 |
184 | self.query_string = self.__get_query_string(self.path)
185 |
186 | self.query_parameters = http_utils.decode_query_string(
187 | self.query_string)
188 |
189 | # Examine the headers and look for a Connection directive.
190 | try:
191 | self.headers = await self.parse_headers()
192 | except http.client.LineTooLong as err:
193 | self.send_error(
194 | HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
195 | "Line too long",
196 | str(err))
197 | return False
198 | except http.client.HTTPException as err:
199 | self.send_error(
200 | HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
201 | "Too many headers",
202 | str(err)
203 | )
204 | return False
205 |
206 | conntype = self.headers.get('Connection', '')
207 |
208 | self.close_connection = not self._keep_alive or conntype.lower(
209 | ) != 'keep-alive' or self.protocol_version != "HTTP/1.1"
210 |
211 | # Examine the headers and look for an Expect directive
212 | expect = self.headers.get('Expect', "")
213 | if (expect.lower() == "100-continue" and
214 | self.protocol_version >= "HTTP/1.1" and
215 | self.request_version >= "HTTP/1.1"):
216 | if not self.handle_expect_100():
217 | return False
218 | return True
219 |
async def parse_headers(self):
    """Read the header section from ``self.reader`` and parse it into a message.

    Raw lines are collected as bytes up to and including the terminating
    blank line (or EOF). They are then decoded as ISO-8859-1 and parsed by
    ``email.parser.Parser`` using ``self._message_class``. Reading byte lines
    ourselves avoids a text wrapper that could buffer body bytes that must
    later be read as bytes.

    Raises:
        http.client.LineTooLong: a line is longer than ``_LINE_MAX_BYTES``.
        http.client.HTTPException: more than ``_MAXHEADERS`` lines were read.
    """
    raw_lines = []
    terminators = (b'\r\n', b'\n', b'')
    while True:
        current = await self.reader.readline()
        if len(current) > _LINE_MAX_BYTES:
            raise http.client.LineTooLong("header line")
        raw_lines.append(current)
        if len(raw_lines) > _MAXHEADERS:
            raise http.client.HTTPException(
                f"got more than {_MAXHEADERS} headers")
        if current in terminators:
            break
    header_text = b''.join(raw_lines).decode('iso-8859-1')
    parser = email.parser.Parser(_class=self._message_class)
    return parser.parsestr(header_text)
244 |
def handle_expect_100(self):
    """Answer an ``Expect: 100-continue`` header before the body is read.

    A client that sent ``Expect: 100-continue`` waits for either an interim
    ``100 Continue`` or a final response before it transmits the body. This
    default buffers a ``100 Continue`` status line, flushes it with
    ``end_headers`` and returns ``True``. Override it to reject requests
    early: send an error response and return ``False``.
    """
    interim_status = HTTPStatus.CONTINUE
    self.send_response_only(interim_status)
    self.end_headers()
    return True
262 |
def __get_query_string(self, ori_path: str):
    """Return the raw (still percent-encoded) query string of a request target.

    The fragment (from the first ``#``) is discarded first. Everything after
    the first ``?`` of the remainder is returned, or ``""`` when there is no
    query. RFC 3986 allows ``?`` inside the query itself, so only the first
    ``?`` acts as the delimiter. Splitting on every ``?`` would throw such a
    query away entirely.
    """
    without_fragment = ori_path.split('#', 1)[0]
    _, sep, query = without_fragment.partition('?')
    return query if sep else ""
269 |
def _get_request_path(self, ori_path: str):
    """Return the percent-decoded path part of a raw request target.

    Text from the first ``?`` (query) and then from the first ``#``
    (fragment) is dropped. ``http_utils.remove_url_first_slash`` is applied
    to what remains, and the result is passed through ``unquote``.
    """
    bare = ori_path.partition('?')[0].partition('#')[0]
    trimmed = http_utils.remove_url_first_slash(bare)
    return unquote(trimmed)
276 |
def send_error(self, code: int, message: str = None, explain: str = None, headers: Dict[str, str] = None):
    """Send a complete error response and mark the connection for closing.

    Args:
        code: HTTP status code (``int`` or ``HTTPStatus``).
        message: Reason phrase. Defaults to the short text in ``self.responses``
            (``'???'`` for unknown codes).
        explain: Longer explanation used in the body. Defaults to the long text
            in ``self.responses``.
        headers: Extra headers to send. ``None`` (the default) sends none.

    A ``Connection: close`` header is always sent. A body is produced only
    for status codes that may carry one: 1xx, 204, 205 and 304 responses get
    neither a body nor ``Content-Type``/``Content-Length`` headers. The body
    comes from ``self.routing_conf.error_page`` when that succeeds, otherwise
    from ``"<message>:<explain>"``. Both texts are HTML-escaped. No body is
    written for ``HEAD`` requests.
    """
    try:
        shortmsg, longmsg = self.responses[code]
    except KeyError:
        shortmsg, longmsg = '???', '???'
    if message is None:
        message = shortmsg
    if explain is None:
        explain = longmsg
    self.log_error(f"code {code}, message {message}")
    self.send_response(code, message)
    self.send_header('Connection', 'close')

    # Message body is omitted for cases described in:
    # - RFC7230: 3.3. 1xx, 204(No Content), 304(Not Modified)
    # - RFC7231: 6.3.6. 205(Reset Content)
    body = None
    if (code >= 200 and
            code not in (HTTPStatus.NO_CONTENT,
                         HTTPStatus.RESET_CONTENT,
                         HTTPStatus.NOT_MODIFIED)):
        escaped_message = html.escape(message, quote=False)
        escaped_explain = html.escape(explain, quote=False)
        try:
            content: Any = self.routing_conf.error_page(
                code, escaped_message, escaped_explain)
        except Exception:
            # Best effort: fall back to plain text if the configured error
            # page cannot be rendered.
            content = escaped_message + ":" + escaped_explain
        content_type, body = http_utils.decode_response_body_to_bytes(
            content)
        # Entity headers only make sense (and only have values) when a body exists.
        self.send_header("Content-Type", content_type)
        self.send_header('Content-Length', str(len(body)))

    if headers:
        for h_name, h_val in headers.items():
            self.send_header(h_name, h_val)
    self.end_headers()

    if self.command != 'HEAD' and body:
        self.writer.write(body)
316 |
def send_response(self, code, message=None):
    """Log ``code`` and buffer the status line plus the standard headers.

    After ``log_request(code)`` and ``send_response_only(code, message)``,
    a ``Server`` header (``self.server_version``) and a ``Date`` header
    (``http_utils.date_time_string()``) are buffered.
    """
    self.log_request(code)
    self.send_response_only(code, message)
    add_header = self.send_header
    add_header('Server', self.server_version)
    add_header('Date', http_utils.date_time_string())
329 |
def send_header(self, keyword: str, value: str):
    """Buffer one ``keyword: value`` header line.

    ``Connection`` headers also update ``self.close_connection``:

    - ``close`` sets it.
    - ``keep-alive`` clears it when ``self._keep_alive`` is enabled.
    - If keep-alive is disabled, a ``keep-alive`` header is logged and
      dropped instead of buffered.

    Nothing is buffered for HTTP/0.9 requests.
    """
    if keyword.lower() == 'connection':
        directive = value.lower()
        if directive == 'close':
            self.close_connection = True
        elif directive == 'keep-alive':
            if not self._keep_alive:
                _logger.warning(
                    f"Keep Alive configuration is set to False, won't send keep-alive header.")
                return
            self.close_connection = False

    if self.request_version == 'HTTP/0.9':
        return
    if not hasattr(self, '_headers_buffer'):
        self._headers_buffer = []
    header_line = f"{keyword}: {value}\r\n"
    self._headers_buffer.append(header_line.encode('latin-1', errors='strict'))
348 |
def end_headers(self):
    """Terminate the header section with a blank line and flush the buffer.

    HTTP/0.9 responses carry no headers, so no blank line is appended for
    them. ``flush_headers`` is called in every case.
    """
    if self.request_version == 'HTTP/0.9':
        self.flush_headers()
        return
    self._headers_buffer.append(b"\r\n")
    self.flush_headers()
354 |
def flush_headers(self):
    """Write every buffered header byte string to ``self.writer`` and reset the buffer.

    Does nothing if no header buffer has been created yet.
    """
    if not hasattr(self, '_headers_buffer'):
        return
    payload = b"".join(self._headers_buffer)
    self.writer.write(payload)
    self._headers_buffer = []
359 |
def send_response_only(self, code, message: str = None):
    """Buffer the status line ``<protocol_version> <code> <message>``.

    Nothing is buffered for HTTP/0.9 requests. When ``message`` is None, the
    reason phrase is ``self.responses[code][0]``, or ``''`` for codes missing
    from ``self.responses``.
    """
    if self.request_version == 'HTTP/0.9':
        return
    if message is None:
        message = self.responses[code][0] if code in self.responses else ''
    if not hasattr(self, '_headers_buffer'):
        self._headers_buffer = []
    status_line = f"{self.protocol_version} {code} {message}\r\n"
    self._headers_buffer.append(status_line.encode('latin-1', errors='strict'))
372 |
def log_request(self, code='-', size='-'):
    """Log the request line with the response ``code`` and ``size``.

    ``HTTPStatus`` members are logged by their numeric value.
    """
    status = code.value if isinstance(code, HTTPStatus) else code
    self.log_message('"%s" %s %s', self.requestline, str(status), str(size))
378 |
def log_error(self, format, *args):
    """Log an error message by delegating unchanged to ``self.log_message``."""
    emit = self.log_message
    emit(format, *args)
381 |
def log_message(self, format, *args):
    """Log a message at INFO level, %-formatting it only when ``args`` are given.

    Some callers pass already-rendered text with no arguments. For example,
    ``send_error`` logs ``f"code {code}, message {message}"``, where the
    message can embed the client's request line. Such text may contain
    ``%`` sequences like ``%2F``. Applying ``%`` to it with an empty
    argument tuple raises ``TypeError``/``ValueError``, so argument-less
    messages are logged verbatim.
    """
    message = format % args if args else format
    _logger.info(message)
384 |
def set_prefer_keep_alive_params(self):
    """Hook for adjusting keep-alive parameters; currently a no-op placeholder."""
387 |
def set_alive_params(self):
    """Apply ``timeout=`` and ``max=`` values from a ``Keep-Alive`` request header.

    ``timeout`` overrides ``self._connection_idle_time`` and ``max``
    overrides ``self._keep_alive_max_req``, each as an ``int``. Values that
    are missing or non-numeric leave the current setting untouched.
    """
    if "Keep-Alive" not in self.headers:
        return
    ka_header = self.headers["Keep-Alive"]
    overrides = (
        (r"^.*timeout=(\d+).*$", "_connection_idle_time"),
        (r"^.*max=(\d+).*$", "_keep_alive_max_req"),
    )
    for pattern, attr_name in overrides:
        found = re.match(pattern, ka_header)
        if found:
            setattr(self, attr_name, int(found.group(1)))
397 |
async def handle_request(self):
    """Serve a connection: the first request, then any keep-alive follow-ups.

    The first request is parsed and its ``Keep-Alive`` parameters are
    applied. An HTTP/1.1 ``GET`` whose ``Upgrade`` header is ``websocket``
    is handed to ``WebsocketControllerHandler``. The value is compared
    case-insensitively, as RFC 6455 requires. Every other request goes to
    ``handle_http_request``.

    More requests are read from the same connection while
    ``self.close_connection`` is false. Once ``self.req_count`` reaches
    ``self._keep_alive_max_req``, the current request is marked as the last.
    """
    parse_request_success = await self.parse_request()
    if not parse_request_success:
        return
    self.set_alive_params()

    upgrade = str(self.headers.get("Upgrade") or "")
    if (self.request_version == "HTTP/1.1" and self.command == "GET"
            and upgrade.lower() == "websocket"):
        _logger.debug("This is a websocket connection. ")
        ws_handler = WebsocketControllerHandler(self)
        await ws_handler.handle_request()
        self.writer.write_eof()
        return

    await self.handle_http_request()
    while not self.close_connection:
        _logger.debug("Keep-Alive, read next request. ")
        parse_request_success = await self.parse_request()
        if not parse_request_success:
            _logger.debug("parse request fails, return. ")
            return
        if self.req_count >= self._keep_alive_max_req:
            # Request limit reached: make this the last request on the
            # connection. A status-line/header cannot be buffered here
            # because the response for this request has not started yet.
            # NOTE(review): req_count is not incremented in this block;
            # presumably the controller handler does it. Also, a
            # "Connection: keep-alive" header sent later goes through
            # send_header and would clear close_connection again. Confirm both.
            self.close_connection = True
        await self.handle_http_request()
        _logger.debug("Handle a keep-alive request successfully!")
422 |
async def handle_http_request(self):
    """Dispatch the current request to ``HTTPControllerHandler``, then call ``write_eof``.

    ``write_eof`` is only called when ``self.writer.can_write_eof()`` is true.
    A ``socket.timeout`` during reading or writing is logged and the
    connection is marked for closing. Other exceptions propagate.
    """
    # NOTE(review): for the socketserver wrapper, write_eof only flushes. If
    # the writer is ever an asyncio StreamWriter, write_eof half-closes the
    # socket, which would break keep-alive follow-up requests. Confirm.
    try:
        await HTTPControllerHandler(self).handle_request()
        if self.writer.can_write_eof():
            self.writer.write_eof()
    except socket.timeout as timeout_err:
        # a read or a write timed out. Discard this connection
        self.log_error("Request timed out: %r", timeout_err)
        self.close_connection = True
434 |
435 |
class SocketServerStreamRequestHandlerWraper(socketserver.StreamRequestHandler, RequestBodyReader):
    """Adapt a blocking ``socketserver`` connection to ``HttpRequestHandler``.

    ``HttpRequestHandler`` awaits ``readline``/``read`` on its reader and
    calls ``write``/``can_write_eof``/``write_eof`` on its writer. This
    class implements those on top of the ``rfile``/``wfile`` that
    ``StreamRequestHandler`` provides. ``handle`` passes the instance itself
    for the first two constructor arguments and runs one connection with
    ``asyncio.run``.
    """

    server_version = HttpRequestHandler.server_version

    async def readline(self):
        """Read one line from ``rfile``, returning at most ``_LINE_MAX_BYTES + 1`` bytes.

        The extra byte lets callers detect an over-long line with
        ``len(line) > _LINE_MAX_BYTES``. With a cap of exactly
        ``_LINE_MAX_BYTES`` that check could never fire, and the rest of an
        over-long line would be read as if it were the next line.
        """
        return self.rfile.readline(_LINE_MAX_BYTES + 1)

    async def read(self, n: int = -1):
        """Read up to ``n`` bytes from ``rfile`` (``-1`` reads until EOF)."""
        return self.rfile.read(n)

    def write(self, data: bytes):
        """Write ``data`` to ``wfile``."""
        self.wfile.write(data)

    def can_write_eof(self) -> bool:
        """Always ``True``; ``write_eof`` is supported (it flushes ``wfile``)."""
        return True

    def write_eof(self):
        """Flush ``wfile``; the connection itself is not closed here."""
        self.wfile.flush()

    def close(self):
        """Close ``wfile``."""
        self.wfile.close()

    def handle(self) -> None:
        """Serve this connection by running ``HttpRequestHandler.handle_request`` to completion."""
        handler: HttpRequestHandler = HttpRequestHandler(
            self, self, request_writer=self.request, routing_conf=self.server)
        asyncio.run(handler.handle_request())

    def finish(self) -> None:
        """Log the end of the connection, then run the base class cleanup."""
        _logger.debug("Finish a socket connection.")
        return super().finish()
467 |
--------------------------------------------------------------------------------
/naja_atra/request_handlers/http_session_local_impl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import threading
26 | import uuid
27 | import time
28 |
29 | from typing import Any, Dict, List, Tuple
30 | from threading import RLock
31 | from ..models import HttpSession, HttpSessionFactory
32 | from ..utils.logger import get_logger
33 |
# Module logger for the in-process session implementation.
_logger = get_logger("naja_atra.request_handlers.http_session_local_impl")

# Seconds the background sweeper thread sleeps between scans for invalid sessions.
_SESSION_TIME_CLEANING_INTERVAL = 60
37 |
38 |
def _get_from_dict(adict: Dict[str, Any], key: str) -> Any:
    """Return ``adict[key]``, or ``None`` when the key is absent.

    Another thread may delete the key between the membership test and the
    lookup. That race is logged at debug level (naming the key) and also
    yields ``None``.
    """
    if key not in adict:
        return None
    try:
        return adict[key]
    except KeyError:
        # The message has a %s placeholder; fill it in so the key is logged.
        _logger.debug("key %s was deleted in other thread." % (key,))
        return None
47 |
48 |
class LocalSessionHolder:
    """In-memory cache of ``HttpSession`` objects keyed by session id.

    A daemon thread is started lazily on the first ``cache_session`` call.
    It wakes every ``_SESSION_TIME_CLEANING_INTERVAL`` seconds and
    invalidates sessions whose ``is_valid`` is false. Mutations of the cache,
    and the sweeper's scan of it, are serialized by one ``RLock``; lookups
    in ``get_session`` are lock-free.
    """

    def __init__(self):
        self.__sessions: Dict[str, HttpSession] = {}
        self.__session_lock = RLock()
        self.__started = False
        self.__clearing_thread = threading.Thread(target=self._clear_time_out_session_in_bg, daemon=True)

    def _start_cleaning(self):
        """Start the sweeper thread the first time this is called."""
        if not self.__started:
            self.__started = True
            self.__clearing_thread.start()

    def _clear_time_out_session_in_bg(self):
        """Sweeper thread body: sleep, then invalidate expired sessions, forever."""
        while True:
            time.sleep(_SESSION_TIME_CLEANING_INTERVAL)
            self._clear_time_out_session()

    def _clear_time_out_session(self):
        """Invalidate every cached session whose ``is_valid`` is false."""
        # Scan a snapshot taken under the lock. Request threads add and remove
        # sessions concurrently, and iterating the live dict could raise
        # "dictionary changed size during iteration". That would end the
        # sweeper thread for good.
        with self.__session_lock:
            snapshot = list(self.__sessions.values())
        timeout_sessions: List[HttpSession] = [s for s in snapshot if not s.is_valid]
        # Invalidate outside the lock; for LocalSessionImpl, invalidate()
        # calls back into clean_session().
        for session in timeout_sessions:
            session.invalidate()

    def clean_session(self, session_id: str):
        """Drop ``session_id`` from the cache if the cached session is no longer valid.

        Only an invalid session is removed, so a valid session that has been
        cached under the same id is kept.
        """
        with self.__session_lock:
            if session_id not in self.__sessions:
                return
            _logger.debug("session[#%s] is being cleaned" % session_id)
            if not self.__sessions[session_id].is_valid:
                del self.__sessions[session_id]

    def get_session(self, session_id: str) -> HttpSession:
        """Return the cached, still-valid session for ``session_id``, otherwise ``None``."""
        if not session_id:
            return None
        sess: HttpSession = _get_from_dict(self.__sessions, session_id)
        if sess and sess.is_valid:
            return sess
        return None

    def cache_session(self, session: HttpSession):
        """Store ``session`` under its id and start the sweeper on first use.

        A different session already cached under the same id is invalidated
        first. Returns ``session``, or ``None`` when ``session`` is falsy.
        """
        if not session:
            return None
        sess: HttpSession = _get_from_dict(self.__sessions, session.id)

        if sess:
            if session is sess:
                # Already cached: return it like the normal path instead of
                # an implicit None.
                return session
            sess.invalidate()
        with self.__session_lock:
            self.__sessions[session.id] = session
            self._start_cleaning()

        return session
108 |
109 |
class LocalSessionImpl(HttpSession):
    """``HttpSession`` kept in process memory and owned by a ``LocalSessionHolder``.

    Attributes live in a private dict. Writes are serialized with an
    ``RLock``, while reads go through ``_get_from_dict``. ``invalidate``
    records a last-accessed time of ``0`` and asks the holder to drop the
    session. Validity itself is decided by ``HttpSession.is_valid``, which
    is not shown here; presumably it is time-based, so confirm.
    """

    def __init__(self, id: str, creation_time: float, session_holder: LocalSessionHolder):
        super().__init__()
        self.__id = id
        self.__creation_time = self.__last_accessed_time = creation_time
        self.__is_new = True
        self.__attr_lock = RLock()
        self.__attrs = {}
        self.__session_holder = session_holder

    @property
    def id(self) -> str:
        """Session identifier."""
        return self.__id

    @property
    def creation_time(self) -> float:
        """Creation timestamp as passed to the constructor."""
        return self.__creation_time

    @property
    def last_accessed_time(self) -> float:
        """Timestamp of the most recent recorded access (``0`` once invalidated)."""
        return self.__last_accessed_time

    @property
    def is_new(self) -> bool:
        """``True`` until the first recorded access after creation."""
        return self.__is_new

    def _set_last_accessed_time(self, last_acessed_time: float):
        """Record an access at ``last_acessed_time``; the session stops being new."""
        self.__last_accessed_time = last_acessed_time
        self.__is_new = False

    @property
    def attribute_names(self) -> Tuple:
        """Snapshot of the attribute names currently set."""
        return tuple(self.__attrs)

    def get_attribute(self, name: str) -> Any:
        """Return the attribute ``name``, or ``None`` if it is not set."""
        return _get_from_dict(self.__attrs, name)

    def set_attribute(self, name: str, value: Any) -> None:
        """Set attribute ``name`` to ``value`` under the attribute lock."""
        with self.__attr_lock:
            self.__attrs[name] = value

    def invalidate(self) -> None:
        """Mark the session as last accessed at time 0 and remove it from the holder."""
        self._set_last_accessed_time(0)
        self.__session_holder.clean_session(session_id=self.id)
156 |
157 |
class LocalSessionFactory(HttpSessionFactory):
    """``HttpSessionFactory`` backed by an in-process ``LocalSessionHolder``."""

    def __init__(self):
        self.__session_holder = LocalSessionHolder()
        self.__session_lock = RLock()

    def _create_local_session(self, session_id: str) -> HttpSession:
        """Build a new session, reusing ``session_id`` if given, else a random UUID hex."""
        # NOTE(review): reusing a caller-supplied id lets whoever controls that
        # id choose the id of a brand-new session (session fixation) if it
        # comes from a client cookie. Confirm what callers pass here.
        sid = session_id if session_id else uuid.uuid4().hex
        return LocalSessionImpl(sid, time.time(), self.__session_holder)

    def get_session(self, session_id: str, create: bool = False) -> HttpSession:
        """Return the valid session for ``session_id``, optionally creating one.

        An existing session has its last-accessed time refreshed. When none
        exists, ``None`` is returned unless ``create`` is true, in which case
        a new session is created and cached.
        """
        sess: LocalSessionImpl = self.__session_holder.get_session(session_id)
        if sess:
            sess._set_last_accessed_time(time.time())
            return sess
        if not create:
            return None
        with self.__session_lock:
            # Re-check under the lock: another thread may have created and
            # cached a session for this id while we waited. Caching a second
            # one would invalidate the first one that thread already returned.
            sess = self.__session_holder.get_session(session_id)
            if sess:
                sess._set_last_accessed_time(time.time())
                return sess
            session = self._create_local_session(session_id)
            self.__session_holder.cache_session(session)
            return session
182 |
183 |
--------------------------------------------------------------------------------
/naja_atra/request_handlers/model_bindings.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import json
26 | from typing import Any, Dict, List, Type
27 | from http.cookies import BaseCookie, SimpleCookie
28 | from ..models import ModelDict, Environment, RegGroup, RegGroups, HttpError, RequestBodyReader, \
29 | Headers, Response, Cookies, Cookie, JSONBody, BytesBody, Header, Parameters, PathValue, Parameter, \
30 | MultipartFile, Request, HttpSession
31 | from ..utils.logger import get_logger
32 |
# Module logger for the model-binding classes.
# NOTE(review): this name does not follow the module path
# (naja_atra/request_handlers/model_bindings.py); confirm whether logging
# configuration relies on it before renaming.
_logger = get_logger("naja_atra.models.model_bindings")
34 |
35 |
class ModelBinding:
    """Base class for objects that produce one controller argument from a request.

    Attributes:
        request / response: the current request and response objects.
        arg_name: name of the controller parameter being bound.
        arg_type: the parameter's declared annotation.
        default_value: the parameter's declared default (``None`` if none).
        _kws: keyword arguments forwarded to subclass builders. It holds
            ``{"val": default_value}`` only when a default exists, so the
            builders' own signature defaults apply otherwise.
    """

    def __init__(self, request: Request,
                 response: Response,
                 arg: str,
                 arg_type,
                 val=None
                 ) -> None:
        self.request = request
        self.response = response
        self.arg_name = arg
        self.arg_type = arg_type
        self.default_value = val
        self._kws = {} if val is None else {"val": val}

    async def bind(self) -> Any:
        """Return the bound value; the base implementation returns ``None``."""
        return None
55 |
56 |
class RequestModelBinding(ModelBinding):
    """Bind the current ``Request`` object itself."""

    async def bind(self) -> Request:
        current_request = self.request
        return current_request
61 |
62 |
class SessionModelBinding(ModelBinding):
    """Bind the request's ``HttpSession`` via ``request.get_session(True)``.

    The ``True`` argument presumably means "create if missing"; confirm.
    """

    async def bind(self) -> HttpSession:
        session = self.request.get_session(True)
        return session
67 |
68 |
class ResponseModelBinding(ModelBinding):
    """Bind the ``Response`` object the controller can write to."""

    async def bind(self) -> Response:
        current_response = self.response
        return current_response
73 |
74 |
class HeadersModelBinding(ModelBinding):
    """Bind all request headers wrapped in a ``Headers`` object."""

    async def bind(self) -> Headers:
        raw_headers = self.request.headers
        return Headers(raw_headers)
79 |
80 |
class RegGroupsModelBinding(ModelBinding):
    """Bind all regex groups captured by the route, wrapped in ``RegGroups``."""

    async def bind(self) -> RegGroups:
        groups = self.request.reg_groups
        return RegGroups(groups)
85 |
86 |
class EnvironmentModelBinding(ModelBinding):
    """Bind ``request.environment`` wrapped in an ``Environment`` object."""

    async def bind(self) -> Environment:
        env = self.request.environment
        return Environment(env)
91 |
92 |
class HeaderModelBinding(ModelBinding):
    """Bind a single ``Header`` argument from the request headers.

    The header name is the default ``Header``'s ``name`` when that is
    non-empty, otherwise the argument name. A missing required header raises
    ``HttpError(400)``. A missing optional header yields the declared default
    ``Header``.
    """

    async def bind(self) -> Header:
        return self.__build_header(self.arg_name, **self._kws)

    def __build_header(self, key, val=Header()):
        name = key if val.name is None or val.name == "" else val.name
        headers = self.request.headers
        if name not in headers:
            if val._required:
                raise HttpError(400, "Missing Header",
                                f"Header[{name}] is required.")
            return val
        return Header(name=name, default=headers[name], required=val._required)
108 |
109 |
class CookiesModelBinding(ModelBinding):
    """Bind the request's full cookie collection (``request.cookies``) as-is."""

    async def bind(self) -> Cookies:
        jar = self.request.cookies
        return jar
114 |
115 |
class CookieModelBinding(ModelBinding):
    """Bind a single ``Cookie`` argument from the request cookies.

    The cookie name is the default ``Cookie``'s ``name`` when that is
    non-empty, otherwise the argument name. A missing required cookie raises
    ``HttpError(400)``. A missing optional cookie yields the declared
    default, or an empty ``Cookie()`` when no default was declared.
    """

    async def bind(self) -> Cookie:
        return self.__build_cookie(self.arg_name, **self._kws)

    def __build_cookie(self, key, val=None):
        if val is None:
            # No default declared (``_kws`` is empty). Without this,
            # ``val.name`` below raised AttributeError. Fall back to an empty
            # instance, as the Header/Parameter bindings do.
            val = Cookie()
        name = val.name if val.name is not None and val.name != "" else key
        if val._required and name not in self.request.cookies:
            raise HttpError(400, "Missing Cookie",
                            f"Cookie[{name}] is required.")
        if name in self.request.cookies:
            morsel = self.request.cookies[name]
            cookie = Cookie()
            cookie.set(morsel.key, morsel.value, morsel.coded_value)
            cookie.update(morsel)
            return cookie
        return val
134 |
135 |
class MultipartFileModelBinding(ModelBinding):
    """Bind an uploaded ``MultipartFile`` from ``request.parameter``.

    A missing required file raises ``HttpError(400)``, and so does a
    parameter that is present but not a file. A missing optional file yields
    the declared default.
    """

    async def bind(self) -> MultipartFile:
        return self.__build_multipart(self.arg_name, **self._kws)

    def __build_multipart(self, key, val=MultipartFile()):
        name = key if val.name is None or val.name == "" else val.name
        params = self.request.parameter
        if name not in params.keys():
            if val._required:
                raise HttpError(400, "Missing Parameter",
                                f"Parameter[{name}] is required.")
            return val
        found = params[name]
        if not isinstance(found, MultipartFile):
            raise HttpError(
                400, None, f"Parameter[{name}] should be a file.")
        return found
155 |
156 |
class ParameterModelBinding(ModelBinding):
    """Bind a single-valued ``Parameter`` from ``request.parameter``.

    A missing required parameter raises ``HttpError(400)``. A missing
    optional parameter yields the declared default ``Parameter``.
    """

    async def bind(self) -> Parameter:
        return self.__build_param(self.arg_name, **self._kws)

    def __build_param(self, key, val=Parameter()):
        name = key if val.name is None or val.name == "" else val.name
        params = self.request.parameter
        if name not in params:
            if val._required:
                raise HttpError(400, "Missing Parameter",
                                f"Parameter[{name}] is required.")
            return val
        return Parameter(name=name, default=params[name], required=val._required)
172 |
173 |
class PathValueModelBinding(ModelBinding):
    """Bind a value captured from the URL path.

    If the only captured value is ``"__path_wildcard"``, that raw value is
    returned, and any ``name`` on the declared ``PathValue`` is ignored with
    a warning. Otherwise the brace value whose name is the ``PathValue``'s
    non-empty ``name``, or else the argument name, is wrapped in a
    ``PathValue``. An unknown name raises ``HttpError(500)``.
    """

    async def bind(self) -> PathValue:
        return self.__build_path_value(self.arg_name, **self._kws)

    def __build_path_value(self, key, val=PathValue()):
        path_values = self.request.path_values
        is_wildcard = len(path_values) == 1 and "__path_wildcard" in path_values
        if is_wildcard:
            if val.name:
                _logger.warning(
                    f"Wildcard value, `name` of the PathValue:: [{val.name}] will be ignored. ")
            return path_values["__path_wildcard"]

        name = key if val.name is None or val.name == "" else val.name
        if name not in path_values:
            raise HttpError(
                500, None, f"path name[{name}] not in your url mapping!")
        return PathValue(name=name, _value=path_values[name])
194 |
195 |
class ParametersModelBinding(ModelBinding):
    """Bind a multi-valued ``Parameters`` argument from ``request.parameters``.

    A missing required parameter raises ``HttpError(400)``. A missing
    optional parameter yields the declared default, or a fresh empty
    ``Parameters()`` when no default was declared.
    """

    async def bind(self) -> Parameters:
        return self.__build_params(self.arg_name, **self._kws)

    def __build_params(self, key, val=None):
        if val is None:
            # Build the fallback per call. A ``Parameters()`` default in the
            # signature is created once, at class definition, and the very
            # same object would be handed to every request that omits the
            # parameter. That is a shared-state leak if ``Parameters`` is
            # mutable (presumably list-like; confirm).
            val = Parameters()
        name = val.name if val.name is not None and val.name != "" else key
        if val._required and name not in self.request.parameters:
            raise HttpError(400, "Missing Parameter",
                            f"Parameter[{name}] is required.")
        if name in self.request.parameters:
            v = self.request.parameters[name]
            return Parameters(name=name, default=v, required=val._required)
        return val
211 |
212 |
class RegGroupModelBinding(ModelBinding):
    """Bind one regex group captured by the route, selected by ``RegGroup.group``.

    With no declared default, group 0 is used. An index beyond the captured
    groups raises ``HttpError(400)``.
    """

    async def bind(self) -> RegGroup:
        return self.__build_reg_group(**self._kws)

    def __build_reg_group(self, val: RegGroup = RegGroup(group=0)):
        groups = self.request.reg_groups
        index = val.group
        if index >= len(groups):
            raise HttpError(
                400, None, f"RegGroup required an element at {val.group}, but the reg length is only {len(self.request.reg_groups)}")
        return RegGroup(group=index, _value=groups[index])
223 |
224 |
class JSONBodyModelBinding(ModelBinding):
    """Bind the parsed JSON request body as a ``JSONBody``.

    Raises ``HttpError(400)`` unless the ``Content-Type`` header starts with
    ``application/json`` (case-insensitively).
    """

    async def bind(self) -> Any:
        return self.__build_json_body()

    def __build_json_body(self):
        lowered = self.request._headers_keys_in_lowcase
        is_json = "content-type" in lowered.keys() and \
            lowered["content-type"].lower().startswith("application/json")
        if not is_json:
            raise HttpError(
                400, None, 'The content type of this request must be "application/json"')
        return JSONBody(self.request.json)
236 |
237 |
class RequestBodyReaderModelBinding(ModelBinding):
    """Bind the request's body reader so the controller can stream the body itself."""

    async def bind(self) -> Any:
        body_reader = self.request.reader
        return body_reader
242 |
243 |
class BytesBodyModelBinding(ModelBinding):
    """Bind the raw request body as ``BytesBody``.

    The body is read from ``request.reader`` and cached in
    ``request._body`` when that is still falsy.
    """

    async def bind(self) -> Any:
        req = self.request
        if not req._body:
            req._body = await req.reader.read()
        return BytesBody(req._body)
250 |
251 |
class StrModelBinding(ModelBinding):
    """Bind a ``str`` argument, returned as a non-required ``Parameter``.

    The request value is used when present, otherwise the declared default.
    ``None`` is returned when neither exists.
    """

    async def bind(self) -> Any:
        return self.__build_str(self.arg_name, **self._kws)

    def __build_str(self, key, val=None):
        params = self.request.parameter
        if key in params.keys():
            return Parameter(name=key, default=params[key], required=False)
        if val is None:
            return None
        return Parameter(name=key, default=val, required=False)
264 |
265 |
class BoolModelBinding(ModelBinding):
    """Bind a ``bool`` argument from ``request.parameter``.

    ``"0"``, ``"false"`` (any case) and ``""`` are false; any other value is
    true. A missing parameter yields the declared default (``None`` if none).
    """

    async def bind(self) -> Any:
        return self.__build_bool(self.arg_name, **self._kws)

    def __build_bool(self, key, val=None):
        if key not in self.request.parameter.keys():
            return val
        raw = self.request.parameter[key]
        return raw.lower() not in ("0", "false", "")
277 |
278 |
class IntModelBinding(ModelBinding):
    """Bind an ``int`` argument from ``request.parameter``.

    A missing parameter yields the declared default (``None`` if none). A
    value that ``int()`` rejects raises ``HttpError(400)``.
    """

    async def bind(self) -> Any:
        return self.__build_int(self.arg_name, **self._kws)

    def __build_int(self, key, val=None):
        if key not in self.request.parameter.keys():
            return val
        try:
            return int(self.request.parameter[key])
        except (TypeError, ValueError):
            # Only conversion failures are a client error. A bare ``except``
            # would also turn e.g. KeyboardInterrupt into a 400 response.
            raise HttpError(
                400, None, f"Parameter[{key}] should be an int. ")
293 |
294 |
class FloatModelBinding(ModelBinding):
    """Bind a ``float`` argument from ``request.parameter``.

    A missing parameter yields the declared default (``None`` if none). A
    value that ``float()`` rejects raises ``HttpError(400)``.
    """

    async def bind(self) -> Any:
        return self.__build_float(self.arg_name, **self._kws)

    def __build_float(self, key, val=None):
        if key not in self.request.parameter.keys():
            return val
        try:
            return float(self.request.parameter[key])
        except (TypeError, ValueError):
            # Only conversion failures are a client error; don't swallow
            # BaseExceptions with a bare ``except``.
            raise HttpError(
                400, None, f"Parameter[{key}] should be an float. ")
309 |
310 |
class ListModelBinding(ModelBinding):
    """Bind a list argument from the multi-valued ``request.parameters`` mapping.

    Elements are converted according to the declared annotation:

    - ``List[int]`` / ``List[float]``: numeric conversion.
    - ``List[bool]``: ``"0"``, ``"false"`` (any case) and ``""`` are false.
    - ``List[dict]`` / ``List[Dict]``: each element is parsed as JSON.
    - ``List[Parameter]``: each element is wrapped in a ``Parameter``.

    Any other list annotation gets the raw values. Conversion failures raise
    ``HttpError(400)``. A missing parameter uses the declared default, or a
    fresh empty list.
    """

    async def bind(self) -> Any:
        # ``_kws`` carries only the declared default (``val``), so pass the
        # annotation explicitly. Otherwise ``target_type`` is always ``list``
        # and none of the element conversions below ever run. A
        # ``target_type`` already present in ``_kws`` still takes precedence.
        kws = {"target_type": self.arg_type, **self._kws}
        return self.__build_list(self.arg_name, **kws)

    def __build_list(self, key, target_type=list, val=None):
        if key in self.request.parameters.keys():
            ori_list = self.request.parameters[key]
        elif val is None:
            # Fresh list per call: a ``[]`` signature default would be a single
            # shared object returned to (and mutable by) every request.
            ori_list = []
        else:
            ori_list = val

        if target_type == List[int]:
            try:
                return [int(p) for p in ori_list]
            except (TypeError, ValueError):
                raise HttpError(
                    400, None, f"One of the parameter[{key}] is not int. ")
        if target_type == List[float]:
            try:
                return [float(p) for p in ori_list]
            except (TypeError, ValueError):
                raise HttpError(
                    400, None, f"One of the parameter[{key}] is not float. ")
        if target_type == List[bool]:
            return [p.lower() not in ("0", "false", "") for p in ori_list]
        if target_type in (List[dict], List[Dict]):
            try:
                return [json.loads(p) for p in ori_list]
            except (TypeError, ValueError, RecursionError):
                # RecursionError: deeply nested client JSON must stay a 400.
                raise HttpError(
                    400, None, f"One of the parameter[{key}] is not JSON string. ")
        if target_type == List[Parameter]:
            return [Parameter(name=key, default=p, required=False) for p in ori_list]
        return ori_list
346 |
347 |
class ModelDictModelBinding(ModelBinding):
    """Bind every request parameter into a ``ModelDict``.

    Single-valued parameters are unwrapped to their only value; parameters
    with any other number of values keep their list.
    """

    async def bind(self) -> Any:
        return self.__build_model_dict()

    def __build_model_dict(self):
        mdict = ModelDict()
        for param_name, values in self.request.parameters.items():
            mdict[param_name] = values[0] if len(values) == 1 else values
        return mdict
361 |
362 |
class DictModelBinding(ModelBinding):
    """Bind a ``dict`` argument by JSON-decoding ``request.parameter[name]``.

    Undecodable JSON raises ``HttpError(400)``. A missing parameter yields
    the declared default, or a fresh empty dict.
    """

    async def bind(self) -> Any:
        return self.__build_dict(self.arg_name, **self._kws)

    def __build_dict(self, key, val=None):
        if key in self.request.parameter.keys():
            try:
                return json.loads(self.request.parameter[key])
            except (TypeError, ValueError, RecursionError):
                # RecursionError: deeply nested client JSON must stay a 400.
                raise HttpError(
                    400, None, f"Parameter[{key}] should be a JSON string.")
        # A new dict per call: a ``{}`` signature default would be one shared
        # object handed to (and mutable by) every request.
        return {} if val is None else val
377 |
378 |
class DefaultModelBinding(ModelBinding):
    """Bind the parameter's declared default value unchanged."""

    async def bind(self) -> Any:
        declared = self.default_value
        return declared
383 |
384 |
class ModelBindingConf:
    """Registry of the binding class used for each controller-argument annotation.

    ``model_bingding_types`` maps an annotation (a class or a ``typing``
    alias) to the ``ModelBinding`` subclass that produces its value.
    ``default_model_binding_type`` is ``DefaultModelBinding``, presumably
    used for annotations missing from the map; confirm at the call site.
    """

    def __init__(self) -> None:
        self.default_model_binding_type = DefaultModelBinding
        bindings: Dict[Type, Type[ModelBinding]] = {
            Request: RequestModelBinding,
            HttpSession: SessionModelBinding,
            Response: ResponseModelBinding,
            Headers: HeadersModelBinding,
            RegGroups: RegGroupsModelBinding,
            Environment: EnvironmentModelBinding,
            Header: HeaderModelBinding,
            Cookies: CookiesModelBinding,
            BaseCookie: CookiesModelBinding,
            SimpleCookie: CookiesModelBinding,
            Cookie: CookieModelBinding,
            MultipartFile: MultipartFileModelBinding,
            Parameter: ParameterModelBinding,
            PathValue: PathValueModelBinding,
            Parameters: ParametersModelBinding,
            RegGroup: RegGroupModelBinding,
            JSONBody: JSONBodyModelBinding,
            RequestBodyReader: RequestBodyReaderModelBinding,
            BytesBody: BytesBodyModelBinding,
            str: StrModelBinding,
            bool: BoolModelBinding,
            int: IntModelBinding,
            float: FloatModelBinding,
        }
        # All list-shaped annotations share ListModelBinding (insertion order
        # matches the original literal).
        for list_type in (list, List, List[str], List[Parameter], List[int],
                          List[float], List[bool], List[dict], List[Dict]):
            bindings[list_type] = ListModelBinding
        bindings[ModelDict] = ModelDictModelBinding
        for dict_type in (dict, Dict):
            bindings[dict_type] = DictModelBinding
        self.model_bingding_types: Dict[Type, Type[ModelBinding]] = bindings
425 |
--------------------------------------------------------------------------------
/naja_atra/request_handlers/websocket_controller_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 |
26 | import asyncio
27 | import os
28 | import struct
29 | import errno
30 |
31 | from threading import Lock
32 | from base64 import b64encode
33 | from hashlib import sha1
34 | from typing import Dict, Tuple, Union, List
35 | from uuid import uuid4
36 | from socket import error as SocketError
37 |
38 | from ..utils.logger import get_logger
39 | from ..models import Headers, WebsocketCloseReason, WebsocketRequest, WebsocketSession
40 | from ..models import WEBSOCKET_OPCODE_BINARY, WEBSOCKET_OPCODE_CLOSE, WEBSOCKET_OPCODE_CONTINUATION, WEBSOCKET_OPCODE_PING, WEBSOCKET_OPCODE_PONG, WEBSOCKET_OPCODE_TEXT
41 | from ..models import DEFAULT_ENCODING
42 |
43 |
# NOTE(review): the logger name says "websocket_request_handler" although this
# module is websocket_controller_handler.py; it is left unchanged because
# renaming it would change which logging configuration applies.
_logger = get_logger("naja_atra.request_handlers.websocket_request_handler")


# Base framing protocol, for reference:
#
#   https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers
#   https://datatracker.ietf.org/doc/html/rfc6455 (section 5.2)
#
#    0                   1                   2                   3
#    0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
#   +-+-+-+-+-------+-+-------------+-------------------------------+
#   |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
#   |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
#   |N|V|V|V|       |S|             |   (if payload len==126/127)   |
#   | |1|2|3|       |K|             |                               |
#   +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
#   |     Extended payload length continued, if payload len == 127  |
#   + - - - - - - - - - - - - - - - +-------------------------------+
#   |                               |Masking-key, if MASK set to 1  |
#   +-------------------------------+-------------------------------+
#   | Masking-key (continued)       |          Payload Data         |
#   +-------------------------------- - - - - - - - - - - - - - - - +
#   :                     Payload Data continued ...                :
#   + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
#   |                     Payload Data continued ...                |
#   +---------------------------------------------------------------+

FIN = 0x80                # first byte: FIN (final fragment) bit
OPCODE = 0x0f             # first byte: the low nibble is the opcode
MASKED = 0x80             # second byte: MASK bit
PAYLOAD_LEN = 0x7f        # second byte: 7-bit payload length
PAYLOAD_LEN_EXT16 = 0x7e  # 126: the real length follows as an unsigned 16-bit int
PAYLOAD_LEN_EXT64 = 0x7f  # 127: the real length follows as an unsigned 64-bit int
# Constant appended to Sec-WebSocket-Key when computing Sec-WebSocket-Accept.
GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# Largest block read from disk at once while streaming a file.
_BUFFER_SIZE = 1024 * 1024

# Opcodes this server understands, with display names used in messages.
OPTYPES = {
    WEBSOCKET_OPCODE_CONTINUATION: "CONTINUATION",
    WEBSOCKET_OPCODE_TEXT: "TEXT",
    WEBSOCKET_OPCODE_BINARY: "BINARY",
    WEBSOCKET_OPCODE_CLOSE: "CLOSE",
    WEBSOCKET_OPCODE_PING: "PING",
    WEBSOCKET_OPCODE_PONG: "PONG",
}
90 |
91 |
92 | class _ContinuationMessageCache:
93 |
94 | def __init__(self, opcode: int) -> None:
95 | self.opcode: int = opcode
96 | self.message_bytes: bytearray = bytearray()
97 |
98 |
class WebsocketException(Exception):
    """Raised to end a websocket connection.

    ``reason`` describes why the connection ends. ``graceful`` tells the
    connection loop whether this is an expected shutdown (logged at info
    level) or a problem (logged as a warning).
    """

    def __init__(self, reason: WebsocketCloseReason = None, graceful: bool = False) -> None:
        super().__init__(reason)
        self.__reason: WebsocketCloseReason = reason
        self.__graceful: bool = graceful

    @property
    def reason(self) -> WebsocketCloseReason:
        """Why the connection is being closed (may be None)."""
        return self.__reason

    @property
    def is_graceful(self) -> bool:
        """True when the close is expected, e.g. the client asked for it."""
        return self.__graceful
114 |
class WebsocketControllerHandler:
    """Serve one websocket connection on top of an upgraded HTTP request.

    The handler performs the opening handshake, reads and unmasks client
    frames, reassembles fragmented messages and dispatches them to the
    controller object routed for the request path. The controller callbacks
    (``on_handshake``, ``on_open``, ``on_text_message``, ``on_binary_message``,
    ``on_binary_frame``, ``on_ping_message``, ``on_pong_message``,
    ``on_close``) are optional and may be plain functions or coroutine
    functions. The ``send_*`` / ``close`` methods are the frame writers used
    by ``WebsocketSessionImpl``.
    """

    def __init__(self, http_request_handler) -> None:
        self.http_request_handler = http_request_handler
        self.request_writer = http_request_handler.request_writer
        self.routing_conf = http_request_handler.routing_conf
        self.send_response = http_request_handler.send_response_only
        self.send_header = http_request_handler.send_header
        self.reader = http_request_handler.reader
        self.keep_alive = True
        self.handshake_done = False

        handler_class, path_values, regroups = self.routing_conf.get_websocket_handler(
            http_request_handler.request_path)
        # None when no controller is routed for this path; handshake() then answers 404.
        self.handler = handler_class.ctrl_object if handler_class else None
        self.ws_request = WebsocketRequest()
        self.ws_request.headers = http_request_handler.headers
        self.ws_request.path = http_request_handler.request_path
        self.ws_request.query_string = http_request_handler.query_string
        self.ws_request.parameters = http_request_handler.query_parameters
        self.ws_request.path_values = path_values
        self.ws_request.reg_groups = regroups
        if "cookie" in self.ws_request.headers:
            self.ws_request.cookies.load(self.ws_request.headers["cookie"])
        elif "Cookie" in self.ws_request.headers:
            self.ws_request.cookies.load(self.ws_request.headers["Cookie"])
        self.session = WebsocketSessionImpl(self, self.ws_request)
        self.close_reason: WebsocketCloseReason = None

        # Reassembly state of the fragmented message being received, if any.
        self._continution_cache: _ContinuationMessageCache = None
        # Held for all frames of one data message so that frames of different
        # data messages never interleave on the wire.
        self._send_msg_lock = Lock()
        # Held while a single frame (header + payload) is written.
        self._send_frame_lock = Lock()

    @property
    def response_headers(self):
        """Raw header lines buffered by the HTTP request handler (empty list if none)."""
        if hasattr(self.http_request_handler, "_headers_buffer"):
            return self.http_request_handler._headers_buffer
        else:
            return []

    async def await_func(self, obj):
        """Await ``obj`` when it is a coroutine, otherwise return it unchanged."""
        if asyncio.iscoroutine(obj):
            return await obj
        return obj

    async def on_handshake(self) -> Tuple[int, Dict[str, List[str]]]:
        """Ask the controller how to answer the handshake.

        Returns ``(status_code, headers)``. ``status_code`` is None unless the
        controller returned an int (alone or inside a tuple); ``headers`` is the
        dict/Headers it returned (tuple items are merged). When the controller
        raises, ``(500, {})`` is returned.
        """
        try:
            if not hasattr(self.handler, "on_handshake") or not callable(self.handler.on_handshake):
                return None, {}
            res = await self.await_func(self.handler.on_handshake(self.ws_request))
            http_status_code = None
            headers = {}
            if not res:
                pass
            elif isinstance(res, int):
                http_status_code = res
            elif isinstance(res, (dict, Headers)):
                headers = res
            elif isinstance(res, tuple):
                for item in res:
                    if isinstance(item, int) and not http_status_code:
                        http_status_code = item
                    elif isinstance(item, (dict, Headers)):
                        headers.update(item)
            else:
                _logger.warning(
                    f"Endpoint[{self.ws_request.path}] returned an unsupported on_handshake result {res!r}, it is ignored.")
            return http_status_code, headers
        except Exception:
            _logger.exception("Error occurs when handshake. ")
            return 500, {}

    async def on_message(self, opcode: int, message_bytes: bytearray):
        """Dispatch one complete message.

        A Close message is answered with a Close frame and ends the connection
        by raising a graceful ``WebsocketException``. If a controller callback
        raises, the connection is closed with status 1011.
        """
        if opcode == WEBSOCKET_OPCODE_CLOSE:
            _logger.info("Client asked to close connection.")
            if len(message_bytes) >= 2:
                code = struct.unpack(">H", message_bytes[0:2])[0]
                reason = message_bytes[2:].decode(
                    'UTF-8', errors="replace")
            else:
                code = None
                reason = ''
            try:
                # RFC 6455 5.5.1: answer a Close frame with a Close frame that
                # echoes the client's status code.
                self._send_close_frame(code)
            except OSError as e:
                _logger.debug(f"Cannot answer the close frame of the client: {e}")
            # Raised outside the callback error handling below so that it reaches
            # handle_request() and ends the connection gracefully.
            raise WebsocketException(graceful=True, reason=WebsocketCloseReason(
                "Client asked to close connection.", code=code, reason=reason))
        try:
            if opcode == WEBSOCKET_OPCODE_TEXT and hasattr(self.handler, "on_text_message") and callable(self.handler.on_text_message):
                await self.await_func(self.handler.on_text_message(self.session, message_bytes.decode("UTF-8", errors="replace")))
            elif opcode == WEBSOCKET_OPCODE_PING and hasattr(self.handler, "on_ping_message") and callable(self.handler.on_ping_message):
                await self.await_func(self.handler.on_ping_message(self.session, bytes(message_bytes)))
            elif opcode == WEBSOCKET_OPCODE_PONG and hasattr(self.handler, "on_pong_message") and callable(self.handler.on_pong_message):
                await self.await_func(self.handler.on_pong_message(self.session, bytes(message_bytes)))
            elif opcode == WEBSOCKET_OPCODE_BINARY and hasattr(self.handler, "on_binary_message") and callable(self.handler.on_binary_message):
                # An unfragmented message has no cache. For a fragmented one, frames
                # consumed by on_binary_frame are not cached, so an empty cache
                # means there is nothing left to deliver.
                if self._continution_cache is None or self._continution_cache.message_bytes:
                    await self.await_func(self.handler.on_binary_message(self.session, bytes(message_bytes)))
        except Exception as e:
            _logger.exception("Error occurs when on message!")
            self.close(f"Error occurs when on_message. {e}", code=1011)

    async def on_continuation_frame(self, first_frame_opcode: int, fin: int, message_frame: bytearray):
        """Handle one frame of a fragmented message.

        Binary fragments go to ``on_binary_frame`` when the controller has it,
        and are cached only when that callback returns True. All other
        fragments are appended to the cache.
        """
        try:
            if first_frame_opcode == WEBSOCKET_OPCODE_BINARY and hasattr(self.handler, "on_binary_frame") and callable(self.handler.on_binary_frame):
                should_append_to_cache = await self.await_func(self.handler.on_binary_frame(self.session, bool(fin), bytes(message_frame)))
                if should_append_to_cache == True:
                    self._continution_cache.message_bytes.extend(message_frame)
            else:
                self._continution_cache.message_bytes.extend(message_frame)
        except Exception as e:
            _logger.exception("Error occurs when on continuation frame!")
            self.close(f"Error occurs when on_continuation_frame. {e}", code=1011)

    async def on_open(self):
        """Notify the controller that the connection is established."""
        try:
            if hasattr(self.handler, "on_open") and callable(self.handler.on_open):
                await self.await_func(self.handler.on_open(self.session))
        except Exception as e:
            _logger.exception("Error occurs when on open!")
            self.close(f"Error occurs when on_open. {e}", code=1011)

    async def on_close(self):
        """Notify the controller that the connection ended, with ``close_reason``."""
        try:
            if hasattr(self.handler, "on_close") and callable(self.handler.on_close):
                await self.await_func(self.handler.on_close(self.session, self.close_reason))
        except Exception:
            _logger.exception("Error occurs when on close!")

    async def handle_request(self):
        """Run the handshake, then read messages until the connection ends.

        ``on_close`` is always called at the end, including when the task is
        cancelled (cancellation itself is not swallowed).
        """
        try:
            while self.keep_alive:
                try:
                    if not self.handshake_done:
                        await self.handshake()
                    else:
                        await self.read_next_message()
                except WebsocketException as e:
                    if not e.is_graceful:
                        _logger.warning(
                            f"Something's wrong, close connection: {e.reason}")
                    else:
                        _logger.info(f"Close connection: {e.reason}")
                    self.keep_alive = False
                    self.close_reason = e.reason
                except Exception:
                    _logger.exception("Errors occur when handling message!")
                    self.keep_alive = False
                    self.close_reason = WebsocketCloseReason(
                        "Errors occur when handling message!")
        finally:
            await self.on_close()

    async def handshake(self):
        """Answer the HTTP upgrade request, then call ``on_open`` on success.

        Responds 404 when no controller is routed for the path, and 400 when the
        request carries no ``Sec-WebSocket-Key``. Responds with the controller's
        code when ``on_handshake`` returns one other than 101, and with
        ``101 Switching Protocols`` otherwise.
        """
        if not self.handler:
            self.keep_alive = False
            self.send_response(404)
        elif not self._get_websocket_key():
            # No Sec-WebSocket-Accept can be computed without the key
            # (RFC 6455 section 4.2.1), so the upgrade is refused.
            self.keep_alive = False
            self.send_response(400)
        else:
            code, headers = await self.on_handshake()
            if code and code != 101:
                self.keep_alive = False
                self.send_response(code)
            else:
                self.send_response(101, "Switching Protocols")
                self.send_header("Upgrade", "websocket")
                self.send_header("Connection", "Upgrade")
                self.send_header("Sec-WebSocket-Accept",
                                 self.calculate_response_key())
                if headers:
                    for h_name, h_val in headers.items():
                        self.send_header(h_name, h_val)

        ws_res_headers = b"".join(self.response_headers) + b"\r\n"
        _logger.debug(ws_res_headers)
        self.request_writer.send(ws_res_headers)
        self.handshake_done = True
        if self.keep_alive:
            await self.on_open()

    def _get_websocket_key(self):
        """Return the request's ``Sec-WebSocket-Key`` header value, or None."""
        headers = self.ws_request.headers
        for name in ("Sec-WebSocket-Key", "Sec-Websocket-Key", "sec-websocket-key"):
            if name in headers:
                return headers[name]
        return None

    def calculate_response_key(self):
        """Return the ``Sec-WebSocket-Accept`` value: base64(SHA-1(key + GUID)).

        Raises KeyError when the request has no ``Sec-WebSocket-Key`` header.
        """
        key = self._get_websocket_key()
        if key is None:
            raise KeyError("Sec-WebSocket-Key")
        _logger.debug(
            f"Sec-WebSocket-Key: {key}")
        key_hash = sha1(key.encode(errors="replace") +
                        GUID.encode(errors="replace"))
        response_key = b64encode(key_hash.digest()).strip()
        return response_key.decode('ASCII', errors="replace")

    async def read_bytes(self, num):
        """Read ``num`` bytes, looping over short reads.

        A single ``reader.read(n)`` may return fewer than ``n`` bytes (as
        ``asyncio.StreamReader.read`` does). A result shorter than ``num``
        therefore means the stream ended.
        """
        chunk = await self.reader.read(num)
        if not chunk:
            return b""
        if len(chunk) >= num:
            return chunk
        buffer = bytearray(chunk)
        while len(buffer) < num:
            chunk = await self.reader.read(num - len(buffer))
            if not chunk:
                break
            buffer.extend(chunk)
        return bytes(buffer)

    async def _read_exactly(self, num: int) -> bytes:
        """Read exactly ``num`` bytes of a frame, or fail the connection."""
        data = await self.read_bytes(num)
        if len(data) < num:
            raise WebsocketException(reason=WebsocketCloseReason(
                f"Connection closed after {len(data)} of {num} expected bytes of a frame."))
        return data

    @staticmethod
    def _unmask(payload: bytes, masks: bytes) -> bytearray:
        """XOR ``payload`` with the repeating 4-byte masking key (RFC 6455 5.3)."""
        length = len(payload)
        if not length:
            return bytearray()
        key = (masks * (length // 4 + 1))[:length]
        # One big-integer XOR instead of a per-byte Python loop; big-endian
        # conversion keeps every byte in its position.
        unmasked = int.from_bytes(payload, "big") ^ int.from_bytes(key, "big")
        return bytearray(unmasked.to_bytes(length, "big"))

    async def _read_message_content(self) -> Tuple[int, int, bytearray]:
        """Read one frame and return ``(fin, opcode, unmasked_payload)``.

        ``fin`` is the raw FIN bit (0 or ``FIN``). Raises a graceful
        ``WebsocketException`` when the client has gone away, and a
        non-graceful one on protocol violations or read errors.
        """
        _logger.debug(f"Read next websocket[{self.ws_request.path}] message")
        try:
            header = await self.read_bytes(2)
        except ConnectionResetError:
            raise WebsocketException(
                graceful=True, reason=WebsocketCloseReason("Client closed connection."))
        except SocketError as e:
            if e.errno == errno.ECONNRESET:
                raise WebsocketException(
                    graceful=True, reason=WebsocketCloseReason("Client closed connection."))
            raise WebsocketException(reason=WebsocketCloseReason(
                f"Cannot read from the connection: {e}")) from e
        except ValueError as e:
            # e.g. a read on an already-closed file object raises ValueError.
            raise WebsocketException(reason=WebsocketCloseReason(
                f"Cannot read from the connection: {e}")) from e

        if len(header) < 2:
            # End of stream instead of a frame header: the client went away.
            raise WebsocketException(
                graceful=True, reason=WebsocketCloseReason("Client closed connection."))

        b1, b2 = header[0], header[1]
        fin = b1 & FIN
        opcode = b1 & OPCODE
        masked = b2 & MASKED
        payload_length = b2 & PAYLOAD_LEN

        if not masked:
            raise WebsocketException(
                reason=WebsocketCloseReason("Client is not masked."))

        if opcode not in OPTYPES:
            raise WebsocketException(
                reason=WebsocketCloseReason(f"Unknown opcode {opcode}."))

        if opcode in (WEBSOCKET_OPCODE_PING, WEBSOCKET_OPCODE_PONG) and payload_length > 125:
            raise WebsocketException(reason=WebsocketCloseReason(
                f"Ping/Pong message payload is too large! The max length of the Ping/Pong messages is 125. but now is {payload_length}"))

        if payload_length == PAYLOAD_LEN_EXT16:
            payload_length = struct.unpack(">H", await self._read_exactly(2))[0]
        elif payload_length == PAYLOAD_LEN_EXT64:
            payload_length = struct.unpack(">Q", await self._read_exactly(8))[0]
        # NOTE(review): no upper bound is enforced on payload_length; a client
        # can announce a very large frame and have it buffered in memory.

        # The masking key is present on every masked frame, including frames
        # with an empty payload; it must always be consumed or the next frame
        # header would be read from the wrong offset.
        masks = await self._read_exactly(4)
        payload = await self._read_exactly(payload_length) if payload_length > 0 else b""
        return fin, opcode, self._unmask(payload, masks)

    async def read_next_message(self):
        """Read one frame, dispatching it or adding it to a fragmented message."""
        fin, opcode, frame_bytes = await self._read_message_content()

        is_data_frame = opcode in (WEBSOCKET_OPCODE_TEXT, WEBSOCKET_OPCODE_BINARY)
        if is_data_frame and self._continution_cache is not None:
            # RFC 6455 5.4: a new data message must not start while a fragmented
            # one is unfinished (control frames may still be interleaved).
            raise WebsocketException(reason=WebsocketCloseReason(
                "Another continution message is not yet finished. Close connection for this error!"))

        if fin and opcode != WEBSOCKET_OPCODE_CONTINUATION:
            # A normal, unfragmented frame: handle the message.
            await self.on_message(opcode, frame_bytes)
            return

        if not fin and opcode != WEBSOCKET_OPCODE_CONTINUATION:
            # First frame of a fragmented message: create the cache.
            if not is_data_frame:
                raise WebsocketException(reason=WebsocketCloseReason(
                    f"Control({OPTYPES[opcode]}) frames MUST NOT be fragmented"))
            self._continution_cache = _ContinuationMessageCache(opcode)

        if self._continution_cache is None:
            # A continuation frame arrived without a first frame.
            raise WebsocketException(reason=WebsocketCloseReason(
                "A continuation fragment frame is received, but the start fragment is not yet received. "))

        await self.on_continuation_frame(self._continution_cache.opcode, fin, frame_bytes)

        if fin:
            # Last fragment: deliver the reassembled message.
            await self.on_message(self._continution_cache.opcode, self._continution_cache.message_bytes)
            self._continution_cache = None

    def send_message(self, message: Union[bytes, str], chunk_size: int = 0):
        """Send ``message`` as a TEXT message (str is encoded with DEFAULT_ENCODING)."""
        if isinstance(message, bytes):
            self.send_bytes(WEBSOCKET_OPCODE_TEXT,
                            message, chunk_size=chunk_size)
        elif isinstance(message, str):
            self.send_bytes(WEBSOCKET_OPCODE_TEXT, message.encode(
                DEFAULT_ENCODING, errors="replace"), chunk_size=chunk_size)
        else:
            _logger.error(f"Cannot send message[{message}]. ")

    def send_ping(self, message: Union[str, bytes]):
        """Send a PING control frame."""
        if isinstance(message, bytes):
            self.send_bytes(WEBSOCKET_OPCODE_PING, message)
        elif isinstance(message, str):
            self.send_bytes(WEBSOCKET_OPCODE_PING, message.encode(
                DEFAULT_ENCODING, errors="replace"))

    def send_pong(self, message: Union[str, bytes]):
        """Send a PONG control frame."""
        if isinstance(message, bytes):
            self.send_bytes(WEBSOCKET_OPCODE_PONG, message)
        elif isinstance(message, str):
            self.send_bytes(WEBSOCKET_OPCODE_PONG, message.encode(
                DEFAULT_ENCODING, errors="replace"))

    def send_bytes(self, opcode: int, payload: bytes, chunk_size: int = 0):
        """Send ``payload`` as one message with the given opcode.

        Data messages (TEXT/BINARY) are split into ``chunk_size``-byte frames
        when ``chunk_size > 0``. Control frames are never fragmented.
        Raises WebsocketException for CONTINUATION or unknown opcodes.
        """
        if opcode not in OPTYPES or opcode == WEBSOCKET_OPCODE_CONTINUATION:
            raise WebsocketException(reason=WebsocketCloseReason(
                f"Cannot send message in a opcode {opcode}. "))

        if opcode in (WEBSOCKET_OPCODE_BINARY, WEBSOCKET_OPCODE_TEXT):
            # Every data message, fragmented or not, is written under the
            # message lock so it never lands between another message's fragments.
            with self._send_msg_lock:
                self._send_bytes_no_lock(opcode, payload, chunk_size=chunk_size)
        else:
            # Control frames are single frames and may be interleaved.
            self._send_bytes_no_lock(opcode, payload)

    def _send_bytes_no_lock(self, opcode: int, payload: bytes, chunk_size: int = 0):
        """Write ``payload`` as one frame, or as ``chunk_size``-byte fragments.

        An empty payload is still sent as a single empty frame.
        """
        total = len(payload)
        if not chunk_size or chunk_size <= 0 or total <= chunk_size:
            self._send_frame(FIN, opcode, payload)
            return
        op = opcode
        for start in range(0, total, chunk_size):
            end = start + chunk_size
            fin = FIN if end >= total else 0
            self._send_frame(fin, op, payload[start:end])
            # Every fragment after the first is a continuation frame.
            op = WEBSOCKET_OPCODE_CONTINUATION

    def _send_frame(self, fin: int, opcode: int, payload: bytes):
        """Write one frame (header then payload) atomically w.r.t. other frames."""
        with self._send_frame_lock:
            self.request_writer.send(
                self._create_frame_header(fin, opcode, len(payload)))
            self.request_writer.send(payload)

    def _create_frame_header(self, fin: int, opcode: int, payload_length: int) -> bytes:
        """Build an unmasked server frame header for ``payload_length`` bytes."""
        header = bytearray()
        # Normal payload
        if payload_length <= 125:
            header.append(fin | opcode)
            header.append(payload_length)

        # Extended payload
        elif payload_length <= 65535:
            header.append(fin | opcode)
            header.append(PAYLOAD_LEN_EXT16)
            header.extend(struct.pack(">H", payload_length))

        # Huge extended payload
        elif payload_length < (1 << 64):
            header.append(fin | opcode)
            header.append(PAYLOAD_LEN_EXT64)
            header.extend(struct.pack(">Q", payload_length))
        else:
            raise Exception(
                "Message is too big. Consider breaking it into chunks.")

        return header

    def send_file(self, path: str, chunk_size: int = 0):
        """Send the file at ``path`` as one BINARY message.

        When ``0 < chunk_size <= file size`` the message is split into
        ``chunk_size``-byte frames; otherwise it is sent as a single frame.
        Raises WebsocketException when the file cannot be accessed.
        """
        try:
            file_size = os.path.getsize(path)
        except (OSError, ValueError) as e:
            raise WebsocketException(reason=WebsocketCloseReason(
                f"File in {path} does not exist or is not accessible.")) from e
        if not chunk_size or chunk_size < 0 or chunk_size > file_size:
            frame_size = file_size
        else:
            frame_size = chunk_size
        with self._send_msg_lock:
            self._send_file_no_lock(path, file_size, frame_size)

    def _send_file_no_lock(self, path: str, file_size: int, chunk_size: int):
        """Stream ``file_size`` bytes of ``path`` in frames of ``chunk_size`` bytes.

        Data is read in blocks of at most ``_BUFFER_SIZE`` bytes, so a large
        frame is never loaded into memory at once.
        """
        try:
            in_file = open(path, 'rb')
        except (OSError, ValueError) as e:
            raise WebsocketException(reason=WebsocketCloseReason(
                f"File in {path} does not exist or is not accessible.")) from e
        with in_file:
            if file_size <= 0:
                # An empty file is still delivered as one (empty) binary message.
                self._send_frame(FIN, WEBSOCKET_OPCODE_BINARY, b"")
                return
            if chunk_size <= 0:
                chunk_size = file_size
            remain_bytes = file_size
            opcode = WEBSOCKET_OPCODE_BINARY
            while remain_bytes > 0:
                with self._send_frame_lock:
                    frame_size = min(remain_bytes, chunk_size)
                    remain_bytes -= frame_size

                    fin = 0 if remain_bytes > 0 else FIN

                    self.request_writer.send(
                        self._create_frame_header(fin, opcode, frame_size))
                    while frame_size > 0:
                        buff_size = min(_BUFFER_SIZE, frame_size)
                        frame_size -= buff_size

                        data = in_file.read(buff_size)
                        if len(data) != buff_size:
                            # The frame header already announced more bytes than the file holds now.
                            raise WebsocketException(reason=WebsocketCloseReason(
                                f"File in {path} changed while it was being sent."))
                        self.request_writer.send(data)
                # After the first frame, every frame is a continuation frame.
                opcode = WEBSOCKET_OPCODE_CONTINUATION

    def _send_close_frame(self, code: int = None, reason: str = ""):
        """Send a Close frame whose body is a 2-byte status code plus a UTF-8 reason.

        Without a ``code`` the body is empty (a reason requires a code).
        """
        payload = b""
        if code is not None:
            # 125-byte control-frame limit minus the 2-byte status code.
            reason_bytes = (reason or "").encode("UTF-8", errors="replace")[:123]
            # Drop a multi-byte character the truncation may have cut in half.
            reason_bytes = reason_bytes.decode("UTF-8", errors="ignore").encode("UTF-8")
            payload = struct.pack(">H", code) + reason_bytes
        self.send_bytes(WEBSOCKET_OPCODE_CLOSE, payload)

    def close(self, reason: str = "", code: int = 1000):
        """Send a Close frame and mark the connection as finished.

        :param reason: human-readable reason; truncated so the frame body fits
            the 125-byte control-frame limit.
        :param code: close status code (1000 = normal closure, RFC 6455 7.4.1).
        """
        try:
            self._send_close_frame(code, reason)
        finally:
            self.keep_alive = False
            self.close_reason = WebsocketCloseReason(
                "Server asked to close connection.")
502 |
503 |
class WebsocketSessionImpl(WebsocketSession):
    """Session object handed to websocket controllers.

    Every operation is delegated to the owning ``WebsocketControllerHandler``.
    """

    def __init__(self, handler: WebsocketControllerHandler, request: WebsocketRequest) -> None:
        self.__id = uuid4().hex
        self.__handler = handler
        self.__request = request

    @property
    def id(self) -> str:
        """Random hex identifier generated when the session is created."""
        return self.__id

    @property
    def request(self) -> WebsocketRequest:
        """The request that opened this websocket."""
        return self.__request

    @property
    def is_closed(self) -> bool:
        """True once the handler has stopped keeping the connection alive."""
        return not self.__handler.keep_alive

    def send_ping(self, message: bytes = b''):
        """Send a PING control frame."""
        self.__handler.send_ping(message)

    def send_pong(self, message: bytes = b''):
        """Send a PONG control frame."""
        self.__handler.send_pong(message)

    def send_pone(self, message: bytes = b''):
        """Misspelled alias of ``send_pong``, kept for backward compatibility."""
        self.__handler.send_pong(message)

    def send(self, message: Union[str, bytes], opcode: int = WEBSOCKET_OPCODE_TEXT, chunk_size: int = 0):
        """Send ``message`` with ``opcode`` (TEXT when ``opcode`` is None).

        A str is encoded with DEFAULT_ENCODING. Raises WebsocketException when
        ``message`` is neither str nor bytes.
        """
        if isinstance(message, bytes):
            msg = message
        elif isinstance(message, str):
            msg = message.encode(DEFAULT_ENCODING, errors="replace")
        else:
            raise WebsocketException(reason=WebsocketCloseReason(
                f"message {message} is not a string nor a bytes object, cannot send it to client. "))
        # Use the caller's opcode; fall back to TEXT only when none is given.
        self.__handler.send_bytes(
            opcode if opcode is not None else WEBSOCKET_OPCODE_TEXT, msg, chunk_size=chunk_size)

    def send_text(self, message: str, chunk_size: int = 0):
        """Send a TEXT message."""
        self.__handler.send_message(message, chunk_size=chunk_size)

    def send_binary(self, binary: bytes, chunk_size: int = 0):
        """Send a BINARY message."""
        self.__handler.send_bytes(
            WEBSOCKET_OPCODE_BINARY, binary, chunk_size=chunk_size)

    def send_file(self, path: str, chunk_size: int = 0):
        """Send the content of the file at ``path`` as a BINARY message."""
        self.__handler.send_file(path, chunk_size=chunk_size)

    def close(self, reason: str = ""):
        """Close the connection, sending ``reason`` to the client."""
        self.__handler.close(reason)
552 |
--------------------------------------------------------------------------------
/naja_atra/server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import os
26 | import sys
27 | import threading
28 | import importlib
29 | import re
30 |
31 | from ssl import PROTOCOL_TLS_SERVER, SSLContext
32 | from typing import Dict
33 |
34 | from .http_servers.http_server import HttpServer
35 |
36 | from .app_conf import AppConf
37 | from .utils.logger import get_logger
38 |
39 |
# The HttpServer managed by this module; None until one is prepared and after stop().
_server: HttpServer = None
# Serializes creating, replacing and shutting down ``_server``.
__lock = threading.Lock()
_logger = get_logger("naja_atra.server")
43 |
44 |
45 | def _is_match(string="", regx=r""):
46 | if not regx:
47 | return True
48 | pattern = re.compile(regx)
49 | match = pattern.match(string)
50 | return True if match else False
51 |
52 |
def _to_module_name(fpath="", regx=r""):
    """Convert a ``.py`` file path into a dotted module name.

    Returns None when ``fpath`` is not a ``.py`` file, or when ``regx``
    matches none of the path, the path without its extension, or the dotted
    name.
    """
    stem, extension = os.path.splitext(fpath)
    if extension != ".py":
        return None
    dotted = stem.replace(os.path.sep, '.')
    if any(_is_match(candidate, regx) for candidate in (fpath, stem, dotted)):
        return dotted
    return None
61 |
62 |
def _load_all_modules(work_dir, pkg, regx):
    """Collect the dotted names of the modules to import under ``pkg``.

    ``pkg`` is a path relative to ``work_dir``. It can be a file, a module
    path without the ``.py`` suffix, or a directory walked recursively
    (``__pycache__`` is skipped). Only paths accepted by ``_to_module_name``
    are returned. Non-matching or non-.py files are dropped instead of
    yielding None. Directory entries are visited in sorted order so the
    import order is deterministic.
    """
    abs_folder = os.path.join(work_dir, pkg)
    if os.path.isfile(abs_folder):
        mname = _to_module_name(pkg, regx)
        return [mname] if mname else []
    if not os.path.exists(abs_folder):
        if not abs_folder.endswith(".py") and os.path.isfile(abs_folder + ".py"):
            mname = _to_module_name(pkg + ".py", regx)
            return [mname] if mname else []
        _logger.warning(
            f"Cannot find package {pkg}, file [{abs_folder}] is not exist")
        return []

    modules = []
    folders = []
    for f in sorted(os.listdir(abs_folder)):
        rel_path = os.path.join(pkg, f)
        if os.path.isfile(os.path.join(abs_folder, f)):
            mname = _to_module_name(rel_path, regx)
            if mname:
                modules.append(mname)
        elif f != "__pycache__":
            folders.append(rel_path)

    for folder in folders:
        modules += _load_all_modules(work_dir, folder, regx)
    return modules
93 |
94 |
95 | def _import_module(mname):
96 | try:
97 | importlib.import_module(mname)
98 | except Exception as e:
99 | _logger.warning(f"Import moudle [{mname}] error! {e}")
100 |
101 |
def scan(base_dir: str = "", regx: str = r"", project_dir: str = "") -> None:
    """
    Scan the given directory to import controllers.

    - base_dir: the directory to scan.
    - regx: only include the modules that match this regular expression, if absent, all files will be included.
    - project_dir: the project directory, default to the entrance file directory.
      When there is no entrance file (e.g. an interactive interpreter or
      ``python -c``), the current working directory is used.

    """
    if project_dir:
        work_dir = project_dir
    else:
        main_file = getattr(sys.modules.get('__main__'), '__file__', None)
        if main_file:
            work_dir = os.path.dirname(main_file)
            _logger.debug(
                f"Project directory is not set, use the directory where the main module is in. {work_dir}")
        else:
            work_dir = os.getcwd()
            _logger.debug(
                f"Project directory is not set and the main module has no file, use the current working directory. {work_dir}")

    for mname in _load_all_modules(work_dir, base_dir, regx):
        _logger.info(f"Import controllers from module: {mname}")
        _import_module(mname)
123 |
124 |
def _prepare_server(host: str = "",
                    port: int = 9090,
                    ssl: bool = False,
                    ssl_protocol: int = PROTOCOL_TLS_SERVER,
                    ssl_check_hostname: bool = False,
                    keyfile: str = "",
                    certfile: str = "",
                    keypass: str = "",
                    ssl_context: SSLContext = None,
                    resources: Dict[str, str] = None,
                    connection_idle_time=None,
                    keep_alive=True,
                    keep_alive_max_request=None,
                    gzip_content_types=None,
                    gzip_compress_level=9,
                    prefer_coroutine=False,
                    app_conf: AppConf = None
                    ) -> None:
    """Create the module-level HttpServer, shutting down any previous one first.

    All arguments are forwarded to the ``HttpServer`` constructor.
    ``resources`` and ``gzip_content_types`` default to a new empty dict and
    set on each call, instead of mutable defaults shared between calls.
    """
    if resources is None:
        resources = {}
    if gzip_content_types is None:
        gzip_content_types = set()
    global _server
    with __lock:
        if _server is not None:
            _server.shutdown()
        _server = HttpServer(host=(host, port),
                             ssl=ssl,
                             ssl_protocol=ssl_protocol,
                             ssl_check_hostname=ssl_check_hostname,
                             keyfile=keyfile,
                             certfile=certfile,
                             keypass=keypass,
                             ssl_context=ssl_context,
                             resources=resources,
                             connection_idle_time=connection_idle_time,
                             keep_alive=keep_alive,
                             keep_alive_max_request=keep_alive_max_request,
                             gzip_content_types=gzip_content_types,
                             gzip_compress_level=gzip_compress_level,
                             # NOTE: the (misspelled) keyword must match HttpServer's signature.
                             prefer_corountine=prefer_coroutine,
                             app_conf=app_conf)
163 |
164 |
def start(host: str = "",
          port: int = 9090,
          ssl: bool = False,
          ssl_protocol: int = PROTOCOL_TLS_SERVER,
          ssl_check_hostname: bool = False,
          keyfile: str = "",
          certfile: str = "",
          keypass: str = "",
          ssl_context: SSLContext = None,
          resources: Dict[str, str] = None,
          connection_idle_time=None,
          keep_alive=True,
          keep_alive_max_request=None,
          gzip_content_types=None,
          gzip_compress_level=9,
          prefer_coroutine=False,
          app_conf: AppConf = None) -> None:
    """Create the module-level server and call its ``start()``.

    The arguments are forwarded to ``HttpServer`` (through ``_prepare_server``).
    ``resources`` and ``gzip_content_types`` default to a new empty dict and
    set on each call rather than shared mutable defaults.
    """
    _prepare_server(
        host=host,
        port=port,
        ssl=ssl,
        ssl_protocol=ssl_protocol,
        ssl_check_hostname=ssl_check_hostname,
        keyfile=keyfile,
        certfile=certfile,
        keypass=keypass,
        ssl_context=ssl_context,
        resources={} if resources is None else resources,
        connection_idle_time=connection_idle_time,
        keep_alive=keep_alive,
        keep_alive_max_request=keep_alive_max_request,
        gzip_content_types=set() if gzip_content_types is None else gzip_content_types,
        gzip_compress_level=gzip_compress_level,
        prefer_coroutine=prefer_coroutine,
        app_conf=app_conf
    )
    # start the server
    _server.start()
203 |
204 |
async def start_async(host: str = "",
                      port: int = 9090,
                      ssl: bool = False,
                      ssl_protocol: int = PROTOCOL_TLS_SERVER,
                      ssl_check_hostname: bool = False,
                      keyfile: str = "",
                      certfile: str = "",
                      keypass: str = "",
                      ssl_context: SSLContext = None,
                      resources: Dict[str, str] = None,
                      connection_idle_time=None,
                      keep_alive=True,
                      keep_alive_max_request=None,
                      gzip_content_types=None,
                      gzip_compress_level=9,
                      prefer_coroutine=True,
                      app_conf: AppConf = None) -> None:
    """Create the module-level server and await its ``start_async()``.

    The arguments are forwarded to ``HttpServer`` (through ``_prepare_server``).
    ``resources`` and ``gzip_content_types`` default to a new empty dict and
    set on each call rather than shared mutable defaults.
    """
    _prepare_server(
        host=host,
        port=port,
        ssl=ssl,
        ssl_protocol=ssl_protocol,
        ssl_check_hostname=ssl_check_hostname,
        keyfile=keyfile,
        certfile=certfile,
        keypass=keypass,
        ssl_context=ssl_context,
        resources={} if resources is None else resources,
        connection_idle_time=connection_idle_time,
        keep_alive=keep_alive,
        keep_alive_max_request=keep_alive_max_request,
        gzip_content_types=set() if gzip_content_types is None else gzip_content_types,
        gzip_compress_level=gzip_compress_level,
        prefer_coroutine=prefer_coroutine,
        app_conf=app_conf
    )

    # start the server
    await _server.start_async()
244 |
245 |
def is_ready() -> bool:
    """Return True when a server has been prepared and reports itself ready."""
    # Read the global once: stop() may reset it to None concurrently.
    server = _server
    return server is not None and bool(server.ready)
248 |
249 |
def stop() -> None:
    """Shut down the current server, if any, and forget it."""
    global _server
    with __lock:
        if _server is None:
            _logger.warning("Server is not ready yet.")
            return
        _logger.info("Shutting down server...")
        _server.shutdown()
        _server = None
259 |
--------------------------------------------------------------------------------
/naja_atra/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
--------------------------------------------------------------------------------
/naja_atra/utils/http_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | Copyright (c) 2018 Keijack Wu
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 |
26 | import inspect
27 | import time
28 | import os
29 | import email
30 | import json
31 | import re
32 | from collections import OrderedDict
33 | from typing import Any, Tuple, Union
34 | from urllib.parse import unquote, quote
35 | from ..models import HttpError, StaticFile, DEFAULT_ENCODING
36 |
37 | from .logger import get_logger
38 |
39 |
40 | _logger = get_logger("naja_atra.utils.http_utils")
41 |
42 |
def remove_url_first_slash(url: str):
    """Return *url* with every leading ``/`` removed, e.g. ``'//a/b'`` -> ``'a/b'``.

    The former loop re-sliced the original ``url`` on each pass instead of the
    running value, so any input starting with two or more slashes never terminated.
    """
    return url.lstrip("/")
48 |
49 |
def get_function_args(func, default_type=str):
    """List the positional parameters of *func* that have no default value.

    Returns a list of ``(name, annotation)`` tuples in declaration order; a parameter
    without an annotation is reported with *default_type*. For bound methods the
    first parameter (``self``/``cls``) is skipped.
    """
    spec = inspect.getfullargspec(func)
    # getfullargspec keeps the bound first argument, so drop it for methods.
    first = 1 if inspect.ismethod(func) else 0
    n_defaults = 0 if spec.defaults is None else len(spec.defaults)
    required = spec.args[first:len(spec.args) - n_defaults]
    return [(name, spec.annotations.get(name, default_type)) for name in required]
66 |
67 |
def get_function_kwargs(func, default_type=str):
    """List the positional parameters of *func* that declare a default value.

    Returns ``(name, default, annotation)`` tuples in declaration order; parameters
    without an annotation are reported with *default_type*. Returns ``[]`` when
    *func* has no defaults at all.
    """
    spec = inspect.getfullargspec(func)
    defaults = spec.defaults
    if defaults is None:
        return []
    # Defaults always belong to the trailing positional parameters.
    names = spec.args[-len(defaults):]
    return [
        (name, value, spec.annotations.get(name, default_type))
        for name, value in zip(names, defaults)
    ]
82 |
83 |
def break_into(txt: str, separator: str):
    """Split *txt* at the first occurrence of *separator*.

    Returns ``(before, after)``; when *separator* does not occur, returns
    ``(txt, None)``.
    """
    position = txt.find(separator)
    if position < 0:
        return txt, None
    return txt[:position], txt[position + len(separator):]
90 |
91 |
def put_to(params, key, val):
    """Append *val* to the list stored under *key* in *params*, creating the list if needed."""
    params.setdefault(key, []).append(val)
97 |
98 |
def decode_query_string(query_string: str):
    """Parse a URL query string into ``{name: [value, ...]}``.

    * Repeated names collect all their values in order of appearance.
    * A segment without ``=`` maps its name to an empty-string value.
    * Names and values are percent-decoded, and ``+`` is decoded as a space as
      required by the ``application/x-www-form-urlencoded`` format (a literal plus
      must be sent as ``%2B``).
    * Empty segments (``a=1&&b=2`` or a trailing ``&``) are skipped.

    Returns an empty dict for an empty or ``None`` input.
    """
    from urllib.parse import unquote_plus

    params = {}
    if not query_string:
        return params
    for item in query_string.split("&"):
        if not item:
            continue
        # Split on the first "=" only, so values may themselves contain "=".
        key, _, val = item.partition("=")
        params.setdefault(unquote_plus(key), []).append(unquote_plus(val))
    return params
111 |
112 |
def date_time_string(timestamp=None):
    """Format *timestamp* (seconds since the epoch; defaults to now) as an HTTP-date in GMT.

    Example: ``date_time_string(0) == 'Thu, 01 Jan 1970 00:00:00 GMT'``.
    """
    # `import email` alone does not load the `email.utils` submodule, so import it explicitly.
    from email.utils import formatdate

    if timestamp is None:
        timestamp = time.time()
    return formatdate(timestamp, usegmt=True)
117 |
118 |
def decode_response_body(raw_body: Any) -> Tuple[str, Union[str, bytes, StaticFile]]:
    """Choose a Content-Type for a controller's return value and normalise the body.

    Returns ``(content_type, body)``:

    * ``None``        -> ``""`` as ``text/plain``.
    * ``dict``        -> JSON string (non-ASCII kept as-is) as ``application/json``.
    * ``str``         -> the stripped string; ``text/xml`` when it starts with ``<?xml``,
      ``text/html`` when it starts (case-insensitively) with ``<!doctype html`` or
      ``<html``, otherwise ``text/plain``.
    * ``StaticFile``  -> the object itself with its own ``content_type``.
    * ``bytes``       -> unchanged as ``application/octet-stream``.
    * anything else   -> unchanged, labelled ``text/plain``.

    Raises:
        HttpError: 404 when a ``StaticFile`` points at a path that is not a regular file.
    """
    # Default for None / unrecognised values (previously misspelled "chartset").
    content_type = "text/plain; charset=utf8"
    if raw_body is None:
        body = ""
    elif isinstance(raw_body, dict):
        content_type = "application/json; charset=utf8"
        body = json.dumps(raw_body, ensure_ascii=False)
    elif isinstance(raw_body, str):
        body = raw_body.strip()
        lowered = body.lower()
        # Sniff markup from the leading tag; everything else stays text/plain.
        if body.startswith("<?xml"):
            content_type = "text/xml; charset=utf8"
        elif lowered.startswith(("<!doctype html", "<html")):
            content_type = "text/html; charset=utf8"
    elif isinstance(raw_body, StaticFile):
        if not os.path.isfile(raw_body.file_path):
            _logger.error(f"Cannot find file[{raw_body.file_path}] specified in StaticFile body.")
            raise HttpError(404, explain="Cannot find file for this url.")
        body = raw_body
        content_type = body.content_type
    elif isinstance(raw_body, bytes):
        body = raw_body
        content_type = "application/octet-stream"
    else:
        body = raw_body
    return content_type, body
149 |
150 |
def decode_response_body_to_bytes(raw_body: Any) -> Tuple[str, bytes]:
    """Like ``decode_response_body`` but always returns the body as ``bytes``.

    Strings are encoded with ``DEFAULT_ENCODING`` (unencodable characters replaced),
    a ``StaticFile`` is read fully from disk, and ``bytes`` pass through unchanged.

    Raises:
        HttpError: 400 when the decoded body has any other type.
    """
    content_type, body = decode_response_body(raw_body)
    if body is None:
        return content_type, b''
    if isinstance(body, str):
        return content_type, body.encode(DEFAULT_ENCODING, 'replace')
    if isinstance(body, bytes):
        return content_type, body
    if isinstance(body, StaticFile):
        with open(body.file_path, "rb") as source:
            return content_type, source.read()
    raise HttpError(400, explain="Cannot read body into bytes!")
165 |
def get_path_reg_pattern(url):
    """Translate a route pattern into ``(regex, [parameter names])``.

    * ``{name}`` placeholders each become one capture group matching a single path
      segment (no ``/``); the returned names are URL-quoted, in placeholder order.
    * Without placeholders, a leading or trailing ``**`` (may span ``/``) or ``*``
      (single segment) becomes one capture group named ``__path_wildcard``. Only one
      wildcard, at the start or the end, is allowed (checked with ``assert``).
    * A URL with neither yields ``(None, [])``.

    Note: the literal parts of *url* are inserted into the regex unescaped.
    """
    segment = "([\\w%.\\-@!\\(\\)\\[\\]\\|\\$]+)"
    multi_segment = "([\\w%.\\-@!\\(\\)\\[\\]\\|\\$/]+)"
    wildcard_error = "You can only config a * or ** at the start or end of a path."

    path_names = re.findall("(?u)\\{\\w+\\}", url)
    if path_names:
        pattern = url
        for placeholder in path_names:
            pattern = pattern.replace(placeholder, segment)
        return f"^{pattern}$", [quote(placeholder[1:-1]) for placeholder in path_names]

    # Order matters: "**" must be tried before "*" on each side, start before end.
    wildcard_rules = (
        ("start", "**", multi_segment),
        ("start", "*", segment),
        ("end", "**", multi_segment),
        ("end", "*", segment),
    )
    for side, token, group in wildcard_rules:
        if side == "start" and url.startswith(token):
            rest = url[len(token):]
            pattern = f"^{group}{rest}$"
        elif side == "end" and url.endswith(token):
            rest = url[:-len(token)]
            pattern = f"^{rest}{group}$"
        else:
            continue
        assert rest.find("*") < 0, wildcard_error
        return pattern, [quote("__path_wildcard")]

    # Plain URL: matched literally, no regex needed.
    return None, path_names
--------------------------------------------------------------------------------
/naja_atra/utils/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Copyright (c) 2018 Keijack Wu
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | from abc import abstractmethod
26 | import sys
27 | import time
28 | import logging
29 | import asyncio
30 | from threading import Thread
31 | from typing import Dict, List, Tuple
32 |
33 |
class LazyCalledLogger(logging.Logger):
    """``logging.Logger`` whose handler dispatch is split into two steps.

    ``callHandlers`` (invoked by ``Logger.handle``) is a hook that subclasses override
    to defer the work; ``do_call_handlers`` performs the standard ``logging.Logger``
    dispatch and is what the deferred code eventually calls.
    """

    def do_call_handlers(self, record):
        """Run the regular ``logging.Logger.callHandlers`` for *record*."""
        super().callHandlers(record)

    @abstractmethod
    def callHandlers(self, record):
        """Hook for subclasses; this base version does nothing, so records are dropped.

        ``logging.Logger`` is not an ``ABC``, so ``@abstractmethod`` here only documents
        intent and does not prevent instantiation.
        """
42 |
43 |
class LazyCalledLoggerThread:
    """Run logger handlers on one background thread driven by an asyncio event loop.

    ``call_logger_handler`` starts the thread on first use and schedules the real
    handler dispatch on that thread's loop, so the calling thread does not run the
    handlers itself. ``stop`` shuts the loop down; a later call starts a fresh thread.
    """

    # The logger thread is a daemon so it never keeps the process alive.
    daemon_threads = True

    def __init__(self) -> None:
        from threading import Event, Lock

        self.coroutine_loop = None
        self.coroutine_thread = None
        # Serialises thread creation/teardown: without it, concurrent first log
        # calls could each see no thread and start several logger threads.
        self._lifecycle_lock = Lock()
        # Set by the worker once `coroutine_loop` is assigned (replaces sleep-polling).
        self._loop_ready = Event()

    def coroutine_main(self):
        """Thread target: create the event loop, publish it, and run it until stopped."""
        self.coroutine_loop = loop = asyncio.new_event_loop()
        self._loop_ready.set()
        try:
            loop.run_forever()
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    def start(self):
        """Start the logger thread if needed and block until its event loop exists."""
        if self.coroutine_loop is not None:
            return
        with self._lifecycle_lock:
            if self.coroutine_thread is None:
                self._loop_ready.clear()
                self.coroutine_thread = Thread(target=self.coroutine_main, name="logger-thread", daemon=self.daemon_threads)
                self.coroutine_thread.start()
            self._loop_ready.wait()

    def stop(self):
        """Stop the event loop, join the thread and reset state; a no-op if not running."""
        with self._lifecycle_lock:
            loop = self.coroutine_loop
            if not loop:
                return
            loop.call_soon_threadsafe(loop.stop)
            self.coroutine_thread.join()
            self.coroutine_loop = None
            self.coroutine_thread = None
            self._loop_ready.clear()

    async def _call(self, logger: "LazyCalledLogger", record):
        logger.do_call_handlers(record)

    def call_logger_handler(self, logger: "LazyCalledLogger", record):
        """Schedule ``logger.do_call_handlers(record)`` on the logger thread (fire-and-forget)."""
        self.start()
        asyncio.run_coroutine_threadsafe(self._call(logger, record), self.coroutine_loop)
86 |
87 |
class CachingLogger(LazyCalledLogger):
    """Logger that hands every record to a shared background thread for handler dispatch.

    All instances share one ``LazyCalledLoggerThread`` (a class attribute), so handlers
    run on that thread instead of on the caller's.
    """

    # Shared by every CachingLogger; its thread is started on the first record.
    logger_thread: LazyCalledLoggerThread = LazyCalledLoggerThread()

    def callHandlers(self, record):
        """Schedule *record* on the shared logger thread instead of handling it inline."""
        worker = CachingLogger.logger_thread
        worker.call_logger_handler(self, record)
94 |
95 |
class LoggerFactory:
    """Create, cache and configure ``CachingLogger`` instances.

    Every logger returned by :meth:`get_logger` is cached by tag, so later changes made
    through ``log_level``, :meth:`add_handler`, :meth:`remove_handler` and
    :meth:`set_handler` are applied to all loggers this factory has already created.
    """

    DEFAULT_LOG_FORMAT: str = '[%(asctime)s]-[%(threadName)s]-[%(name)s:%(lineno)d] %(levelname)-4s: %(message)s'

    DEFAULT_DATE_FORMAT: str = '%Y-%m-%d %H:%M:%S'

    # Accepted (upper-case) level names; anything else falls back to "INFO".
    # "WARN" is kept for backward compatibility; "WARNING" and "CRITICAL" are the
    # canonical `logging` names and used to be silently mapped to "INFO".
    _LOG_LVELS: Tuple[str, ...] = ("DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL")

    def __init__(self, log_level: str = "INFO", log_format: str = DEFAULT_LOG_FORMAT, date_format: str = DEFAULT_DATE_FORMAT) -> None:
        """
        :param log_level: level name (case-insensitive); unsupported names become ``"INFO"``.
        :param log_format: format string of the default stdout handler's ``Formatter``.
        :param date_format: date format of the default stdout handler's ``Formatter``.
        """
        self.__cache_loggers: Dict[str, CachingLogger] = {}
        self._log_level = self._normalize_level(log_level)
        self.log_format = log_format
        self.date_format = date_format

        # Filled lazily by the `handlers` property (default: one stdout handler).
        self._handlers: List[logging.Handler] = []

    @classmethod
    def _normalize_level(cls, log_level: str) -> str:
        """Return *log_level* upper-cased when it is a supported name, otherwise ``"INFO"``."""
        if log_level and log_level.upper() in cls._LOG_LVELS:
            return log_level.upper()
        return "INFO"

    @property
    def handlers(self) -> List[logging.Handler]:
        """Handlers attached to new loggers; creates a stdout handler if none exist yet."""
        if self._handlers:
            return self._handlers
        _handler = logging.StreamHandler(sys.stdout)
        _formatter_ = logging.Formatter(fmt=self.log_format, datefmt=self.date_format)
        _handler.setFormatter(_formatter_)
        _handler.setLevel(self._log_level)
        self._handlers.append(_handler)
        return self._handlers

    @property
    def log_level(self):
        """Current level name shared by this factory's handlers and loggers."""
        return self._log_level

    @log_level.setter
    def log_level(self, log_level: str):
        """Apply a new level to the factory's handlers and every cached logger."""
        self._log_level = self._normalize_level(log_level)
        _logger_ = self.get_logger("Logger")
        _logger_.info(f"global logger set to {self._log_level}")
        for h in self._handlers:
            h.setLevel(self._log_level)
        for l in self.__cache_loggers.values():
            l.setLevel(self._log_level)

    def add_handler(self, handler: logging.Handler) -> None:
        """Add *handler* for future loggers and attach it to every cached logger."""
        if self.__cache_loggers:
            self._handlers.append(handler)
        else:
            # No logger yet: going through `handlers` first materialises the default
            # stdout handler, so *handler* is added alongside it.
            self.handlers.append(handler)
        for l in self.__cache_loggers.values():
            l.addHandler(handler)

    def remove_handler(self, handler: logging.Handler) -> None:
        """Detach *handler* from the factory and from every cached logger."""
        if handler in self._handlers:
            self._handlers.remove(handler)
        for l in self.__cache_loggers.values():
            l.removeHandler(handler)

    def set_handler(self, handler: logging.Handler) -> None:
        """Make *handler* the only handler of the factory and of every cached logger."""
        self._handlers.clear()
        self._handlers.append(handler)
        for l in self.__cache_loggers.values():
            # Iterate over a snapshot: removeHandler() mutates `l.handlers`, and
            # iterating the live list skipped every other handler.
            for hdlr in list(l.handlers):
                l.removeHandler(hdlr)
            l.addHandler(handler)

    def get_logger(self, tag: str = "naja_atra") -> logging.Logger:
        """Return the cached logger for *tag*, creating it with the factory's level and handlers."""
        if tag not in self.__cache_loggers:
            self.__cache_loggers[tag] = CachingLogger(tag, self._log_level)
            for hdlr in self.handlers:
                self.__cache_loggers[tag].addHandler(hdlr)
        return self.__cache_loggers[tag]
165 |
166 |
# Factory used whenever no factory tag is given.
_default_logger_factory: LoggerFactory = LoggerFactory()

# Named factories, created on first request by get_logger_factory().
_logger_factories: Dict[str, LoggerFactory] = {}


def get_logger_factory(tag: str = "") -> LoggerFactory:
    """Return the factory registered under *tag*, creating it on first use.

    An empty (falsy) *tag* selects the module's default factory.
    """
    if not tag:
        return _default_logger_factory
    factory = _logger_factories.get(tag)
    if factory is None:
        factory = _logger_factories[tag] = LoggerFactory()
    return factory
178 |
179 |
def set_level(level) -> None:
    """Set the level of the default factory and of every logger it created."""
    factory = _default_logger_factory
    factory.log_level = level


def add_handler(handler: logging.Handler) -> None:
    """Attach *handler* to the default factory and all of its loggers."""
    factory = _default_logger_factory
    factory.add_handler(handler)


def remove_handler(handler: logging.Handler) -> None:
    """Detach *handler* from the default factory and all of its loggers."""
    factory = _default_logger_factory
    factory.remove_handler(handler)


def set_handler(handler: logging.Handler) -> None:
    """Make *handler* the only handler of the default factory and its loggers."""
    factory = _default_logger_factory
    factory.set_handler(handler)


def get_logger(tag: str = "naja_atra", factory: str = "") -> logging.Logger:
    """Return logger *tag* from the factory named *factory* (the default factory when empty)."""
    owning_factory = get_logger_factory(factory)
    return owning_factory.get_logger(tag)
198 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=65",
4 | "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [project]
9 | name = "naja-atra"
description = "A simple HTTP server with an MVC-like design."
11 | readme = "README.md"
12 | authors = [
13 | { name = "keijack", email = "keijack.wu@gmail.com" }
14 | ]
15 | requires-python = ">=3.7"
16 | keywords = ["http-server", "websocket", "http", "web", "web-server"]
17 | license = {file = "LICENSE"}
18 | classifiers = [
19 | "Programming Language :: Python :: 3",
20 | "License :: OSI Approved :: MIT License",
21 | "Operating System :: OS Independent",
22 | ]
23 | dynamic = ["version"]
24 |
25 | [project.optional-dependencies]
26 | test = ["websocket-client", "pytest"]
27 | dev = ["websocket-client"]
28 |
29 | [tool.setuptools.packages.find]
30 | include=["naja_atra*"]
31 |
32 | [tool.setuptools.dynamic]
33 | version = {attr = "naja_atra.version"}
34 |
35 | [project.urls]
36 | homepage = "https://github.com/naja-atra/naja-atra"
37 | repository = "https://github.com/naja-atra/naja-atra"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
from setuptools import setup

# All project metadata is declared in pyproject.toml; this argument-less call is
# presumably kept only for tooling that still invokes `setup.py` directly.
setup()
5 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keijack/python-simple-http-server/3bb7859bc849d790b5e8b48c6cfb7680ba96a77c/tests/__init__.py
--------------------------------------------------------------------------------
/tests/ctrls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keijack/python-simple-http-server/3bb7859bc849d790b5e8b48c6cfb7680ba96a77c/tests/ctrls/__init__.py
--------------------------------------------------------------------------------
/tests/ctrls/my_controllers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import time
5 | from typing import List, OrderedDict
6 |
7 | from naja_atra import BytesBody, FilterContext, ModelDict, Redirect, RegGroup, RequestBodyReader, request_filter
8 | from naja_atra import Headers
9 | from naja_atra import HttpError
10 | from naja_atra import JSONBody
11 | from naja_atra import Header
12 | from naja_atra import Parameters
13 | from naja_atra import Cookie
14 | from naja_atra import Cookies
15 | from naja_atra import PathValue
16 | from naja_atra import Parameter
17 | from naja_atra import MultipartFile
18 | from naja_atra import Response
19 | from naja_atra import Request
20 | from naja_atra import HttpSession
21 | from naja_atra import request_map, route
22 | from naja_atra import controller
23 | from naja_atra import error_message
24 | from naja_atra.app_conf import get_app_conf
25 | import os
26 | import naja_atra.utils.logger as logger
27 |
28 |
# NOTE(review): this assignment is overwritten two lines below; its only effect is
# creating/caching the "my_test_main" logger in the default logger factory.
_logger = logger.get_logger("my_test_main")


_logger = logger.get_logger("controller")

# Secondary app configuration (tag "2"); `my_ctrl` below also registers "/" on it
# through `_app.route`.
_app = get_app_conf("2")
35 |
36 |
@request_map("/")
@request_map("/index")
@_app.route("/")
def my_ctrl():
    """Return a fixed success payload from "/" and "/index" (and from "/" on ``_app``)."""
    # A controller may return a dict, a str or a `naja_atra.Response` object.
    payload = {"code": 0, "message": "success"}
    return payload
43 |
44 |
@request_map("/say_hello", method=["GET", "POST"])
def my_ctrl2(name, name2=Parameter("name", default="KEIJACK")):
    """Greet the ``name`` parameter, bound twice: implicitly by argument name and via ``Parameter``."""
    return "hello, {}, {}".format(name, name2)
49 |
50 |
@request_map("/error")
def my_ctrl3():
    """Always fail with HTTP 400 to exercise the error-page handling."""
    error = HttpError(400, "Parameter Error!", "Test Parameter Error!")
    raise error
54 |
55 |
@request_map("/sleep")
def sleep_secs(secs: int = 10):
    """Block the handling thread for *secs* seconds, then answer ``{"message": "OK"}``."""
    _logger.info(f"Sleep {secs} secondes...")
    time.sleep(secs)
    result = {"message": "OK"}
    return result
63 |
64 |
async def say(sth: str = ""):
    """Log *sth* and return it prefixed with ``"Success! "``."""
    _logger.info(f"Say: {sth}")
    reply = f"Success! {sth}"
    return reply
68 |
69 |
@request_map("/中文/coroutine")
async def coroutine_ctrl(hey: str = "Hey!"):
    """Async controller on a non-ASCII path; delegates to the ``say`` coroutine."""
    answer = await say(hey)
    return answer
74 |
75 |
@request_map("/exception")
def exception_ctrl():
    """Raise a generic (non-HTTP) exception to exercise how the server reports it."""
    error = Exception("some error occurs!")
    raise error
79 |
80 |
@request_map("/upload", method="POST")
def my_upload(img=MultipartFile("img"), txt=Parameter("中文text", required=False, default="DEFAULT"), req=Request()):
    """Save the uploaded ``img`` part under ``<this dir>/imgs/`` and echo the ``中文text`` field.

    Only the base name of the client-supplied filename is used, so names such as
    ``../../x`` cannot escape the ``imgs`` directory; the directory is created if missing.

    Raises:
        HttpError: 400 when the upload carries no usable file name.
    """
    for k, v in req.parameter.items():
        print("%s (%s)====> %s " % (k, str(type(k)), v))
    print(txt)

    root = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.join(root, "imgs")
    # The filename comes from the client: strip any directory components.
    safe_name = os.path.basename(img.filename or "")
    if safe_name in ("", ".", ".."):
        raise HttpError(400, "Invalid file name", "The uploaded file has no usable file name.")
    os.makedirs(target_dir, exist_ok=True)
    img.save_to_file(os.path.join(target_dir, safe_name))
    return f"upload ok! {txt} "
90 |
91 |
@request_map("/post_txt", method=["GET", "POST"])
def normal_form_post(txt=Parameter("中文txt", required=False, default="DEFAULT"), req=Request(), bd=BytesBody()):
    """Dump request parameters and the raw body to stdout, then greet with ``txt``."""
    for name, value in req.parameter.items():
        print("%s ====> %s " % (name, value))
    print(req.body)
    print(bd)
    return "hi, {}".format(txt)
99 |
100 |
@request_map("/post_json", method="POST")
def post_json(host=Header("Host"), json=JSONBody()):
    """Print the ``Host`` header and the JSON body, then return the body as an ordered dict."""
    print(f"Host: {host}")
    print(json)
    echoed = OrderedDict(json)
    return echoed
106 |
107 |
@request_map("/cookies")
def set_headers(res: Response, headers: Headers, cookies: Cookies, cookie=Cookie("sc")):
    """Print the incoming cookies, set two ``sc`` cookies and answer ``"OK!"``.

    The first cookie carries an ``Expires`` date in the past (31 Oct 2018).
    """
    print("==================cookies==========")
    print(cookies)
    print("==================cookies==========")
    print(cookie)
    # "Wed", not "Web": 31 Oct 2018 was a Wednesday and an HTTP-date needs a valid day name.
    res.add_header(
        "Set-Cookie", "sc=keijack; Expires=Wed, 31 Oct 2018 00:00:00 GMT;")
    res.add_header("Set-Cookie", "sc=keijack2;")
    res.body = "OK!"
118 |
119 |
@request_map("tuple")
async def tuple_results():
    """Answer with a ``(status, headers, body)`` tuple from an async controller."""
    status = 200
    extra_headers = Headers({"MyHeader": "my header"})
    return status, extra_headers, "hello tuple result!"
123 |
124 |
@request_map("session")
def test_session(session: HttpSession, invalid=False):
    """Store a marker in the session and return the value that was there before.

    The marker is written only when the attribute is falsy. With ``invalid`` set, the
    session is invalidated after reading. The result is ``str()`` of the previous value.
    """
    previous = session.get_attribute("in-session")
    if not previous:
        session.set_attribute("in-session", "Hello, Session!")

    _logger.info("session id: %s" % session.id)
    if invalid:
        _logger.info("session[%s] is being invalidated. " % session.id)
        session.invalidate()
    return "%s" % str(previous)
136 |
137 |
@request_map("tuple_cookie")
def tuple_with_cookies(headers=Headers(), all_cookies=Cookies(), cookie_sc=Cookie("sc")):
    """Print the incoming headers and cookies, then answer with a tuple result.

    The tuple carries status 200, a ``Headers`` mapping (``xx: yyy``), a ``Cookies`` jar
    setting ``ck1=keijack`` (path ``/``, expiring 2018-12-31) and the body ``"OK"``.
    """
    import datetime

    print("=====>headers")
    print(headers)
    print("=====> cookies ")
    print(all_cookies)
    print("=====> cookie sc ")
    print(cookie_sc)
    print("======<")
    expires = datetime.datetime(2018, 12, 31)

    cks = Cookies()
    # A standard-library `http.cookies.SimpleCookie` could be used here instead.
    cks["ck1"] = "keijack"
    cks["ck1"]["path"] = "/"
    cks["ck1"]["expires"] = expires.strftime(Cookies.EXPIRE_DATE_FORMAT)

    # `Headers` is the header mapping used in tuple results (see `tuple_results`,
    # `header_echo`); `Header` is the single-header parameter marker and was a typo here.
    return 200, Headers({"xx": "yyy"}), cks, "OK"
157 |
158 |
@request_map("header_echo")
def header_echo(headers: Headers):
    """Send the request headers back as response headers with an empty body."""
    status, body = 200, ""
    return status, headers, body
162 |
163 |
@request_filter("/abcde/**")
def fil(ctx: FilterContext):
    """Pass-through filter for paths under ``/abcde/``: print a marker, then continue."""
    banner = "---------- through filter ---------------"
    print(banner)
    ctx.do_chain()
168 |
169 |
@request_filter(regexp="^/abcd")
@request_filter("/tuple")
async def filter_tuple(ctx: FilterContext):
    """Async filter for ``/tuple`` and paths matching ``^/abcd``.

    Tags the request with a ``filter-set`` header, then:

    * redirects to ``/index`` when ``user_name`` is missing,
    * answers 400 when ``pass`` is missing,
    * otherwise adds CORS and ``Res-Filter-Header`` response headers and continues.
    """
    print("---------- through filter async ---------------")
    # Filters may modify the request headers seen by later handlers.
    ctx.request.headers["filter-set"] = "through filter"
    if "user_name" not in ctx.request.parameter:
        ctx.response.send_redirect("/index")
        return
    if "pass" not in ctx.request.parameter:
        # Raising HttpError(400, "pass should be passed") would work as well.
        ctx.response.send_error(400, "pass should be passed")
        return
    res: Response = ctx.response
    for header_name, header_value in (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "*"),
        ("Access-Control-Allow-Headers", "*"),
        ("Res-Filter-Header", "from-filter"),
    ):
        res.add_header(header_name, header_value)
    # Always call do_chain() to move on to the next filter / the controller.
    ctx.do_chain()
190 |
191 |
@request_map("/redirect")
def redirect():
    """Redirect the client to ``/index``."""
    target = "/index"
    return Redirect(target)
195 |
196 |
@request_map("/params")
def my_ctrl4(user_name,
             password=Parameter(name="passwd", required=True),
             remember_me=True,
             # NOTE(review): the list/dict defaults below look like type hints for the
             # framework's parameter binding — confirm before replacing them with None.
             locations=[],
             json_param={},
             lcs=Parameters(name="locals", required=True),
             content_type=Header("Content-Type", default="application/json"),
             connection=Header("Connection"),
             ua=Header("User-Agent"),
             headers=Headers()
             ):
    """Echo every bound request parameter and header in one formatted string.

    ``password`` is read from the ``passwd`` parameter and ``lcs`` from ``locals``;
    both are declared ``required=True``.
    """
    return f"""


show all params!


user_name: {user_name}
password: {password}
remember_me: {remember_me}
locations: {locations}
json_param: {json_param}
locals: {lcs}
conent_type: {content_type}
user_agent: {ua}
connection: {connection}
all Headers: {headers}


"""
228 |
229 |
@request_map("/int_status_code")
def return_int(status_code=200):
    """Return the ``status_code`` parameter as a bare int (presumably used as the response status)."""
    code = status_code
    return code
233 |
234 |
@request_map("/path_values/{pval}/{path_val}/x")
def my_path_val_ctr(pval: PathValue, path_val=PathValue()):
    """Echo the two path values, bound via annotation and via default respectively."""
    return "{}, {}".format(pval, path_val)
238 |
239 |
@controller(args=["my-ctr"], kwargs={"desc": "desc"})
@route("/obj")
class MyController:
    """Object controller mounted under ``/obj``, built with ``args``/``kwargs`` from ``@controller``."""

    def __init__(self, name, desc="") -> None:
        self._name = "ctr object[#{}]:{}".format(name, desc)

    @route("/hello", method="GET, POST")
    @request_map
    def my_ctrl_mth(self, model: ModelDict):
        """Greet ``model['name']`` on behalf of this controller instance."""
        greeting = f"hello, {model['name']}, {self._name} says. "
        return {"message": greeting}

    @request_map("/hello2", method=("GET", "POST"))
    def my_ctr_mth2(self, name: str, i: List[int]):
        """Echo ``name`` and the int-list ``i``, prefixed with the instance name twice."""
        prefix = self._name * 2
        return f"{prefix}: {name}, {i}"

    @route(regexp="^(reg/(.+))$", method="GET")
    def my_reg_ctr(self, reg_group: RegGroup = RegGroup(1)):
        """Echo a regex capture group of the matched path (``RegGroup(1)``)."""
        return f"{self._name}, {reg_group.group},{reg_group}"
259 |
260 |
@error_message("400")
def my_40x_page(message: str, explain=""):
    """Custom error body for HTTP 400 responses."""
    return "code:400, message: {}, explain: {}".format(message, explain)
264 |
265 |
@error_message
def my_other_error_page(code, message, explain=""):
    """Generic error body (registered without a code, presumably the fallback page)."""
    return "{}-{}-{}".format(code, message, explain)
269 |
270 |
@request_map("abcde/**")
def wildcard_match(path_val=PathValue()):
    """Echo the path value bound for the ``**`` wildcard route."""
    return "path values{}".format(path_val)
274 |
275 |
276 | """
277 | curl -X PUT --data-binary "@/data1/clamav/scan/trojans/000.exe" \
278 | -H "Content-Type: application/octet-stream" \
279 | http://10.0.2.16:9090/put/file
280 | """
281 |
282 |
@request_map("/put/file", method="PUT")
async def reader_test(
        content_type: Header = Header("Content-Type"),
        reader: RequestBodyReader = None):
    """Stream a PUT body into ``tmp/target_file`` next to this module, 1 MiB per read."""
    chunk_size = 1024 * 1024
    tmp_dir = os.path.dirname(os.path.abspath(__file__)) + "/tmp"
    if not os.path.isdir(tmp_dir):
        os.mkdir(tmp_dir)
    _logger.info(f"content-type:: {content_type}")
    with open(f"{tmp_dir}/target_file", "wb") as target:
        while True:
            _logger.info("read file")
            chunk = await reader.read(chunk_size)
            _logger.info(f"read data {len(chunk)} and write")
            # An empty read marks the end of the body.
            if chunk == b'':
                break
            target.write(chunk)
    return None
301 |
302 |
@route("/res/write/bytes")
def res_writer(response: Response):
    """Write the body straight through the response object in two chunks, then close it."""
    response.status_code = 200
    response.add_header("Content-Type", "application/octet-stream")
    for chunk in (b'abcd', bytearray(b'efg')):
        response.write_bytes(chunk)
    response.close()
310 |
311 |
@route("/param/narrowing", params="a=b")
def params_narrowing():
    """Route variant selected by the ``a=b`` parameter expression."""
    result = "a=b"
    return result
315 |
316 |
@route("/param/narrowing", params="a!=b")
def params_narrowing2():
    """Route variant selected by the ``a!=b`` parameter expression."""
    result = "a!=b"
    return result
320 |
321 |
@controller
@request_map(url="/page", params=("a=b", ))
class IndexPage:
    """Controller under ``/page`` carrying the class-level ``a=b`` parameter expression."""

    @request_map("/index", method='GET', params="x=y", match_all_params_expressions=False)
    def index_page(self):
        """Return a fixed non-ASCII greeting."""
        greeting = "你好你好,世界!"
        return greeting
329 |
330 |
@route(url="/header_narrowing", method="POST", headers="Content-Type^=text/")
def header_narrowing():
    """POST route narrowed by a ``Content-Type`` header expression (``^=`` presumably means 'starts with')."""
    result = "a^=b"
    return result
334 |
--------------------------------------------------------------------------------
/tests/ctrls/my_controllers_model_binding.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Any
3 | from naja_atra.request_handlers.model_bindings import ModelBinding
4 | from naja_atra import model_binding, default_model_binding
5 | from naja_atra import HttpError, route
6 | from naja_atra.utils.logger import get_logger
7 |
# Module logger from naja_atra's default logger factory.
_logger = get_logger("my_ctrls_model_bindings")
9 |
10 |
class Person:
    """Plain model built by ``PersonModelBinding`` from request parameters."""

    def __init__(self, name: str = "", sex: str = "", age: int = 0) -> None:
        self.name = name
        # Free-form text (the binding defaults it to "secret"); was wrongly annotated `int`.
        self.sex = sex
        self.age = age
17 |
18 |
class Dog:
    """Model with a ``name`` attribute; presumably bound by the default set-attribute binding."""

    def __init__(self, name="a dog") -> None:
        self.name = name

    def wang(self):
        """Log a bark with the instance attributes and return the name (``None`` if it was removed)."""
        # The old code read `self.name` unconditionally before its own `hasattr`
        # guard, so the guard could never help; getattr() does both in one step.
        name = getattr(self, "name", None)
        _logger.info(f"wang, I'm {name}, with attrs {self.__dict__}")
        return name
30 |
31 |
@model_binding(Person)
class PersonModelBinding(ModelBinding):
    """Bind ``Person`` arguments from the ``name``, ``sex`` and ``age`` request parameters."""

    async def bind(self) -> Any:
        """Build a ``Person``; ``name``/``sex`` default to ``"no-one"``/``"secret"``.

        Raises:
            HttpError: 400 when ``age`` is missing or not an integer.
        """
        name = self.request.get_parameter("name", "no-one")
        sex = self.request.get_parameter("sex", "secret")
        try:
            age = int(self.request.get_parameter("age", ""))
        except (TypeError, ValueError) as e:
            # Only conversion failures mean "bad age"; other errors should surface as-is.
            raise HttpError(400, "Age is required, and must be an integer") from e
        return Person(name, sex, age)
43 |
44 |
@default_model_binding
class SetAttrModelBinding(ModelBinding):
    """Fallback binding: instantiate the argument type and copy every request parameter onto it."""

    def bind(self) -> Any:
        """Return ``self.arg_type()`` with one attribute per request parameter.

        Best-effort: if the object cannot be created or an attribute cannot be set, a
        warning including the exception traceback is logged and ``self.default_value``
        is returned instead.
        """
        try:
            obj = self.arg_type()
            for k, v in self.request.parameter.items():
                setattr(obj, k, v)
            return obj
        except Exception:
            # exc_info (not stack_info) so the log shows *why* binding failed.
            _logger.warning(
                f"Cannot create Object with given type {self.arg_type}. ", exc_info=True)
            return self.default_value
58 |
59 |
@route("/model_binding/person")
def test_model_binding(person: Person):
    """Return the bound ``Person``'s fields as a dict."""
    return {field: getattr(person, field) for field in ("name", "sex", "age")}
67 |
68 |
@route("/model_binding/dog")
def test_model_binding_dog(dog: Dog):
    """Return the name reported by the bound ``Dog``'s ``wang()``."""
    barked = dog.wang()
    return {"name": barked}
74 |
--------------------------------------------------------------------------------
/tests/ctrls/ws_controllers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import os
5 | import shutil
6 | from uuid import uuid4
7 | from naja_atra import WebsocketCloseReason, WebsocketHandler, WebsocketRequest, WebsocketSession, websocket_handler, websocket_message, websocket_handshake, websocket_open, websocket_close, WEBSOCKET_MESSAGE_TEXT
8 | import naja_atra.utils.logger as logger
9 |
# Module logger from naja_atra's default logger factory.
_logger = logger.get_logger("ws_test")
11 |
12 |
@websocket_handler(endpoint="/ws/{path_val}")
class WSHandler(WebsocketHandler):
    """Class-based websocket handler for ``/ws/{path_val}`` used by the tests."""

    def __init__(self) -> None:
        # Random per-instance id, included in text-message logs.
        self.uuid: str = uuid4().hex

    def on_handshake(self, request: WebsocketRequest):
        # NOTE(review): presumably (status, extra response headers), 0 meaning
        # "accept unchanged" — confirm against WebsocketHandler.
        return 0, {}

    async def on_open(self, session: WebsocketSession):

        _logger.info(f">>{session.id}<< open! {session.request.path_values}")

    def on_text_message(self, session: WebsocketSession, message: str):
        """Echo *message* prefixed with the ``path_val`` path value; close on ``"close"``."""
        _logger.info(
            f">>{session.id}::{self.uuid}<< on text message: {message}")
        session.send(f"{session.request.path_values['path_val']}-{message}")
        if message == "close":
            session.close()

    def on_close(self, session: WebsocketSession, reason: WebsocketCloseReason):
        _logger.info(
            f">>{session.id}<< close::{reason.message}-{reason.code}-{reason.reason}")

    def on_binary_frame(self, session: WebsocketSession = None, fin: bool = False, frame_data: bytes = b''):
        # NOTE(review): the True return value's meaning is defined by WebsocketHandler — confirm.
        _logger.info(f"Fin => {fin}, Data: {frame_data}")
        return True

    def on_binary_message(self, session: WebsocketSession = None, message: bytes = b''):
        """Acknowledge a binary message, then send it back twice as a file (whole, then 10-byte chunks).

        The payload is staged in ``tmp/ws_bi_re.tmp`` next to this module. Afterwards the
        file (or the ``tmp`` folder, if this call created it) is removed, even when
        sending fails.
        """
        _logger.info(f'Binary Message:: {message}')
        tmp_folder = os.path.dirname(os.path.abspath(__file__)) + "/tmp"
        is_dir_tmp_folder = os.path.isdir(tmp_folder)
        if not is_dir_tmp_folder:
            os.mkdir(tmp_folder)
        tmp_file_path = f"{tmp_folder}/ws_bi_re.tmp"
        try:
            session.send(
                'binary-message-received, and this is some message for the long size.', chunk_size=10)
            with open(tmp_file_path, 'wb') as out_file:
                out_file.write(message)

            session.send_file(tmp_file_path)
            session.send_file(tmp_file_path, chunk_size=10)
        finally:
            # Remove only our file when the folder pre-existed; otherwise drop the
            # folder this call created. Never mask an exception from the try body.
            if is_dir_tmp_folder:
                if os.path.exists(tmp_file_path):
                    os.remove(tmp_file_path)
            else:
                shutil.rmtree(tmp_folder, ignore_errors=True)
59 |
60 |
@websocket_handler(regexp="^/ws-reg/([a-zA-Z0-9]+)$", singleton=False)
class WSRegHander(WebsocketHandler):
    """Regex-routed handler (``singleton=False``) echoing text prefixed with the first capture group."""

    def __init__(self) -> None:
        # Random per-instance id, included in text-message logs.
        self.uuid: str = uuid4().hex

    def on_text_message(self, session: WebsocketSession, message: str):
        """Log the message and regex groups, then reply ``"<group 0>-<message>"``."""
        _logger.info(f">>{session.id}::{self.uuid}<< on text message: {message}")
        groups = session.request.reg_groups
        _logger.info(f"{groups}")
        session.send(f"{groups[0]}-{message}")
72 |
73 |
@websocket_handshake(endpoint="/ws-fun/{path_val}")
def ws_handshake(request: WebsocketRequest):
    """Handshake hook for the function-style ``/ws-fun/{path_val}`` endpoint."""
    status, extra_headers = 0, {}
    return status, extra_headers
77 |
78 |
@websocket_open(endpoint="/ws-fun/{path_val}")
def ws_open(session: WebsocketSession):
    """Log the session id and path values when a ``/ws-fun`` connection opens."""
    values = session.request.path_values
    _logger.info(f">>{session.id}<< open! {values}")
82 |
83 |
@websocket_close(endpoint="/ws-fun/{path_val}")
def ws_close(session: WebsocketSession, reason: WebsocketCloseReason):
    """Log why a ``/ws-fun`` connection closed."""
    summary = f"{reason.message}-{reason.code}-{reason.reason}"
    _logger.info(f">>{session.id}<< close::{summary}")
88 |
89 |
@websocket_message(endpoint="/ws-fun/{path_val}", message_type=WEBSOCKET_MESSAGE_TEXT)
async def ws_text(session: WebsocketSession, message: str):
    """Echo text prefixed with the ``path_val`` path value; close the session on ``"close"``."""
    _logger.info(f">>{session.id}<< on text message: {message}")
    prefix = session.request.path_values['path_val']
    session.send(f"{prefix}-{message}")
    if message == "close":
        session.close()
96 |
--------------------------------------------------------------------------------
/tests/static/a.txt:
--------------------------------------------------------------------------------
1 | hello world!
--------------------------------------------------------------------------------
/tests/static/inner/b.txt:
--------------------------------------------------------------------------------
1 | hello, inner!
--------------------------------------------------------------------------------
/tests/static/inner/y.ini:
--------------------------------------------------------------------------------
1 | [my]
2 | hello = inner
--------------------------------------------------------------------------------
/tests/static/x.ini:
--------------------------------------------------------------------------------
1 | [conf]
2 | hello = world
--------------------------------------------------------------------------------
/tests/static/中文.txt:
--------------------------------------------------------------------------------
1 | hello world!
--------------------------------------------------------------------------------
/tests/test_all_ctrls.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | from gzip import GzipFile
4 | import os
5 | import json
6 | import websocket
7 | import unittest
8 | import urllib.request
9 | import urllib.error
10 | import http.client
11 | from typing import Dict
12 | from threading import Thread
13 | from time import sleep
14 |
15 | from naja_atra.utils.logger import get_logger, set_level
16 | import naja_atra.server as server
17 |
# Verbosity handed to naja_atra's set_level for the whole test run.
_TEST_LOG_LEVEL = "DEBUG"
set_level(_TEST_LOG_LEVEL)

# Module-wide logger used by the test cases below.
_logger = get_logger("http_test")
21 |
22 |
class ThreadingServerTest(unittest.TestCase):
    """End-to-end tests for the naja_atra server started with ``prefer_coroutine=False``.

    ``setUpClass`` does the following:
    - starts the server from a daemon thread on ``PORT``;
    - serves the controllers scanned from ``tests/ctrls`` and the files in
      ``tests/static`` under ``/public/*``;
    - polls ``server.is_ready()`` up to ``WAIT_COUNT`` times, once per second.

    Subclasses override ``PORT`` / ``COROUTINE`` to rerun every case against
    another server configuration.
    """

    # Port the server under test listens on.
    PORT = 9090

    # Maximum number of one-second polls of ``server.is_ready()``.
    WAIT_COUNT = 10

    # Passed to ``server.start`` as ``prefer_coroutine``.
    COROUTINE = False

    @classmethod
    def start_server(cls):
        """Stop any running server, scan the test controllers and start serving.

        Run from a background thread by ``setUpClass``.
        """
        cls.tearDownClass()
        _logger.info("start server in background. ")
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server.scan(project_dir=root, base_dir="tests/ctrls",
                    regx=r'.*controllers.*')
        server.start(
            port=cls.PORT,
            resources={"/public/*": f"{root}/tests/static"},
            gzip_content_types={"text/plain"},
            prefer_coroutine=cls.COROUTINE)

    @classmethod
    def setUpClass(cls):
        """Start the server in a daemon thread and poll until it reports ready.

        Raises:
            RuntimeError: if the server is still not ready after ``WAIT_COUNT`` polls.
        """
        Thread(target=cls.start_server, daemon=True, name="t").start()
        retry = 0
        while not server.is_ready():
            sleep(1)
            retry = retry + 1
            _logger.info(
                f"server is not ready wait. {retry}/{cls.WAIT_COUNT} ")
            if retry >= cls.WAIT_COUNT:
                raise RuntimeError("Server start wait timeout.")

    @classmethod
    def tearDownClass(cls):
        """Stop the server on a best-effort basis; a failure is logged, not raised."""
        try:
            server.stop()
        except Exception as err:
            # Also called from start_server, when there may be nothing to stop yet.
            _logger.info(f"server.stop() failed and was ignored: {err!r}")

    @classmethod
    def visit(cls, ctx_path, headers: "Dict[str, str] | None" = None, data=None, return_type: str = "TEXT"):
        """Request ``http://127.0.0.1:{PORT}/{ctx_path}`` and return the result.

        Args:
            ctx_path: path and query string relative to the server root,
                without a leading slash.
            headers: extra request headers. ``None`` adds none.
            data: optional request body, passed to ``urlopen`` unchanged.
            return_type: selects what is returned:
                ``"RESPONSE"`` returns the still-open response, which the
                caller must close; ``"HEADERS"`` returns the response headers;
                ``"JSON"`` returns the JSON-decoded body; any other value
                returns the body decoded as UTF-8 text.

        Raises:
            urllib.error.HTTPError: raised by ``urlopen`` for error status codes.
        """
        req: urllib.request.Request = urllib.request.Request(
            f"http://127.0.0.1:{cls.PORT}/{ctx_path}")
        for name, value in (headers or {}).items():
            req.add_header(name, value)
        res: http.client.HTTPResponse = urllib.request.urlopen(req, data=data)

        if return_type == "RESPONSE":
            return res
        # Every other mode consumes the response, so close it here even if reading fails.
        with res:
            if return_type == "HEADERS":
                return res.headers
            body = res.read().decode("utf-8")
        if return_type == "JSON":
            return json.loads(body)
        return body

    def test_header_echo(self):
        with self.visit(
                "header_echo", headers={"X-KJ-ABC": "my-headers"}, return_type="RESPONSE") as res:
            assert "X-Kj-Abc" in res.headers
            assert res.headers["X-Kj-Abc"] == "my-headers"

    def test_static(self):
        txt = self.visit("public/a.txt")
        assert txt == "hello world!"

    def test_gzip(self):
        with self.visit(
                "public/a.txt", headers={"Accept-Encoding": "gzip ,deflate"}, return_type="RESPONSE") as res:
            content_encoding = res.info().get("Content-Encoding")
            assert "gzip" == content_encoding
            with GzipFile(fileobj=res) as f:
                txt = f.read().decode()
        assert txt == "hello world!"

    def test_path_value(self):
        pval = "abc"
        path_val = "xyz"
        txt = self.visit(f"path_values/{pval}/{path_val}/x")
        assert txt == f"{pval}, {path_val}"

    def test_error(self):
        # assertRaises makes the test fail if no HTTPError is raised at all.
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            self.visit("error")
        err = ctx.exception
        assert err.code == 400
        error_msg = err.read().decode("utf-8")
        err.close()
        _logger.info(error_msg)
        assert error_msg == "code:400, message: Parameter Error!, explain: Test Parameter Error!"

    def test_coroutine(self):
        txt = self.visit("%E4%B8%AD%E6%96%87/coroutine?hey=KJ2")
        assert txt == "Success! KJ2"

    def test_post_json(self):
        data_dict = {
            "code": 0,
            "msg": "xxx"
        }
        res: str = self.visit("post_json", headers={
            "Content-Type": "application/json"}, data=json.dumps(data_dict).encode(errors="replace"))
        res_dict: dict = json.loads(res)
        assert data_dict["code"] == res_dict["code"]
        assert data_dict["msg"] == res_dict["msg"]

    def test_filter(self):
        with self.visit("tuple?user_name=kj&pass=wu", return_type="RESPONSE") as res:
            assert "Res-Filter-Header" in res.headers
            assert res.headers["Res-Filter-Header"] == "from-filter"

    def test_exception(self):
        # assertRaises makes the test fail if no HTTPError is raised at all.
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            self.visit("exception")
        err = ctx.exception
        assert err.code == 500
        error_msg = err.read().decode("utf-8")
        err.close()
        _logger.info(error_msg)
        assert error_msg == '500-Internal Server Error-some error occurs!'

    def test_res_write_bytes(self):
        body = self.visit("res/write/bytes")
        assert body == 'abcdefg'

    def test_ws(self):
        ws = websocket.WebSocket()
        path_val = "test-ws"
        msg = "hello websocket!"
        ws.connect(f"ws://127.0.0.1:{self.PORT}/ws/{path_val}")
        try:
            ws.send(msg)
            txt = ws.recv()
        finally:
            ws.close()
        assert txt == f"{path_val}-{msg}"

    def test_ws_fun(self):
        ws = websocket.WebSocket()
        path_val = "test-ws-fun"
        msg = "hello websocket!"
        ws.connect(f"ws://127.0.0.1:{self.PORT}/ws-fun/{path_val}")
        try:
            ws.send(msg)
            txt = ws.recv()
        finally:
            ws.close()
        assert txt == f"{path_val}-{msg}"

    def test_ws_continuation(self):
        ws = websocket.WebSocket()
        path_val = "test-ws"

        ws.connect(f"ws://127.0.0.1:{self.PORT}/ws/{path_val}")
        try:
            # One text message split over a TEXT frame and two CONT frames.
            # Only the last frame has fin=1.
            msg0 = "Hello "
            frame0 = websocket.ABNF.create_frame(
                msg0, websocket.ABNF.OPCODE_TEXT, 0)
            ws.send_frame(frame0)
            msg1 = "Websocket "
            frame1 = websocket.ABNF.create_frame(
                msg1, websocket.ABNF.OPCODE_CONT, 0)
            ws.send_frame(frame1)
            msg2 = "Frames!"
            frame2 = websocket.ABNF.create_frame(
                msg2, websocket.ABNF.OPCODE_CONT, 1)
            ws.send_frame(frame2)

            txt = ws.recv()
        finally:
            ws.close()
        assert txt == f"{path_val}-{msg0 + msg1 + msg2}"

    def test_ws_bytes_continuation(self):
        ws = websocket.WebSocket()
        path_val = "test-ws"

        ws.connect(f"ws://127.0.0.1:{self.PORT}/ws/{path_val}")
        try:
            # One binary message split over a BINARY frame and two CONT frames.
            msg0 = "Hello "
            frame0 = websocket.ABNF.create_frame(
                msg0, websocket.ABNF.OPCODE_BINARY, 0)
            ws.send_frame(frame0)
            msg1 = "Websocket "
            frame1 = websocket.ABNF.create_frame(
                msg1, websocket.ABNF.OPCODE_CONT, 0)
            ws.send_frame(frame1)
            msg2 = "Frames!"
            frame2 = websocket.ABNF.create_frame(
                msg2, websocket.ABNF.OPCODE_CONT, 1)
            ws.send_frame(frame2)

            # Expected replies: a text acknowledgement, then the payload echoed back twice.
            txt: str = ws.recv()
            bs: bytes = ws.recv()
            bs2: bytes = ws.recv()
        finally:
            ws.close()
        assert txt == "binary-message-received, and this is some message for the long size."
        assert bs.decode() == bs2.decode() == msg0 + msg1 + msg2

    def test_ws_regexp(self):
        ws = websocket.WebSocket()
        path_val = "wstest"
        msg = 'hello, reg'

        ws.connect(f"ws://127.0.0.1:{self.PORT}/ws-reg/{path_val}")
        try:
            ws.send(msg)
            txt: str = ws.recv()
        finally:
            ws.close()
        _logger.info(txt)
        assert txt == f"{path_val}-{msg}"

    def test_params_narrowing(self):
        body = self.visit("param/narrowing?a=b")
        assert body == 'a=b'
        body = self.visit("param/narrowing?a=c")
        assert body == 'a!=b'

    def test_model_binding(self):
        name = "keijack"
        sex = "male"
        age = 18
        res: Dict = self.visit(
            f"model_binding/person?name={name}&sex={sex}&age={age}", return_type="JSON")
        assert res["name"] == name
        assert res["sex"] == sex
        assert res["age"] == age
250 |
class CoroutineServerTest(ThreadingServerTest):
    """Rerun every ThreadingServerTest case with ``prefer_coroutine=True``.

    This suite listens on port 9091, whereas ThreadingServerTest uses 9090.
    """

    # Forwarded to server.start(prefer_coroutine=...) by the inherited start_server.
    COROUTINE = True

    PORT = 9091
256 |
--------------------------------------------------------------------------------
/upload.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Build fresh distributions into dist/ and upload them to the package index with twine.
# Abort on the first failing command, so a partial or failed build is never uploaded.
set -euo pipefail

# -rf: succeed even when dist/ is missing or empty (e.g. on the first run).
rm -rf dist

python3 -m build

python3 -m twine upload dist/*
--------------------------------------------------------------------------------