├── docs ├── README.md └── images │ └── .gitignore ├── src ├── forge_tools │ ├── builtin │ │ ├── __init__.mojo │ │ └── error.mojo │ ├── memory │ │ ├── deactivate │ │ │ └── __init__.mojo │ │ ├── arc_pointer.mojo │ │ ├── rc_pointer.mojo │ │ ├── rc.mojo │ │ ├── arena_pointer.mojo │ │ └── pointer.mojo │ ├── ffi │ │ ├── __init__.mojo │ │ ├── jvm │ │ │ ├── __init__.mojo │ │ │ ├── logging.mojo │ │ │ └── types.mojo │ │ └── c │ │ │ ├── __init__.mojo │ │ │ └── types.mojo │ ├── __init__.mojo │ ├── complex │ │ └── __init__.mojo │ ├── collections │ │ └── __init__.mojo │ ├── socket │ │ ├── deactivate │ │ │ └── __init__.mojo │ │ ├── _windows.mojo │ │ ├── _bsd.mojo │ │ ├── _apple.mojo │ │ ├── _linux.mojo │ │ ├── _wasi.mojo │ │ ├── _freertos.mojo │ │ └── README.md │ └── datetime │ │ ├── __init__.mojo │ │ ├── timezone.mojo │ │ └── _lists.mojo ├── benchmarks │ ├── benchmarks_array_list_inlinearray_numeric_ops.png │ ├── benchmarks_array_list_inlinearray_vector_ops.png │ ├── benchmarks_array_list_inlinearray_collection_ops.png │ ├── README.md │ └── collections │ │ ├── bench_array.mojo │ │ └── bench_list.mojo └── test │ ├── ffi │ ├── jvm │ │ └── test_logging.mojo │ └── c │ │ └── test_logging.mojo │ ├── socket │ ├── socket_echo_server.py │ └── test_socket.mojo │ ├── datetime │ ├── test_timezone.mojo │ ├── test_calendar.mojo │ └── test_zoneinfo.mojo │ ├── complex │ └── test_quaternion.mojo │ └── collections │ ├── test_dbuffer.mojo │ └── test_result.mojo ├── .gitattributes ├── .gitignore ├── pixi.toml └── scripts ├── package-lib.sh ├── test.sh └── benchmark.sh /docs/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/forge_tools/builtin/__init__.mojo: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/forge_tools/memory/deactivate/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Memory package.""" 2 | -------------------------------------------------------------------------------- /src/forge_tools/ffi/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Foreign Function Interface utils.""" 2 | -------------------------------------------------------------------------------- /src/forge_tools/ffi/jvm/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Java Native Interface package.""" 2 | -------------------------------------------------------------------------------- /src/forge_tools/ffi/jvm/logging.mojo: -------------------------------------------------------------------------------- 1 | """Java Native Interface logging.""" 2 | -------------------------------------------------------------------------------- /src/forge_tools/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Tools to extend the functionality of the Mojo standard library.""" 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # GitHub syntax highlighting 2 | pixi.lock linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /src/forge_tools/complex/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Implements the complex package.""" 2 | 3 | from .quaternion import Quaternion, DualQuaternion 4 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # pixi environments 2 | .pixi 3 | *.egg-info 4 | # magic environments 5 | .magic/ 6 | magic.lock 7 | build/ 8 | # packages 9 | *.mojopkg 10 | -------------------------------------------------------------------------------- /src/benchmarks/benchmarks_array_list_inlinearray_numeric_ops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinvuyk/forge-tools/HEAD/src/benchmarks/benchmarks_array_list_inlinearray_numeric_ops.png -------------------------------------------------------------------------------- /src/benchmarks/benchmarks_array_list_inlinearray_vector_ops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinvuyk/forge-tools/HEAD/src/benchmarks/benchmarks_array_list_inlinearray_vector_ops.png -------------------------------------------------------------------------------- /src/forge_tools/collections/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Implements the collections package.""" 2 | 3 | from .array import Array 4 | from .dbuffer import DBuffer 5 | from .result import Result 6 | -------------------------------------------------------------------------------- /src/benchmarks/benchmarks_array_list_inlinearray_collection_ops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinvuyk/forge-tools/HEAD/src/benchmarks/benchmarks_array_list_inlinearray_collection_ops.png -------------------------------------------------------------------------------- /src/test/ffi/jvm/test_logging.mojo: -------------------------------------------------------------------------------- 1 | # RUN: %mojo %s 2 | 3 | from testing import 
assert_equal, assert_false, assert_raises, assert_true 4 | 5 | from memory import UnsafePointer, stack_allocation 6 | from forge_tools.ffi.jvm.logging import * 7 | from forge_tools.ffi.jvm.types import * 8 | 9 | 10 | def main(): 11 | ... 12 | -------------------------------------------------------------------------------- /src/forge_tools/ffi/c/__init__.mojo: -------------------------------------------------------------------------------- 1 | """FFI utils for the C programming language. 2 | 3 | Notes: 4 | The functions in this module follow only the Libc POSIX standard. Exceptions 5 | are made only for Windows. 6 | """ 7 | 8 | 9 | from .constants import * 10 | from .types import * 11 | from .libc import Libc, TryLibc 12 | -------------------------------------------------------------------------------- /src/test/socket/socket_echo_server.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 4 | s.bind(("0.0.0.0", 8000)) 5 | s.listen() 6 | while True: 7 | conn, addr = s.accept() 8 | with conn: 9 | print(f"Connected by {addr}") 10 | while True: 11 | data = conn.recv(1024) 12 | if not data: 13 | break 14 | print(data) 15 | conn.sendall(data) 16 | -------------------------------------------------------------------------------- /pixi.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | authors = ["martinvuyk "] 3 | channels = ["conda-forge", "https://conda.modular.com/max-nightly"] 4 | description = "Add a short description here" 5 | name = "forge_tools" 6 | platforms = ["osx-arm64", "linux-aarch64", "linux-64"] 7 | version = "0.1.0" 8 | 9 | [tasks] 10 | build = { cmd = "./scripts/package-lib.sh" } 11 | format = { cmd = "mojo format ./src" } 12 | test = { cmd = "./scripts/test.sh" } 13 | benchmark = { cmd = "./scripts/benchmark.sh" } 14 | 15 | [dependencies] 16 | max = ">=25.6.0.dev2025082605,<26" 17 | 
-------------------------------------------------------------------------------- /src/forge_tools/socket/deactivate/__init__.mojo: -------------------------------------------------------------------------------- 1 | """Socket package. 2 | The goal is to achieve as close an interface as possible to 3 | Python's [socket implementation](https://docs.python.org/3/library/socket.html). 4 | 5 | Examples: 6 | ```mojo 7 | from forge_tools.socket import Socket 8 | 9 | 10 | async def main(): 11 | with Socket.create_server(("0.0.0.0", 8000)) as server: 12 | while True: 13 | conn, addr = await server.accept() 14 | ... # handle new connection 15 | 16 | # TODO: once we have async generators: 17 | # async for conn, addr in server: 18 | # ... # handle new connection 19 | ``` 20 | . 21 | """ 22 | # TODO: better docs and show examples. 23 | from .socket import ( 24 | Socket, 25 | SockFamily, 26 | SockType, 27 | SockProtocol, 28 | SockPlatform, 29 | ) 30 | from .address import IPv4Addr, IPv6Addr 31 | -------------------------------------------------------------------------------- /src/benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | ## Benchmark Array against List and InlineArray 4 | 5 | - mojo `2024.6.2905` 6 | - Using: Intel® i7-7700HQ @2.80 GHz (Instruction Set Extensions 7 | Intel® SSE4.1, Intel® SSE4.2, Intel® AVX2). 
4 cores 8 threads 8 | 9 | | Cache | | 10 | |----------|-------------------| 11 | | Cache L1 | 64 KB (per core) | 12 | | Cache L2 | 256 KB (per core) | 13 | | Cache L3 | 6 MB (shared) | 14 | 15 | - amount items = (3, 8, 16, 32, 64, 128, 256) 16 | - datatype = UInt64 17 | - average of 5 iterations with 100 warmup (several runs had similar results) 18 | 19 | #### Results for "standard" sequential collection operations: 20 | 21 | ![](./benchmarks_array_list_inlinearray_collection_ops.png) 22 | 23 | 24 | 25 | #### Numeric operations 26 | 27 | ![](./benchmarks_array_list_inlinearray_numeric_ops.png) 28 | 29 | 30 | #### Vector operations 31 | 32 | - 1k times reverse (Int64) 33 | - 1k times dot product (Float64) 34 | - 5k times cross product (Float64) 35 | 36 | ![](./benchmarks_array_list_inlinearray_vector_ops.png) 37 | -------------------------------------------------------------------------------- /scripts/package-lib.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ##===----------------------------------------------------------------------===## 3 | # Copyright (c) 2024, Modular Inc. All rights reserved. 4 | # 5 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 6 | # https://llvm.org/LICENSE.txt 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | ##===----------------------------------------------------------------------===## 14 | 15 | set -euo pipefail 16 | 17 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 18 | REPO_ROOT=$(realpath "${SCRIPT_DIR}/..") 19 | BUILD_DIR="${REPO_ROOT}"/build 20 | mkdir -p "${BUILD_DIR}" 21 | 22 | LIB_PATH="${REPO_ROOT}/src/forge_tools" 23 | 24 | echo "Packaging up the Library." 25 | PACKAGE_NAME="forge_tools.mojopkg" 26 | FULL_PACKAGE_PATH="${BUILD_DIR}"/"${PACKAGE_NAME}" 27 | mojo package "${LIB_PATH}" -o "${FULL_PACKAGE_PATH}" 28 | 29 | echo Successfully created "${FULL_PACKAGE_PATH}" 30 | -------------------------------------------------------------------------------- /src/forge_tools/ffi/jvm/types.mojo: -------------------------------------------------------------------------------- 1 | """Java Native Interface types.""" 2 | 3 | 4 | # ===----------------------------------------------------------------------=== # 5 | # Base Types 6 | # ===----------------------------------------------------------------------=== # 7 | 8 | 9 | struct J: 10 | """Java types.""" 11 | 12 | alias boolean = UInt8 13 | """Type: `boolean`. Type signature: `Z`.""" 14 | alias byte = Int8 15 | """Type: `byte`. Type signature: `B`.""" 16 | alias char = UInt16 17 | """Type: `char`. Type signature: `C`.""" 18 | alias short = Int16 19 | """Type: `short`. Type signature: `S`.""" 20 | alias int = Int32 21 | """Type: `int`. Type signature: `I`.""" 22 | alias long = Int64 23 | """Type: `long`. Type signature: `J`.""" 24 | alias float = Float32 25 | """Type: `float`. Type signature: `F`.""" 26 | alias double = Float64 27 | """Type: `double`. Type signature: `D`.""" 28 | alias null = None 29 | """Type: `void` (exposed here as `null`).
Type signature: `V`.""" 30 | alias ptr_addr = Int 31 | """Type: A Pointer Address.""" 32 | 33 | 34 | # ===----------------------------------------------------------------------=== # 35 | # Utils 36 | # ===----------------------------------------------------------------------=== # 37 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ##===----------------------------------------------------------------------===## 3 | # Copyright (c) 2024, Modular Inc. All rights reserved. 4 | # 5 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 6 | # https://llvm.org/LICENSE.txt 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | ##===----------------------------------------------------------------------===## 14 | 15 | set -euo pipefail 16 | 17 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 18 | REPO_ROOT="${SCRIPT_DIR}"/.. 19 | BUILD_DIR="${REPO_ROOT}"/build 20 | 21 | echo "Creating build directory for building the Library and running the tests in."
22 | mkdir -p "${BUILD_DIR}" 23 | 24 | source "${SCRIPT_DIR}"/package-lib.sh 25 | TEST_PATH="${REPO_ROOT}/src/test" 26 | if [[ $# -gt 0 ]]; then 27 | # If an argument is provided, use it as the specific test file or directory 28 | TEST_PATH=$1 29 | fi 30 | 31 | # Run the tests (package-lib.sh set BUILD_DIR to an absolute path, so any cwd works) 32 | mojo test -D ASSERT=all -I "${BUILD_DIR}" "${TEST_PATH}" 33 | -------------------------------------------------------------------------------- /scripts/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ##===----------------------------------------------------------------------===## 3 | # Copyright (c) 2024, Modular Inc. All rights reserved. 4 | # 5 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 6 | # https://llvm.org/LICENSE.txt 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | ##===----------------------------------------------------------------------===## 14 | 15 | set -euo pipefail 16 | 17 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 18 | REPO_ROOT="${SCRIPT_DIR}"/.. 19 | BUILD_DIR="${REPO_ROOT}"/build 20 | 21 | echo "Creating build directory for building the Library and running the benchmarks in."
22 | mkdir -p "${BUILD_DIR}" 23 | 24 | source "${SCRIPT_DIR}"/package-lib.sh 25 | BENCHMARK_PATH="${REPO_ROOT}/src/benchmarks" 26 | if [[ $# -gt 0 ]]; then 27 | # If an argument is provided, use it as the specific benchmark file or directory 28 | BENCHMARK_PATH=$1 29 | fi 30 | 31 | # Run every *.mojo / *.🔥 file (NUL-delimited so paths with spaces survive) 32 | while IFS= read -r -d '' f; do 33 | echo "-------------------- BENCHMARK START ${f} --------------------" 34 | mojo run -I "${BUILD_DIR}" "${f}" 35 | echo "-------------------- BENCHMARK END ${f} --------------------" 36 | done < <(find "${BENCHMARK_PATH}" -type f \( -name '*.mojo' -o -name '*.🔥' \) -print0) 37 | -------------------------------------------------------------------------------- /src/test/datetime/test_timezone.mojo: -------------------------------------------------------------------------------- 1 | # RUN: %mojo %s 2 | 3 | from testing import assert_equal, assert_false, assert_raises, assert_true 4 | 5 | from forge_tools.datetime.timezone import ( 6 | TimeZone, 7 | ZoneInfo, 8 | ZoneInfoMem32, 9 | ZoneInfoMem8, 10 | ) 11 | 12 | 13 | def test_tz_no_iana(): 14 | alias TZ = TimeZone[iana=False, pyzoneinfo=False, native=False] 15 | tz0 = TZ("Etc/UTC", 0, 0) 16 | tz_1 = TZ("Etc/UTC-1", 1, 0) 17 | tz_2 = TZ("Etc/UTC-2", 2, 30) 18 | tz_3 = TZ("Etc/UTC-3", 3, 45) 19 | tz1_ = TZ("Etc/UTC+1", 1, 0, -1) 20 | tz2_ = TZ("Etc/UTC+2", 2, 30, -1) 21 | tz3_ = TZ("Etc/UTC+3", 3, 45, -1) 22 | assert_true(tz0 == TZ()) 23 | assert_true(tz1_ != tz_1 and tz2_ != tz_2 and tz3_ != tz_3) 24 | d = (1970, 1, 1, 0, 0, 0) 25 | tz0_of = tz0.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 26 | tz_1_of = tz_1.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 27 | tz_2_of = tz_2.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 28 | tz_3_of = tz_3.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 29 | tz1__of = tz1_.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 30 | tz2__of = tz2_.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 31 | tz3__of = tz3_.offset_at(d[0], d[1], d[2], d[3], d[4], d[5]) 32 | assert_equal(tz0_of.hour, 0) 33 | assert_equal(tz0_of.minute, 0) 34 |
assert_equal(tz0_of.sign, 1) 35 | assert_equal(tz_1_of.hour, 1) 36 | assert_equal(tz_1_of.minute, 0) 37 | assert_equal(tz_1_of.sign, 1) 38 | assert_equal(tz_2_of.hour, 2) 39 | assert_equal(tz_2_of.minute, 30) 40 | assert_equal(tz_2_of.sign, 1) 41 | assert_equal(tz_3_of.hour, 3) 42 | assert_equal(tz_3_of.minute, 45) 43 | assert_equal(tz_3_of.sign, 1) 44 | assert_equal(tz1__of.hour, 1) 45 | assert_equal(tz1__of.minute, 0) 46 | assert_equal(tz1__of.sign, -1) 47 | assert_equal(tz2__of.hour, 2) 48 | assert_equal(tz2__of.minute, 30) 49 | assert_equal(tz2__of.sign, -1) 50 | assert_equal(tz3__of.hour, 3) 51 | assert_equal(tz3__of.minute, 45) 52 | assert_equal(tz3__of.sign, -1) 53 | 54 | 55 | def test_tz_iana_dst(): 56 | # TODO: test from positive and negative UTC 57 | # TODO: test transitions to and from DST 58 | # TODO: test for Australia/Lord_Howe and Antarctica/Troll base 59 | pass 60 | 61 | 62 | def test_tz_iana_no_dst(): 63 | # TODO: test from positive and negative UTC 64 | pass 65 | 66 | 67 | def main(): 68 | test_tz_no_iana() 69 | test_tz_iana_dst() 70 | test_tz_iana_no_dst() 71 | -------------------------------------------------------------------------------- /src/forge_tools/memory/arc_pointer.mojo: -------------------------------------------------------------------------------- 1 | """Atomic Reference Counted Pointer module.""" 2 | 3 | from memory.arc import ArcPointer as Arc 4 | from memory import UnsafePointer 5 | from .pointer import Pointer 6 | 7 | 8 | struct ArcPointer[ 9 | is_mutable: Bool, //, 10 | type: AnyType, 11 | origin: Origin[is_mutable], 12 | address_space: AddressSpace = AddressSpace.GENERIC, 13 | ](Copyable, Movable): 14 | """Atomic Reference Counted Pointer.""" 15 | 16 | alias _P = Pointer[type, origin, address_space] 17 | alias _U = UnsafePointer[type, address_space] 18 | var _ptr: Arc[Self._P] 19 | 20 | @doc_private 21 | @always_inline("nodebug") 22 | fn __init__( 23 | out self, 24 | *, 25 | ptr: Self._U, 26 | is_allocated: Bool, 27 | in_registers:
Bool, 28 | is_initialized: Bool, 29 | ): 30 | """Constructs an ArcPointer from an UnsafePointer. 31 | 32 | Args: 33 | ptr: The UnsafePointer. 34 | is_allocated: Whether the pointer's memory is allocated. 35 | in_registers: Whether the pointer is allocated in registers. 36 | is_initialized: Whether the memory is initialized. 37 | """ 38 | self._ptr = Arc( 39 | Self._P( 40 | ptr=ptr, 41 | is_allocated=is_allocated, 42 | in_registers=in_registers, 43 | is_initialized=is_initialized, 44 | self_is_owner=True, 45 | ) 46 | ) 47 | 48 | fn __init__(out self, *, ptr: Self._P): 49 | """Constructs an ArcPointer from a Pointer. 50 | 51 | Args: 52 | ptr: The Pointer. 53 | """ 54 | self._ptr = Arc(ptr) 55 | 56 | @staticmethod 57 | @always_inline 58 | fn alloc(count: Int) -> Self: 59 | """Allocate memory according to the pointer's logic. 60 | 61 | Args: 62 | count: The number of elements in the buffer. 63 | 64 | Returns: 65 | The pointer to the newly allocated buffer. 66 | """ 67 | return Self( 68 | ptr=Self._U.alloc(count).bitcast[address_space=address_space](), 69 | is_allocated=True, 70 | in_registers=False, 71 | is_initialized=False, 72 | ) 73 | 74 | fn __del__(deinit self): 75 | """Free the memory referenced by the pointer or ignore.""" 76 | 77 | @parameter 78 | if address_space is AddressSpace.GENERIC and is_mutable: 79 | if self._ptr.count() == 1: 80 | self._ptr.unsafe_ptr()[]._flags |= 0b0101_0000 81 | -------------------------------------------------------------------------------- /src/forge_tools/memory/rc_pointer.mojo: -------------------------------------------------------------------------------- 1 | """Reference Counted Pointer module.""" 2 | 3 | from memory import UnsafePointer 4 | from .pointer import Pointer 5 | from .rc import Rc 6 | 7 | 8 | struct RcPointer[ 9 | is_mutable: Bool, //, 10 | type: AnyType, 11 | origin: Origin[is_mutable], 12 | address_space: AddressSpace = AddressSpace.GENERIC, 13 | ](Copyable, Movable): 14 | """Reference
Counted Pointer. 15 | 16 | Safety: 17 | This is not thread safe. 18 | """ 19 | 20 | alias _P = Pointer[type, origin, address_space] 21 | alias _U = UnsafePointer[type, address_space] 22 | var _ptr: Rc[Self._P] 23 | 24 | @doc_private 25 | @always_inline("nodebug") 26 | fn __init__( 27 | out self, 28 | *, 29 | ptr: Self._U, 30 | is_allocated: Bool, 31 | in_registers: Bool, 32 | is_initialized: Bool, 33 | ): 34 | """Constructs an RcPointer from an UnsafePointer. 35 | 36 | Args: 37 | ptr: The UnsafePointer. 38 | is_allocated: Whether the pointer's memory is allocated. 39 | in_registers: Whether the pointer is allocated in registers. 40 | is_initialized: Whether the memory is initialized. 41 | """ 42 | self._ptr = Rc( 43 | Self._P( 44 | ptr=ptr, 45 | is_allocated=is_allocated, 46 | in_registers=in_registers, 47 | is_initialized=is_initialized, 48 | self_is_owner=True, 49 | ) 50 | ) 51 | 52 | fn __init__(out self, *, ptr: Self._P): 53 | """Constructs an RcPointer from a Pointer. 54 | 55 | Args: 56 | ptr: The Pointer. 57 | """ 58 | self._ptr = Rc(ptr) 59 | 60 | @staticmethod 61 | @always_inline 62 | fn alloc(count: Int) -> Self: 63 | """Allocate memory according to the pointer's logic. 64 | 65 | Args: 66 | count: The number of elements in the buffer. 67 | 68 | Returns: 69 | The pointer to the newly allocated buffer.
70 | """ 71 | return Self( 72 | ptr=Self._U.alloc(count).bitcast[address_space=address_space](), 73 | is_allocated=True, 74 | in_registers=False, 75 | is_initialized=False, 76 | ) 77 | 78 | fn __del__(deinit self): 79 | """Free the memory referenced by the pointer or ignore.""" 80 | 81 | @parameter 82 | if address_space is AddressSpace.GENERIC and is_mutable: 83 | if self._ptr.count() == 1: 84 | # Mutate the shared Pointer through the Rc's payload pointer (as 85 | # ArcPointer.__del__ does); rebinding the Rc handle itself to a 86 | # Pointer would reinterpret a different type. 87 | self._ptr.unsafe_ptr()[]._flags |= 0b0101_0000 88 | -------------------------------------------------------------------------------- /src/forge_tools/builtin/error.mojo: -------------------------------------------------------------------------------- 1 | # ===----------------------------------------------------------------------=== # 2 | # Copyright (c) 2024, Martin Vuyk Loperena 3 | # 4 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 5 | # https://llvm.org/LICENSE.txt 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License.
12 | # ===----------------------------------------------------------------------=== # 13 | """Implements the Error2 type.""" 14 | 15 | # ===----------------------------------------------------------------------===# 16 | # Error 17 | # ===----------------------------------------------------------------------===# 18 | 19 | 20 | @fieldwise_init 21 | struct Error2[T: StaticString = "AnyError"]( 22 | Boolable, Copyable, Movable, Representable, Stringable 23 | ): 24 | """This type represents a parametric Error.""" 25 | 26 | alias kind = T 27 | """The kind of Error.""" 28 | var message: String 29 | """The Error message.""" 30 | 31 | @always_inline("nodebug") 32 | fn __bool__(self) -> Bool: 33 | """Returns True if the Error is set and false otherwise. 34 | 35 | Returns: 36 | True if the Error object contains a message and False otherwise. 37 | """ 38 | return Bool(self.message) 39 | 40 | @always_inline("nodebug") 41 | fn __str__(self) -> String: 42 | """Converts the Error to string representation. 43 | 44 | Returns: 45 | A String of the Error kind and message. 46 | """ 47 | return String(self.kind, ": ", self.message) 48 | 49 | @always_inline("nodebug") 50 | fn __repr__(self) -> String: 51 | """Converts the Error to printable representation. 52 | 53 | Returns: 54 | A printable representation of the Error message. 55 | """ 56 | return String(self) 57 | 58 | fn __eq__[A: StaticString](self, other: Error2[A]) -> Bool: 59 | """Whether the Errors have the same message (their kinds are ignored). 60 | 61 | Args: 62 | other: The Error to compare to. 63 | 64 | Returns: 65 | The comparison. 66 | """ 67 | 68 | return self.message == other.message 69 | 70 | fn __eq__(self, value: StringLiteral) -> Bool: 71 | """Whether the Error message is set and self.kind is equal to the 72 | StringLiteral. Error kind "AnyError" matches with all errors. 73 | 74 | Args: 75 | value: The StringLiteral to compare to. 76 | 77 | Returns: 78 | True if the Error is set and its kind matches `value`.
79 | """ 80 | 81 | return Bool(self) and (self.kind == value or value == "AnyError") 82 | -------------------------------------------------------------------------------- /src/forge_tools/memory/rc.mojo: -------------------------------------------------------------------------------- 1 | """Reference counter module.""" 2 | from memory import UnsafePointer 3 | 4 | 5 | struct _RcInner[T: Movable]: 6 | var refcount: UInt64 7 | var payload: T 8 | 9 | fn __init__(out self, var value: T): 10 | """Create an initialized instance of this with a refcount of 1.""" 11 | self.refcount = 1 12 | self.payload = value^ 13 | 14 | fn add_ref(mut self): 15 | """Increment the refcount.""" 16 | self.refcount += 1 17 | 18 | fn drop_ref(mut self): 19 | """Decrement the refcount.""" 20 | self.refcount -= 1 21 | 22 | 23 | @register_passable 24 | struct Rc[T: Movable](Copyable, Movable): 25 | """Reference counter.""" 26 | 27 | alias _inner_type = _RcInner[T] 28 | var _inner: UnsafePointer[_RcInner[T]] 29 | 30 | fn __init__(out self, var value: T): 31 | """Create an initialized instance of this with a refcount of 1.""" 32 | self._inner = UnsafePointer[Self._inner_type].alloc(1) 33 | # Cannot use init_pointee_move as _RcInner isn't movable. 34 | __get_address_as_uninit_lvalue(self._inner.address) = Self._inner_type( 35 | value^ 36 | ) 37 | 38 | fn __init__(out self, *, other: Self): 39 | """Copy the object. 40 | 41 | Args: 42 | other: The value to copy. 43 | """ 44 | other._inner[].add_ref() 45 | self._inner = other._inner 46 | 47 | fn __copyinit__(out self, existing: Self): 48 | """Copy an existing reference. Increment the refcount to the object. 49 | 50 | Args: 51 | existing: The existing reference. 52 | """ 53 | # Order here does not matter since `existing` can't be destroyed until 54 | # sometime after we return.
55 | existing._inner[].add_ref() 56 | self._inner = existing._inner 57 | 58 | fn increment(self): 59 | self._inner[].add_ref() 60 | 61 | fn decrement(self): 62 | self._inner[].drop_ref() 63 | 64 | @no_inline 65 | fn __del__(deinit self): 66 | """Delete the smart pointer reference. 67 | 68 | Decrement the ref count for the reference. If there are no more 69 | references, delete the object and free its memory. 70 | """ 71 | self._inner[].drop_ref() 72 | # Only the last owner (refcount now 0) may destroy and free the payload. 73 | if self._inner[].refcount == 0: 74 | # Call inner destructor, then free the memory. 75 | self._inner.destroy_pointee() 76 | self._inner.free() 77 | 78 | fn count(self) -> UInt64: 79 | """Count the amount of current references. 80 | 81 | Returns: 82 | The current amount of references to the pointee. 83 | """ 84 | return self._inner[].refcount 85 | 86 | fn unsafe_ptr(self) -> UnsafePointer[T]: 87 | """Retrieves a pointer to the underlying memory. 88 | 89 | Returns: 90 | The UnsafePointer to the underlying memory. 91 | """ 92 | return UnsafePointer(to=self._inner[].payload) 93 | -------------------------------------------------------------------------------- /src/test/complex/test_quaternion.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal, assert_false, assert_true, assert_almost_equal 2 | 3 | from forge_tools.complex import Quaternion, DualQuaternion 4 | 5 | 6 | def test_quaternion_ops(): 7 | q1 = Quaternion(2, 3, 4, 5) 8 | q2 = Quaternion(2, 3, 4, 5) 9 | q3 = Quaternion(5, 4, 3, 2) 10 | assert_almost_equal(7.348, q1.__abs__(), rtol=0.1) 11 | assert_almost_equal(Quaternion(4, 6, 8, 10).vec, (q1 + q2).vec, rtol=0.1) 12 | assert_almost_equal(Quaternion(0, 0, 0, 0).vec, (q1 - q2).vec, rtol=0.1) 13 | assert_almost_equal( 14 | Quaternion(-46, 12, 16, 20).vec, (q1 * q2).vec, rtol=0.1 15 | ) 16 | assert_almost_equal((q1 * q2).vec, (q2 * q1).vec, rtol=0.1) 17 | assert_almost_equal( 18 | Quaternion(-24, 16, 40, 22).vec, (q1 * q3).vec, rtol=0.1 19 | ) 20
| assert_almost_equal( 21 | Quaternion(-24, 30, 12, 36).vec, (q3 * q1).vec, rtol=0.1 22 | ) 23 | assert_almost_equal( 24 | Quaternion(0.0925926, -0.0740741, -0.0555556, -0.037037).vec, 25 | q3.inverse().vec, 26 | rtol=0.1, 27 | ) 28 | assert_almost_equal( 29 | Quaternion(0.037037, -0.0555556, -0.0740741, -0.0925926).vec, 30 | q1.inverse().vec, 31 | rtol=0.1, 32 | ) 33 | assert_almost_equal( 34 | Quaternion(0.8148, 0.2593, 0, 0.5185).vec, (q1 / q3).vec, rtol=0.1 35 | ) 36 | assert_almost_equal( 37 | Quaternion(0.8148, -0.2593, 0, -0.5185).vec, (q3 / q1).vec, rtol=0.1 38 | ) 39 | # TODO 40 | # assert_almost_equal(..., q1**3) 41 | # assert_almost_equal(..., q1.exp()) 42 | # assert_almost_equal(..., q1.ln()) 43 | # assert_almost_equal(..., q1.sqrt()) 44 | # assert_almost_equal(..., q1.phi()) 45 | 46 | 47 | def test_quaternion_matrix(): 48 | # TODO 49 | # q1 = Quaternion(2, 3, 4, 5) 50 | # assert_almost_equal( 51 | # List( 52 | # -14 / 27, 53 | # 2 / 27, 54 | # 23 / 27, 55 | # 22 / 27, 56 | # -7 / 27, 57 | # 14 / 27, 58 | # 7 / 27, 59 | # 26 / 27, 60 | # 2 / 27, 61 | # ), 62 | # q1.to_matrix(), 63 | # rtol=0.1, 64 | # ) 65 | pass 66 | 67 | 68 | def test_dualquaternion_ops(): 69 | q1 = DualQuaternion(2, 3, 4, 5, 6, 7, 8, 9) 70 | q2 = DualQuaternion(2, 3, 4, 5, 6, 7, 8, 9) 71 | q3 = DualQuaternion(9, 8, 7, 6, 5, 4, 3, 2) 72 | assert_almost_equal( 73 | DualQuaternion(4, 6, 8, 10, 12, 14, 16, 18).vec, (q1 + q2).vec, rtol=0.1 74 | ) 75 | assert_almost_equal( 76 | DualQuaternion(0, 0, 0, 0, 0, 0, 0, 0).vec, (q1 - q2).vec, rtol=0.1 77 | ) 78 | assert_almost_equal( 79 | DualQuaternion(-46, 12, 16, 20, -172, 64, 80, 96).vec, 80 | (q1 * q2).vec, 81 | rtol=0.1, 82 | ) 83 | assert_almost_equal((q1 * q2).vec, (q2 * q1).vec, rtol=0.1) 84 | assert_almost_equal( 85 | DualQuaternion(-64, 32, 72, 46, -136, 112, 184, 124).vec, 86 | (q1 * q3).vec, 87 | rtol=0.1, 88 | ) 89 | assert_almost_equal( 90 | DualQuaternion(-64, 54, 28, 68, -136, 156, 96, 168).vec, 91 | (q3 * q1).vec, 92 | rtol=0.1, 
93 | ) 94 | # TODO 95 | # assert_almost_equal(..., q1**3) 96 | 97 | 98 | def test_dualquaternion_matrix(): 99 | # TODO 100 | pass 101 | 102 | 103 | def test_dualquaternion_screw(): 104 | # TODO 105 | pass 106 | 107 | 108 | def main(): 109 | test_quaternion_ops() 110 | test_quaternion_matrix() 111 | test_dualquaternion_ops() 112 | test_dualquaternion_matrix() 113 | test_dualquaternion_screw() 114 | -------------------------------------------------------------------------------- /src/forge_tools/datetime/__init__.mojo: -------------------------------------------------------------------------------- 1 | # ===----------------------------------------------------------------------=== # 2 | # Copyright (c) 2024, Martin Vuyk Loperena 3 | # 4 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 5 | # https://llvm.org/LICENSE.txt 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # ===----------------------------------------------------------------------=== # 13 | """The datetime package. 14 | 15 | - `DateTime` 16 | - A structure aware of TimeZone, Calendar, and leap days and seconds. 17 | - Nanosecond resolution, though when using dunder methods (e.g. 18 | `dt1 == dt2`) it has only Microsecond resolution. 19 | - `Date` 20 | - A structure aware of TimeZone, Calendar, and leap days and seconds. 21 | - `DateTime64`, `DateTime32`, `DateTime16`, `DateTime8` 22 | - Fast implementations of DateTime, no leap seconds or years, 23 | and some have much lower resolutions but better performance. 24 | - `TimeZone` 25 | - By default UTC, highly customizable and options for full or partial 26 | IANA timezones support. 
27 | - Notes: 28 | - The caveats of each implementation are better explained in each struct's docstrings. 29 | 30 | Examples: 31 | 32 | ```mojo 33 | from testing import assert_equal, assert_true 34 | from forge_tools.datetime import DateTime, Calendar, IsoFormat 35 | from forge_tools.datetime.calendar import PythonCalendar, UTCCalendar 36 | 37 | alias DateT = DateTime[iana=False, pyzoneinfo=False, native=False] 38 | dt = DateT(2024, 6, 18, 22, 14, 7) 39 | print(dt) # 2024-06-18T22:14:07+00:00 40 | alias fstr = IsoFormat.HH_MM_SS 41 | iso_str = dt.to_iso[fstr]() 42 | dt = ( 43 | DateT.from_iso[fstr](iso_str, calendar=Calendar(2024, 6, 18)) 44 | .value() 45 | .replace(calendar=Calendar()) # Calendar() == PythonCalendar 46 | ) 47 | print(dt) # 2024-06-18T22:14:07+00:00 48 | 49 | 50 | # TODO: current mojo limitation. Parametrized structs need to be bound to an 51 | # alias and used for interoperability 52 | # customtz = TimeZone[False, False, False]("my_str", 1, 0) 53 | tz_0 = DateT._tz("my_str", 0, 0) 54 | tz_1 = DateT._tz("my_str", 1, 0) 55 | assert_equal(DateT(2024, 6, 18, 0, tz=tz_0), DateT(2024, 6, 18, 1, tz=tz_1)) 56 | 57 | 58 | # using python and unix calendar should have no difference in results 59 | alias pycal = PythonCalendar 60 | alias unixcal = UTCCalendar 61 | tz_0_ = DateT._tz("Etc/UTC", 0, 0) 62 | tz_1 = DateT._tz("Etc/UTC-1", 1, 0) 63 | tz1_ = DateT._tz("Etc/UTC+1", 1, 0, -1) 64 | 65 | dt = DateT(2022, 6, 1, tz=tz_0_, calendar=pycal) + DateT( 66 | 2, 6, 31, tz=tz_0_, calendar=pycal 67 | ) 68 | offset_0 = DateT(2025, 1, 1, tz=tz_0_, calendar=unixcal) 69 | offset_p_1 = DateT(2025, 1, 1, hour=1, tz=tz_1, calendar=unixcal) 70 | offset_n_1 = DateT(2024, 12, 31, hour=23, tz=tz1_, calendar=unixcal) 71 | assert_equal(dt, offset_0) 72 | assert_equal(dt, offset_p_1) 73 | assert_equal(dt, offset_n_1) 74 | 75 | 76 | fstr = "mojo: %Y🔥%m🤯%d" 77 | assert_equal("mojo: 0009🔥06🤯01", DateT(9, 6, 1).strftime(fstr)) 78 | fstr = "%Y-%m-%d %H:%M:%S.%f" 79 | ref1 = 
DateT(2024, 9, 9, 9, 9, 9, 9, 9) 80 | assert_equal("2024-09-09 09:09:09.009009", ref1.strftime(fstr)) 81 | 82 | 83 | fstr = "mojo: %Y🔥%m🤯%d" 84 | vstr = "mojo: 0009🔥06🤯01" 85 | ref1 = DateT(9, 6, 1) 86 | parsed = DateT.strptime(vstr, fstr) 87 | assert_true(parsed) 88 | assert_equal(ref1, parsed.value()) 89 | fstr = "%Y-%m-%d %H:%M:%S.%f" 90 | vstr = "2024-09-09 09:09:09.009009" 91 | ref1 = DateT(2024, 9, 9, 9, 9, 9, 9, 9) 92 | parsed = DateT.strptime(vstr, fstr) 93 | assert_true(parsed) 94 | assert_equal(ref1, parsed.value()) 95 | ``` 96 | . 97 | """ 98 | 99 | from .calendar import Calendar, ZeroCalendar 100 | from .date import Date 101 | from .datetime import DateTime, timedelta 102 | from .dt_str import IsoFormat 103 | from .fast import DateTime64, DateTime32, DateTime16, DateTime8 104 | from .timezone import TimeZone 105 | from .zoneinfo import get_zoneinfo 106 | 107 | alias datetime = DateTime 108 | """A `DateTime` alias for Python code not to break too much.""" 109 | alias date = Date 110 | """A `Date` alias for Python code not to break too much.""" 111 | -------------------------------------------------------------------------------- /src/test/socket/test_socket.mojo: -------------------------------------------------------------------------------- 1 | from testing import assert_equal, assert_false, assert_true, assert_almost_equal 2 | from sys import is_big_endian 3 | from bit import bit_reverse 4 | from memory import UnsafePointer, stack_allocation 5 | from utils import StringSlice, Span 6 | 7 | from forge_tools.socket import Socket 8 | 9 | 10 | def test_ntohs(): 11 | value = UInt16(1 << 15) 12 | res = Socket.ntohs(value) 13 | 14 | @parameter 15 | if is_big_endian(): 16 | assert_equal(value, res) 17 | else: 18 | assert_equal(1 << 7, res) 19 | 20 | 21 | def test_ntohl(): 22 | value = UInt32(1 << 31) 23 | res = Socket.ntohl(value) 24 | 25 | @parameter 26 | if is_big_endian(): 27 | assert_equal(value, res) 28 | else: 29 | assert_equal(1 << 7, res) 30 | 31 | 32 |
def test_htons(): 33 | value = UInt16(1 << 15) 34 | res = Socket.htons(value) 35 | 36 | @parameter 37 | if is_big_endian(): 38 | assert_equal(value, res) 39 | else: 40 | assert_equal(1 << 7, res) 41 | 42 | 43 | def test_htonl(): 44 | value = UInt32(1 << 31) 45 | res = Socket.htonl(value) 46 | 47 | @parameter 48 | if is_big_endian(): 49 | assert_equal(value, res) 50 | else: 51 | assert_equal(1 << 7, res) 52 | 53 | 54 | def test_inet_aton(): 55 | res = Socket.inet_aton(String("123.45.67.89")) 56 | assert_true(res) 57 | value = UInt32(0b01111011001011010100001101011001) 58 | 59 | @parameter 60 | if not is_big_endian(): 61 | b0 = value << 24 62 | b1 = (value << 8) & 0xFF_00_00 63 | b2 = (value >> 8) & 0xFF_00 64 | b3 = value >> 24 65 | value = b0 | b1 | b2 | b3 66 | assert_equal(value, res.value()) 67 | 68 | 69 | def test_inet_ntoa(): 70 | value = UInt32(0b01111011001011010100001101011001) 71 | 72 | @parameter 73 | if not is_big_endian(): 74 | b0 = value << 24 75 | b1 = (value << 8) & 0xFF_00_00 76 | b2 = (value >> 8) & 0xFF_00 77 | b3 = value >> 24 78 | value = b0 | b1 | b2 | b3 79 | res = Socket.inet_ntoa(value) 80 | assert_equal(String("123.45.67.89"), res) 81 | 82 | 83 | def test_server_sync_ipv4(): 84 | socket = Socket() 85 | socket.bind(("0.0.0.0", 8001)) 86 | socket.listen() 87 | 88 | 89 | # def test_client_sync_ipv4(): 90 | # socket = Socket() 91 | # await socket.connect(("0.0.0.0", 8000)) 92 | # client_msg = String("123456789") 93 | # bytes_sent = await socket.send(client_msg.as_bytes_span()) 94 | # _ = socket 95 | 96 | 97 | def test_create_server_sync_ipv4(): 98 | server = Socket.create_server(("0.0.0.0", 8002)) 99 | _ = server^ 100 | 101 | 102 | # def test_create_connection_sync_ipv4(): 103 | # client = Socket.create_connection(("0.0.0.0", 8000)) 104 | # _ = client 105 | 106 | # async def test_client_server_ipv4(): 107 | # server = Socket.create_server(("0.0.0.0", 8000)) 108 | # client = Socket.create_connection(("0.0.0.0", 8000)) 109 | 110 | # client_msg = 
String("123456789") 111 | # bytes_sent = client.send(client_msg.as_bytes_span()) 112 | # conn = (await server.accept())[0] 113 | # assert_equal(9, await bytes_sent^) 114 | 115 | # alias Life = ImmutableAnyLifetime 116 | # server_ptr = UnsafePointer[UInt8](stack_allocation[10, UInt8]()) 117 | # server_buf = Span[UInt8, Life](unsafe_ptr=server_ptr, len=10) 118 | # server_bytes_recv = await conn.recv(server_buf) 119 | # assert_equal(9, server_bytes_recv) 120 | # assert_equal(client_msg, String(ptr=server_ptr, len=10)) 121 | 122 | # server_msg = String("987654321") 123 | # server_sent = await conn.send(server_msg.as_bytes_span()) 124 | # assert_equal(9, server_sent) 125 | 126 | # client_ptr = UnsafePointer[UInt8](stack_allocation[10, UInt8]()) 127 | # client_buf = Span[UInt8, Life](unsafe_ptr=client_ptr, len=10) 128 | # client_bytes_recv = await client.recv(client_buf) 129 | # assert_equal(9, client_bytes_recv) 130 | # assert_equal(server_msg, String(ptr=client_ptr, len=10)) 131 | 132 | 133 | def main(): 134 | test_ntohs() 135 | test_ntohl() 136 | test_htons() 137 | test_htonl() 138 | test_inet_aton() 139 | test_inet_ntoa() 140 | test_server_sync_ipv4() 141 | # test_client_sync_ipv4() 142 | test_create_server_sync_ipv4() 143 | # test_create_connection_sync_ipv4() 144 | # await test_client_server_ipv4() 145 | -------------------------------------------------------------------------------- /src/forge_tools/socket/_windows.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional 2 | from memory import UnsafePointer, ArcPointer 3 | from sys.intrinsics import _type_is_eq 4 | from memory import Span 5 | from forge_tools.ffi.c.types import C 6 | from .socket import ( 7 | # SocketInterface, 8 | SockType, 9 | SockProtocol, 10 | ) 11 | from .address import SockFamily, SockAddr, IPv4Addr, IPv6Addr 12 | 13 | 14 | struct _WindowsSocket[ 15 | sock_family: SockFamily, 16 | sock_type: SockType, 17 | sock_protocol: 
SockProtocol, 18 | sock_address: SockAddr, 19 | ](Copyable, Movable): 20 | var fd: ArcPointer[FileDescriptor] 21 | """The Socket's `ArcPointer[FileDescriptor]`.""" 22 | 23 | fn __init__(out self) raises: 24 | """Create a new socket object.""" 25 | raise Error("Failed to create socket.") 26 | 27 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 28 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`. 29 | """ 30 | self.fd = fd 31 | 32 | fn close(var self) raises: 33 | """Closes the Socket.""" 34 | ... # TODO: implement 35 | 36 | fn __del__(deinit self): 37 | """Closes the Socket if it's the last reference to its 38 | `FileDescriptor`. 39 | """ 40 | ... 41 | 42 | fn bind(self, address: sock_address) raises: 43 | """Bind the socket to address. The socket must not already be bound.""" 44 | ... 45 | 46 | fn listen(self, backlog: UInt = 0) raises: 47 | """Enable a server to accept connections. `backlog` specifies the number 48 | of unaccepted connections that the system will allow before refusing 49 | new connections. If `backlog == 0`, a default value is chosen. 50 | """ 51 | ... 52 | 53 | async fn connect(self, address: sock_address) raises: 54 | """Connect to a remote socket at address.""" 55 | ... 56 | 57 | async fn accept(self) -> Optional[(Self, sock_address)]: 58 | """Return a new socket representing the connection, and the address of 59 | the client. 
60 | """ 61 | return None 62 | 63 | @staticmethod 64 | fn socketpair() raises -> (Self, Self): 65 | """Create a pair of socket objects from the sockets returned by the 66 | platform `socketpair()` function.""" 67 | raise Error("Failed to create socket.") 68 | 69 | fn get_fd(self) -> FileDescriptor: 70 | """Get the Socket's FileDescriptor.""" 71 | return 0 72 | 73 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 74 | """Send file descriptors to the socket.""" 75 | return False 76 | 77 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 78 | """Receive file descriptors from the socket.""" 79 | return List[FileDescriptor]() 80 | 81 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 82 | """Send a buffer of bytes to the socket.""" 83 | return -1 84 | 85 | async fn recv[ 86 | O: MutableOrigin 87 | ](self, buf: Span[UInt8, O], flags: C.int = 0) -> Int: 88 | return -1 89 | 90 | @staticmethod 91 | fn gethostname() -> Optional[String]: 92 | """Return the current hostname.""" 93 | return None 94 | 95 | @staticmethod 96 | fn gethostbyname(name: String) -> Optional[sock_address]: 97 | """Map a hostname to its Address.""" 98 | return None 99 | 100 | @staticmethod 101 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 102 | """Map an Address to DNS info.""" 103 | return None 104 | 105 | @staticmethod 106 | fn getservbyname(name: String) -> Optional[sock_address]: 107 | """Map a service name and a protocol name to a port number.""" 108 | return None 109 | 110 | @staticmethod 111 | fn getdefaulttimeout() -> Optional[Float64]: 112 | """Get the default timeout value.""" 113 | return None 114 | 115 | @staticmethod 116 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 117 | """Set the default timeout value.""" 118 | return False 119 | 120 | fn settimeout(self, value: Optional[Float64]) -> Bool: 121 | """Set the socket timeout value.""" 122 | return False 123 | 124 | @staticmethod 125 | fn create_connection( 126 | address: 
IPv4Addr, 127 | timeout: Optional[Float64] = None, 128 | source_address: IPv4Addr = IPv4Addr(("", 0)), 129 | *, 130 | all_errors: Bool = False, 131 | ) raises -> Self: 132 | """Connects to an address, with an optional timeout and optional source 133 | address.""" 134 | alias s = sock_address 135 | alias cond = _type_is_eq[s, IPv4Addr]() or _type_is_eq[s, IPv6Addr]() 136 | constrained[cond, "sock_address must be IPv4Addr or IPv6Addr"]() 137 | raise Error("Failed to create socket.") 138 | 139 | @staticmethod 140 | fn create_server( 141 | address: IPv4Addr, 142 | *, 143 | backlog: Optional[Int] = None, 144 | reuse_port: Bool = False, 145 | ) raises -> Self: 146 | """Create a socket, bind it to a specified address, and listen.""" 147 | constrained[ 148 | _type_is_eq[sock_address, IPv4Addr](), 149 | "sock_address must be IPv4Addr", 150 | ]() 151 | raise Error("Failed to create socket.") 152 | 153 | @staticmethod 154 | fn create_server( 155 | address: IPv6Addr, 156 | *, 157 | backlog: Optional[Int] = None, 158 | reuse_port: Bool = False, 159 | dualstack_ipv6: Bool = False, 160 | ) raises -> Self: 161 | """Create a socket, bind it to a specified address, and listen.""" 162 | constrained[ 163 | _type_is_eq[sock_address, IPv6Addr](), 164 | "sock_address must be IPv6Addr", 165 | ]() 166 | raise Error("Failed to create socket.") 167 | 168 | fn keep_alive( 169 | self, 170 | enable: Bool = True, 171 | idle: C.int = 2 * 60 * 60, 172 | interval: C.int = 75, 173 | count: C.int = 10, 174 | ) raises: 175 | """Whether and how to keep the connection alive.""" 176 | raise Error("Failed to set socket options.") 177 | 178 | fn reuse_address( 179 | self, value: Bool = True, *, full_duplicates: Bool = True 180 | ) raises: 181 | """Whether to allow duplicated addresses.""" 182 | raise Error("Failed to set socket options.") 183 | -------------------------------------------------------------------------------- /src/test/collections/test_dbuffer.mojo: 
-------------------------------------------------------------------------------- 1 | from collections import InlineArray, List 2 | from testing import assert_equal, assert_true 3 | 4 | from forge_tools.collections import DBuffer 5 | 6 | 7 | def test_dbuffer_list_init_trivial(): 8 | # test taking ownership 9 | var l1 = List[Int](1, 2, 3, 4, 5, 6, 7) 10 | var l1_copy = l1.copy() 11 | var s1 = DBuffer[origin=MutableAnyOrigin].own(l1^) 12 | assert_true(s1.is_owner()) 13 | assert_equal(len(s1), len(l1_copy)) 14 | for i in range(len(s1)): 15 | assert_equal(l1_copy[i], s1[i]) 16 | # subslice 17 | var slice_1 = s1[2:] 18 | assert_true(not slice_1.is_owner()) 19 | assert_equal(slice_1[0], l1_copy[2]) 20 | assert_equal(slice_1[1], l1_copy[3]) 21 | assert_equal(slice_1[2], l1_copy[4]) 22 | assert_equal(slice_1[3], l1_copy[5]) 23 | assert_equal(s1[-1], l1_copy[-1]) 24 | 25 | # test non owning Buffer 26 | var l2 = List[Int](1, 2, 3, 4, 5, 6, 7) 27 | var s2 = DBuffer(l2) 28 | assert_true(not s2.is_owner()) 29 | assert_equal(len(s2), len(l2)) 30 | for i in range(len(s2)): 31 | assert_equal(l2[i], s2[i]) 32 | # subslice 33 | var slice_2 = s2[2:] 34 | assert_true(not slice_2.is_owner()) 35 | assert_equal(slice_2[0], l2[2]) 36 | assert_equal(slice_2[1], l2[3]) 37 | assert_equal(slice_2[2], l2[4]) 38 | assert_equal(slice_2[3], l2[5]) 39 | assert_equal(s2[-1], l2[-1]) 40 | 41 | # Test mutation 42 | s2[0] = 9 43 | assert_equal(s2[0], 9) 44 | assert_equal(l2[0], 9) 45 | 46 | s2[-1] = 0 47 | assert_equal(s2[-1], 0) 48 | assert_equal(l2[-1], 0) 49 | 50 | 51 | def test_dbuffer_list_init_memory(): 52 | # test taking ownership 53 | var l1 = List[String]("a", "b", "c", "d", "e", "f", "g") 54 | var l1_copy = l1.copy() 55 | var s1 = DBuffer[origin=MutableAnyOrigin].own(l1^) 56 | assert_true(s1.is_owner()) 57 | assert_equal(len(s1), len(l1_copy)) 58 | for i in range(len(s1)): 59 | assert_equal(l1_copy[i], s1[i]) 60 | # subslice 61 | var slice_1 = s1[2:] 62 | assert_true(not slice_1.is_owner()) 
63 | assert_equal(slice_1[0], l1_copy[2]) 64 | assert_equal(slice_1[1], l1_copy[3]) 65 | assert_equal(slice_1[2], l1_copy[4]) 66 | assert_equal(slice_1[3], l1_copy[5]) 67 | 68 | # test non owning Buffer 69 | var l2 = List[String]("a", "b", "c", "d", "e", "f", "g") 70 | var s2 = DBuffer(l2) 71 | assert_true(not s2.is_owner()) 72 | assert_equal(len(s2), len(l2)) 73 | for i in range(len(s2)): 74 | assert_equal(l2[i], s2[i]) 75 | # subslice 76 | var slice_2 = s2[2:] 77 | assert_true(not slice_2.is_owner()) 78 | assert_equal(slice_2[0], l2[2]) 79 | assert_equal(slice_2[1], l2[3]) 80 | assert_equal(slice_2[2], l2[4]) 81 | assert_equal(slice_2[3], l2[5]) 82 | 83 | # Test mutation 84 | s2[0] = "h" 85 | assert_equal(s2[0], "h") 86 | assert_equal(l2[0], "h") 87 | 88 | s2[-1] = "i" 89 | assert_equal(s2[-1], "i") 90 | assert_equal(l2[-1], "i") 91 | 92 | 93 | def test_dbuffer_array_int(): 94 | var l = InlineArray[Int, 7](1, 2, 3, 4, 5, 6, 7) 95 | var s = DBuffer[Int](array=l) 96 | assert_equal(len(s), len(l)) 97 | for i in range(len(s)): 98 | assert_equal(l[i], s[i]) 99 | # subslice 100 | var s2 = s[2:] 101 | assert_equal(s2[0], l[2]) 102 | assert_equal(s2[1], l[3]) 103 | assert_equal(s2[2], l[4]) 104 | assert_equal(s2[3], l[5]) 105 | 106 | # Test mutation 107 | s[0] = 9 108 | assert_equal(s[0], 9) 109 | assert_equal(l[0], 9) 110 | 111 | s[-1] = 0 112 | assert_equal(s[-1], 0) 113 | assert_equal(l[-1], 0) 114 | 115 | 116 | def test_dbuffer_array_str(): 117 | var l = InlineArray[String, 7]("a", "b", "c", "d", "e", "f", "g") 118 | var s = DBuffer[String](array=l) 119 | assert_true(not s.is_owner()) 120 | assert_equal(len(s), len(l)) 121 | for i in range(len(s)): 122 | assert_equal(l[i], s[i]) 123 | # subslice 124 | var s2 = s[2:] 125 | assert_equal(s2[0], l[2]) 126 | assert_equal(s2[1], l[3]) 127 | assert_equal(s2[2], l[4]) 128 | assert_equal(s2[3], l[5]) 129 | 130 | # Test mutation 131 | s[0] = "h" 132 | assert_equal(s[0], "h") 133 | assert_equal(l[0], "h") 134 | 135 | s[-1] = 
"i" 136 | assert_equal(s[-1], "i") 137 | assert_equal(l[-1], "i") 138 | 139 | 140 | def test_indexing(): 141 | var l = InlineArray[Int, 7](1, 2, 3, 4, 5, 6, 7) 142 | var s = DBuffer[Int](array=l) 143 | assert_equal(s[True], 2) 144 | assert_equal(s[Int(0)], 1) 145 | assert_equal(s[3], 4) 146 | 147 | 148 | def test_dbuffer_slice(): 149 | def compare(s: DBuffer[Int], l: List[Int]) -> Bool: 150 | if len(s) != len(l): 151 | return False 152 | for i in range(len(s)): 153 | if s[i] != l[i]: 154 | return False 155 | return True 156 | 157 | var l = List(1, 2, 3, 4, 5) 158 | var s = DBuffer(l) 159 | var res = s[1:2] 160 | assert_equal(res[0], 2) 161 | res = s[1:-1:1] 162 | assert_equal(res[0], 2) 163 | assert_equal(res[1], 3) 164 | assert_equal(res[2], 4) 165 | # Test slicing with negative step 166 | res = s[1::-1] 167 | assert_equal(res[0], 2) 168 | assert_equal(res[1], 1) 169 | res = s[2:1:-1] 170 | assert_equal(res[0], 3) 171 | assert_equal(len(res), 1) 172 | res = s[5:1:-2] 173 | assert_equal(res[0], 5) 174 | assert_equal(res[1], 3) 175 | 176 | 177 | def test_bool(): 178 | var l = InlineArray[String, 7]("a", "b", "c", "d", "e", "f", "g") 179 | var s = DBuffer[String](l) 180 | assert_true(s) 181 | assert_true(not s[0:0]) 182 | 183 | 184 | def test_equality(): 185 | var l = InlineArray[String, 7]("a", "b", "c", "d", "e", "f", "g") 186 | var l2 = List[String]("a", "b", "c", "d", "e", "f", "g") 187 | var sp = DBuffer[String](l) 188 | var sp2 = DBuffer[String](l) 189 | var sp3 = DBuffer(l2) 190 | # same pointer 191 | assert_true(sp == sp2) 192 | # different pointer 193 | assert_true(sp == sp3) 194 | # different length 195 | assert_true(sp != sp3[:-1]) 196 | # empty 197 | assert_true(sp[0:0] == sp3[0:0]) 198 | 199 | 200 | def test_fill(): 201 | var l1 = List[Int](0, 1, 2, 3, 4, 5, 6, 7, 8) 202 | var s1 = DBuffer(l1) 203 | 204 | s1.fill(2) 205 | 206 | for i in range(len(l1)): 207 | assert_equal(l1[i], 2) 208 | assert_equal(s1[i], 2) 209 | 210 | var l2 = List[String]("a", "b", 
"c", "d", "e", "f", "g") 211 | var s2 = DBuffer(l2) 212 | 213 | s2.fill("hi") 214 | 215 | for i in range(len(s2)): 216 | assert_equal(l2[i], "hi") 217 | assert_equal(s2[i], "hi") 218 | 219 | 220 | def main(): 221 | test_dbuffer_list_init_trivial() 222 | test_dbuffer_list_init_memory() 223 | test_dbuffer_array_int() 224 | test_dbuffer_array_str() 225 | test_indexing() 226 | test_dbuffer_slice() 227 | test_bool() 228 | test_equality() 229 | test_fill() 230 | -------------------------------------------------------------------------------- /src/test/collections/test_result.mojo: -------------------------------------------------------------------------------- 1 | # RUN: %mojo %s 2 | 3 | from testing import assert_true, assert_false, assert_equal 4 | from collections import Dict 5 | from forge_tools.collections.result import Result, Result2, Error2 6 | 7 | 8 | def _returning_err[T: Copyable & Movable](value: T) -> Result[T]: 9 | result = Result[T](err=Error("something")) 10 | if not result: 11 | return result 12 | raise Error("shouldn't get here") 13 | 14 | 15 | def _returning_ok[T: Copyable & Movable](value: T) -> Result[T]: 16 | result = Result[T](value) 17 | if result: 18 | return result 19 | raise Error("shouldn't get here") 20 | 21 | 22 | def _returning_transferred_err[T: Copyable & Movable](value: T) -> Result[T]: 23 | # this value and err at the same time will never happen, just for testing 24 | # the value "some other string" should NOT get transferred 25 | res1 = Result(String("some other string")) 26 | res1.err = Error("some error") 27 | if res1: 28 | return res1 29 | raise Error("shouldn't get here") 30 | 31 | 32 | def _returning_none_err[T: Copyable & Movable](value: T) -> Result[T]: 33 | res1 = Result[String](err=Error("some error")) 34 | if res1.err: 35 | return None, res1.err 36 | raise Error("shouldn't get here") 37 | 38 | 39 | def test_none_err_constructor(): 40 | res1 = _returning_none_err(String("some string")) 41 | assert_true(not res1 and res1.err 
and String(res1.err) == "some error") 42 | res2 = _returning_none_err[String]("some string") 43 | assert_true(not res2 and res2.err and String(res2.err) == "some error") 44 | res3 = _returning_none_err[StaticString]("some string") 45 | assert_true(not res3 and res3.err and String(res3.err) == "some error") 46 | res4 = _returning_none_err("some string") 47 | assert_true(not res4 and res4.err and String(res4.err) == "some error") 48 | 49 | 50 | def test_error_transfer(): 51 | res1 = _returning_transferred_err(String("some string")) 52 | assert_true(res1 is None and String(res1.err) == "some error") 53 | res2 = _returning_transferred_err[String]("some string") 54 | assert_true(res2 is None and String(res2.err) == "some error") 55 | res3 = _returning_transferred_err[StaticString]("some string") 56 | assert_true(res3 is None and String(res3.err) == "some error") 57 | res4 = _returning_transferred_err("some string") 58 | assert_true(res4 is None and String(res4.err) == "some error") 59 | 60 | 61 | def test_returning_err(): 62 | item_s = _returning_err(String("string")) 63 | assert_true(not item_s and item_s.err and String(item_s.err) == "something") 64 | # item_ti = _returning_err(Tuple[Int]()) 65 | # assert_true(not item_ti and item_ti.err and String(item_ti.err) == "something") 66 | # item_ts = _returning_err(Tuple[String]()) 67 | # assert_true(not item_ts and item_ts.err and String(item_ts.err) == "something") 68 | item_li = _returning_err(List[Int]()) 69 | assert_true( 70 | not item_li and item_li.err and String(item_li.err) == "something" 71 | ) 72 | item_ls = _returning_err(List[String]()) 73 | assert_true( 74 | not item_ls and item_ls.err and String(item_ls.err) == "something" 75 | ) 76 | item_dii = _returning_err(Dict[Int, Int]()) 77 | assert_true( 78 | not item_dii and item_dii.err and String(item_dii.err) == "something" 79 | ) 80 | item_dss = _returning_err(Dict[String, String]()) 81 | assert_true( 82 | not item_dss and item_dss.err and String(item_dss.err) == 
"something" 83 | ) 84 | item_oi = _returning_err(Result[Int]()) 85 | assert_true( 86 | not item_oi and item_oi.err and String(item_oi.err) == "something" 87 | ) 88 | item_os = _returning_err(Result[String]()) 89 | assert_true( 90 | not item_os and item_os.err and String(item_os.err) == "something" 91 | ) 92 | 93 | 94 | def test_returning_ok(): 95 | # this one would fail if the String gets implicitly cast to Error(src: String) 96 | item_s = _returning_ok(String("string")) 97 | assert_true(item_s and not item_s.err and String(item_s.err) == "") 98 | # item_ti = _returning_ok(Tuple[Int]()) 99 | # assert_true(item_ti and not item_ti.err and String(item_ti.err) == "") 100 | # item_ts = _returning_ok(Tuple[String]()) 101 | # assert_true(item_ts and not item_ts.err and String(item_ts.err) == "") 102 | item_li = _returning_ok(List[Int]()) 103 | assert_true(item_li and not item_li.err and String(item_li.err) == "") 104 | item_ls = _returning_ok(List[String]()) 105 | assert_true(item_ls and not item_ls.err and String(item_ls.err) == "") 106 | item_dii = _returning_ok(Dict[Int, Int]()) 107 | assert_true(item_dii and not item_dii.err and String(item_dii.err) == "") 108 | item_dss = _returning_ok(Dict[String, String]()) 109 | assert_true(item_dss and not item_dss.err and String(item_dss.err) == "") 110 | item_oi = _returning_ok(Result[Int]()) 111 | assert_true(item_oi and not item_oi.err and String(item_oi.err) == "") 112 | item_os = _returning_ok(Result[String]()) 113 | assert_true(item_os and not item_os.err and String(item_os.err) == "") 114 | 115 | 116 | def test_basic(): 117 | a = Result(1) 118 | b = Result[Int]() 119 | 120 | assert_true(a) 121 | assert_false(b) 122 | 123 | assert_true(a and True) 124 | assert_true(True and a) 125 | assert_false(a and False) 126 | 127 | assert_false(b and True) 128 | assert_false(b and False) 129 | 130 | assert_true(a or True) 131 | assert_true(a or False) 132 | 133 | assert_true(b or True) 134 | assert_false(b or False) 135 | 136 | 
assert_equal(1, a.value()) 137 | 138 | # Test invert operator 139 | assert_false(~a) 140 | assert_true(~b) 141 | 142 | # TODO(27776): can't inline these, they need to be mutable lvalues 143 | a1 = a.or_else(2) 144 | b1 = b.or_else(2) 145 | 146 | assert_equal(1, a1) 147 | assert_equal(2, b1) 148 | 149 | assert_equal(1, a.value()) 150 | 151 | # TODO: this currently only checks for mutable references. 152 | # We may want to come back and add an immutable test once 153 | # there are the language features to do so. 154 | a2 = Result(1) 155 | a2.value() = 2 156 | assert_equal(a2.value(), 2) 157 | 158 | 159 | def test_result_is(): 160 | a = Result(1) 161 | assert_false(a is None) 162 | 163 | a = Result[Int]() 164 | assert_true(a is None) 165 | 166 | 167 | def test_result_isnot(): 168 | a = Result(1) 169 | assert_true(a is not None) 170 | 171 | a = Result[Int]() 172 | assert_false(a is not None) 173 | 174 | 175 | def _do_something(i: Int) -> Result2[Int, "IndexError"]: 176 | if i < 0: 177 | return None, Error2["IndexError"]("index out of bounds: " + String(i)) 178 | return 1 179 | 180 | 181 | def _do_some_other_thing() -> Result2[String, "OtherError"]: 182 | a = _do_something(-1) 183 | if a.err: 184 | print(String(a.err)) # IndexError: index out of bounds: -1 185 | return a 186 | return String("success") 187 | 188 | 189 | def test_result2(): 190 | res = _do_some_other_thing() 191 | assert_false(res) 192 | assert_equal(res.err.message, "index out of bounds: -1") 193 | 194 | 195 | def main(): 196 | test_basic() 197 | test_result_is() 198 | test_result_isnot() 199 | test_returning_ok() 200 | test_returning_err() 201 | test_error_transfer() 202 | test_none_err_constructor() 203 | test_result2() 204 | -------------------------------------------------------------------------------- /src/forge_tools/ffi/c/types.mojo: -------------------------------------------------------------------------------- 1 | """C POSIX types.""" 2 | 3 | from sys.info import is_64bit 4 | from sys.ffi 
import external_call 5 | from os import abort 6 | from utils import StaticTuple 7 | from memory import memcpy, UnsafePointer 8 | 9 | # ===----------------------------------------------------------------------=== # 10 | # Base Types 11 | # ===----------------------------------------------------------------------=== # 12 | 13 | 14 | struct C: 15 | """C types. This assumes that the platform is 32 or 64 bit, and char is 16 | always 8 bit (POSIX standard). 17 | """ 18 | 19 | alias char = Int8 20 | """Type: `char`. The signedness of `char` is platform specific. Most 21 | systems, including x86 GNU/Linux and Windows, use `signed char`, but those 22 | based on PowerPC and ARM processors typically use `unsigned char`. This alias is always `Int8` regardless of platform.""" 23 | alias s_char = Int8 24 | """Type: `signed char`.""" 25 | alias u_char = UInt8 26 | """Type: `unsigned char`.""" 27 | alias short = Int16 28 | """Type: `short`.""" 29 | alias u_short = UInt16 30 | """Type: `unsigned short`.""" 31 | alias int = Int32 32 | """Type: `int`.""" 33 | alias u_int = UInt32 34 | """Type: `unsigned int`.""" 35 | alias long = Scalar[_c_long_dtype()] 36 | """Type: `long`. 32 bit on 64-bit Windows and on 32-bit targets, 64 bit on other 64-bit targets (see `_c_long_dtype`).""" 37 | alias u_long = Scalar[_c_u_long_dtype()] 38 | """Type: `unsigned long`. Same width as `long` (see `_c_u_long_dtype`).""" 39 | alias long_long = Int64 40 | """Type: `long long`.""" 41 | alias u_long_long = UInt64 42 | """Type: `unsigned long long`.""" 43 | alias float = Float32 44 | """Type: `float`.""" 45 | alias double = Float64 46 | """Type: `double`.""" 47 | alias void = Int8 48 | """Type: `void`. Represented as `Int8`; used as a pointee type, e.g. `UnsafePointer[C.void]`.""" 49 | alias NULL = UnsafePointer[Self.void]() 50 | """Constant: NULL pointer (a default-constructed `UnsafePointer[C.void]`).""" 51 | alias ptr_addr = Int 52 | """Type: A pointer address stored as an `Int`.""" 53 | 54 | 55 | alias size_t = UInt 56 | """Type: `size_t`.""" 57 | alias ssize_t = Int 58 | """Type: `ssize_t`.""" 59 | 60 | 61 | # ===----------------------------------------------------------------------=== # 62 | # Utils 63 | # ===----------------------------------------------------------------------=== # 64 | 65 | 66 | fn _c_long_dtype() -> DType: 67 | # DType of C `long` for the target's data model (LLP64 / LP64 / ILP32). NOTE(review): `CompilationTarget` is not among this file's visible imports; confirm it is in scope. See
https://en.wikipedia.org/wiki/64-bit_computing#64-bit_data_models 68 | 69 | @parameter 70 | if is_64bit() and CompilationTarget.is_windows(): 71 | return DType.int32 # LLP64 72 | elif is_64bit(): 73 | return DType.int64 # LP64 74 | else: 75 | return DType.int32 # ILP32 76 | 77 | 78 | fn _c_u_long_dtype() -> DType: 79 | # Unsigned counterpart of `_c_long_dtype`, with the same widths. See https://en.wikipedia.org/wiki/64-bit_computing#64-bit_data_models 80 | 81 | @parameter 82 | if is_64bit() and CompilationTarget.is_windows(): 83 | return DType.uint32 # LLP64 84 | elif is_64bit(): 85 | return DType.uint64 # LP64 86 | else: 87 | return DType.uint32 # ILP32 88 | 89 | 90 | @always_inline 91 | fn char_ptr(item: String) -> UnsafePointer[C.char]: 92 | """Get a `C.char` pointer to the string's bytes without copying. No null terminator is appended here. 93 | 94 | Args: 95 | item: The string whose bytes are pointed to. 96 | 97 | Returns: 98 | The address of `item`'s bytes, bitcast to `C.char`. 99 | """ 100 | return item.unsafe_ptr().bitcast[C.char]() 101 | 102 | 103 | @always_inline 104 | fn char_ptr(item: StringSlice) -> UnsafePointer[C.char]: 105 | """Get a `C.char` pointer to the slice's bytes without copying. No null terminator is appended here. 106 | 107 | Args: 108 | item: The string slice whose bytes are pointed to. 109 | 110 | Returns: 111 | The address of `item`'s bytes, bitcast to `C.char`. 112 | """ 113 | return item.unsafe_ptr().bitcast[C.char]() 114 | 115 | 116 | @always_inline 117 | fn char_ptr(item: StringLiteral) -> UnsafePointer[C.char]: 118 | """Get a `C.char` pointer to the literal's bytes without copying. No null terminator is appended here. 119 | 120 | Args: 121 | item: The string literal whose bytes are pointed to. 122 | 123 | Returns: 124 | The address of `item`'s bytes, bitcast to `C.char`. 125 | """ 126 | return item.unsafe_ptr().bitcast[C.char]() 127 | 128 | 129 | @always_inline 130 | fn char_ptr[T: AnyType](ptr: UnsafePointer[T]) -> UnsafePointer[C.char]: 131 | """Reinterpret a pointer as a `C.char` pointer (same address, no copy). 132 | 133 | Parameters: 134 | T: The pointee type of `ptr`. 135 | 136 | Args: 137 | ptr: The pointer to reinterpret. 138 | 139 | Returns: 140 | `ptr` bitcast to `UnsafePointer[C.char]`. 141 | """ 142 | return ptr.bitcast[C.char]() 143 | 144 | 145 | fn char_ptr_to_string(s: UnsafePointer[C.char]) -> String: 146 | """Create a String **copying** a **null terminated** char pointer. 147 | 148 | Args: 149 | s: A pointer to a C string. 150 | 151 | Returns: 152 | A new String holding the bytes before the first 0 byte.
""" 154 | idx = 0 155 | while s[idx] != 0: 156 | idx += 1 157 | return String( 158 | bytes=rebind[Span[Byte, MutableAnyOrigin]]( 159 | Span(ptr=s.bitcast[Byte](), length=idx) 160 | ) 161 | ) 162 | 163 | 164 | # ===----------------------------------------------------------------------=== # 165 | # Networking Types 166 | # ===----------------------------------------------------------------------=== # 167 | 168 | alias sa_family_t = C.u_short 169 | """Type: `sa_family_t`.""" 170 | alias socklen_t = C.u_int 171 | """Type: `socklen_t`.""" 172 | alias in_addr_t = C.u_int 173 | """Type: `in_addr_t`.""" 174 | alias in_port_t = C.u_short 175 | """Type: `in_port_t`.""" 176 | 177 | 178 | @fieldwise_init 179 | @register_passable("trivial") 180 | struct in_addr: 181 | """IPv4 address (`struct in_addr`).""" 182 | 183 | var s_addr: in_addr_t 184 | """The IPv4 address as an `in_addr_t`.""" 185 | 186 | 187 | @fieldwise_init 188 | @register_passable("trivial") 189 | struct in6_addr: 190 | """IPv6 address (`struct in6_addr`).""" 191 | 192 | var s6_addr: StaticTuple[C.char, 16] 193 | """The 16 bytes of the IPv6 address. NOTE(review): POSIX declares this as `uint8_t[16]`, while this uses signed `C.char`.""" 194 | 195 | 196 | @fieldwise_init 197 | @register_passable("trivial") 198 | struct sockaddr: 199 | """Generic socket address (`struct sockaddr`). NOTE(review): this matches the Linux layout; BSD-derived systems (including macOS) start with an 8-bit `sa_len` followed by an 8-bit `sa_family`, so confirm per platform.""" 200 | 201 | var sa_family: sa_family_t 202 | """Socket Address Family.""" 203 | var sa_data: StaticTuple[C.char, 14] 204 | """Family-specific address data (14 bytes).""" 205 | 206 | 207 | @fieldwise_init 208 | @register_passable("trivial") 209 | struct sockaddr_in: 210 | """IPv4 socket address (`struct sockaddr_in`).""" 211 | 212 | var sin_family: sa_family_t 213 | """Socket Address Family.""" 214 | var sin_port: in_port_t 215 | """Socket Address Port.""" 216 | var sin_addr: in_addr 217 | """IPv4 address.""" 218 | var sin_zero: StaticTuple[C.char, 8] 219 | """Socket zero padding.""" 220 | 221 | 222 | @fieldwise_init 223 | @register_passable("trivial") 224 | struct sockaddr_in6: 225 | """IPv6 socket address (`struct sockaddr_in6`).""" 226 | 227 | var sin6_family: sa_family_t 228 | """Socket Address Family.""" 229 | var sin6_port:
in_port_t 230 | """Socket Address Port.""" 231 | var sin6_flowinfo: C.u_int 232 | """Flow Information.""" 233 | var sin6_addr: in6_addr 234 | """IPv6 address.""" 235 | var sin6_scope_id: C.u_int 236 | """Scope ID.""" 237 | 238 | 239 | @fieldwise_init 240 | @register_passable("trivial") 241 | struct addrinfo: 242 | """Address information entry (`struct addrinfo`).""" 243 | 244 | var ai_flags: C.int 245 | """Address Information Flags.""" 246 | var ai_family: C.int 247 | """Address Family.""" 248 | var ai_socktype: C.int 249 | """Socket Type.""" 250 | var ai_protocol: C.int 251 | """Socket Protocol.""" 252 | var ai_addrlen: socklen_t 253 | """Socket Address Length.""" 254 | var ai_addr: UnsafePointer[sockaddr] 255 | """Pointer to the socket address.""" 256 | var ai_canonname: UnsafePointer[C.char] 257 | """Canonical name as a C string pointer.""" 258 | # FIXME: This should be UnsafePointer[addrinfo]. NOTE(review): this field order (ai_addr before ai_canonname) matches glibc/Linux; BSD/macOS declare ai_canonname before ai_addr, so confirm per platform. 259 | var ai_next: UnsafePointer[C.void] 260 | """Next entry in the list, held as an untyped pointer (see the FIXME above).""" 261 | 262 | fn __init__(out self): 263 | """Construct an empty addrinfo: every integer field is 0 and every pointer is null.""" 264 | p0 = UnsafePointer[sockaddr]() 265 | self = Self(0, 0, 0, 0, 0, p0, UnsafePointer[C.char](), C.NULL) 266 | 267 | 268 | # ===----------------------------------------------------------------------=== # 269 | # File Types 270 | # ===----------------------------------------------------------------------=== # 271 | 272 | alias off_t = Int64 273 | """Type: `off_t`.""" 274 | alias mode_t = UInt32 275 | """Type: `mode_t`.""" 276 | 277 | 278 | @register_passable("trivial") 279 | struct FILE: 280 | """Type: `FILE`. An opaque placeholder with no fields; presumably only used behind a pointer.""" 281 | 282 | pass 283 | -------------------------------------------------------------------------------- /src/forge_tools/socket/_bsd.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional 2 | from memory import UnsafePointer, ArcPointer 3 | from sys.intrinsics import _type_is_eq 4 | from memory import Span 5 | from forge_tools.ffi.c.types import C 6 | from .socket import ( 7 | #
SocketInterface, 8 | SockType, 9 | SockProtocol, 10 | ) 11 | from .address import SockFamily, SockAddr, IPv4Addr, IPv6Addr 12 | from ._unix import _UnixSocket 13 | 14 | 15 | struct _BSDSocket[ 16 | sock_family: SockFamily, 17 | sock_type: SockType, 18 | sock_protocol: SockProtocol, 19 | sock_address: SockAddr, 20 | ](Copyable, Movable): 21 | alias _ST = _UnixSocket[sock_family, sock_type, sock_protocol, sock_address] 22 | var _sock: Self._ST 23 | 24 | alias _ipv4 = _BSDSocket[ 25 | SockFamily.AF_INET, sock_type, sock_protocol, IPv4Addr 26 | ] 27 | 28 | fn __init__(out self) raises: 29 | """Create a new socket object.""" 30 | self._sock = Self._ST() 31 | 32 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 33 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`. 34 | """ 35 | self._sock = Self._ST(fd=fd) 36 | 37 | fn close(var self) raises: 38 | """Closes the Socket.""" 39 | self._sock.close() 40 | 41 | fn __del__(deinit self): 42 | """Closes the Socket if it's the last reference to its 43 | `FileDescriptor`. 44 | """ 45 | ... 46 | 47 | fn setsockopt[ 48 | D: DType = C.int.element_type 49 | ](self, level: C.int, option_name: C.int, option_value: Scalar[D]) raises: 50 | """Set socket options.""" 51 | self._sock.setsockopt(level, option_name, option_value) 52 | 53 | fn bind(self, address: sock_address) raises: 54 | """Bind the socket to address. The socket must not already be bound.""" 55 | self._sock.bind(address) 56 | 57 | fn listen(self, backlog: UInt = 0) raises: 58 | """Enable a server to accept connections. `backlog` specifies the number 59 | of unaccepted connections that the system will allow before refusing 60 | new connections. If `backlog == 0`, a default value is chosen. 
61 | """ 62 | self._sock.listen(backlog) 63 | 64 | async fn connect(self, address: sock_address) raises: 65 | """Connect to a remote socket at address.""" 66 | await self._sock.connect(address) 67 | 68 | async fn accept(self) -> Optional[(Self, sock_address)]: 69 | """Return a new socket representing the connection, and the address of 70 | the client.""" 71 | res = await self._sock.accept() 72 | if not res: 73 | return None 74 | s_a = res.value() 75 | return Self(fd=s_a[0].get_fd()), s_a[1] 76 | 77 | @staticmethod 78 | fn socketpair() raises -> (Self, Self): 79 | """Create a pair of socket objects from the sockets returned by the 80 | platform `socketpair()` function.""" 81 | s_s = Self._ST.socketpair() 82 | return Self(fd=s_s[0].get_fd()), Self(fd=s_s[1].get_fd()) 83 | 84 | fn get_fd(self) -> FileDescriptor: 85 | """Get the Socket's `FileDescriptor`.""" 86 | return self._sock.get_fd() 87 | 88 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 89 | """Send file descriptors to the socket.""" 90 | return await self._sock.send_fds(fds) 91 | 92 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 93 | """Receive file descriptors from the socket.""" 94 | return await self._sock.recv_fds(maxfds) 95 | 96 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 97 | """Send a buffer of bytes to the socket.""" 98 | return await self._sock.send(buf, flags) 99 | 100 | async fn recv[ 101 | O: MutableOrigin 102 | ](self, buf: Span[UInt8, O], flags: C.int = 0) -> Int: 103 | return await self._sock.recv(buf, flags) 104 | 105 | @staticmethod 106 | fn gethostname() -> Optional[String]: 107 | """Return the current hostname.""" 108 | return Self._ST.gethostname() 109 | 110 | @staticmethod 111 | fn gethostbyname(name: String) -> Optional[sock_address]: 112 | """Map a hostname to its Address.""" 113 | return Self._ST.gethostbyname(name) 114 | 115 | @staticmethod 116 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 117 | """Map an Address to 
DNS info.""" 118 | return Self._ST.gethostbyaddr(address) 119 | 120 | @staticmethod 121 | fn getservbyname(name: String) -> Optional[sock_address]: 122 | """Map a service name and a protocol name to a port number.""" 123 | return Self._ST.getservbyname(name) 124 | 125 | @staticmethod 126 | fn getdefaulttimeout() -> Optional[Float64]: 127 | """Get the default timeout value.""" 128 | return Self._ST.getdefaulttimeout() 129 | 130 | @staticmethod 131 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 132 | """Set the default timeout value.""" 133 | return Self._ST.setdefaulttimeout(value) 134 | 135 | fn settimeout(self, value: Optional[Float64]) -> Bool: 136 | """Set the socket timeout value.""" 137 | return self._sock.settimeout(value) 138 | 139 | # TODO: should this return an iterator instead? 140 | @staticmethod 141 | fn getaddrinfo( 142 | address: sock_address, flags: Int = 0 143 | ) raises -> List[ 144 | (SockFamily, SockType, SockProtocol, String, sock_address) 145 | ]: 146 | """Get the available address information. 147 | 148 | Notes: 149 | [Reference](\ 150 | https://man7.org/linux/man-pages/man3/freeaddrinfo.3p.html). 
151 | """ 152 | return Self._ST.getaddrinfo(address, flags) 153 | 154 | @staticmethod 155 | fn create_connection( 156 | address: IPv4Addr, 157 | timeout: Optional[Float64] = None, 158 | source_address: IPv4Addr = IPv4Addr(), 159 | *, 160 | all_errors: Bool = False, 161 | ) raises -> Self: 162 | """Connects to an address, with an optional timeout and optional source 163 | address.""" 164 | return Self._ST.create_connection( 165 | address, timeout, source_address, all_errors=all_errors 166 | ) 167 | 168 | @staticmethod 169 | fn create_connection( 170 | address: IPv6Addr, 171 | timeout: Optional[Float64] = None, 172 | source_address: IPv6Addr = IPv6Addr(), 173 | *, 174 | all_errors: Bool = False, 175 | ) raises -> Self: 176 | """Connects to an address, with an optional timeout and optional source 177 | address.""" 178 | return Self._ST.create_connection( 179 | address, timeout, source_address, all_errors=all_errors 180 | ) 181 | 182 | @staticmethod 183 | fn create_server( 184 | address: IPv4Addr, 185 | *, 186 | backlog: Optional[Int] = None, 187 | reuse_port: Bool = False, 188 | ) raises -> Self: 189 | """Create a socket, bind it to a specified address, and listen.""" 190 | return Self._ST.create_server( 191 | address, backlog=backlog, reuse_port=reuse_port 192 | ) 193 | 194 | @staticmethod 195 | fn create_server( 196 | address: IPv6Addr, 197 | *, 198 | backlog: Optional[Int] = None, 199 | reuse_port: Bool = False, 200 | ) raises -> Self: 201 | """Create a socket, bind it to a specified address, and listen. 
Default 202 | no dual stack IPv6.""" 203 | return Self._ST.create_server( 204 | address, backlog=backlog, reuse_port=reuse_port 205 | ) 206 | 207 | @staticmethod 208 | fn create_server( 209 | address: IPv6Addr, 210 | *, 211 | dualstack_ipv6: Bool, 212 | backlog: Optional[Int] = None, 213 | reuse_port: Bool = False, 214 | ) raises -> (Self, Self._ipv4): 215 | """Create a socket, bind it to a specified address, and listen.""" 216 | s_s = Self._ST.create_server( 217 | address, 218 | dualstack_ipv6=dualstack_ipv6, 219 | backlog=backlog, 220 | reuse_port=reuse_port, 221 | ) 222 | return Self(fd=s_s[0].get_fd()), Self._ipv4(fd=s_s[1].get_fd()) 223 | 224 | fn keep_alive( 225 | self, 226 | enable: Bool = True, 227 | idle: C.int = 2 * 60 * 60, 228 | interval: C.int = 75, 229 | count: C.int = 10, 230 | ) raises: 231 | """Whether and how to keep the connection alive.""" 232 | return self._sock.keep_alive(enable, idle, interval, count) 233 | 234 | fn reuse_address( 235 | self, value: Bool = True, *, full_duplicates: Bool = True 236 | ) raises: 237 | """Whether to allow duplicated addresses.""" 238 | return self._sock.reuse_address(value, full_duplicates=full_duplicates) 239 | 240 | fn no_delay(self, value: Bool = True) raises: 241 | """Whether to send packets ASAP without accumulating more.""" 242 | return self._sock.no_delay(value) 243 | -------------------------------------------------------------------------------- /src/forge_tools/socket/_apple.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional 2 | from memory import UnsafePointer, ArcPointer 3 | from sys.intrinsics import _type_is_eq 4 | from memory import Span 5 | from forge_tools.ffi.c.types import C 6 | from .socket import ( 7 | # SocketInterface, 8 | SockType, 9 | SockProtocol, 10 | ) 11 | from .address import SockFamily, SockAddr, IPv4Addr, IPv6Addr 12 | from ._unix import _UnixSocket 13 | 14 | 15 | struct _AppleSocket[ 16 | sock_family: 
SockFamily, 17 | sock_type: SockType, 18 | sock_protocol: SockProtocol, 19 | sock_address: SockAddr, 20 | ](Copyable, Movable): 21 | alias _ST = _UnixSocket[sock_family, sock_type, sock_protocol, sock_address] 22 | var _sock: Self._ST 23 | 24 | alias _ipv4 = _AppleSocket[ 25 | SockFamily.AF_INET, sock_type, sock_protocol, IPv4Addr 26 | ] 27 | 28 | fn __init__(out self) raises: 29 | """Create a new socket object.""" 30 | self._sock = Self._ST() 31 | 32 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 33 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`. 34 | """ 35 | self._sock = Self._ST(fd=fd) 36 | 37 | fn close(var self) raises: 38 | """Closes the Socket.""" 39 | self._sock.close() 40 | 41 | fn __del__(deinit self): 42 | """Closes the Socket if it's the last reference to its 43 | `FileDescriptor`. 44 | """ 45 | ... 46 | 47 | fn setsockopt[ 48 | D: DType = C.int.element_type 49 | ](self, level: C.int, option_name: C.int, option_value: Scalar[D]) raises: 50 | """Set socket options.""" 51 | self._sock.setsockopt(level, option_name, option_value) 52 | 53 | fn bind(self, address: sock_address) raises: 54 | """Bind the socket to address. The socket must not already be bound.""" 55 | self._sock.bind(address) 56 | 57 | fn listen(self, backlog: UInt = 0) raises: 58 | """Enable a server to accept connections. `backlog` specifies the number 59 | of unaccepted connections that the system will allow before refusing 60 | new connections. If `backlog == 0`, a default value is chosen. 
61 | """ 62 | self._sock.listen(backlog) 63 | 64 | async fn connect(self, address: sock_address) raises: 65 | """Connect to a remote socket at address.""" 66 | await self._sock.connect(address) 67 | 68 | async fn accept(self) -> Optional[(Self, sock_address)]: 69 | """Return a new socket representing the connection, and the address of 70 | the client.""" 71 | res = await self._sock.accept() 72 | if not res: 73 | return None 74 | s_a = res.value() 75 | return Self(fd=s_a[0].get_fd()), s_a[1] 76 | 77 | @staticmethod 78 | fn socketpair() raises -> (Self, Self): 79 | """Create a pair of socket objects from the sockets returned by the 80 | platform `socketpair()` function.""" 81 | s_s = Self._ST.socketpair() 82 | return Self(fd=s_s[0].get_fd()), Self(fd=s_s[1].get_fd()) 83 | 84 | fn get_fd(self) -> FileDescriptor: 85 | """Get the Socket's `FileDescriptor`.""" 86 | return self._sock.get_fd() 87 | 88 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 89 | """Send file descriptors to the socket.""" 90 | return await self._sock.send_fds(fds) 91 | 92 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 93 | """Receive file descriptors from the socket.""" 94 | return await self._sock.recv_fds(maxfds) 95 | 96 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 97 | """Send a buffer of bytes to the socket.""" 98 | return await self._sock.send(buf, flags) 99 | 100 | async fn recv[ 101 | O: MutableOrigin 102 | ](self, buf: Span[UInt8, O], flags: C.int = 0) -> Int: 103 | return await self._sock.recv(buf, flags) 104 | 105 | @staticmethod 106 | fn gethostname() -> Optional[String]: 107 | """Return the current hostname.""" 108 | return Self._ST.gethostname() 109 | 110 | @staticmethod 111 | fn gethostbyname(name: String) -> Optional[sock_address]: 112 | """Map a hostname to its Address.""" 113 | return Self._ST.gethostbyname(name) 114 | 115 | @staticmethod 116 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 117 | """Map an Address to 
DNS info.""" 118 | return Self._ST.gethostbyaddr(address) 119 | 120 | @staticmethod 121 | fn getservbyname(name: String) -> Optional[sock_address]: 122 | """Map a service name and a protocol name to a port number.""" 123 | return Self._ST.getservbyname(name) 124 | 125 | @staticmethod 126 | fn getdefaulttimeout() -> Optional[Float64]: 127 | """Get the default timeout value.""" 128 | return Self._ST.getdefaulttimeout() 129 | 130 | @staticmethod 131 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 132 | """Set the default timeout value.""" 133 | return Self._ST.setdefaulttimeout(value) 134 | 135 | fn settimeout(self, value: Optional[Float64]) -> Bool: 136 | """Set the socket timeout value.""" 137 | return self._sock.settimeout(value) 138 | 139 | # TODO: should this return an iterator instead? 140 | @staticmethod 141 | fn getaddrinfo( 142 | address: sock_address, flags: Int = 0 143 | ) raises -> List[ 144 | (SockFamily, SockType, SockProtocol, String, sock_address) 145 | ]: 146 | """Get the available address information. 147 | 148 | Notes: 149 | [Reference](\ 150 | https://man7.org/linux/man-pages/man3/freeaddrinfo.3p.html). 
151 | """ 152 | return Self._ST.getaddrinfo(address, flags) 153 | 154 | @staticmethod 155 | fn create_connection( 156 | address: IPv4Addr, 157 | timeout: Optional[Float64] = None, 158 | source_address: IPv4Addr = IPv4Addr(), 159 | *, 160 | all_errors: Bool = False, 161 | ) raises -> Self: 162 | """Connects to an address, with an optional timeout and optional source 163 | address.""" 164 | return Self._ST.create_connection( 165 | address, timeout, source_address, all_errors=all_errors 166 | ) 167 | 168 | @staticmethod 169 | fn create_connection( 170 | address: IPv6Addr, 171 | timeout: Optional[Float64] = None, 172 | source_address: IPv6Addr = IPv6Addr(), 173 | *, 174 | all_errors: Bool = False, 175 | ) raises -> Self: 176 | """Connects to an address, with an optional timeout and optional source 177 | address.""" 178 | return Self._ST.create_connection( 179 | address, timeout, source_address, all_errors=all_errors 180 | ) 181 | 182 | @staticmethod 183 | fn create_server( 184 | address: IPv4Addr, 185 | *, 186 | backlog: Optional[Int] = None, 187 | reuse_port: Bool = False, 188 | ) raises -> Self: 189 | """Create a socket, bind it to a specified address, and listen.""" 190 | return Self._ST.create_server( 191 | address, backlog=backlog, reuse_port=reuse_port 192 | ) 193 | 194 | @staticmethod 195 | fn create_server( 196 | address: IPv6Addr, 197 | *, 198 | backlog: Optional[Int] = None, 199 | reuse_port: Bool = False, 200 | ) raises -> Self: 201 | """Create a socket, bind it to a specified address, and listen. 
Default 202 | no dual stack IPv6.""" 203 | return Self._ST.create_server( 204 | address, backlog=backlog, reuse_port=reuse_port 205 | ) 206 | 207 | @staticmethod 208 | fn create_server( 209 | address: IPv6Addr, 210 | *, 211 | dualstack_ipv6: Bool, 212 | backlog: Optional[Int] = None, 213 | reuse_port: Bool = False, 214 | ) raises -> (Self, Self._ipv4): 215 | """Create a socket, bind it to a specified address, and listen.""" 216 | s_s = Self._ST.create_server( 217 | address, 218 | dualstack_ipv6=dualstack_ipv6, 219 | backlog=backlog, 220 | reuse_port=reuse_port, 221 | ) 222 | return Self(fd=s_s[0].get_fd()), Self._ipv4(fd=s_s[1].get_fd()) 223 | 224 | fn keep_alive( 225 | self, 226 | enable: Bool = True, 227 | idle: C.int = 2 * 60 * 60, 228 | interval: C.int = 75, 229 | count: C.int = 10, 230 | ) raises: 231 | """Whether and how to keep the connection alive.""" 232 | return self._sock.keep_alive(enable, idle, interval, count) 233 | 234 | fn reuse_address( 235 | self, value: Bool = True, *, full_duplicates: Bool = True 236 | ) raises: 237 | """Whether to allow duplicated addresses.""" 238 | return self._sock.reuse_address(value, full_duplicates=full_duplicates) 239 | 240 | fn no_delay(self, value: Bool = True) raises: 241 | """Whether to send packets ASAP without accumulating more.""" 242 | return self._sock.no_delay(value) 243 | -------------------------------------------------------------------------------- /src/forge_tools/socket/_linux.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional 2 | from memory import UnsafePointer, ArcPointer 3 | from sys.intrinsics import _type_is_eq 4 | from memory import Span 5 | from forge_tools.ffi.c.types import C 6 | from .socket import ( 7 | # SocketInterface, 8 | SockType, 9 | SockProtocol, 10 | ) 11 | from .address import SockFamily, SockAddr, IPv4Addr, IPv6Addr 12 | from ._unix import _UnixSocket 13 | 14 | 15 | struct _LinuxSocket[ 16 | sock_family: 
SockFamily, 17 | sock_type: SockType, 18 | sock_protocol: SockProtocol, 19 | sock_address: SockAddr, 20 | ](Copyable, Movable): 21 | alias _ST = _UnixSocket[sock_family, sock_type, sock_protocol, sock_address] 22 | var _sock: Self._ST 23 | 24 | alias _ipv4 = _LinuxSocket[ 25 | SockFamily.AF_INET, sock_type, sock_protocol, IPv4Addr 26 | ] 27 | 28 | fn __init__(out self) raises: 29 | """Create a new socket object.""" 30 | self._sock = Self._ST() 31 | 32 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 33 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`. 34 | """ 35 | self._sock = Self._ST(fd=fd) 36 | 37 | fn close(var self) raises: 38 | """Closes the Socket.""" 39 | self._sock.close() 40 | 41 | fn __del__(deinit self): 42 | """Closes the Socket if it's the last reference to its 43 | `FileDescriptor`. 44 | """ 45 | ... 46 | 47 | fn setsockopt[ 48 | D: DType = C.int.element_type 49 | ](self, level: C.int, option_name: C.int, option_value: Scalar[D]) raises: 50 | """Set socket options.""" 51 | self._sock.setsockopt(level, option_name, option_value) 52 | 53 | fn bind(self, address: sock_address) raises: 54 | """Bind the socket to address. The socket must not already be bound.""" 55 | self._sock.bind(address) 56 | 57 | fn listen(self, backlog: UInt = 0) raises: 58 | """Enable a server to accept connections. `backlog` specifies the number 59 | of unaccepted connections that the system will allow before refusing 60 | new connections. If `backlog == 0`, a default value is chosen. 
61 | """ 62 | self._sock.listen(backlog) 63 | 64 | async fn connect(self, address: sock_address) raises: 65 | """Connect to a remote socket at address.""" 66 | await self._sock.connect(address) 67 | 68 | async fn accept(self) -> Optional[(Self, sock_address)]: 69 | """Return a new socket representing the connection, and the address of 70 | the client.""" 71 | res = await self._sock.accept() 72 | if not res: 73 | return None 74 | s_a = res.value() 75 | return Self(fd=s_a[0].get_fd()), s_a[1] 76 | 77 | @staticmethod 78 | fn socketpair() raises -> (Self, Self): 79 | """Create a pair of socket objects from the sockets returned by the 80 | platform `socketpair()` function.""" 81 | s_s = Self._ST.socketpair() 82 | return Self(fd=s_s[0].get_fd()), Self(fd=s_s[1].get_fd()) 83 | 84 | fn get_fd(self) -> FileDescriptor: 85 | """Get the Socket's `FileDescriptor`.""" 86 | return self._sock.get_fd() 87 | 88 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 89 | """Send file descriptors to the socket.""" 90 | return await self._sock.send_fds(fds) 91 | 92 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 93 | """Receive file descriptors from the socket.""" 94 | return await self._sock.recv_fds(maxfds) 95 | 96 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 97 | """Send a buffer of bytes to the socket.""" 98 | return await self._sock.send(buf, flags) 99 | 100 | async fn recv[ 101 | O: MutableOrigin 102 | ](self, buf: Span[UInt8, O], flags: C.int = 0) -> Int: 103 | return await self._sock.recv(buf, flags) 104 | 105 | @staticmethod 106 | fn gethostname() -> Optional[String]: 107 | """Return the current hostname.""" 108 | return Self._ST.gethostname() 109 | 110 | @staticmethod 111 | fn gethostbyname(name: String) -> Optional[sock_address]: 112 | """Map a hostname to its Address.""" 113 | return Self._ST.gethostbyname(name) 114 | 115 | @staticmethod 116 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 117 | """Map an Address to 
DNS info.""" 118 | return Self._ST.gethostbyaddr(address) 119 | 120 | @staticmethod 121 | fn getservbyname(name: String) -> Optional[sock_address]: 122 | """Map a service name and a protocol name to a port number.""" 123 | return Self._ST.getservbyname(name) 124 | 125 | @staticmethod 126 | fn getdefaulttimeout() -> Optional[Float64]: 127 | """Get the default timeout value.""" 128 | return Self._ST.getdefaulttimeout() 129 | 130 | @staticmethod 131 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 132 | """Set the default timeout value.""" 133 | return Self._ST.setdefaulttimeout(value) 134 | 135 | fn settimeout(self, value: Optional[Float64]) -> Bool: 136 | """Set the socket timeout value.""" 137 | return self._sock.settimeout(value) 138 | 139 | # TODO: should this return an iterator instead? 140 | @staticmethod 141 | fn getaddrinfo( 142 | address: sock_address, flags: Int = 0 143 | ) raises -> List[ 144 | (SockFamily, SockType, SockProtocol, String, sock_address) 145 | ]: 146 | """Get the available address information. 147 | 148 | Notes: 149 | [Reference](\ 150 | https://man7.org/linux/man-pages/man3/freeaddrinfo.3p.html). 
151 | """ 152 | return Self._ST.getaddrinfo(address, flags) 153 | 154 | @staticmethod 155 | fn create_connection( 156 | address: IPv4Addr, 157 | timeout: Optional[Float64] = None, 158 | source_address: IPv4Addr = IPv4Addr(), 159 | *, 160 | all_errors: Bool = False, 161 | ) raises -> Self: 162 | """Connects to an address, with an optional timeout and optional source 163 | address.""" 164 | return Self._ST.create_connection( 165 | address, timeout, source_address, all_errors=all_errors 166 | ) 167 | 168 | @staticmethod 169 | fn create_connection( 170 | address: IPv6Addr, 171 | timeout: Optional[Float64] = None, 172 | source_address: IPv6Addr = IPv6Addr(), 173 | *, 174 | all_errors: Bool = False, 175 | ) raises -> Self: 176 | """Connects to an address, with an optional timeout and optional source 177 | address.""" 178 | return Self._ST.create_connection( 179 | address, timeout, source_address, all_errors=all_errors 180 | ) 181 | 182 | @staticmethod 183 | fn create_server( 184 | address: IPv4Addr, 185 | *, 186 | backlog: Optional[Int] = None, 187 | reuse_port: Bool = False, 188 | ) raises -> Self: 189 | """Create a socket, bind it to a specified address, and listen.""" 190 | return Self._ST.create_server( 191 | address, backlog=backlog, reuse_port=reuse_port 192 | ) 193 | 194 | @staticmethod 195 | fn create_server( 196 | address: IPv6Addr, 197 | *, 198 | backlog: Optional[Int] = None, 199 | reuse_port: Bool = False, 200 | ) raises -> Self: 201 | """Create a socket, bind it to a specified address, and listen. 
Default 202 | no dual stack IPv6.""" 203 | return Self._ST.create_server( 204 | address, backlog=backlog, reuse_port=reuse_port 205 | ) 206 | 207 | @staticmethod 208 | fn create_server( 209 | address: IPv6Addr, 210 | *, 211 | dualstack_ipv6: Bool, 212 | backlog: Optional[Int] = None, 213 | reuse_port: Bool = False, 214 | ) raises -> (Self, Self._ipv4): 215 | """Create a socket, bind it to a specified address, and listen.""" 216 | s_s = Self._ST.create_server( 217 | address, 218 | dualstack_ipv6=dualstack_ipv6, 219 | backlog=backlog, 220 | reuse_port=reuse_port, 221 | ) 222 | return Self(fd=s_s[0].get_fd()), Self._ipv4(fd=s_s[1].get_fd()) 223 | 224 | fn keep_alive( 225 | self, 226 | enable: Bool = True, 227 | idle: C.int = 2 * 60 * 60, 228 | interval: C.int = 75, 229 | count: C.int = 10, 230 | ) raises: 231 | """Whether and how to keep the connection alive.""" 232 | return self._sock.keep_alive(enable, idle, interval, count) 233 | 234 | fn reuse_address( 235 | self, value: Bool = True, *, full_duplicates: Bool = True 236 | ) raises: 237 | """Whether to allow duplicated addresses.""" 238 | return self._sock.reuse_address(value, full_duplicates=full_duplicates) 239 | 240 | fn no_delay(self, value: Bool = True) raises: 241 | """Whether to send packets ASAP without accumulating more.""" 242 | return self._sock.no_delay(value) 243 | -------------------------------------------------------------------------------- /src/forge_tools/socket/_wasi.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional 2 | from memory import UnsafePointer, ArcPointer 3 | from sys.intrinsics import _type_is_eq 4 | from memory import Span 5 | from forge_tools.ffi.c.types import C 6 | from .socket import ( 7 | # SocketInterface, 8 | SockType, 9 | SockProtocol, 10 | ) 11 | from .address import SockFamily, SockAddr, IPv4Addr, IPv6Addr 12 | from ._unix import _UnixSocket 13 | 14 | 15 | struct _WASISocket[ 16 | sock_family: 
SockFamily, 17 | sock_type: SockType, 18 | sock_protocol: SockProtocol, 19 | sock_address: SockAddr, 20 | ](Copyable, Movable): 21 | alias _ST = _UnixSocket[sock_family, sock_type, sock_protocol, sock_address] 22 | var _sock: Self._ST 23 | 24 | alias _ipv4 = _WASISocket[ 25 | SockFamily.AF_INET, sock_type, sock_protocol, IPv4Addr 26 | ] 27 | 28 | fn __init__(out self) raises: 29 | """Create a new socket object.""" 30 | self._sock = Self._ST() 31 | 32 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 33 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`. 34 | """ 35 | self._sock = Self._ST(fd=fd) 36 | 37 | fn close(var self) raises: 38 | """Closes the Socket.""" 39 | self._sock.close() 40 | 41 | fn __del__(deinit self): 42 | """Closes the Socket if it's the last reference to its 43 | `FileDescriptor`. 44 | """ 45 | ... 46 | 47 | fn setsockopt[ 48 | D: DType = C.int.element_type 49 | ](self, level: C.int, option_name: C.int, option_value: Scalar[D]) raises: 50 | """Set socket options.""" 51 | self._sock.setsockopt(level, option_name, option_value) 52 | 53 | fn bind(self, address: sock_address) raises: 54 | """Bind the socket to address. The socket must not already be bound.""" 55 | self._sock.bind(address) 56 | 57 | fn listen(self, backlog: UInt = 0) raises: 58 | """Enable a server to accept connections. `backlog` specifies the number 59 | of unaccepted connections that the system will allow before refusing 60 | new connections. If `backlog == 0`, a default value is chosen. 
61 | """ 62 | self._sock.listen(backlog) 63 | 64 | async fn connect(self, address: sock_address) raises: 65 | """Connect to a remote socket at address.""" 66 | await self._sock.connect(address) 67 | 68 | async fn accept(self) -> Optional[(Self, sock_address)]: 69 | """Return a new socket representing the connection, and the address of 70 | the client.""" 71 | res = await self._sock.accept() 72 | if not res: 73 | return None 74 | s_a = res.value() 75 | return Self(fd=s_a[0].get_fd()), s_a[1] 76 | 77 | @staticmethod 78 | fn socketpair() raises -> (Self, Self): 79 | """Create a pair of socket objects from the sockets returned by the 80 | platform `socketpair()` function.""" 81 | s_s = Self._ST.socketpair() 82 | return Self(fd=s_s[0].get_fd()), Self(fd=s_s[1].get_fd()) 83 | 84 | fn get_fd(self) -> FileDescriptor: 85 | """Get the Socket's `FileDescriptor`.""" 86 | return self._sock.get_fd() 87 | 88 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 89 | """Send file descriptors to the socket.""" 90 | return await self._sock.send_fds(fds) 91 | 92 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 93 | """Receive file descriptors from the socket.""" 94 | return await self._sock.recv_fds(maxfds) 95 | 96 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 97 | """Send a buffer of bytes to the socket.""" 98 | return await self._sock.send(buf, flags) 99 | 100 | async fn recv[ 101 | O: MutableOrigin 102 | ](self, buf: Span[UInt8, O], flags: C.int = 0) -> Int: 103 | return await self._sock.recv(buf, flags) 104 | 105 | @staticmethod 106 | fn gethostname() -> Optional[String]: 107 | """Return the current hostname.""" 108 | return Self._ST.gethostname() 109 | 110 | @staticmethod 111 | fn gethostbyname(name: String) -> Optional[sock_address]: 112 | """Map a hostname to its Address.""" 113 | return Self._ST.gethostbyname(name) 114 | 115 | @staticmethod 116 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 117 | """Map an Address to 
DNS info.""" 118 | return Self._ST.gethostbyaddr(address) 119 | 120 | @staticmethod 121 | fn getservbyname(name: String) -> Optional[sock_address]: 122 | """Map a service name and a protocol name to a port number.""" 123 | return Self._ST.getservbyname(name) 124 | 125 | @staticmethod 126 | fn getdefaulttimeout() -> Optional[Float64]: 127 | """Get the default timeout value.""" 128 | return Self._ST.getdefaulttimeout() 129 | 130 | @staticmethod 131 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 132 | """Set the default timeout value.""" 133 | return Self._ST.setdefaulttimeout(value) 134 | 135 | fn settimeout(self, value: Optional[Float64]) -> Bool: 136 | """Set the socket timeout value.""" 137 | return self._sock.settimeout(value) 138 | 139 | # TODO: should this return an iterator instead? 140 | @staticmethod 141 | fn getaddrinfo( 142 | address: sock_address, flags: Int = 0 143 | ) raises -> List[ 144 | (SockFamily, SockType, SockProtocol, String, sock_address) 145 | ]: 146 | """Get the available address information. 147 | 148 | Notes: 149 | [Reference](\ 150 | https://man7.org/linux/man-pages/man3/freeaddrinfo.3p.html). 
151 | """ 152 | return Self._ST.getaddrinfo(address, flags) 153 | 154 | @staticmethod 155 | fn create_connection( 156 | address: IPv4Addr, 157 | timeout: Optional[Float64] = None, 158 | source_address: IPv4Addr = IPv4Addr(), 159 | *, 160 | all_errors: Bool = False, 161 | ) raises -> Self: 162 | """Connects to an address, with an optional timeout and optional source 163 | address.""" 164 | return Self(fd=Self._ST.create_connection( # wrap the _ST result into Self, as socketpair() does 165 | address, timeout, source_address, all_errors=all_errors 166 | ).get_fd()) 167 | 168 | @staticmethod 169 | fn create_connection( 170 | address: IPv6Addr, 171 | timeout: Optional[Float64] = None, 172 | source_address: IPv6Addr = IPv6Addr(), 173 | *, 174 | all_errors: Bool = False, 175 | ) raises -> Self: 176 | """Connects to an address, with an optional timeout and optional source 177 | address.""" 178 | return Self(fd=Self._ST.create_connection( # wrap the _ST result into Self, as socketpair() does 179 | address, timeout, source_address, all_errors=all_errors 180 | ).get_fd()) 181 | 182 | @staticmethod 183 | fn create_server( 184 | address: IPv4Addr, 185 | *, 186 | backlog: Optional[Int] = None, 187 | reuse_port: Bool = False, 188 | ) raises -> Self: 189 | """Create a socket, bind it to a specified address, and listen.""" 190 | return Self(fd=Self._ST.create_server( # wrap the _ST result into Self, as socketpair() does 191 | address, backlog=backlog, reuse_port=reuse_port 192 | ).get_fd()) 193 | 194 | @staticmethod 195 | fn create_server( 196 | address: IPv6Addr, 197 | *, 198 | backlog: Optional[Int] = None, 199 | reuse_port: Bool = False, 200 | ) raises -> Self: 201 | """Create a socket, bind it to a specified address, and listen.
Default 202 | no dual stack IPv6.""" 203 | return Self(fd=Self._ST.create_server( # wrap the _ST result into Self, as socketpair() does 204 | address, backlog=backlog, reuse_port=reuse_port 205 | ).get_fd()) 206 | 207 | @staticmethod 208 | fn create_server( 209 | address: IPv6Addr, 210 | *, 211 | dualstack_ipv6: Bool, 212 | backlog: Optional[Int] = None, 213 | reuse_port: Bool = False, 214 | ) raises -> (Self, Self._ipv4): 215 | """Create a socket, bind it to a specified address, and listen.""" 216 | s_s = Self._ST.create_server( 217 | address, 218 | dualstack_ipv6=dualstack_ipv6, 219 | backlog=backlog, 220 | reuse_port=reuse_port, 221 | ) 222 | return Self(fd=s_s[0].get_fd()), Self._ipv4(fd=s_s[1].get_fd()) 223 | 224 | fn keep_alive( 225 | self, 226 | enable: Bool = True, 227 | idle: C.int = 2 * 60 * 60, 228 | interval: C.int = 75, 229 | count: C.int = 10, 230 | ) raises: 231 | """Whether and how to keep the connection alive.""" 232 | return self._sock.keep_alive(enable, idle, interval, count) 233 | 234 | fn reuse_address( 235 | self, value: Bool = True, *, full_duplicates: Bool = True 236 | ) raises: 237 | """Whether to allow duplicated addresses.""" 238 | return self._sock.reuse_address(value, full_duplicates=full_duplicates) 239 | 240 | fn no_delay(self, value: Bool = True) raises: 241 | """Whether to send packets ASAP without accumulating more.""" 242 | return self._sock.no_delay(value) 243 | -------------------------------------------------------------------------------- /src/forge_tools/socket/_freertos.mojo: -------------------------------------------------------------------------------- 1 | from collections import Optional 2 | from memory import UnsafePointer, ArcPointer 3 | from sys.intrinsics import _type_is_eq 4 | from memory import Span 5 | from forge_tools.ffi.c.types import C 6 | from .socket import ( 7 | # SocketInterface, 8 | SockType, 9 | SockProtocol, 10 | ) 11 | from .address import SockFamily, SockAddr, IPv4Addr, IPv6Addr 12 | from ._unix import _UnixSocket 13 | 14 | 15 | struct _FreeRTOSSocket[ 16 |
SockFamily, 17 | sock_type: SockType, 18 | sock_protocol: SockProtocol, 19 | sock_address: SockAddr, 20 | ](Copyable, Movable): 21 | alias _ST = _UnixSocket[sock_family, sock_type, sock_protocol, sock_address] 22 | var _sock: Self._ST 23 | 24 | alias _ipv4 = _FreeRTOSSocket[ 25 | SockFamily.AF_INET, sock_type, sock_protocol, IPv4Addr 26 | ] 27 | 28 | fn __init__(out self) raises: 29 | """Create a new socket object.""" 30 | self._sock = Self._ST() 31 | 32 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 33 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`. 34 | """ 35 | self._sock = Self._ST(fd=fd) 36 | 37 | fn close(var self) raises: 38 | """Closes the Socket.""" 39 | self._sock.close() 40 | 41 | fn __del__(deinit self): 42 | """Closes the Socket if it's the last reference to its 43 | `FileDescriptor`. 44 | """ 45 | ... 46 | 47 | fn setsockopt[ 48 | D: DType = C.int.element_type 49 | ](self, level: C.int, option_name: C.int, option_value: Scalar[D]) raises: 50 | """Set socket options.""" 51 | self._sock.setsockopt(level, option_name, option_value) 52 | 53 | fn bind(self, address: sock_address) raises: 54 | """Bind the socket to address. The socket must not already be bound.""" 55 | self._sock.bind(address) 56 | 57 | fn listen(self, backlog: UInt = 0) raises: 58 | """Enable a server to accept connections. `backlog` specifies the number 59 | of unaccepted connections that the system will allow before refusing 60 | new connections. If `backlog == 0`, a default value is chosen. 
61 | """ 62 | self._sock.listen(backlog) 63 | 64 | async fn connect(self, address: sock_address) raises: 65 | """Connect to a remote socket at address.""" 66 | await self._sock.connect(address) 67 | 68 | async fn accept(self) -> Optional[(Self, sock_address)]: 69 | """Return a new socket representing the connection, and the address of 70 | the client.""" 71 | res = await self._sock.accept() 72 | if not res: 73 | return None 74 | s_a = res.value() 75 | return Self(fd=s_a[0].get_fd()), s_a[1] 76 | 77 | @staticmethod 78 | fn socketpair() raises -> (Self, Self): 79 | """Create a pair of socket objects from the sockets returned by the 80 | platform `socketpair()` function.""" 81 | s_s = Self._ST.socketpair() 82 | return Self(fd=s_s[0].get_fd()), Self(fd=s_s[1].get_fd()) 83 | 84 | fn get_fd(self) -> FileDescriptor: 85 | """Get the Socket's `FileDescriptor`.""" 86 | return self._sock.get_fd() 87 | 88 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 89 | """Send file descriptors to the socket.""" 90 | return await self._sock.send_fds(fds) 91 | 92 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 93 | """Receive file descriptors from the socket.""" 94 | return await self._sock.recv_fds(maxfds) 95 | 96 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 97 | """Send a buffer of bytes to the socket.""" 98 | return await self._sock.send(buf, flags) 99 | 100 | async fn recv[ 101 | O: MutableOrigin 102 | ](self, buf: Span[UInt8, O], flags: C.int = 0) -> Int: 103 | return await self._sock.recv(buf, flags) 104 | 105 | @staticmethod 106 | fn gethostname() -> Optional[String]: 107 | """Return the current hostname.""" 108 | return Self._ST.gethostname() 109 | 110 | @staticmethod 111 | fn gethostbyname(name: String) -> Optional[sock_address]: 112 | """Map a hostname to its Address.""" 113 | return Self._ST.gethostbyname(name) 114 | 115 | @staticmethod 116 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 117 | """Map an Address to 
DNS info.""" 118 | return Self._ST.gethostbyaddr(address) 119 | 120 | @staticmethod 121 | fn getservbyname(name: String) -> Optional[sock_address]: 122 | """Map a service name and a protocol name to a port number.""" 123 | return Self._ST.getservbyname(name) 124 | 125 | @staticmethod 126 | fn getdefaulttimeout() -> Optional[Float64]: 127 | """Get the default timeout value.""" 128 | return Self._ST.getdefaulttimeout() 129 | 130 | @staticmethod 131 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 132 | """Set the default timeout value.""" 133 | return Self._ST.setdefaulttimeout(value) 134 | 135 | fn settimeout(self, value: Optional[Float64]) -> Bool: 136 | """Set the socket timeout value.""" 137 | return self._sock.settimeout(value) 138 | 139 | # TODO: should this return an iterator instead? 140 | @staticmethod 141 | fn getaddrinfo( 142 | address: sock_address, flags: Int = 0 143 | ) raises -> List[ 144 | (SockFamily, SockType, SockProtocol, String, sock_address) 145 | ]: 146 | """Get the available address information. 147 | 148 | Notes: 149 | [Reference](\ 150 | https://man7.org/linux/man-pages/man3/freeaddrinfo.3p.html). 
151 | """ 152 | return Self._ST.getaddrinfo(address, flags) 153 | 154 | @staticmethod 155 | fn create_connection( 156 | address: IPv4Addr, 157 | timeout: Optional[Float64] = None, 158 | source_address: IPv4Addr = IPv4Addr(), 159 | *, 160 | all_errors: Bool = False, 161 | ) raises -> Self: 162 | """Connects to an address, with an optional timeout and optional source 163 | address.""" 164 | return Self(fd=Self._ST.create_connection( # wrap the _ST result into Self, as socketpair() does 165 | address, timeout, source_address, all_errors=all_errors 166 | ).get_fd()) 167 | 168 | @staticmethod 169 | fn create_connection( 170 | address: IPv6Addr, 171 | timeout: Optional[Float64] = None, 172 | source_address: IPv6Addr = IPv6Addr(), 173 | *, 174 | all_errors: Bool = False, 175 | ) raises -> Self: 176 | """Connects to an address, with an optional timeout and optional source 177 | address.""" 178 | return Self(fd=Self._ST.create_connection( # wrap the _ST result into Self, as socketpair() does 179 | address, timeout, source_address, all_errors=all_errors 180 | ).get_fd()) 181 | 182 | @staticmethod 183 | fn create_server( 184 | address: IPv4Addr, 185 | *, 186 | backlog: Optional[Int] = None, 187 | reuse_port: Bool = False, 188 | ) raises -> Self: 189 | """Create a socket, bind it to a specified address, and listen.""" 190 | return Self(fd=Self._ST.create_server( # wrap the _ST result into Self, as socketpair() does 191 | address, backlog=backlog, reuse_port=reuse_port 192 | ).get_fd()) 193 | 194 | @staticmethod 195 | fn create_server( 196 | address: IPv6Addr, 197 | *, 198 | backlog: Optional[Int] = None, 199 | reuse_port: Bool = False, 200 | ) raises -> Self: 201 | """Create a socket, bind it to a specified address, and listen.
Default 202 | no dual stack IPv6.""" 203 | return Self(fd=Self._ST.create_server( # wrap the _ST result into Self, as socketpair() does 204 | address, backlog=backlog, reuse_port=reuse_port 205 | ).get_fd()) 206 | 207 | @staticmethod 208 | fn create_server( 209 | address: IPv6Addr, 210 | *, 211 | dualstack_ipv6: Bool, 212 | backlog: Optional[Int] = None, 213 | reuse_port: Bool = False, 214 | ) raises -> (Self, Self._ipv4): 215 | """Create a socket, bind it to a specified address, and listen.""" 216 | s_s = Self._ST.create_server( 217 | address, 218 | dualstack_ipv6=dualstack_ipv6, 219 | backlog=backlog, 220 | reuse_port=reuse_port, 221 | ) 222 | return Self(fd=s_s[0].get_fd()), Self._ipv4(fd=s_s[1].get_fd()) 223 | 224 | fn keep_alive( 225 | self, 226 | enable: Bool = True, 227 | idle: C.int = 2 * 60 * 60, 228 | interval: C.int = 75, 229 | count: C.int = 10, 230 | ) raises: 231 | """Whether and how to keep the connection alive.""" 232 | return self._sock.keep_alive(enable, idle, interval, count) 233 | 234 | fn reuse_address( 235 | self, value: Bool = True, *, full_duplicates: Bool = True 236 | ) raises: 237 | """Whether to allow duplicated addresses.""" 238 | return self._sock.reuse_address(value, full_duplicates=full_duplicates) 239 | 240 | fn no_delay(self, value: Bool = True) raises: 241 | """Whether to send packets ASAP without accumulating more.""" 242 | return self._sock.no_delay(value) 243 | -------------------------------------------------------------------------------- /src/test/ffi/c/test_logging.mojo: -------------------------------------------------------------------------------- 1 | # RUN: %mojo %s 2 | 3 | from testing import assert_equal, assert_false, assert_raises, assert_true 4 | 5 | from memory import UnsafePointer, stack_allocation, memcmp 6 | 7 | from forge_tools.ffi.c.libc import Libc, TryLibc 8 | from forge_tools.ffi.c.types import C, char_ptr, char_ptr_to_string 9 | from forge_tools.ffi.c.constants import * 10 | 11 | 12 | alias error_message = ( 13 | (SUCCESS, "Success"), 14 | (EPERM, "Operation not permitted"),
15 | (ENOENT, "No such file or directory"), 16 | (ESRCH, "No such process"), 17 | (EINTR, "Interrupted system call"), 18 | (EIO, "Input/output error"), 19 | (ENXIO, "No such device or address"), 20 | (E2BIG, "Argument list too long"), 21 | (ENOEXEC, "Exec format error"), 22 | (EBADF, "Bad file descriptor"), 23 | (ECHILD, "No child processes"), 24 | (EAGAIN, "Resource temporarily unavailable"), 25 | (EWOULDBLOCK, "Resource temporarily unavailable"), 26 | (ENOMEM, "Cannot allocate memory"), 27 | (EACCES, "Permission denied"), 28 | (EFAULT, "Bad address"), 29 | (ENOTBLK, "Block device required"), 30 | (EBUSY, "Device or resource busy"), 31 | (EEXIST, "File exists"), 32 | (EXDEV, "Invalid cross-device link"), 33 | (ENODEV, "No such device"), 34 | (ENOTDIR, "Not a directory"), 35 | (EISDIR, "Is a directory"), 36 | (EINVAL, "Invalid argument"), 37 | (ENFILE, "Too many open files in system"), 38 | (EMFILE, "Too many open files"), 39 | (ENOTTY, "Inappropriate ioctl for device"), 40 | (ETXTBSY, "Text file busy"), 41 | (EFBIG, "File too large"), 42 | (ENOSPC, "No space left on device"), 43 | (ESPIPE, "Illegal seek"), 44 | (EROFS, "Read-only file system"), 45 | (EMLINK, "Too many links"), 46 | (EPIPE, "Broken pipe"), 47 | (EDOM, "Numerical argument out of domain"), 48 | (ERANGE, "Numerical result out of range"), 49 | (EDEADLK, "Resource deadlock avoided"), 50 | (ENAMETOOLONG, "File name too long"), 51 | (ENOLCK, "No locks available"), 52 | (ENOSYS, "Function not implemented"), 53 | (ENOTEMPTY, "Directory not empty"), 54 | (ELOOP, "Too many levels of symbolic links"), 55 | (ENOMSG, "No message of desired type"), 56 | (EIDRM, "Identifier removed"), 57 | (ECHRNG, "Channel number out of range"), 58 | (EL2NSYNC, "Level 2 not synchronized"), 59 | (EL3HLT, "Level 3 halted"), 60 | (EL3RST, "Level 3 reset"), 61 | (ELNRNG, "Link number out of range"), 62 | (EUNATCH, "Protocol driver not attached"), 63 | (ENOCSI, "No CSI structure available"), 64 | (EL2HLT, "Level 2 halted"), 65 | 
(EBADE, "Invalid exchange"), 66 | (EBADR, "Invalid request descriptor"), 67 | (EXFULL, "Exchange full"), 68 | (ENOANO, "No anode"), 69 | (EBADRQC, "Invalid request code"), 70 | (EBADSLT, "Invalid slot"), 71 | (EBFONT, "Bad font file format"), 72 | (ENOSTR, "Device not a stream"), 73 | (ENODATA, "No data available"), 74 | (ETIME, "Timer expired"), 75 | (ENOSR, "Out of streams resources"), 76 | (ENONET, "Machine is not on the network"), 77 | (ENOPKG, "Package not installed"), 78 | (EREMOTE, "Object is remote"), 79 | (ENOLINK, "Link has been severed"), 80 | (EADV, "Advertise error"), 81 | (ESRMNT, "Srmount error"), 82 | (ECOMM, "Communication error on send"), 83 | (EPROTO, "Protocol error"), 84 | (EMULTIHOP, "Multihop attempted"), 85 | (EDOTDOT, "RFS specific error"), 86 | (EBADMSG, "Bad message"), 87 | (EOVERFLOW, "Value too large for defined data type"), 88 | (ENOTUNIQ, "Name not unique on network"), 89 | (EBADFD, "File descriptor in bad state"), 90 | (EREMCHG, "Remote address changed"), 91 | (ELIBACC, "Can not access a needed shared library"), 92 | (ELIBBAD, "Accessing a corrupted shared library"), 93 | (ELIBSCN, ".lib section in a.out corrupted"), 94 | (ELIBMAX, "Attempting to link in too many shared libraries"), 95 | (ELIBEXEC, "Cannot exec a shared library directly"), 96 | (EILSEQ, "Invalid or incomplete multibyte or wide character"), 97 | (ERESTART, "Interrupted system call should be restarted"), 98 | (ESTRPIPE, "Streams pipe error"), 99 | (EUSERS, "Too many users"), 100 | (ENOTSOCK, "Socket operation on non-socket"), 101 | (EDESTADDRREQ, "Destination address required"), 102 | (EMSGSIZE, "Message too long"), 103 | (EPROTOTYPE, "Protocol wrong type for socket"), 104 | (ENOPROTOOPT, "Protocol not available"), 105 | (EPROTONOSUPPORT, "Protocol not supported"), 106 | (ESOCKTNOSUPPORT, "Socket type not supported"), 107 | (EOPNOTSUPP, "Operation not supported"), 108 | (EPFNOSUPPORT, "Protocol family not supported"), 109 | (EAFNOSUPPORT, "Address family not supported 
by protocol"), 110 | (EADDRINUSE, "Address already in use"), 111 | (EADDRNOTAVAIL, "Cannot assign requested address"), 112 | (ENETDOWN, "Network is down"), 113 | (ENETUNREACH, "Network is unreachable"), 114 | (ENETRESET, "Network dropped connection on reset"), 115 | (ECONNABORTED, "Software caused connection abort"), 116 | (ECONNRESET, "Connection reset by peer"), 117 | (ENOBUFS, "No buffer space available"), 118 | (EISCONN, "Transport endpoint is already connected"), 119 | (ENOTCONN, "Transport endpoint is not connected"), 120 | (ESHUTDOWN, "Cannot send after transport endpoint shutdown"), 121 | (ETOOMANYREFS, "Too many references: cannot splice"), 122 | (ETIMEDOUT, "Connection timed out"), 123 | (ECONNREFUSED, "Connection refused"), 124 | (EHOSTDOWN, "Host is down"), 125 | (EHOSTUNREACH, "No route to host"), 126 | (EALREADY, "Operation already in progress"), 127 | (EINPROGRESS, "Operation now in progress"), 128 | (ESTALE, "Stale file handle"), 129 | (EUCLEAN, "Structure needs cleaning"), 130 | (ENOTNAM, "Not a XENIX named type file"), 131 | (ENAVAIL, "No XENIX semaphores available"), 132 | (EISNAM, "Is a named type file"), 133 | (EREMOTEIO, "Remote I/O error"), 134 | (EDQUOT, "Disk quota exceeded"), 135 | (ENOMEDIUM, "No medium found"), 136 | (EMEDIUMTYPE, "Wrong medium type"), 137 | (ECANCELED, "Operation canceled"), 138 | (ENOKEY, "Required key not available"), 139 | (EKEYEXPIRED, "Key has expired"), 140 | (EKEYREVOKED, "Key has been revoked"), 141 | (EKEYREJECTED, "Key was rejected by service"), 142 | (EOWNERDEAD, "Owner died"), 143 | (ENOTRECOVERABLE, "State not recoverable"), 144 | (ERFKILL, "Operation not possible due to RF-kill"), 145 | (EHWPOISON, "Memory page has hardware error"), 146 | ) 147 | 148 | 149 | def _test_errno(libc: Libc): 150 | @parameter 151 | for i in range(len(error_message)): 152 | errno_msg = error_message[i] 153 | errno = errno_msg[0] 154 | libc.set_errno(errno) # round-trip the table's errno constant, not the loop index 155 | assert_equal(libc.get_errno(), errno) 156 | libc.set_errno(0) 157 | 158 |
159 | def test_dynamic_errno(): 160 | _test_errno(Libc[static=False]()) 161 | 162 | 163 | def test_static_errno(): 164 | _test_errno(Libc[static=True]()) 165 | 166 | 167 | def _test_strerror(libc: Libc): 168 | @parameter 169 | for i in range(len(error_message)): 170 | errno_msg = error_message[i] 171 | errno = errno_msg[0] 172 | msg = errno_msg[1] 173 | res = char_ptr_to_string(libc.strerror(errno)) 174 | 175 | @parameter 176 | if CompilationTarget.is_linux(): 177 | assert_equal(res, msg) 178 | 179 | 180 | def test_dynamic_strerror(): 181 | _test_strerror(Libc[static=False]()) 182 | 183 | 184 | def test_static_strerror(): 185 | _test_strerror(Libc[static=True]()) 186 | 187 | 188 | def _test_perror(libc: Libc): 189 | @parameter 190 | for i in range(len(error_message)): 191 | errno_msg = error_message[i] 192 | errno = errno_msg[0] 193 | libc.set_errno(errno) 194 | libc.perror() 195 | libc.set_errno(0) 196 | 197 | 198 | def test_dynamic_perror(): 199 | _test_perror(Libc[static=False]()) 200 | 201 | 202 | def test_static_perror(): 203 | _test_perror(Libc[static=True]()) 204 | 205 | 206 | alias log_levels = ( 207 | LOG_EMERG, 208 | LOG_ALERT, 209 | LOG_CRIT, 210 | LOG_ERR, 211 | LOG_WARNING, 212 | LOG_NOTICE, 213 | LOG_INFO, 214 | LOG_DEBUG, 215 | ) 216 | alias log_options = ( 217 | LOG_PID, 218 | LOG_CONS, 219 | LOG_ODELAY, 220 | LOG_NDELAY, 221 | LOG_NOWAIT, 222 | LOG_PERROR, 223 | ) 224 | alias log_facilities = ( 225 | LOG_KERN, 226 | LOG_USER, 227 | LOG_MAIL, 228 | LOG_DAEMON, 229 | LOG_AUTH, 230 | LOG_SYSLOG, 231 | LOG_LPR, 232 | LOG_NEWS, 233 | LOG_UUCP, 234 | LOG_CRON, 235 | LOG_AUTHPRIV, 236 | LOG_FTP, 237 | ) 238 | 239 | 240 | def _test_log(libc: Libc): 241 | with TryLibc(libc): 242 | name = "log_tester" 243 | identity = char_ptr(name) 244 | 245 | @parameter 246 | for i in range(len(log_levels)): 247 | alias level = log_levels[i] 248 | 249 | @parameter 250 | for j in range(len(log_options)): 251 | alias option = log_options[j] 252 | 253 | @parameter 254 | for k 
in range(len(log_facilities)): 255 | alias facility = log_facilities[k] 256 | libc.openlog(identity, option, facility) 257 | _ = libc.setlogmask(level) 258 | libc.syslog( 259 | level, char_ptr("test i:%d, j:%d, k:%d"), i, j, k 260 | ) 261 | libc.closelog() 262 | _ = name 263 | 264 | 265 | def test_dynamic_log(): 266 | _test_log(Libc[static=False]()) 267 | 268 | 269 | def test_static_log(): 270 | _test_log(Libc[static=True]()) 271 | 272 | 273 | def main(): 274 | test_dynamic_errno() 275 | test_static_errno() 276 | test_dynamic_strerror() 277 | test_static_strerror() 278 | test_dynamic_perror() 279 | test_static_perror() 280 | test_dynamic_log() 281 | test_static_log() 282 | -------------------------------------------------------------------------------- /src/forge_tools/socket/README.md: -------------------------------------------------------------------------------- 1 | # Notes on Socket 2 | 3 | This is an attempt to design something that will be usable for any use case 4 | where one would like to connect two machines. The base abstraction layer for all 5 | communication protocols should be the simple BSD socket API. With some minor 6 | async additions where I/O has no reason to block the main thread, this 7 | implementation follows that philosophy. 8 | 9 | #### Current plan 10 | 1. Build scaffolding for the most important platforms in an extensible manner. 11 | 2. Set up a unified socket interface that all platforms adhere to, with each 12 | platform constrained to what it currently supports. 13 | 3. Make sync APIs first with async wrappers, then progressively develop async infra. 14 | 1. Make sync TCP work for Linux as a starting point. 15 | 2. Develop sync TCP for other platforms. 16 | 3. Start making things really async under the hood. 17 | 4. Develop other protocols. 18 | 19 | #### Current outlook 20 | 21 | Current blockers: no Mojo async, no parametrizable traits.
22 | 23 | The idea is for the Socket struct to be the overarching API over any 24 | platform-specific socket implementation: 25 | ```mojo 26 | struct Socket[ 27 | sock_family: SockFamily = SockFamily.AF_INET, 28 | sock_type: SockType = SockType.SOCK_STREAM, 29 | sock_protocol: SockProtocol = SockProtocol.TCP, 30 | sock_address: SockAddr = IPv4Addr, 31 | sock_platform: SockPlatform = current_sock_platform(), 32 | ](Copyable, Movable): 33 | """Struct for using Sockets. In the future this struct should be able to 34 | use any implementation that conforms to the `SocketInterface` trait, once 35 | traits can be parametrized. This will allow the user to implement the 36 | interface for whatever functionality is missing and inject the type. 37 | 38 | Parameters: 39 | sock_family: The socket family e.g. `SockFamily.AF_INET`. 40 | sock_type: The socket type e.g. `SockType.SOCK_STREAM`. 41 | sock_protocol: The socket protocol e.g. `SockProtocol.TCP`. 42 | sock_address: The address type for the socket. 43 | sock_platform: The socket platform e.g. `SockPlatform.LINUX`. 44 | """ 45 | ... 46 | ``` 47 | 48 | The idea is for the interface to be generic and to let each implementation 49 | constrain, at compile time, what it supports and what it doesn't. 50 | 51 | The Socket struct should be parametrizable with the implementation of the 52 | socket interface: 53 | ```mojo 54 | socket_impl: SocketInterface = _LinuxSocket[ 55 | sock_family, sock_type, sock_protocol, sock_address 56 | ] 57 | ``` 58 | 59 | The interface for any socket implementation would look like this 60 | (many of these features are not yet part of the Mojo language, so treat it as pseudocode): 61 | ```mojo 62 | trait SocketInterface[ 63 | sock_family: SockFamily, 64 | sock_type: SockType, 65 | sock_protocol: SockProtocol, 66 | sock_address: SockAddr, 67 | sock_platform: SockPlatform, 68 | ](Copyable, Movable): 69 | """Interface for Sockets.""" 70 | 71 | fn __init__(out self) raises: 72 | """Create a new socket object.""" 73 | ...
74 | 75 | fn __init__(out self, fd: ArcPointer[FileDescriptor]): 76 | """Create a new socket object from an open `ArcPointer[FileDescriptor]`.""" 77 | ... 78 | 79 | fn __init__(out self, fd: FileDescriptor): 80 | """Create a new socket object from an open `FileDescriptor`.""" 81 | ... 82 | 83 | fn close(var self) raises: 84 | """Closes the Socket.""" 85 | ... 86 | 87 | fn __del__(deinit self): 88 | """Closes the Socket if it's the last reference to its 89 | `FileDescriptor`. 90 | """ 91 | ... 92 | 93 | fn setsockopt[ 94 | D: DType = C.int.element_type 95 | ](self, level: C.int, option_name: C.int, option_value: Scalar[D]) raises: 96 | """Set socket options.""" 97 | ... 98 | 99 | fn bind(self, address: sock_address) raises: 100 | """Bind the socket to address. The socket must not already be bound.""" 101 | ... 102 | 103 | fn listen(self, backlog: UInt = 0) raises: 104 | """Enable a server to accept connections. `backlog` specifies the number 105 | of unaccepted connections that the system will allow before refusing 106 | new connections. If `backlog == 0`, a default value is chosen. 107 | """ 108 | ... 109 | 110 | async fn connect(self, address: sock_address) raises: 111 | """Connect to a remote socket at address.""" 112 | ... 113 | 114 | async fn accept(self) -> Optional[(Self, sock_address)]: 115 | """Return a new socket representing the connection, and the address of 116 | the client.""" 117 | ... 118 | 119 | # TODO: once we have async generators 120 | fn __iter__(self) -> _SocketIter: 121 | """Iterate asynchronously over the incoming connections.""" 122 | ... 123 | 124 | @staticmethod 125 | fn socketpair() raises -> (Self, Self): 126 | """Create a pair of socket objects from the sockets returned by the 127 | platform `socketpair()` function.""" 128 | ... 129 | 130 | fn get_fd(self) -> FileDescriptor: 131 | """Get the Socket's `FileDescriptor`.""" 132 | ... 
133 | 134 | async fn send_fds(self, fds: List[FileDescriptor]) -> Bool: 135 | """Send file descriptors to the socket.""" 136 | ... 137 | 138 | async fn recv_fds(self, maxfds: Int) -> List[FileDescriptor]: 139 | """Receive file descriptors from the socket.""" 140 | ... 141 | 142 | async fn send(self, buf: Span[UInt8], flags: C.int = 0) -> Int: 143 | """Send a buffer of bytes to the socket.""" 144 | ... 145 | 146 | async fn recv[O: MutableOrigin]( 147 | self, buf: Span[UInt8, O], flags: C.int = 0 148 | ) -> Int: 149 | """Receive up to `len(buf)` bytes into the buffer.""" 150 | ... 151 | 152 | @staticmethod 153 | fn gethostname() -> Optional[String]: 154 | """Return the current hostname.""" 155 | ... 156 | 157 | @staticmethod 158 | fn gethostbyname(name: String) -> Optional[sock_address]: 159 | """Map a hostname to its Address.""" 160 | ... 161 | 162 | @staticmethod 163 | fn gethostbyaddr(address: sock_address) -> Optional[String]: 164 | """Map an Address to DNS info.""" 165 | ... 166 | 167 | @staticmethod 168 | fn getservbyname( 169 | name: String, proto: SockProtocol = SockProtocol.TCP 170 | ) -> Optional[sock_address]: 171 | """Map a service name and a protocol name to a port number.""" 172 | ... 173 | 174 | @staticmethod 175 | fn getdefaulttimeout() -> Optional[Float64]: 176 | """Get the default timeout value.""" 177 | ... 178 | 179 | @staticmethod 180 | fn setdefaulttimeout(value: Optional[Float64]) -> Bool: 181 | """Set the default timeout value.""" 182 | ... 183 | 184 | fn settimeout(self, value: Optional[Float64]) -> Bool: 185 | """Set the socket timeout value.""" 186 | ... 187 | 188 | # TODO: This should return an iterator instead 189 | @staticmethod 190 | fn getaddrinfo( 191 | address: sock_address, flags: Int = 0 192 | ) raises -> List[ 193 | (SockFamily, SockType, SockProtocol, String, sock_address) 194 | ]: 195 | """Get the available address information.""" 196 | ...
197 | 198 | fn keep_alive( 199 | self, 200 | enable: Bool = True, 201 | idle: C.int = 2 * 60 * 60, 202 | interval: C.int = 75, 203 | count: C.int = 10, 204 | ) raises: 205 | """Whether and how to keep the connection alive.""" 206 | ... 207 | 208 | fn reuse_address( 209 | self, value: Bool = True, *, full_duplicates: Bool = True 210 | ) raises: 211 | """Whether to allow duplicated addresses.""" 212 | ... 213 | 214 | fn no_delay(self, value: Bool = True) raises: 215 | """Whether to send packets ASAP without accumulating more.""" 216 | ... 217 | ``` 218 | 219 | 220 | All of this will make it possible to build higher-level, Pythonic syntax for 221 | servers of any protocol, and to inject any platform-specific implementation 222 | that the user does not find in the stdlib but that exists in an external 223 | library. 224 | 225 | Examples: 226 | 227 | ```mojo 228 | from forge_tools.socket import Socket 229 | 230 | 231 | async def main(): 232 | # TODO: once we have async generators: 233 | # async for conn_attempt in Socket.create_server(("0.0.0.0", 8000)): 234 | # conn_attempt = await server.accept() 235 | # if not conn_attempt: 236 | # continue 237 | # conn, addr = conn_attempt.value() 238 | # ... # handle new connection 239 | 240 | with Socket.create_server(("0.0.0.0", 8000)) as server: 241 | while True: 242 | conn_attempt = await server.accept() 243 | if not conn_attempt: 244 | continue 245 | conn, addr = conn_attempt.value() 246 | ... # handle new connection 247 | ``` 248 | 249 | In the future, something like this should be possible: 250 | ```mojo 251 | from collections import Optional 252 | from multiprocessing import Pool 253 | from forge_tools.socket import Socket, IPv4Addr 254 | 255 | 256 | async fn handler(conn_attempt: Optional[(Socket, IPv4Addr)]): 257 | if not conn_attempt: 258 | return 259 | conn, addr = conn_attempt.value() 260 | ...
261 | 262 | async def main(): 263 | server = Socket.create_server(("0.0.0.0", 8000)) 264 | with Pool() as pool: 265 | _ = await pool.map(handler, iter(server)) 266 | ``` 267 | 268 | #### On future implementation of kernel async IO protocols 269 | 270 | - Is it worth it using [io_uring](https://kernel.dk/io_uring.pdf) (Linux), 271 | [kqueue](https://man.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2) (Unix), 272 | [IOCP]( 273 | https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports) 274 | (Windows) ? 275 | - How much would we need to deviate from Python's APIs ? 276 | - How do we deal with external C library dependencies like [liburing]( 277 | https://github.com/axboe/liburing) if we decide to use it ? 278 | - Do we wait for everything to be implemented in Mojo ? 279 | ([io_uring project](https://github.com/dmitry-salin/io_uring)) 280 | - Could we just leave the implementation to the community and setup a solid 281 | interface? 282 | -------------------------------------------------------------------------------- /src/forge_tools/memory/arena_pointer.mojo: -------------------------------------------------------------------------------- 1 | """Arena Pointer module.""" 2 | 3 | from memory import UnsafePointer, memset, stack_allocation 4 | from collections import Optional 5 | from memory import Span 6 | from sys.info import bitwidthof, simdwidthof 7 | from sys.ffi import OpaquePointer 8 | from os import abort 9 | from .arc_pointer import ArcPointer 10 | 11 | 12 | struct GladiatorPointer[ 13 | is_mutable: Bool, //, 14 | type: AnyType, 15 | origin: Origin[is_mutable], 16 | address_space: AddressSpace = AddressSpace.GENERIC, 17 | ]: 18 | """Gladiator Pointer (Weak Arena Pointer) that resides in an Arena.""" 19 | 20 | alias _U = UnsafePointer[type, address_space] 21 | alias _C = ColosseumPointer[type, origin, address_space] 22 | alias _A = ArcPointer[UnsafePointer[OpaquePointer], origin, address_space] 23 | var _colosseum: Self._A 24 | """A pointer to
the colosseum.""" 25 | var _start: Int 26 | """The absolute starting offset from the colosseum pointer.""" 27 | var _len: Int 28 | """The length of the pointer.""" 29 | 30 | fn __init__(out self): # NOTE(review): wraps a one-slot stack_allocation made inside this call; that memory likely does not outlive __init__ — confirm. 31 | self._colosseum = Self._A( 32 | ptr=stack_allocation[1, Self._C]().bitcast[OpaquePointer](), 33 | is_allocated=True, 34 | in_registers=True, 35 | is_initialized=True, 36 | ) 37 | self._start = 0 38 | self._len = 0 39 | 40 | fn __init__(out self, *, colosseum: Self._A, start: Int, length: Int): 41 | """Constructs a GladiatorPointer from a Pointer. 42 | 43 | Args: 44 | colosseum: A pointer to the colosseum. 45 | start: The absolute starting offset from the colosseum pointer. 46 | length: The length of the pointer. 47 | """ 48 | 49 | self._colosseum = colosseum 50 | self._start = start 51 | self._len = length 52 | 53 | @staticmethod 54 | @always_inline 55 | fn alloc(count: Int) -> Self: # NOTE(review): the body reads self._colosseum, but a @staticmethod has no self; it needs a colosseum argument or must become an instance method. 56 | """Allocate memory according to the pointer's logic. 57 | 58 | Args: 59 | count: The number of elements in the buffer. 60 | 61 | Returns: 62 | The pointer to the newly allocated buffer.
63 | """ 64 | return ( 65 | self._colosseum._ptr.unsafe_ptr()[][] 66 | .bitcast[Self._C]() 67 | .alloc(count) 68 | ) 69 | 70 | fn __del__(deinit self): 71 | """Free the memory referenced by the pointer or ignore.""" 72 | self._colosseum._ptr.unsafe_ptr()[].bitcast[Self._C]()[]._free(self^) 73 | 74 | fn __int__(self) -> Int: 75 | return Int(self._colosseum._ptr.unsafe_ptr()[]) 76 | 77 | fn __bool__(self) -> Bool: 78 | return Bool(self._colosseum._ptr.unsafe_ptr()[]) 79 | 80 | 81 | struct ColosseumPointer[ 82 | is_mutable: Bool, //, 83 | type: AnyType, 84 | origin: Origin[is_mutable], 85 | address_space: AddressSpace = AddressSpace.GENERIC, 86 | ](Copyable, Movable): 87 | """Colosseum Pointer (Arena Owner Pointer) that deallocates the arena when 88 | deleted.""" 89 | 90 | var _free_slots: UnsafePointer[Byte] 91 | """Bits indicating whether the slot is free.""" 92 | var _len: Int 93 | """The amount of bits set in the _free_slots pointer.""" 94 | alias _P = UnsafePointer[type, address_space] 95 | var _ptr: Self._P 96 | """The data.""" 97 | alias _S = ArcPointer[UnsafePointer[OpaquePointer], origin, address_space] 98 | var _self_ptr: Self._S 99 | """A self pointer.""" 100 | alias _G = GladiatorPointer[type, origin, address_space] 101 | 102 | fn __init__(out self): 103 | self._free_slots = UnsafePointer[Byte]() 104 | self._len = 0 105 | self._ptr = Self._P() 106 | self._self_ptr = Self._S.alloc(1) 107 | 108 | fn __init__(out self, *, var other: Self): 109 | self._free_slots = other._free_slots 110 | self._len = other._len 111 | self._ptr = other._ptr 112 | self._self_ptr = other._self_ptr 113 | 114 | @doc_private # NOTE(review): doc_private is not imported in this module; pointer.mojo imports it with `from documentation import doc_private`. 115 | @always_inline("nodebug") 116 | fn __init__(out self, *, ptr: Self._P, length: Int): 117 | """Constructs a ColosseumPointer that owns `ptr` and marks all of its slots as free. 118 | 119 | Args: 120 | ptr: The UnsafePointer. 121 | length: The length of the pointer.
122 | """ 123 | s = Self() 124 | s._ptr = ptr 125 | amnt = length // 8 + Int(length < 8) # NOTE(review): not a ceiling division (length=9 gives 1 byte for 9 slots); (length + 7) // 8 is presumably intended. 126 | p = UnsafePointer[Byte].alloc(amnt) 127 | memset(p, 0xFF, amnt) 128 | s._free_slots = p 129 | s._len = length // 8 + length % 8 # NOTE(review): _len is documented as a bit count, but this mixes whole bytes and leftover bits (length=16 gives 2) — confirm. 130 | s._self_ptr = Self._S(UnsafePointer(to=s).bitcast[OpaquePointer]()) # NOTE(review): points at the local s, which is moved into self on the next line, so this self pointer likely dangles. 131 | self = Self(other=s) 132 | 133 | @staticmethod 134 | @always_inline 135 | fn alloc(count: Int) -> Self: 136 | """Allocate memory according to the pointer's logic. 137 | 138 | Args: 139 | count: The number of elements in the buffer. 140 | 141 | Returns: 142 | The pointer to the newly allocated buffer. 143 | """ 144 | return Self(ptr=Self._P.alloc(count), length=count) 145 | 146 | @always_inline 147 | fn alloc(out self, count: Int) -> Self._G: # NOTE(review): `out self` treats self as an uninitialized result, yet the body reads self._free_slots and self._len; `mut self` is presumably intended. It also shares the (count: Int) signature with the @staticmethod alloc above. 148 | """Reserve `count` free slots in the arena and return a GladiatorPointer to them (work in progress). 149 | 150 | Args: 151 | count: The number of elements in the array. 152 | 153 | Returns: 154 | The pointer to the newly allocated array. 155 | """ 156 | 157 | alias int_bitwidth = bitwidthof[Int]() 158 | alias int_simdwidth = simdwidthof[Int]() 159 | mask = Scalar[DType.index](0) 160 | for i in range(min(count, int_bitwidth)): 161 | mask |= 1 << (int_bitwidth - i) # NOTE(review): for i == 0 this shifts by the full bit width; int_bitwidth - 1 - i is presumably intended. 162 | 163 | alias widths = (128, 64, 32, 16, 8) 164 | ptr = self._free_slots.bitcast[DType.index]() 165 | num_bytes = self._len // 8 + Int(self._len < 8) # NOTE(review): same non-ceiling byte count as in __init__. 166 | amnt = UInt8(0) 167 | start = 0 168 | 169 | @parameter 170 | for i in range(len(widths)): 171 | alias w = widths[i] 172 | 173 | @parameter 174 | if simdwidthof[Int]() >= w: 175 | rest = num_bytes - start 176 | for _ in range(rest // w): 177 | vec = (ptr + start).load[width=w]() # NOTE(review): ptr was bitcast to DType.index above while start and num_bytes count bytes, so this offset looks mis-scaled — confirm. 178 | res = vec == mask 179 | if res.reduce_or(): 180 | amnt = res.cast[DType.uint8]().reduce_add() 181 | break 182 | start += w 183 | 184 | if amnt * int_bitwidth < count: 185 | # TODO: realloc parent 186 | # TODO: return call alloc 187 | abort("support for realloc is still in development") 188 | 189 | if start == num_bytes: 190 | return GladiatorPointer[type, origin,
address_space]() 191 | 192 | p = self._free_slots.offset(start // 8) 193 | mask = ~mask 194 | remaining = count - int_bitwidth 195 | if remaining > int_bitwidth: 196 | while remaining > int_bitwidth: 197 | p = p + remaining - int_bitwidth 198 | new_value = ( 199 | p.bitcast[Scalar[DType.index]]().load[width=int_simdwidth]() 200 | & mask 201 | ) 202 | p.store(0, new_value.cast[DType.uint8]()) # NOTE(review): cast converts each index lane to a uint8 value (truncating) rather than reinterpreting the bytes — confirm this is intended. 203 | remaining -= int_bitwidth 204 | mask = 0 # NOTE(review): this overwrites the inverted mask from above; when remaining <= int_bitwidth that mask is never applied — confirm the first word is meant to be skipped. 205 | for i in range(remaining): 206 | mask |= 1 << (int_bitwidth - i) # NOTE(review): same full-bit-width shift for i == 0 as above. 207 | mask = ~mask 208 | new_value = ( 209 | p.bitcast[Scalar[DType.index]]().load[width=int_simdwidth]() & mask 210 | ) 211 | p.store(0, new_value.cast[DType.uint8]()) 212 | 213 | return __type_of(self)._G( # NOTE(review): GladiatorPointer.__init__ takes (colosseum, start, length), and ColosseumPointer declares no _owner_is_alive field, so these keyword arguments match no visible constructor. 214 | ptr=self._ptr, 215 | start=start, 216 | length=count, 217 | owner_is_alive=self._owner_is_alive, 218 | free_slots=self._free_slots, 219 | ) 220 | 221 | fn _free(out self, var gladiator: Self._G): # NOTE(review): reads and writes through self, so `mut self` is presumably intended. 222 | p0 = self._free_slots - gladiator._start # NOTE(review): subtracts the slot offset here, then adds full_byte_start (a byte offset) below — confirm the pointer arithmetic. 223 | full_byte_start = gladiator._start // 8 224 | full_byte_end = gladiator._len // 8 225 | memset(p0 + full_byte_start, 0xFF, full_byte_end) 226 | mask = 0 227 | for i in range(full_byte_end, full_byte_end + gladiator._start % 8): 228 | mask |= 1 << (bitwidthof[Int]() - i) # NOTE(review): i starts at full_byte_end (a byte index) but is used as a bit position here — confirm. 229 | p0[full_byte_end] = p0[full_byte_end] | mask 230 | 231 | fn unsafe_ptr(self) -> UnsafePointer[type, address_space]: 232 | alias P = Pointer[type, MutableAnyOrigin, address_space] # unused 233 | return self._ptr.bitcast[address_space=address_space]() 234 | 235 | fn __int__(self) -> Int: 236 | return Int(self._ptr) 237 | 238 | fn __bool__(self) -> Bool: 239 | return Bool(self._ptr) 240 | 241 | fn __del__(deinit self): 242 | """Free the memory referenced by the pointer or ignore.""" 243 | self._free_slots.free() 244 | self._owner_is_alive[0] = False # mark as deleted first — NOTE(review): ColosseumPointer declares no _owner_is_alive field. 245 | self._ptr.free() 246 | 247 | 248 | struct SpartacusPointer[ 249 | is_mutable: Bool, //, 250 | type: AnyType, 251 | origin: Origin[is_mutable], 252 | address_space:
AddressSpace = AddressSpace.GENERIC, 253 | ](Copyable, Movable): 254 | """Reference Counted Arena Pointer that deallocates the arena when it's the 255 | last one. 256 | 257 | Safety: 258 | This is not thread safe. 259 | 260 | Notes: 261 | Spartacus is arguably the most famous Roman gladiator, a tough fighter 262 | who led a massive slave rebellion. After being enslaved and put through 263 | gladiator training school, an incredibly brutal place, he and 78 others 264 | revolted against their master Batiatus using only kitchen knives. 265 | [Source]( 266 | https://www.historyextra.com/period/roman/who-were-roman-gladiators-famous-spartacus-crixus/ 267 | ). 268 | """ 269 | 270 | ... # TODO: not implemented yet (no fields or methods). 271 | 272 | 273 | struct FlammaPointer[ 274 | is_mutable: Bool, //, 275 | type: AnyType, 276 | origin: Origin[is_mutable], 277 | address_space: AddressSpace = AddressSpace.GENERIC, 278 | ](Copyable, Movable): 279 | """Atomic Reference Counted Arena Pointer that deallocates the arena when 280 | it's the last one. 281 | 282 | Notes: 283 | Gladiators were usually slaves, and Flamma came from the faraway 284 | province of Syria. However, the fighting lifestyle seemed to suit him 285 | well - he was offered his freedom four times, after winning 21 battles, 286 | but refused it and continued to entertain the crowds of the Colosseum 287 | until he died aged 30. His face was even used on coins. [Source]( 288 | https://www.historyextra.com/period/roman/who-were-roman-gladiators-famous-spartacus-crixus/ 289 | ). 290 | """ 291 | 292 | ... # TODO: not implemented yet (no fields or methods).
293 | -------------------------------------------------------------------------------- /src/forge_tools/memory/pointer.mojo: -------------------------------------------------------------------------------- 1 | from documentation import doc_private 2 | from collections import Optional 3 | from memory.unsafe_pointer import _default_alignment, UnsafePointer 4 | from memory import stack_allocation 5 | from os import abort 6 | 7 | 8 | trait SafePointer: 9 | """Trait for generic safe pointers.""" 10 | 11 | # TODO: this needs parametrized __getitem__, unsafe_ptr() etc. 12 | 13 | @staticmethod 14 | @always_inline 15 | fn alloc(count: Int) -> Self: 16 | """Allocate memory according to the pointer's logic. 17 | 18 | Args: 19 | count: The number of elements in the buffer. 20 | 21 | Returns: 22 | The pointer to the newly allocated buffer. 23 | """ 24 | ... 25 | 26 | fn __del__(deinit self): 27 | """Free the memory referenced by the pointer or ignore.""" 28 | ... 29 | 30 | 31 | struct Pointer[ 32 | is_mutable: Bool, //, 33 | type: AnyType, 34 | origin: Origin[is_mutable], 35 | address_space: AddressSpace = AddressSpace.GENERIC, 36 | ](Copyable, Movable): 37 | """Defines a base pointer. 38 | 39 | Safety: 40 | This is not thread safe. This is not reference counted. When doing an 41 | explicit copy from another pointer, the self_is_owner flag is set to 42 | False. 43 | """ 44 | 45 | alias _mlir_type = __mlir_type[ 46 | `!lit.ref<`, 47 | type, 48 | `, `, 49 | origin, 50 | `, `, 51 | address_space._value.value, 52 | `>`, 53 | ] 54 | 55 | var _mlir_value: Self._mlir_type 56 | """The underlying MLIR representation.""" 57 | var _flags: UInt8 58 | """Bitwise flags for the pointer. 59 | 60 | #### Bits (7 is the most significant): 61 | 62 | - 7: in_registers: Whether the pointer is allocated in registers. 63 | - 6: is_allocated: Whether the pointer's memory is allocated. 64 | - 5: is_initialized: Whether the memory is initialized. 65 | - 4: self_is_owner: Whether the pointer owns the memory. 66 | - 3: unset.
67 | - 2: unset. 68 | - 1: unset. 69 | - 0: unset. 70 | """ 71 | 72 | # ===------------------------------------------------------------------===# 73 | # Initializers 74 | # ===------------------------------------------------------------------===# 75 | 76 | fn __init__(out self): # Null pointer: nothing allocated or initialized yet. 77 | self = Self( 78 | ptr=UnsafePointer[type, address_space](), 79 | is_allocated=False, 80 | in_registers=False, 81 | is_initialized=False, 82 | self_is_owner=True, 83 | ) 84 | 85 | @doc_private 86 | @always_inline("nodebug") 87 | fn __init__( 88 | out self, 89 | *, 90 | _mlir_value: Self._mlir_type, 91 | is_allocated: Bool, 92 | in_registers: Bool = False, 93 | is_initialized: Bool = True, 94 | self_is_owner: Bool = True, 95 | ): 96 | """Constructs a Pointer from its MLIR representation. 97 | 98 | Args: 99 | _mlir_value: The MLIR representation of the pointer. 100 | is_allocated: Whether the pointer's memory is allocated. 101 | in_registers: Whether the pointer is allocated in registers. 102 | is_initialized: Whether the memory is initialized. 103 | self_is_owner: Whether the pointer owns the memory. 104 | """ 105 | self._mlir_value = _mlir_value 106 | self._flags = ( 107 | (UInt8(in_registers) << 7) 108 | | (UInt8(is_allocated) << 6) 109 | | (UInt8(is_initialized) << 5) 110 | | (UInt8(self_is_owner) << 4) 111 | ) 112 | 113 | @staticmethod 114 | @always_inline("nodebug") 115 | fn address_of(ref [origin, address_space._value.value]value: type) -> Self: 116 | """Constructs a Pointer from a reference to a value. 117 | 118 | Args: 119 | value: The value to get the address of. 120 | 121 | Returns: 122 | The result Pointer. 123 | """ 124 | return Pointer( 125 | _mlir_value=__get_mvalue_as_litref(value), 126 | is_allocated=True, 127 | in_registers=True, 128 | is_initialized=True, 129 | self_is_owner=False, 130 | ) 131 | 132 | fn __init__(out self, *, other: Self): 133 | """Constructs a copy from another Pointer **(not the data)**. 134 | 135 | Args: 136 | other: The `Pointer` to copy.
137 | """ 138 | self._mlir_value = other._mlir_value 139 | self._flags = other._flags & 0b1110_1111 # clear bit 4 (self_is_owner): a copy never owns the memory 140 | 141 | @doc_private 142 | @always_inline("nodebug") 143 | fn __init__( 144 | out self, 145 | *, 146 | ptr: UnsafePointer[type, address_space], 147 | is_allocated: Bool, 148 | in_registers: Bool, 149 | is_initialized: Bool, 150 | self_is_owner: Bool, 151 | ): 152 | """Constructs a Pointer from an UnsafePointer. 153 | 154 | Args: 155 | ptr: The UnsafePointer. 156 | is_allocated: Whether the pointer's memory is allocated. 157 | in_registers: Whether the pointer is allocated in registers. 158 | is_initialized: Whether the memory is initialized. 159 | self_is_owner: Whether the pointer owns the memory. 160 | """ 161 | self = __type_of(self)( 162 | _mlir_value=__mlir_op.`lit.ref.from_pointer`[ 163 | _type = __type_of(self)._mlir_type 164 | ](ptr.address), 165 | is_allocated=is_allocated, 166 | in_registers=in_registers, 167 | is_initialized=is_initialized, 168 | self_is_owner=self_is_owner, 169 | ) 170 | 171 | @staticmethod 172 | @always_inline 173 | fn alloc(count: Int) -> Self: 174 | """Allocate memory according to the pointer's logic. 175 | 176 | Args: 177 | count: The number of elements in the buffer. 178 | 179 | Returns: 180 | The pointer to the newly allocated buffer. 181 | """ 182 | return Self( 183 | ptr=UnsafePointer[type, address_space] 184 | .alloc(count) 185 | .bitcast[address_space=address_space](), 186 | is_allocated=True, 187 | in_registers=False, 188 | is_initialized=False, 189 | self_is_owner=True, 190 | ) 191 | 192 | @staticmethod 193 | @always_inline 194 | fn alloc[ 195 | count: Int, 196 | /, 197 | O: MutableOrigin, 198 | *, 199 | stack_alloc_limit: Int = 1 * 2**20, 200 | name: Optional[StringLiteral] = None, 201 | ]() -> Pointer[type, O, address_space]: 202 | """Allocate an array on the stack with specified or default alignment. 203 | 204 | Parameters: 205 | count: The number of elements in the array. 206 | O: The origin of the Pointer.
207 | stack_alloc_limit: The limit of bytes to allocate on the stack 208 | (default 1 MiB; not enforced yet). 209 | name: The name of the global variable (only honored in certain 210 | cases; not used yet). 211 | 212 | Returns: 213 | The pointer to the newly allocated array. 214 | """ 215 | return Pointer[type, O, address_space]( 216 | ptr=stack_allocation[count, type, address_space=address_space](), 217 | is_allocated=True, 218 | in_registers=True, # stack memory: __del__ skips freeing it 219 | is_initialized=True, 220 | self_is_owner=True, 221 | ) 222 | 223 | fn bitcast[ 224 | T: AnyType = Self.type 225 | ](self, out output: Pointer[T, origin, address_space]): 226 | """Bitcasts a `Pointer` to a different type. 227 | 228 | Parameters: 229 | T: The target type. 230 | 231 | Returns: 232 | A new `Pointer` object with the specified type and the same address, 233 | as the original `Pointer`. 234 | """ 235 | alias P = Pointer[T, MutableAnyOrigin, address_space] 236 | s = rebind[Pointer[T, MutableAnyOrigin, address_space]](self) 237 | output = rebind[__type_of(output)]( 238 | P( 239 | ptr=s.unsafe_ptr().bitcast[T](), 240 | is_allocated=s.is_allocated, 241 | in_registers=s.in_registers, 242 | is_initialized=s.is_initialized, 243 | self_is_owner=s.self_is_owner, 244 | ) 245 | ) 246 | 247 | fn unsafe_ptr(self, out output: UnsafePointer[type, address_space]): 248 | """Get a raw pointer to the underlying data. 249 | 250 | Returns: 251 | The raw pointer to the data. 252 | """ 253 | p = __mlir_op.`lit.ref.to_pointer`(self._mlir_value) 254 | output = __type_of(output)(rebind[__type_of(output)._mlir_type](p)) 255 | 256 | @always_inline 257 | fn __getattr__[name: StringLiteral](self) -> Bool: 258 | """Read one of the flag bits: in_registers, is_allocated, is_initialized or self_is_owner. 259 | 260 | Parameters: 261 | name: The name of the attribute. 262 | 263 | Returns: 264 | The attribute value.
265 | """ 266 | 267 | @parameter 268 | if name == "in_registers": 269 | return Bool((self._flags >> 7) & 0b1) 270 | elif name == "is_allocated": 271 | return Bool((self._flags >> 6) & 0b1) 272 | elif name == "is_initialized": 273 | return Bool((self._flags >> 5) & 0b1) 274 | elif name == "self_is_owner": 275 | return Bool((self._flags >> 4) & 0b1) 276 | else: 277 | constrained[False, "unknown attribute"]() 278 | return abort[Bool]() 279 | 280 | @always_inline 281 | fn __setattr__[name: StringLiteral](mut self, value: Bool): 282 | """Set or clear one of the flag bits, leaving the other bits untouched. 283 | 284 | Parameters: 285 | name: The name of the attribute. 286 | 287 | Args: 288 | value: The value to set the attribute to. 289 | """ 290 | 291 | @parameter 292 | if name == "in_registers": 293 | self._flags = (self._flags & 0b0111_1111) | (UInt8(value) << 7) # clear the bit, then write value into it 294 | elif name == "is_allocated": 295 | self._flags = (self._flags & 0b1011_1111) | (UInt8(value) << 6) 296 | elif name == "is_initialized": 297 | self._flags = (self._flags & 0b1101_1111) | (UInt8(value) << 5) 298 | elif name == "self_is_owner": 299 | self._flags = (self._flags & 0b1110_1111) | (UInt8(value) << 4) 300 | else: 301 | constrained[False, "unknown attribute"]() 302 | 303 | fn __bool__(self) -> Bool: 304 | return (self._flags & 0b0110_0000) == 0b0110_0000 # allocated and initialized 305 | 306 | fn __int__(self) -> Int: 307 | return Int( 308 | rebind[Pointer[type, MutableAnyOrigin, address_space]]( 309 | self 310 | ).unsafe_ptr() 311 | ) 312 | 313 | fn __del__(deinit self): 314 | @parameter 315 | if address_space is AddressSpace.GENERIC and is_mutable: 316 | if (self._flags & 0b1101_0000) == 0b0101_0000: # owned, allocated, and not in registers (heap memory) 317 | p = __mlir_op.`lit.ref.to_pointer`(self._mlir_value) 318 | alias UP = UnsafePointer[ 319 | type, AddressSpace.GENERIC, _default_alignment[type]() 320 | ] 321 | UP(rebind[UP._mlir_type](p)).free() 322 | self._flags &= 0b0001_1111 323 | -------------------------------------------------------------------------------- /src/forge_tools/datetime/timezone.mojo:
-------------------------------------------------------------------------------- 1 | # ===----------------------------------------------------------------------=== # 2 | # Copyright (c) 2024, Martin Vuyk Loperena 3 | # 4 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 5 | # https://llvm.org/LICENSE.txt 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # ===----------------------------------------------------------------------=== # 13 | """`TimeZone` module. 14 | 15 | - Notes: 16 | - IANA is supported: [`TimeZone` and DST data sources]( 17 | http://www.iana.org/time-zones/repository/tz-link.html). 18 | [List of TZ identifiers (`tz_str`)]( 19 | https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). 20 | """ 21 | 22 | from collections import Optional 23 | 24 | from .zoneinfo import ( 25 | Offset, 26 | ZoneDST, 27 | ZoneInfo, 28 | ZoneInfoMem32, 29 | ZoneInfoMem8, 30 | ZoneStorageDST, 31 | ZoneStorageNoDST, 32 | offset_at, 33 | get_zoneinfo, 34 | ) 35 | 36 | 37 | struct TimeZone[ 38 | dst_storage: ZoneStorageDST = ZoneInfoMem32, 39 | no_dst_storage: ZoneStorageNoDST = ZoneInfoMem8, 40 | iana: Bool = True, 41 | pyzoneinfo: Bool = True, 42 | native: Bool = False, 43 | ](Copyable, Movable, Writable): 44 | """`TimeZone` struct. Because of a POSIX standard, if you set 45 | the tz_str e.g. Etc/UTC-4 it means 4 hours east of UTC 46 | which is UTC + 4 in numbers. That is: 47 | `TimeZone("Etc/UTC-4", offset_h=4, offset_m=0, sign=1)`. If 48 | `TimeZone[iana=True]("Etc/UTC-4")`, the correct offsets are 49 | returned for the calculations, but the attributes offset_h, 50 | offset_m and sign will remain the default 0, 0, 1 respectively.
51 | 52 | Parameters: 53 | dst_storage: The type of storage to use for ZoneInfo 54 | for zones with Daylight Saving Time. Default Memory. 55 | no_dst_storage: The type of storage to use for ZoneInfo 56 | for zones with no Daylight Saving Time. Default Memory. 57 | iana: Whether timezones from the [IANA database]( 58 | http://www.iana.org/time-zones/repository/tz-link.html) 59 | are used. It defaults to using all available timezones, 60 | if getting them fails at compile time, it tries using 61 | python's zoneinfo if pyzoneinfo is set to True, otherwise 62 | it uses the offsets as is, no daylight saving or 63 | special exceptions. [List of TZ identifiers]( 64 | https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). 65 | pyzoneinfo: Whether to use python's zoneinfo and 66 | datetime to get full IANA support. 67 | native: (fast, partial IANA support) Whether to use a native Dict 68 | with the current timezones from the [List of TZ identifiers]( 69 | https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) 70 | at the time of compilation (for now they're hardcoded 71 | at stdlib release time, in the future it should get them 72 | from the OS). If it fails at compile time, it defaults to 73 | using the given offsets when the timezone was constructed. 74 | """ 75 | 76 | var tz_str: StaticString 77 | """[`TZ identifier`]( 78 | https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).""" 79 | var has_dst: Bool 80 | """Whether the `TimeZone` has Daylight Saving Time.""" 81 | var _dst: dst_storage 82 | var _no_dst: no_dst_storage 83 | 84 | fn __init__( 85 | out self, 86 | tz_str: StaticString = "Etc/UTC", 87 | offset_h: UInt8 = 0, 88 | offset_m: UInt8 = 0, 89 | sign: UInt8 = 1, 90 | has_dst: Bool = False, 91 | zoneinfo: Optional[ZoneInfo[dst_storage, no_dst_storage]] = None, 92 | ): 93 | """Construct a `TimeZone`. 94 | 95 | Args: 96 | tz_str: The [`TZ identifier`]( 97 | https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
98 | offset_h: Offset for the hour. 99 | offset_m: Offset for the minute. 100 | sign: Sign: {1, -1}. 101 | has_dst: Whether the `TimeZone` has Daylight Saving Time. 102 | zoneinfo: The ZoneInfo for the `TimeZone` to instantiate. 103 | Defaults to looking for info on all available timezones. 104 | """ 105 | debug_assert( 106 | offset_h < 100 107 | and offset_h >= 0 108 | and offset_m < 100 109 | and offset_m >= 0 110 | and (sign == 1 or sign == -1), 111 | ( 112 | "utc offsets can't have a member bigger than 100, and sign must" 113 | " be either 1 or -1" 114 | ), 115 | ) 116 | 117 | self.tz_str = tz_str 118 | self.has_dst = has_dst 119 | self._dst = dst_storage() 120 | self._no_dst = no_dst_storage() 121 | if not has_dst: 122 | s = ( 123 | -1 if sign == -1 124 | and not (offset_h == 0 and offset_m == 0) else 1 125 | ) 126 | self._no_dst.add(tz_str, Offset(offset_h, offset_m, s)) 127 | 128 | z = zoneinfo 129 | 130 | @parameter 131 | if native: 132 | if not zoneinfo: 133 | z = get_zoneinfo[dst_storage, no_dst_storage]() 134 | if not z: 135 | return 136 | 137 | @parameter 138 | if iana: 139 | zi = z.value() 140 | if has_dst: 141 | dst = zi.with_dst.get(tz_str) 142 | if not dst: 143 | return 144 | self._dst.add(tz_str, dst.value()) 145 | return 146 | tz = zi.with_no_dst.get(tz_str) 147 | if not tz: 148 | return 149 | self._no_dst.add(tz_str, tz.value()) 150 | 151 | fn __getattr__(self, name: StaticString) raises -> Int8: 152 | """Get `offset_h`, `offset_m` or `sign` from the stored zone data. 153 | 154 | Args: 155 | name: The name of the attribute. 156 | 157 | Returns: 158 | The attribute. 159 | 160 | Raises: 161 | "ZoneInfo not found" or "there is no such attribute".
162 | """ 163 | 164 | if name not in ["offset_h", "offset_m", "sign"]: 165 | raise Error("there is no such attribute") # name is a runtime value, so this cannot be a compile-time constrained[] check 166 | 167 | 168 | var offset: Offset 169 | if self.has_dst: 170 | var data = self._dst.get(self.tz_str) 171 | if not data: 172 | raise Error("ZoneInfo not found") 173 | offset = data.value().from_hash()[2] 174 | else: 175 | var data = self._no_dst.get(self.tz_str) 176 | if not data: 177 | raise Error("ZoneInfo not found") 178 | offset = data.value() 179 | 180 | if name == "offset_h": 181 | return offset.hour.cast[DType.int8]() 182 | elif name == "offset_m": 183 | return offset.minute.cast[DType.int8]() 184 | elif name == "sign": 185 | return offset.sign 186 | raise Error("there is no such attribute") 187 | 188 | 189 | fn offset_at( 190 | self, 191 | year: UInt16, 192 | month: UInt8, 193 | day: UInt8, 194 | hour: UInt8 = 0, 195 | minute: UInt8 = 0, 196 | second: UInt8 = 0, 197 | ) -> Offset: 198 | """Return the UTC offset for the `TimeZone` at the given date. 199 | 200 | Args: 201 | year: Year. 202 | month: Month. 203 | day: Day. 204 | hour: Hour. 205 | minute: Minute. 206 | second: Second. 207 | 208 | Returns: 209 | The Offset; Offset(0, 0, 1) when no zone data is found.
210 | """ 211 | 212 | @parameter 213 | if iana and native: 214 | tz = self._dst.get(self.tz_str) 215 | var offset = offset_at(tz, year, month, day, hour, minute, second) 216 | if offset: 217 | return offset.value() 218 | elif iana and pyzoneinfo: 219 | try: 220 | from python import Python 221 | 222 | zoneinfo = Python.import_module("zoneinfo") 223 | dt = Python.import_module("datetime") 224 | zone = zoneinfo.ZoneInfo(self.tz_str) 225 | local = dt.datetime(year, month, day, hour, minute, second, tzinfo=zone) 226 | offset = local.utcoffset() 227 | sign = -1 if Int(offset.days) < 0 else 1 # a negative timedelta has days == -1 and 0 <= seconds < 86400 228 | total = abs(Int(offset.days) * 24 * 60 * 60 + Int(offset.seconds)) 229 | minutes = (total // 60) % 60 230 | return Offset(total // (60 * 60), minutes, sign) 231 | except: 232 | pass 233 | 234 | data = self._no_dst.get(self.tz_str) 235 | if data: 236 | return data.value() 237 | return Offset(0, 0, 1) 238 | 239 | @always_inline 240 | fn write_to[W: Writer](self, mut writer: W): 241 | """Write the `TimeZone` to a writer. 242 | 243 | Parameters: 244 | W: The writer type. 245 | 246 | Args: 247 | writer: The writer to write to. 248 | """ 249 | writer.write(self.tz_str) 250 | 251 | @always_inline("nodebug") 252 | fn __str__(self) -> String: 253 | """Return the TZ identifier. 254 | 255 | Returns: 256 | The `tz_str` of this `TimeZone`. 257 | """ 258 | return self.tz_str 259 | 260 | @always_inline("nodebug") 261 | fn __repr__(self) -> String: 262 | """Return the TZ identifier (same as `__str__`). 263 | 264 | Returns: 265 | The `tz_str` of this `TimeZone`. 266 | """ 267 | return self.__str__() 268 | 269 | @always_inline("nodebug") 270 | fn __eq__(self, other: Self) -> Bool: 271 | """Whether the tz_str from both TimeZones 272 | are the same. 273 | 274 | Args: 275 | other: Other. 276 | 277 | Returns: 278 | Bool. 279 | """ 280 | return self.tz_str == other.tz_str 281 | 282 | @always_inline("nodebug") 283 | fn __ne__(self, other: Self) -> Bool: 284 | """Whether the tz_str from both TimeZones 285 | are different. 286 | 287 | Args: 288 | other: Other. 289 | 290 | Returns: 291 | Bool.
292 | """ 293 | return self.tz_str != other.tz_str 294 | 295 | @staticmethod 296 | fn from_offset( 297 | year: UInt16, 298 | month: UInt8, 299 | day: UInt8, 300 | offset_h: UInt8, 301 | offset_m: UInt8, 302 | sign: UInt8, 303 | ) -> Self: 304 | """Build a `TimeZone` from a UTC offset (not implemented yet: always returns the default `TimeZone`). 305 | 306 | Args: 307 | year: Year. 308 | month: Month. 309 | day: Day. 310 | offset_h: Offset for the hour. 311 | offset_m: Offset for the minute. 312 | sign: Sign: {1, -1}. 313 | 314 | Returns: 315 | Self. 316 | """ 317 | _ = year, month, day, offset_h, offset_m, sign 318 | # TODO: it should create an Etc/UTC-X TimeZone 319 | return Self() 320 | -------------------------------------------------------------------------------- /src/test/datetime/test_calendar.mojo: -------------------------------------------------------------------------------- 1 | # RUN: %mojo %s 2 | 3 | from testing import assert_equal, assert_false, assert_raises, assert_true 4 | 5 | from forge_tools.datetime.calendar import ( 6 | CalendarHashes, 7 | Calendar, 8 | Gregorian, 9 | UTCFast, 10 | PythonCalendar, 11 | UTCCalendar, 12 | UTCFastCal, 13 | _date, 14 | ) 15 | 16 | 17 | def _get_dates_as_lists(t1: _date, t2: _date) -> (List[Int], List[Int]): 18 | l1 = List[Int]( 19 | Int(t1[0]), 20 | Int(t1[1]), 21 | Int(t1[2]), 22 | Int(t1[3]), 23 | Int(t1[4]), 24 | Int(t1[5]), 25 | Int(t1[6]), 26 | Int(t1[7]), 27 | ) 28 | l2 = List[Int]( 29 | Int(t2[0]), 30 | Int(t2[1]), 31 | Int(t2[2]), 32 | Int(t2[3]), 33 | Int(t2[4]), 34 | Int(t2[5]), 35 | Int(t2[6]), 36 | Int(t2[7]), 37 | ) 38 | return l1^, l2^ 39 | 40 | 41 | def test_calendar_hashes(): 42 | alias calh64 = CalendarHashes(CalendarHashes.UINT64) 43 | alias calh32 = CalendarHashes(CalendarHashes.UINT32) 44 | alias calh16 = CalendarHashes(CalendarHashes.UINT16) 45 | alias calh8 = CalendarHashes(CalendarHashes.UINT8) 46 | 47 | greg = Gregorian() 48 | d = _date(9999, 12, 31, 23, 59, 59, 999, 999) 49 | h = greg.hash[calh64](d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) 50 |
result = _get_dates_as_lists(d, greg.from_hash[calh64](h)) 51 | assert_equal(result[0].__str__(), result[1].__str__()) 52 | d = _date(4095, 12, 31, 0, 0, 0, 0, 0) 53 | h = greg.hash[calh32](d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) 54 | result = _get_dates_as_lists(d, greg.from_hash[calh32](h)) 55 | assert_equal(result[0].__str__(), result[1].__str__()) 56 | 57 | utcfast = UTCFast() 58 | d = _date(9999, 12, 31, 23, 59, 59, 999, 0) 59 | h = utcfast.hash[calh64](d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) 60 | result = _get_dates_as_lists(d, utcfast.from_hash[calh64](h)) 61 | assert_equal(result[0].__str__(), result[1].__str__()) 62 | d = _date(4095, 12, 31, 23, 59, 0, 0, 0) 63 | h = utcfast.hash[calh32](d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) 64 | result = _get_dates_as_lists(d, utcfast.from_hash[calh32](h)) 65 | assert_equal(result[0].__str__(), result[1].__str__()) 66 | d = _date(3, 12, 31, 23, 0, 0, 0, 0) 67 | h = utcfast.hash[calh16](d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) 68 | result = _get_dates_as_lists(d, utcfast.from_hash[calh16](h)) 69 | assert_equal(result[0].__str__(), result[1].__str__()) 70 | d = _date(0, 0, 6, 23, 0, 0, 0, 0) 71 | h = utcfast.hash[calh8](d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]) 72 | result = _get_dates_as_lists(d, utcfast.from_hash[calh8](h)) 73 | assert_equal(result[0].__str__(), result[1].__str__()) 74 | 75 | 76 | def test_python_calendar(): 77 | alias cal = PythonCalendar 78 | assert_equal(3, Int(cal.day_of_week(2023, 6, 15))) 79 | assert_equal(5, Int(cal.day_of_week(2024, 6, 15))) 80 | assert_equal(166, Int(cal.day_of_year(2023, 6, 15))) 81 | assert_equal(167, Int(cal.day_of_year(2024, 6, 15))) 82 | assert_equal(365, Int(cal.day_of_year(2023, 12, 31))) 83 | assert_equal(366, Int(cal.day_of_year(2024, 12, 31))) 84 | 85 | for i in range(1, 3_000): 86 | if i % 4 == 0 and (i % 100 != 0 or i % 400 == 0): 87 | assert_true(cal.is_leapyear(i)) 88 | assert_equal(29, Int(cal.max_days_in_month(i, 2))) 89 | 
else: 90 | assert_false(cal.is_leapyear(i)) 91 | assert_equal(28, Int(cal.max_days_in_month(i, 2))) 92 | 93 | assert_equal(27, Int(cal.leapsecs_since_epoch(2017, 1, 2))) 94 | res = cal.monthrange(2023, 2) 95 | assert_equal(2, Int(res[0])) 96 | assert_equal(28, Int(res[1])) 97 | res = cal.monthrange(2024, 2) 98 | assert_equal(3, Int(res[0])) 99 | assert_equal(29, Int(res[1])) 100 | assert_equal(60, Int(cal.max_second(1972, 6, 30, 23, 59))) 101 | assert_equal(60, Int(cal.max_second(1972, 12, 31, 23, 59))) 102 | assert_equal(60, Int(cal.max_second(1973, 12, 31, 23, 59))) 103 | assert_equal(60, Int(cal.max_second(1974, 12, 31, 23, 59))) 104 | assert_equal(60, Int(cal.max_second(1975, 12, 31, 23, 59))) 105 | assert_equal(60, Int(cal.max_second(1976, 12, 31, 23, 59))) 106 | assert_equal(60, Int(cal.max_second(1977, 12, 31, 23, 59))) 107 | assert_equal(60, Int(cal.max_second(1978, 12, 31, 23, 59))) 108 | assert_equal(60, Int(cal.max_second(1979, 12, 31, 23, 59))) 109 | assert_equal(60, Int(cal.max_second(1981, 6, 30, 23, 59))) 110 | assert_equal(60, Int(cal.max_second(1982, 6, 30, 23, 59))) 111 | assert_equal(60, Int(cal.max_second(1983, 6, 30, 23, 59))) 112 | assert_equal(60, Int(cal.max_second(1985, 6, 30, 23, 59))) 113 | assert_equal(60, Int(cal.max_second(1987, 12, 31, 23, 59))) 114 | assert_equal(60, Int(cal.max_second(1989, 12, 31, 23, 59))) 115 | assert_equal(60, Int(cal.max_second(1990, 12, 31, 23, 59))) 116 | assert_equal(60, Int(cal.max_second(1992, 6, 30, 23, 59))) 117 | assert_equal(60, Int(cal.max_second(1993, 6, 30, 23, 59))) 118 | assert_equal(60, Int(cal.max_second(1994, 6, 30, 23, 59))) 119 | assert_equal(60, Int(cal.max_second(1995, 12, 31, 23, 59))) 120 | assert_equal(60, Int(cal.max_second(1997, 6, 30, 23, 59))) 121 | assert_equal(60, Int(cal.max_second(1998, 12, 31, 23, 59))) 122 | assert_equal(60, Int(cal.max_second(2005, 12, 31, 23, 59))) 123 | assert_equal(60, Int(cal.max_second(2008, 12, 31, 23, 59))) 124 | assert_equal(60, 
Int(cal.max_second(2012, 6, 30, 23, 59))) 125 | assert_equal(60, Int(cal.max_second(2015, 6, 30, 23, 59))) 126 | assert_equal(60, Int(cal.max_second(2016, 12, 31, 23, 59))) 127 | assert_equal(120, Int(cal.seconds_since_epoch(1, 1, 1, 0, 2, 0))) 128 | assert_equal( 129 | 120 * 1_000, Int(cal.m_seconds_since_epoch(1, 1, 1, 0, 2, 0, 0)) 130 | ) 131 | assert_equal( 132 | Int(120 * 1e9), 133 | Int(cal.n_seconds_since_epoch(1, 1, 1, 0, 2, 0, 0, 0, 0)), 134 | ) 135 | d1 = cal.seconds_since_epoch(2024, 1, 1, 0, 2, 0) 136 | d2 = cal.seconds_since_epoch(2024, 1, 1, 0, 0, 0) 137 | assert_equal(120, Int(d1 - d2)) 138 | d1 = cal.m_seconds_since_epoch(2024, 1, 1, 0, 2, 0, 0) 139 | d2 = cal.m_seconds_since_epoch(2024, 1, 1, 0, 0, 0, 0) 140 | assert_equal(120 * 1_000, Int(d1 - d2)) 141 | d1 = cal.n_seconds_since_epoch(500, 1, 1, 0, 2, 0, 0, 0, 0) 142 | d2 = cal.n_seconds_since_epoch(500, 1, 1, 0, 0, 0, 0, 0, 0) 143 | assert_equal(Int(120 * 1e9), Int(d1 - d2)) 144 | 145 | alias day_to_sec: UInt64 = 24 * 60 * 60 146 | alias sec_to_nano: UInt64 = 1_000_000_000 147 | d1 = cal.seconds_since_epoch(2024, 12, 31, 3, 4, 5) 148 | d2 = cal.seconds_since_epoch(2025, 1, 1, 3, 4, 5) 149 | assert_equal(1 * day_to_sec, d2 - d1) 150 | d1 = cal.m_seconds_since_epoch(2024, 12, 31, 3, 4, 5, 6) 151 | d2 = cal.m_seconds_since_epoch(2025, 1, 1, 3, 4, 5, 6) 152 | assert_equal(1 * day_to_sec * 1_000, d2 - d1) 153 | d1 = cal.n_seconds_since_epoch(500, 12, 31, 3, 4, 5, 6, 7, 8) 154 | d2 = cal.n_seconds_since_epoch(501, 1, 1, 3, 4, 5, 6, 7, 8) 155 | assert_equal(1 * day_to_sec * sec_to_nano, d2 - d1) 156 | 157 | 158 | def test_gregorian_utc_calendar(): 159 | alias cal = UTCCalendar 160 | assert_equal(3, Int(cal.day_of_week(2023, 6, 15))) 161 | assert_equal(5, Int(cal.day_of_week(2024, 6, 15))) 162 | assert_equal(166, Int(cal.day_of_year(2023, 6, 15))) 163 | assert_equal(167, Int(cal.day_of_year(2024, 6, 15))) 164 | assert_equal(27, Int(cal.leapsecs_since_epoch(2017, 1, 2))) 165 | res = cal.monthrange(2023, 
2) 166 | assert_equal(2, Int(res[0])) 167 | assert_equal(28, Int(res[1])) 168 | res = cal.monthrange(2024, 2) 169 | assert_equal(3, Int(res[0])) 170 | assert_equal(29, Int(res[1])) 171 | assert_equal(120, Int(cal.seconds_since_epoch(1970, 1, 1, 0, 2, 0))) 172 | assert_equal( 173 | 120 * 1_000, Int(cal.m_seconds_since_epoch(1970, 1, 1, 0, 2, 0, 0)) 174 | ) 175 | assert_equal( 176 | Int(120 * 1e9), 177 | Int(cal.n_seconds_since_epoch(1970, 1, 1, 0, 2, 0, 0, 0, 0)), 178 | ) 179 | d1 = cal.seconds_since_epoch(2024, 1, 1, 0, 2, 0) 180 | d2 = cal.seconds_since_epoch(2024, 1, 1, 0, 0, 0) 181 | assert_equal(UInt64(120), d1 - d2) 182 | d1 = cal.m_seconds_since_epoch(2024, 1, 1, 0, 2, 0, 0) 183 | d2 = cal.m_seconds_since_epoch(2024, 1, 1, 0, 0, 0, 0) 184 | assert_equal(UInt64(120 * 1_000), d1 - d2) 185 | d1 = cal.n_seconds_since_epoch(2024, 1, 1, 0, 2, 0, 0, 0, 0) 186 | d2 = cal.n_seconds_since_epoch(2024, 1, 1, 0, 0, 0, 0, 0, 0) 187 | assert_equal(UInt64(Int(120 * 1e9)), d1 - d2) 188 | 189 | alias day_to_sec: UInt64 = 24 * 60 * 60 190 | alias sec_to_nano: UInt64 = 1_000_000_000 191 | d1 = cal.seconds_since_epoch(2024, 12, 31, 3, 4, 5) 192 | d2 = cal.seconds_since_epoch(2025, 1, 1, 3, 4, 5) 193 | assert_equal(1 * day_to_sec, d2 - d1) 194 | d1 = cal.m_seconds_since_epoch(2024, 12, 31, 3, 4, 5, 6) 195 | d2 = cal.m_seconds_since_epoch(2025, 1, 1, 3, 4, 5, 6) 196 | assert_equal(1 * day_to_sec * 1_000, d2 - d1) 197 | d1 = cal.n_seconds_since_epoch(2024, 12, 31, 3, 4, 5, 6, 7, 8) 198 | d2 = cal.n_seconds_since_epoch(2025, 1, 1, 3, 4, 5, 6, 7, 8) 199 | assert_equal(1 * day_to_sec * sec_to_nano, d2 - d1) 200 | 201 | 202 | def test_utcfast_calendar(): 203 | alias cal = UTCFastCal 204 | assert_equal(3, Int(cal.day_of_week(2023, 6, 15))) 205 | assert_equal(5, Int(cal.day_of_week(2024, 6, 15))) 206 | assert_equal(166, Int(cal.day_of_year(2023, 6, 15))) 207 | assert_equal(167, Int(cal.day_of_year(2024, 6, 15))) 208 | assert_equal(365, Int(cal.day_of_year(2023, 12, 31))) 209 | 
assert_equal(366, Int(cal.day_of_year(2024, 12, 31))) 210 | 211 | assert_equal(0, Int(cal.leapsecs_since_epoch(2017, 1, 2))) 212 | res = cal.monthrange(2023, 2) 213 | assert_equal(2, Int(res[0])) 214 | assert_equal(28, Int(res[1])) 215 | res = cal.monthrange(2024, 2) 216 | assert_equal(3, Int(res[0])) 217 | assert_equal(29, Int(res[1])) 218 | assert_equal(120, Int(cal.seconds_since_epoch(1970, 1, 1, 0, 2, 0))) 219 | assert_equal( 220 | 120 * 1_000, Int(cal.m_seconds_since_epoch(1970, 1, 1, 0, 2, 0, 0)) 221 | ) 222 | assert_equal( 223 | Int(120 * 1e9), 224 | Int(cal.n_seconds_since_epoch(1970, 1, 1, 0, 2, 0, 0, 0, 0)), 225 | ) 226 | d1 = cal.seconds_since_epoch(2024, 1, 1, 0, 2, 0) 227 | d2 = cal.seconds_since_epoch(2024, 1, 1, 0, 0, 0) 228 | assert_equal(120, Int(d1 - d2)) 229 | d1 = cal.m_seconds_since_epoch(2024, 1, 1, 0, 2, 0, 0) 230 | d2 = cal.m_seconds_since_epoch(2024, 1, 1, 0, 0, 0, 0) 231 | assert_equal(120 * 1_000, Int(d1 - d2)) 232 | d1 = cal.n_seconds_since_epoch(2024, 1, 1, 0, 2, 0, 0, 0, 0) 233 | d2 = cal.n_seconds_since_epoch(2024, 1, 1, 0, 0, 0, 0, 0, 0) 234 | assert_equal(Int(120 * 1e9), Int(d1 - d2)) 235 | 236 | alias day_to_sec: UInt64 = 24 * 60 * 60 237 | alias sec_to_nano: UInt64 = 1_000_000_000 238 | d1 = cal.seconds_since_epoch(2024, 12, 31, 3, 4, 5) 239 | d2 = cal.seconds_since_epoch(2025, 1, 1, 3, 4, 5) 240 | assert_equal(1 * day_to_sec, d2 - d1) 241 | d1 = cal.m_seconds_since_epoch(2024, 12, 31, 3, 4, 5, 6) 242 | d2 = cal.m_seconds_since_epoch(2025, 1, 1, 3, 4, 5, 6) 243 | assert_equal(1 * day_to_sec * 1_000, d2 - d1) 244 | d1 = cal.n_seconds_since_epoch(2024, 12, 31, 3, 4, 5, 6, 7, 8) 245 | d2 = cal.n_seconds_since_epoch(2025, 1, 1, 3, 4, 5, 6, 7, 8) 246 | assert_equal(1 * day_to_sec * sec_to_nano, d2 - d1) 247 | 248 | 249 | def main(): 250 | test_calendar_hashes() 251 | test_python_calendar() 252 | test_gregorian_utc_calendar() 253 | test_utcfast_calendar() 254 | 
-------------------------------------------------------------------------------- /src/forge_tools/datetime/_lists.mojo: -------------------------------------------------------------------------------- 1 | # ===----------------------------------------------------------------------=== # 2 | # Copyright (c) 2024, Martin Vuyk Loperena 3 | # 4 | # Licensed under the Apache License v2.0 with LLVM Exceptions: 5 | # https://llvm.org/LICENSE.txt 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # ===----------------------------------------------------------------------=== # 13 | """Hardcoded lists.""" 14 | 15 | from .zoneinfo import Leapsecs 16 | 17 | alias tz_list = List[StaticString]( 18 | "Africa/Abidjan", 19 | "Africa/Algiers", 20 | "Africa/Bissau", 21 | "Africa/Cairo", 22 | "Africa/Casablanca", 23 | "Africa/Ceuta", 24 | "Africa/El_Aaiun", 25 | "Africa/Johannesburg", 26 | "Africa/Juba", 27 | "Africa/Khartoum", 28 | "Africa/Lagos", 29 | "Africa/Maputo", 30 | "Africa/Monrovia", 31 | "Africa/Nairobi", 32 | "Africa/Ndjamena", 33 | "Africa/Sao_Tome", 34 | "Africa/Tripoli", 35 | "Africa/Tunis", 36 | "Africa/Windhoek", 37 | "America/Adak", 38 | "America/Anchorage", 39 | "America/Araguaina", 40 | "America/Argentina/Buenos_Aires", 41 | "America/Argentina/Catamarca", 42 | "America/Argentina/Cordoba", 43 | "America/Argentina/Jujuy", 44 | "America/Argentina/La_Rioja", 45 | "America/Argentina/Mendoza", 46 | "America/Argentina/Rio_Gallegos", 47 | "America/Argentina/Salta", 48 | "America/Argentina/San_Juan", 49 | "America/Argentina/San_Luis", 50 | "America/Argentina/Tucuman", 51 | "America/Argentina/Ushuaia", 52 | "America/Asuncion", 53 | "America/Bahia", 54 | "America/Bahia_Banderas", 55 |
"America/Barbados", 56 | "America/Belem", 57 | "America/Belize", 58 | "America/Boa_Vista", 59 | "America/Bogota", 60 | "America/Boise", 61 | "America/Cambridge_Bay", 62 | "America/Campo_Grande", 63 | "America/Cancun", 64 | "America/Caracas", 65 | "America/Cayenne", 66 | "America/Chicago", 67 | "America/Chihuahua", 68 | "America/Ciudad_Juarez", 69 | "America/Costa_Rica", 70 | "America/Cuiaba", 71 | "America/Danmarkshavn", 72 | "America/Dawson", 73 | "America/Dawson_Creek", 74 | "America/Denver", 75 | "America/Detroit", 76 | "America/Edmonton", 77 | "America/Eirunepe", 78 | "America/El_Salvador", 79 | "America/Fort_Nelson", 80 | "America/Fortaleza", 81 | "America/Glace_Bay", 82 | "America/Goose_Bay", 83 | "America/Grand_Turk", 84 | "America/Guatemala", 85 | "America/Guayaquil", 86 | "America/Guyana", 87 | "America/Halifax", 88 | "America/Havana", 89 | "America/Hermosillo", 90 | "America/Indiana/Indianapolis", 91 | "America/Indiana/Knox", 92 | "America/Indiana/Marengo", 93 | "America/Indiana/Petersburg", 94 | "America/Indiana/Tell_City", 95 | "America/Indiana/Vevay", 96 | "America/Indiana/Vincennes", 97 | "America/Indiana/Winamac", 98 | "America/Inuvik", 99 | "America/Iqaluit", 100 | "America/Jamaica", 101 | "America/Juneau", 102 | "America/Kentucky/Louisville", 103 | "America/Kentucky/Monticello", 104 | "America/La_Paz", 105 | "America/Lima", 106 | "America/Los_Angeles", 107 | "America/Maceio", 108 | "America/Managua", 109 | "America/Manaus", 110 | "America/Martinique", 111 | "America/Matamoros", 112 | "America/Mazatlan", 113 | "America/Menominee", 114 | "America/Merida", 115 | "America/Metlakatla", 116 | "America/Mexico_City", 117 | "America/Miquelon", 118 | "America/Moncton", 119 | "America/Monterrey", 120 | "America/Montevideo", 121 | "America/New_York", 122 | "America/Nome", 123 | "America/Noronha", 124 | "America/North_Dakota/Beulah", 125 | "America/North_Dakota/Center", 126 | "America/North_Dakota/New_Salem", 127 | "America/Nuuk", 128 | "America/Ojinaga", 129 |
"America/Panama", 130 | "America/Paramaribo", 131 | "America/Phoenix", 132 | "America/Port-au-Prince", 133 | "America/Porto_Velho", 134 | "America/Puerto_Rico", 135 | "America/Punta_Arenas", 136 | "America/Rankin_Inlet", 137 | "America/Recife", 138 | "America/Regina", 139 | "America/Resolute", 140 | "America/Rio_Branco", 141 | "America/Santarem", 142 | "America/Santiago", 143 | "America/Santo_Domingo", 144 | "America/Sao_Paulo", 145 | "America/Scoresbysund", 146 | "America/Sitka", 147 | "America/St_Johns", 148 | "America/Swift_Current", 149 | "America/Tegucigalpa", 150 | "America/Thule", 151 | "America/Tijuana", 152 | "America/Toronto", 153 | "America/Vancouver", 154 | "America/Whitehorse", 155 | "America/Winnipeg", 156 | "America/Yakutat", 157 | "Antarctica/Casey", 158 | "Antarctica/Davis", 159 | "Antarctica/Macquarie", 160 | "Antarctica/Mawson", 161 | "Antarctica/Palmer", 162 | "Antarctica/Rothera", 163 | "Antarctica/Troll", 164 | "Asia/Almaty", 165 | "Asia/Amman", 166 | "Asia/Anadyr", 167 | "Asia/Aqtau", 168 | "Asia/Aqtobe", 169 | "Asia/Ashgabat", 170 | "Asia/Atyrau", 171 | "Asia/Baghdad", 172 | "Asia/Baku", 173 | "Asia/Bangkok", 174 | "Asia/Barnaul", 175 | "Asia/Beirut", 176 | "Asia/Bishkek", 177 | "Asia/Chita", 178 | "Asia/Choibalsan", 179 | "Asia/Colombo", 180 | "Asia/Damascus", 181 | "Asia/Dhaka", 182 | "Asia/Dili", 183 | "Asia/Dubai", 184 | "Asia/Dushanbe", 185 | "Asia/Famagusta", 186 | "Asia/Gaza", 187 | "Asia/Hebron", 188 | "Asia/Ho_Chi_Minh", 189 | "Asia/Hong_Kong", 190 | "Asia/Hovd", 191 | "Asia/Irkutsk", 192 | "Asia/Jakarta", 193 | "Asia/Jayapura", 194 | "Asia/Jerusalem", 195 | "Asia/Kabul", 196 | "Asia/Kamchatka", 197 | "Asia/Karachi", 198 | "Asia/Kathmandu", 199 | "Asia/Khandyga", 200 | "Asia/Kolkata", 201 | "Asia/Krasnoyarsk", 202 | "Asia/Kuching", 203 | "Asia/Macau", 204 | "Asia/Magadan", 205 | "Asia/Makassar", 206 | "Asia/Manila", 207 | "Asia/Nicosia", 208 | "Asia/Novokuznetsk", 209 | "Asia/Novosibirsk", 210 | "Asia/Omsk", 211 | "Asia/Oral", 212 |
"Asia/Pontianak", 213 | "Asia/Pyongyang", 214 | "Asia/Qatar", 215 | "Asia/Qostanay", 216 | "Asia/Qyzylorda", 217 | "Asia/Riyadh", 218 | "Asia/Sakhalin", 219 | "Asia/Samarkand", 220 | "Asia/Seoul", 221 | "Asia/Shanghai", 222 | "Asia/Singapore", 223 | "Asia/Srednekolymsk", 224 | "Asia/Taipei", 225 | "Asia/Tashkent", 226 | "Asia/Tbilisi", 227 | "Asia/Tehran", 228 | "Asia/Thimphu", 229 | "Asia/Tokyo", 230 | "Asia/Tomsk", 231 | "Asia/Ulaanbaatar", 232 | "Asia/Urumqi", 233 | "Asia/Ust-Nera", 234 | "Asia/Vladivostok", 235 | "Asia/Yakutsk", 236 | "Asia/Yangon", 237 | "Asia/Yekaterinburg", 238 | "Asia/Yerevan", 239 | "Atlantic/Azores", 240 | "Atlantic/Bermuda", 241 | "Atlantic/Canary", 242 | "Atlantic/Cape_Verde", 243 | "Atlantic/Faroe", 244 | "Atlantic/Madeira", 245 | "Atlantic/South_Georgia", 246 | "Atlantic/Stanley", 247 | "Australia/Adelaide", 248 | "Australia/Brisbane", 249 | "Australia/Broken_Hill", 250 | "Australia/Darwin", 251 | "Australia/Eucla", 252 | "Australia/Hobart", 253 | "Australia/Lindeman", 254 | "Australia/Lord_Howe", 255 | "Australia/Melbourne", 256 | "Australia/Perth", 257 | "Australia/Sydney", 258 | "CET", 259 | "CST6CDT", 260 | "EET", 261 | "EST", 262 | "EST5EDT", 263 | "Etc/GMT", 264 | "Etc/GMT+1", 265 | "Etc/GMT+10", 266 | "Etc/GMT+11", 267 | "Etc/GMT+12", 268 | "Etc/GMT+2", 269 | "Etc/GMT+3", 270 | "Etc/GMT+4", 271 | "Etc/GMT+5", 272 | "Etc/GMT+6", 273 | "Etc/GMT+7", 274 | "Etc/GMT+8", 275 | "Etc/GMT+9", 276 | "Etc/GMT-1", 277 | "Etc/GMT-10", 278 | "Etc/GMT-11", 279 | "Etc/GMT-12", 280 | "Etc/GMT-13", 281 | "Etc/GMT-14", 282 | "Etc/GMT-2", 283 | "Etc/GMT-3", 284 | "Etc/GMT-4", 285 | "Etc/GMT-5", 286 | "Etc/GMT-6", 287 | "Etc/GMT-7", 288 | "Etc/GMT-8", 289 | "Etc/GMT-9", 290 | "Etc/UTC", 291 | "Europe/Andorra", 292 | "Europe/Astrakhan", 293 | "Europe/Athens", 294 | "Europe/Belgrade", 295 | "Europe/Berlin", 296 | "Europe/Brussels", 297 | "Europe/Bucharest", 298 | "Europe/Budapest", 299 | "Europe/Chisinau", 300 | "Europe/Dublin", 301 |
"Europe/Gibraltar", 302 | "Europe/Helsinki", 303 | "Europe/Istanbul", 304 | "Europe/Kaliningrad", 305 | "Europe/Kirov", 306 | "Europe/Kyiv", 307 | "Europe/Lisbon", 308 | "Europe/London", 309 | "Europe/Madrid", 310 | "Europe/Malta", 311 | "Europe/Minsk", 312 | "Europe/Moscow", 313 | "Europe/Paris", 314 | "Europe/Prague", 315 | "Europe/Riga", 316 | "Europe/Rome", 317 | "Europe/Samara", 318 | "Europe/Saratov", 319 | "Europe/Simferopol", 320 | "Europe/Sofia", 321 | "Europe/Tallinn", 322 | "Europe/Tirane", 323 | "Europe/Ulyanovsk", 324 | "Europe/Vienna", 325 | "Europe/Vilnius", 326 | "Europe/Volgograd", 327 | "Europe/Warsaw", 328 | "Europe/Zurich", 329 | "HST", 330 | "Indian/Chagos", 331 | "Indian/Maldives", 332 | "Indian/Mauritius", 333 | "MET", 334 | "MST", 335 | "MST7MDT", 336 | "PST8PDT", 337 | "Pacific/Apia", 338 | "Pacific/Auckland", 339 | "Pacific/Bougainville", 340 | "Pacific/Chatham", 341 | "Pacific/Easter", 342 | "Pacific/Efate", 343 | "Pacific/Fakaofo", 344 | "Pacific/Fiji", 345 | "Pacific/Galapagos", 346 | "Pacific/Gambier", 347 | "Pacific/Guadalcanal", 348 | "Pacific/Guam", 349 | "Pacific/Honolulu", 350 | "Pacific/Kanton", 351 | "Pacific/Kiritimati", 352 | "Pacific/Kosrae", 353 | "Pacific/Kwajalein", 354 | "Pacific/Marquesas", 355 | "Pacific/Nauru", 356 | "Pacific/Niue", 357 | "Pacific/Norfolk", 358 | "Pacific/Noumea", 359 | "Pacific/Pago_Pago", 360 | "Pacific/Palau", 361 | "Pacific/Pitcairn", 362 | "Pacific/Port_Moresby", 363 | "Pacific/Rarotonga", 364 | "Pacific/Tahiti", 365 | "Pacific/Tarawa", 366 | "Pacific/Tongatapu", 367 | "WET", 368 | ) 369 | """List of tz_str.""" 370 | 371 | from .calendar import PythonCalendar, CalendarHashes 372 | 373 | alias cal = PythonCalendar 374 | alias calh32 = CalendarHashes(CalendarHashes.UINT32) 375 | alias leapsecs = List[UInt32]( 376 | cal.hash[calh32](1972, 6, 30), 377 | cal.hash[calh32](1972, 12, 31), 378 | cal.hash[calh32](1973, 12, 31), 379 | cal.hash[calh32](1974, 12, 31), 380 | cal.hash[calh32](1975, 12, 31), 381 |
cal.hash[calh32](1976, 12, 31), 382 | cal.hash[calh32](1977, 12, 31), 383 | cal.hash[calh32](1978, 12, 31), 384 | cal.hash[calh32](1979, 12, 31), 385 | cal.hash[calh32](1981, 6, 30), 386 | cal.hash[calh32](1982, 6, 30), 387 | cal.hash[calh32](1983, 6, 30), 388 | cal.hash[calh32](1985, 6, 30), 389 | cal.hash[calh32](1987, 12, 31), 390 | cal.hash[calh32](1989, 12, 31), 391 | cal.hash[calh32](1990, 12, 31), 392 | cal.hash[calh32](1992, 6, 30), 393 | cal.hash[calh32](1993, 6, 30), 394 | cal.hash[calh32](1994, 6, 30), 395 | cal.hash[calh32](1995, 12, 31), 396 | cal.hash[calh32](1997, 6, 30), 397 | cal.hash[calh32](1998, 12, 31), 398 | cal.hash[calh32](2005, 12, 31), 399 | cal.hash[calh32](2008, 12, 31), 400 | cal.hash[calh32](2012, 6, 30), 401 | cal.hash[calh32](2015, 6, 30), 402 | cal.hash[calh32](2016, 12, 31), 403 | ) 404 | """List of leap seconds: cal.hash[calh32](year, month, day). 405 | They MUST be on either June 30th at 23:59 or Dec. 31st at 23:59.""" 406 | -------------------------------------------------------------------------------- /src/benchmarks/collections/bench_array.mojo: -------------------------------------------------------------------------------- 1 | from benchmark import ( 2 | Bench, 3 | BenchConfig, 4 | Bencher, 5 | BenchId, 6 | Unit, 7 | keep, 8 | run, 9 | ) 10 | from random import seed, random_float64, random_si64 # Only these names are bound; there is no `random` module name in scope. 11 | 12 | from forge_tools.collections import Array 13 | 14 | 15 | # ===----------------------------------------------------------------------===# 16 | # Benchmark Data 17 | # ===----------------------------------------------------------------------===# 18 | fn make_array[ # Build an Array whose `capacity` slots hold random values bounded by `capacity`. 19 | capacity: Int, static: Bool, T: DType = DType.int64 20 | ]() -> Array[T, capacity, static]: 21 | a = Array[T, capacity, static](fill=0) 22 | for i in range(0, capacity): 23 | 24 | @parameter 25 | if T == DType.int64: 26 | a.vec[i] = rebind[Scalar[T]](random_si64(0, capacity)) 27 | elif T == DType.float64: 28 | a.vec[i] = rebind[Scalar[T]](random_float64(0,
capacity)) 29 | a.capacity_left = 0 # Every slot was written above; presumably this marks the Array as full (confirm against Array). 30 | return a 31 | 32 | 33 | # ===----------------------------------------------------------------------===# 34 | # Benchmark Array init 35 | # ===----------------------------------------------------------------------===# 36 | 37 | 38 | @parameter 39 | fn bench_array_init[capacity: Int, static: Bool](mut b: Bencher) raises: # Time constructing a zero-filled Array. 40 | @always_inline 41 | @parameter 42 | fn call_fn(): 43 | res = Array[DType.int64, capacity, static](fill=0) 44 | keep(res) 45 | 46 | b.iter[call_fn]() 47 | 48 | 49 | # ===----------------------------------------------------------------------===# 50 | # Benchmark Array Insert 51 | # ===----------------------------------------------------------------------===# 52 | @parameter 53 | fn bench_array_insert[capacity: Int, static: Bool](mut b: Bencher) raises: # NOTE(review): arr comes from make_array (capacity_left == 0); confirm Array.insert semantics on a full array. 54 | arr = make_array[capacity, static]() 55 | 56 | @always_inline 57 | @parameter 58 | fn call_fn() raises: 59 | for i in range(0, capacity): 60 | arr.insert(i, random_si64(0, capacity).value) 61 | 62 | b.iter[call_fn]() 63 | keep(arr.vec.value) 64 | 65 | 66 | # ===----------------------------------------------------------------------===# 67 | # Benchmark Array Lookup 68 | # ===----------------------------------------------------------------------===# 69 | @parameter 70 | fn bench_array_lookup[capacity: Int, static: Bool](mut b: Bencher) raises: # Time Array.index for every value in [0, capacity). 71 | arr = make_array[capacity, static]() 72 | 73 | @always_inline 74 | @parameter 75 | fn call_fn() raises: 76 | for i in range(0, capacity): 77 | res = arr.index(i) 78 | keep(res._value._impl) 79 | 80 | b.iter[call_fn]() 81 | keep(arr.vec.value) 82 | 83 | 84 | # ===----------------------------------------------------------------------===# 85 | # Benchmark Array contains 86 | # ===----------------------------------------------------------------------===# 87 | @parameter 88 | fn bench_array_contains[capacity: Int, static: Bool](mut b: Bencher) raises: # Time `in` for every value in [0, capacity). 89 | arr = make_array[capacity, static]() 90 | 91 |
@always_inline 92 | @parameter 93 | fn call_fn() raises: 94 | for i in range(0, capacity): 95 | res = i in arr 96 | keep(res) 97 | 98 | b.iter[call_fn]() 99 | keep(arr.vec.value) 100 | 101 | 102 | # ===----------------------------------------------------------------------===# 103 | # Benchmark Array count 104 | # ===----------------------------------------------------------------------===# 105 | @parameter 106 | fn bench_array_count[capacity: Int, static: Bool](mut b: Bencher) raises: # Time Array.count for every value in [0, capacity). 107 | arr = make_array[capacity, static]() 108 | 109 | @always_inline 110 | @parameter 111 | fn call_fn() raises: 112 | for i in range(0, capacity): 113 | res = arr.count(i) 114 | keep(res) 115 | 116 | b.iter[call_fn]() 117 | keep(arr.vec.value) 118 | 119 | 120 | # ===----------------------------------------------------------------------===# 121 | # Benchmark Array sum 122 | # ===----------------------------------------------------------------------===# 123 | @parameter 124 | fn bench_array_sum[capacity: Int](mut b: Bencher) raises: # Time a full Array.sum(). 125 | arr = make_array[capacity, False]() 126 | 127 | @always_inline 128 | @parameter 129 | fn call_fn() raises: 130 | res = arr.sum() 131 | keep(res) 132 | 133 | b.iter[call_fn]() 134 | keep(arr.vec.value) 135 | 136 | 137 | # ===----------------------------------------------------------------------===# 138 | # Benchmark Array filter 139 | # ===----------------------------------------------------------------------===# 140 | @parameter 141 | fn bench_array_filter[capacity: Int, static: Bool](mut b: Bencher) raises: # Time Array.filter keeping values below capacity // 2. 142 | arr = make_array[capacity, static]() 143 | 144 | fn filterfn(a: Int64) -> Scalar[DType.bool]: 145 | return a < (capacity // 2) 146 | 147 | @always_inline 148 | @parameter 149 | fn call_fn() raises: 150 | res = arr.filter(filterfn) 151 | keep(res) 152 | 153 | b.iter[call_fn]() 154 | keep(arr.vec.value) 155 | 156 | 157 | # ===----------------------------------------------------------------------===# 158 | # Benchmark Array apply 159 |
===----------------------------------------------------------------------===# 160 | @parameter 161 | fn bench_array_apply[capacity: Int, static: Bool](mut b: Bencher) raises: # applyfn rewrites arr in place, so doubling compounds across b.iter calls until the MAX_FINITE // 2 guard stops it. 162 | arr = make_array[capacity, static]() 163 | 164 | fn applyfn(a: Int64) -> Scalar[DType.int64]: 165 | if a < Int64.MAX_FINITE // 2: 166 | return a * 2 167 | return a 168 | 169 | @always_inline 170 | @parameter 171 | fn call_fn() raises: 172 | arr.apply(applyfn) 173 | 174 | b.iter[call_fn]() 175 | keep(arr.vec.value) 176 | 177 | 178 | # ===----------------------------------------------------------------------===# 179 | # Benchmark Array multiply 180 | # ===----------------------------------------------------------------------===# 181 | @parameter 182 | fn bench_array_multiply[capacity: Int](mut b: Bencher) raises: # NOTE(review): arr keeps doubling across b.iter calls and may eventually overflow Int64. 183 | arr = make_array[capacity, False]() 184 | 185 | @always_inline 186 | @parameter 187 | fn call_fn() raises: 188 | arr *= 2 189 | 190 | b.iter[call_fn]() 191 | keep(arr.vec.value) 192 | 193 | 194 | # ===----------------------------------------------------------------------===# 195 | # Benchmark Array reverse 196 | # ===----------------------------------------------------------------------===# 197 | @parameter 198 | fn bench_array_reverse[capacity: Int](mut b: Bencher) raises: # Time 1000 in-place reversals per iteration. 199 | arr = make_array[capacity, False, DType.int64]() 200 | 201 | @always_inline 202 | @parameter 203 | fn call_fn() raises: 204 | for _ in range(1_000): 205 | arr.reverse() 206 | 207 | b.iter[call_fn]() 208 | keep(arr.vec.value) 209 | 210 | 211 | # ===----------------------------------------------------------------------===# 212 | # Benchmark Array dot 213 | # ===----------------------------------------------------------------------===# 214 | @parameter 215 | fn bench_array_dot[capacity: Int](mut b: Bencher) raises: # Time 1000 dot products of two random float64 Arrays per iteration. 216 | arr1 = make_array[capacity, True, DType.float64]() 217 | arr2 = make_array[capacity, True, DType.float64]() 218 | 219 | @always_inline 220 | @parameter 221 | fn call_fn() raises: 222 |
for _ in range(1_000): 223 | res = arr1.dot(arr2) 224 | keep(res) 225 | 226 | b.iter[call_fn]() 227 | keep(arr1) 228 | keep(arr2) 229 | 230 | 231 | # ===----------------------------------------------------------------------===# 232 | # Benchmark Array cross 233 | # ===----------------------------------------------------------------------===# 234 | @parameter 235 | fn bench_array_cross(mut b: Bencher) raises: # Time 1000 cross products of two random 3-element Arrays per iteration. 236 | arr1 = Array[DType.float64, 3, True]( 237 | random_float64(0, 500), random_float64(0, 500), random_float64(0, 500) 238 | ) 239 | arr2 = Array[DType.float64, 3, True]( 240 | random_float64(0, 500), random_float64(0, 500), random_float64(0, 500) 241 | ) 242 | 243 | @always_inline 244 | @parameter 245 | fn call_fn() raises: 246 | for _ in range(1_000): 247 | res = arr1.cross(arr2) 248 | keep(res) 249 | 250 | b.iter[call_fn]() 251 | keep(arr1) 252 | keep(arr2) 253 | 254 | 255 | # ===----------------------------------------------------------------------===# 256 | # Benchmark Main 257 | # ===----------------------------------------------------------------------===# 258 | def main(): 259 | seed() 260 | m = Bench(BenchConfig(num_repetitions=5, warmup_iters=100)) 261 | alias sizes = Tuple(3, 8, 16, 32, 64, 128, 256) 262 | 263 | @parameter 264 | for i in range(7): # 7 == number of entries in `sizes`. 265 | alias size = sizes[i] 266 | # m.bench_function[bench_array_init[size, False]]( 267 | # BenchId("bench_array_init[" + String(size) + "]") 268 | # ) 269 | # FIXME: for some reason, static does not appear faster in these benchmarks 270 | # m.bench_function[bench_array_init[size, True]]( 271 | # BenchId("bench_array_init_static[" + String(size) + "]") 272 | # ) 273 | # m.bench_function[bench_array_insert[size, False]]( 274 | # BenchId("bench_array_insert[" + String(size) + "]") 275 | # ) 276 | # m.bench_function[bench_array_insert[size, True]]( 277 | # BenchId("bench_array_insert_static[" + String(size) + "]") 278 | # ) 279 | # m.bench_function[bench_array_lookup[size, False]]( 280 | #
BenchId("bench_array_lookup[" + String(size) + "]") 281 | # ) 282 | # m.bench_function[bench_array_lookup[size, True]]( 283 | # BenchId("bench_array_lookup_static[" + String(size) + "]") 284 | # ) 285 | # m.bench_function[bench_array_contains[size, False]]( 286 | # BenchId("bench_array_contains[" + String(size) + "]") 287 | # ) 288 | # m.bench_function[bench_array_contains[size, True]]( 289 | # BenchId("bench_array_contains_static[" + String(size) + "]") 290 | # ) 291 | # m.bench_function[bench_array_count[size, False]]( 292 | # BenchId("bench_array_count[" + String(size) + "]") 293 | # ) 294 | # m.bench_function[bench_array_count[size, True]]( 295 | # BenchId("bench_array_count_static[" + String(size) + "]") 296 | # ) 297 | m.bench_function[bench_array_sum[size]]( 298 | BenchId("bench_array_sum[" + String(size) + "]") 299 | ) 300 | m.bench_function[bench_array_filter[size, False]]( 301 | BenchId("bench_array_filter[" + String(size) + "]") 302 | ) 303 | # m.bench_function[bench_array_filter[size, True]]( 304 | # BenchId("bench_array_filter_static[" + String(size) + "]") 305 | # ) 306 | m.bench_function[bench_array_apply[size, False]]( # Non-static variant, matching its label; the static one is commented out below. 307 | BenchId("bench_array_apply[" + String(size) + "]") 308 | ) 309 | # m.bench_function[bench_array_apply[size, True]]( 310 | # BenchId("bench_array_apply_static[" + String(size) + "]") 311 | # ) 312 | m.bench_function[bench_array_multiply[size]]( 313 | BenchId("bench_array_multiply[" + String(size) + "]") 314 | ) 315 | # m.bench_function[bench_array_reverse[size]]( 316 | # BenchId("bench_array_reverse[" + String(size) + "]") 317 | # ) 318 | # m.bench_function[bench_array_dot[size]]( 319 | # BenchId("bench_array_dot[" + String(size) + "]") 320 | # ) 321 | # m.bench_function[bench_array_cross](BenchId("bench_array_cross")) 322 | 323 | print("") 324 | values = Dict[String, List[Float64]]() # name -> [sum of result means, number of info_vec entries]. 325 | for i in m.info_vec: 326 | res = i[].result.mean() 327 | val = values.get(i[].name, List[Float64](0, 0)) 328 | values[i[].name] =
List[Float64](res + val[0], val[1] + 1) 329 | for i in values.items(): 330 | print(i[].key, ":", i[].value[0] / i[].value[1]) 331 | -------------------------------------------------------------------------------- /src/benchmarks/collections/bench_list.mojo: -------------------------------------------------------------------------------- 1 | from benchmark import ( 2 | Bench, 3 | BenchConfig, 4 | Bencher, 5 | BenchId, 6 | Unit, 7 | keep, 8 | run, 9 | clobber_memory, 10 | ) 11 | from random import seed, random_float64 12 | import random # Binds the module name used by the `random.` calls in this file. 13 | 14 | # ===----------------------------------------------------------------------===# 15 | # Benchmark Data 16 | # ===----------------------------------------------------------------------===# 17 | fn make_list[capacity: Int, T: DType = DType.int64]() -> List[Scalar[T]]: # Return a List of `capacity` random values (zeros for dtypes other than int64/float64). 18 | a = List[Scalar[T]](capacity=capacity) # Reserve up front so the appends below do not reallocate. 19 | for _ in range(0, capacity): 20 | 21 | @parameter 22 | if T == DType.int64: 23 | a.append(rebind[Scalar[T]](random.random_si64(0, capacity))) 24 | elif T == DType.float64: 25 | a.append(rebind[Scalar[T]](random.random_float64(0, capacity))) 26 | else: 27 | a.append(0) 28 | 29 | return a^ 30 | 31 | 32 | # ===----------------------------------------------------------------------===# 33 | # Benchmark list init 34 | # ===----------------------------------------------------------------------===# 35 | 36 | 37 | @parameter 38 | fn bench_list_init[capacity: Int](mut b: Bencher) raises: 39 | @always_inline 40 | @parameter 41 | fn call_fn(): 42 | p = DTypePointer[DType.int64].alloc(capacity) # NOTE(review): DTypePointer / _from_dtype_ptr / List(unsafe_pointer=, size=) come from an older Mojo stdlib than the `mut` convention used here; confirm they still compile on the pinned toolchain. 43 | p.scatter(Int64(1), Int64(0)) 44 | res = List[Int64]( 45 | unsafe_pointer=UnsafePointer[Int64]._from_dtype_ptr(p), 46 | size=capacity, 47 | capacity=capacity, 48 | ) 49 | clobber_memory() 50 | keep(res.data.address) 51 | 52 | b.iter[call_fn]() 53 | 54 | 55 | # ===----------------------------------------------------------------------===# 56 | # Benchmark list Insert 57 |
===----------------------------------------------------------------------===# 58 | @parameter 59 | fn bench_list_insert[capacity: Int](mut b: Bencher) raises: 60 | items = make_list[capacity]() 61 | 62 | @always_inline 63 | @parameter 64 | fn call_fn() raises: 65 | for i in range(0, capacity): 66 | items.insert(i, random.random_si64(0, capacity).value) 67 | clobber_memory() 68 | 69 | b.iter[call_fn]() 70 | keep(items.data.address) 71 | 72 | 73 | # ===----------------------------------------------------------------------===# 74 | # Benchmark list Lookup 75 | # ===----------------------------------------------------------------------===# 76 | @parameter 77 | fn bench_list_lookup[capacity: Int](mut b: Bencher) raises: 78 | items = make_list[capacity]() 79 | 80 | @always_inline 81 | @parameter 82 | fn call_fn() raises: 83 | for i in range(0, capacity): 84 | res = 0 85 | for idx in range(capacity): 86 | if items.unsafe_get(idx) == i: 87 | res = idx 88 | break 89 | keep(res) 90 | 91 | b.iter[call_fn]() 92 | keep(items.data.address) 93 | 94 | 95 | # ===----------------------------------------------------------------------===# 96 | # Benchmark list contains 97 | # ===----------------------------------------------------------------------===# 98 | @parameter 99 | fn bench_list_contains[capacity: Int](mut b: Bencher) raises: 100 | items = make_list[capacity]() 101 | 102 | @always_inline 103 | @parameter 104 | fn call_fn() raises: 105 | for i in range(0, capacity): 106 | res = False 107 | for idx in range(capacity): 108 | if items.unsafe_get(idx) == i: 109 | res = True 110 | break 111 | keep(res) 112 | 113 | b.iter[call_fn]() 114 | keep(items.data.address) 115 | 116 | 117 | # ===----------------------------------------------------------------------===# 118 | # Benchmark list count 119 | # ===----------------------------------------------------------------------===# 120 | @parameter 121 | fn bench_list_count[capacity: Int](mut b: Bencher) raises: 122 | items = 
make_list[capacity]() 123 | 124 | @always_inline 125 | @parameter 126 | fn call_fn() raises: 127 | for i in range(0, capacity): 128 | res = 0 129 | for idx in range(capacity): 130 | if items.unsafe_get(idx) == i: 131 | res += 1 132 | keep(res) 133 | 134 | b.iter[call_fn]() 135 | keep(items.data.address) 136 | 137 | 138 | # ===----------------------------------------------------------------------===# 139 | # Benchmark list sum 140 | # ===----------------------------------------------------------------------===# 141 | @parameter 142 | fn bench_list_sum[capacity: Int](mut b: Bencher) raises: 143 | items = make_list[capacity]() 144 | 145 | @always_inline 146 | @parameter 147 | fn call_fn() raises: 148 | res: Int64 = 0 149 | for i in range(capacity): 150 | res += items.unsafe_get(i) 151 | clobber_memory() 152 | keep(res) 153 | 154 | b.iter[call_fn]() 155 | keep(items.data.address) 156 | 157 | 158 | # ===----------------------------------------------------------------------===# 159 | # Benchmark list filter 160 | # ===----------------------------------------------------------------------===# 161 | @parameter 162 | fn bench_list_filter[capacity: Int](mut b: Bencher) raises: 163 | items = make_list[capacity]() 164 | 165 | fn filterfn(a: Int64) -> Scalar[DType.bool]: 166 | return a < (capacity // 2) 167 | 168 | @always_inline 169 | @parameter 170 | fn call_fn() raises: 171 | res = List[Int64](capacity=capacity) 172 | amnt = 0 173 | for i in range(capacity): 174 | if filterfn(items.unsafe_get(i)): 175 | res.unsafe_set(amnt, items.unsafe_get(i)) 176 | amnt += 1 177 | clobber_memory() 178 | keep(res.data.address) 179 | 180 | b.iter[call_fn]() 181 | keep(items.data.address) 182 | 183 | 184 | # ===----------------------------------------------------------------------===# 185 | # Benchmark list apply 186 | # ===----------------------------------------------------------------------===# 187 | @parameter 188 | fn bench_list_apply[capacity: Int](mut b: Bencher) raises: 189 | items = 
make_list[capacity]() 190 | 191 | fn applyfn(a: Int64) -> Scalar[DType.int64]: 192 | if a < Int64.MAX_FINITE // 2: 193 | return a * 2 194 | return a 195 | 196 | @always_inline 197 | @parameter 198 | fn call_fn() raises: 199 | for i in range(capacity): 200 | items.unsafe_set(i, applyfn(items.unsafe_get(i))) 201 | clobber_memory() 202 | 203 | b.iter[call_fn]() 204 | keep(items.data.address) 205 | 206 | 207 | # ===----------------------------------------------------------------------===# 208 | # Benchmark list multiply 209 | # ===----------------------------------------------------------------------===# 210 | @parameter 211 | fn bench_list_multiply[capacity: Int](mut b: Bencher) raises: 212 | items = make_list[capacity]() 213 | 214 | @always_inline 215 | @parameter 216 | fn call_fn() raises: 217 | for i in range(capacity): 218 | items.unsafe_set(i, items.unsafe_get(i) * 2) 219 | clobber_memory() 220 | 221 | b.iter[call_fn]() 222 | keep(items.data.address) 223 | 224 | 225 | # ===----------------------------------------------------------------------===# 226 | # Benchmark list reverse 227 | # ===----------------------------------------------------------------------===# 228 | @parameter 229 | fn bench_list_reverse[capacity: Int](mut b: Bencher) raises: 230 | items = make_list[capacity, DType.uint8]() 231 | 232 | @always_inline 233 | @parameter 234 | fn call_fn() raises: 235 | for _ in range(1_000): 236 | items.reverse() 237 | clobber_memory() 238 | 239 | b.iter[call_fn]() 240 | keep(items.data.address) 241 | 242 | 243 | # ===----------------------------------------------------------------------===# 244 | # Benchmark list dot 245 | # ===----------------------------------------------------------------------===# 246 | @parameter 247 | fn bench_list_dot[capacity: Int](mut b: Bencher) raises: 248 | arr1 = make_list[capacity, DType.float64]() 249 | arr2 = make_list[capacity, DType.float64]() 250 | 251 | @always_inline 252 | @parameter 253 | fn call_fn() raises: 254 | for _ in 
range(1_000): 255 | res: Float64 = 0 256 | for i in range(len(arr1)): 257 | res += arr1.unsafe_get(i) * arr2.unsafe_get(i) 258 | clobber_memory() 259 | keep(res) 260 | 261 | b.iter[call_fn]() 262 | keep(arr1.data) 263 | keep(arr2.data) 264 | 265 | 266 | # ===----------------------------------------------------------------------===# 267 | # Benchmark list cross 268 | # ===----------------------------------------------------------------------===# 269 | @parameter 270 | fn bench_list_cross(mut b: Bencher) raises: 271 | arr1 = List[Float64](capacity=3) 272 | arr1[0] = random_float64(0, 500) 273 | arr1[1] = random_float64(0, 500) 274 | arr1[2] = random_float64(0, 500) 275 | arr2 = List[Float64](capacity=3) 276 | arr2[0] = random_float64(0, 500) 277 | arr2[1] = random_float64(0, 500) 278 | arr2[2] = random_float64(0, 500) 279 | 280 | @always_inline 281 | @parameter 282 | fn call_fn() raises: 283 | for _ in range(1_000): 284 | res = List[Float64]( 285 | arr1.unsafe_get(1) * arr2.unsafe_get(2) 286 | - arr1.unsafe_get(2) * arr2.unsafe_get(1), 287 | arr1.unsafe_get(2) * arr2.unsafe_get(0) 288 | - arr1.unsafe_get(0) * arr2.unsafe_get(2), 289 | arr1.unsafe_get(0) * arr2.unsafe_get(1) 290 | - arr1.unsafe_get(1) * arr2.unsafe_get(0), 291 | ) 292 | keep(res.data.address) 293 | 294 | b.iter[call_fn]() 295 | keep(arr1.data) 296 | keep(arr2.data) 297 | 298 | 299 | # ===----------------------------------------------------------------------===# 300 | # Benchmark Main 301 | # ===----------------------------------------------------------------------===# 302 | def main(): 303 | seed() 304 | m = Bench(BenchConfig(num_repetitions=5, warmup_iters=100)) 305 | alias sizes = Tuple(3, 8, 16, 32, 64, 128, 256) 306 | 307 | @parameter 308 | for i in range(7): 309 | alias size = sizes[i] 310 | # m.bench_function[bench_list_init[size]]( 311 | # BenchId("bench_list_init[" + String(size) + "]") 312 | # ) 313 | # m.bench_function[bench_list_insert[size]]( 314 | # BenchId("bench_list_insert[" + 
String(size) + "]") 315 | # ) 316 | # m.bench_function[bench_list_lookup[size]]( 317 | # BenchId("bench_list_lookup[" + String(size) + "]") 318 | # ) 319 | # m.bench_function[bench_list_contains[size]]( 320 | # BenchId("bench_list_contains[" + String(size) + "]") 321 | # ) 322 | # m.bench_function[bench_list_count[size]]( 323 | # BenchId("bench_list_count[" + String(size) + "]") 324 | # ) 325 | m.bench_function[bench_list_sum[size]]( 326 | BenchId("bench_list_sum[" + String(size) + "]") 327 | ) 328 | m.bench_function[bench_list_filter[size]]( 329 | BenchId("bench_list_filter[" + String(size) + "]") 330 | ) 331 | m.bench_function[bench_list_apply[size]]( 332 | BenchId("bench_list_apply[" + String(size) + "]") 333 | ) 334 | m.bench_function[bench_list_multiply[size]]( 335 | BenchId("bench_list_multiply[" + String(size) + "]") 336 | ) 337 | # m.bench_function[bench_list_reverse[size]]( 338 | # BenchId("bench_list_reverse[" + String(size) + "]") 339 | # ) 340 | # m.bench_function[bench_list_dot[size]]( 341 | # BenchId("bench_list_dot[" + String(size) + "]") 342 | # ) 343 | # m.bench_function[bench_list_cross](BenchId("bench_list_cross")) 344 | print("") 345 | values = Dict[String, List[Float64]]() 346 | for i in m.info_vec: 347 | res = i[].result.mean() 348 | val = values.get(i[].name, List[Float64](0, 0)) 349 | values[i[].name] = List[Float64](res + val[0], val[1] + 1) 350 | for i in values.items(): 351 | print(i[].key, ":", i[].value[0] / i[].value[1]) 352 | -------------------------------------------------------------------------------- /src/test/datetime/test_zoneinfo.mojo: -------------------------------------------------------------------------------- 1 | # RUN: %mojo %s 2 | 3 | from testing import assert_equal, assert_false, assert_raises, assert_true 4 | 5 | from forge_tools.datetime.zoneinfo import ( 6 | Offset, 7 | TzDT, 8 | ZoneDST, 9 | ZoneInfoFile32, 10 | ZoneInfoFile8, 11 | ZoneInfoMem32, 12 | ZoneInfoMem8, 13 | get_zoneinfo, 14 | get_leapsecs, 15 | # 
_parse_iana_leapsecs,
16 |     # _parse_iana_zonenow,
17 |     # _parse_iana_dst_transitions,
18 | )
19 |
20 |
21 | fn _check_offset(
22 |     offset: Offset, hour: Int, minute: UInt8, sign: Int
23 | ) raises:
24 |     """Assert that `offset` reports the given hour, minute and sign."""
25 |     assert_equal(offset.hour, hour)
26 |     assert_equal(offset.minute, minute)
27 |     assert_equal(offset.sign, sign)
28 |
29 |
30 | fn _check_same_offset(actual: Offset, expected: Offset) raises:
31 |     """Assert that two offsets agree on every field and on `buf`."""
32 |     assert_equal(actual.hour, expected.hour)
33 |     assert_equal(actual.minute, expected.minute)
34 |     assert_equal(actual.sign, expected.sign)
35 |     assert_equal(actual.buf, expected.buf)
36 |
37 |
38 | fn _check_tzdt(
39 |     dt: TzDT, month: Int, dow: Int, eomon: Int, week: Int, hour: UInt8
40 | ) raises:
41 |     """Assert that `dt` reports the given month/dow/eomon/week/hour."""
42 |     assert_equal(dt.month, month)
43 |     assert_equal(dt.dow, dow)
44 |     assert_equal(dt.eomon, eomon)
45 |     assert_equal(dt.week, week)
46 |     assert_equal(dt.hour, hour)
47 |
48 |
49 | def test_offset():
50 |     """Offsets keep hour/minute/sign, also after a round trip via `buf`."""
51 |     alias minutes = SIMD[DType.uint8, 4](0, 30, 45, 0)
52 |     for k in range(2):
53 |         sign = 1 - 2 * k  # +1 on the first pass, -1 on the second
54 |         for j in range(3):
55 |             minute = minutes[j]
56 |             for hour in range(16):
57 |                 first = Offset(hour, minute, sign)
58 |                 _check_offset(first, hour, minute, sign)
59 |                 _check_offset(Offset(buf=first.buf), hour, minute, sign)
60 |
61 |
62 | def test_tzdst():
63 |     """TzDT keeps every field, also after a round trip via `buf`."""
64 |     alias hours = SIMD[DType.uint8, 8](20, 21, 22, 23, 0, 1, 2, 3)
65 |     for month in range(1, 13):
66 |         for dow in range(2):
67 |             for eomon in range(2):
68 |                 for week in range(2):
69 |                     for h in range(8):
70 |                         hour = hours[h]
71 |                         dt = TzDT(month, dow, eomon, week, hour)
72 |                         _check_tzdt(dt, month, dow, eomon, week, hour)
73 |                         _check_tzdt(
74 |                             TzDT(buf=dt.buf), month, dow, eomon, week, hour
75 |                         )
76 |
77 |
78 | def test_zonedst():
79 |     """`from_hash` recovers the TzDT pair and the Offset of a ZoneDST."""
80 |     alias hours = SIMD[DType.uint8, 8](20, 21, 22, 23, 0, 1, 2, 3)
81 |     alias minutes = SIMD[DType.uint8, 4](0, 30, 45, 0)
82 |     for month in range(1, 13):
83 |         for dow in range(2):
84 |             for eomon in range(2):
85 |                 for week in range(2):
86 |                     for h in range(8):
87 |                         for k in range(2):
88 |                             sign = 1 - 2 * k
89 |                             for j in range(3):
90 |                                 for of_hour in range(16):
91 |                                     dt = TzDT(
92 |                                         month, dow, eomon, week, hours[h]
93 |                                     )
94 |                                     offset = Offset(of_hour, minutes[j], sign)
95 |                                     parsed = ZoneDST(dt, dt, offset).from_hash()
96 |                                     assert_equal(dt.buf, parsed[0].buf)
97 |                                     assert_equal(dt.buf, parsed[1].buf)
98 |                                     assert_equal(offset.buf, parsed[2].buf)
99 |
100 |
101 | def test_zoneinfomem32():
102 |     """ZoneInfoMem32 hands back the Offset stored under each zone name."""
103 |     storage = ZoneInfoMem32()
104 |     name_0 = "tz0"
105 |     name_1 = "tz1"
106 |     name_2 = "tz2"
107 |     name_30 = "tz30"
108 |     name_45 = "tz45"
109 |     want_0 = Offset(0, 0, 1)
110 |     want_1 = Offset(1, 0, 1)
111 |     want_2 = Offset(2, 0, 1)
112 |     want_30 = Offset(0, 30, 1)
113 |     want_45 = Offset(0, 45, 1)
114 |     # Only the Offset is under test; both TzDT slots stay default-built.
115 |     storage.add(name_0, ZoneDST(TzDT(), TzDT(), want_0))
116 |     storage.add(name_1, ZoneDST(TzDT(), TzDT(), want_1))
117 |     storage.add(name_2, ZoneDST(TzDT(), TzDT(), want_2))
118 |     storage.add(name_30, ZoneDST(TzDT(), TzDT(), want_30))
119 |     storage.add(name_45, ZoneDST(TzDT(), TzDT(), want_45))
120 |     # Element 2 of `from_hash()` is the Offset the ZoneDST was built with.
121 |     got_0 = storage.get(name_0).value().from_hash()[2]
122 |     got_1 = storage.get(name_1).value().from_hash()[2]
123 |     got_2 = storage.get(name_2).value().from_hash()[2]
124 |     got_30 = storage.get(name_30).value().from_hash()[2]
125 |     got_45 = storage.get(name_45).value().from_hash()[2]
126 |     _check_same_offset(got_0, want_0)
127 |     _check_same_offset(got_1, want_1)
128 |     _check_same_offset(got_2, want_2)
129 |     _check_same_offset(got_30, want_30)
130 |     _check_same_offset(got_45, want_45)
131 |
132 |
133 | def test_zoneinfomem8():
134 |     """ZoneInfoMem8 hands back the Offset stored under each zone name."""
135 |     storage = ZoneInfoMem8()
136 |     name_0 = "tz0"
137 |     name_1 = "tz1"
138 |     name_2 = "tz2"
139 |     name_30 = "tz30"
140 |     name_45 = "tz45"
141 |     want_0 = Offset(0, 0, 1)
142 |     want_1 = Offset(1, 0, 1)
143 |     want_2 = Offset(2, 0, 1)
144 |     want_30 = Offset(0, 30, 1)
145 |     want_45 = Offset(0, 45, 1)
146 |     storage.add(name_0, want_0)
147 |     storage.add(name_1, want_1)
148 |     storage.add(name_2, want_2)
149 |     storage.add(name_30, want_30)
150 |     storage.add(name_45, want_45)
151 |     got_0 = storage.get(name_0).value()
152 |     got_1 = storage.get(name_1).value()
153 |     got_2 = storage.get(name_2).value()
154 |     got_30 = storage.get(name_30).value()
155 |     got_45 = storage.get(name_45).value()
156 |     _check_same_offset(got_0, want_0)
157 |     _check_same_offset(got_1, want_1)
158 |     _check_same_offset(got_2, want_2)
159 |     _check_same_offset(got_30, want_30)
160 |     _check_same_offset(got_45, want_45)
161 |
162 |
163 | # FIXME
164 | # def test_zoneinfofile32():
165 | #     storage = ZoneInfoFile32()
166 | #     name_0 = "tz0"
167 | #     name_1 = "tz1"
168 | #     name_2 = "tz2"
169 | #     name_30 = "tz30"
170 | #     name_45 = "tz45"
171 | #     want_0 = Offset(0, 0, 1)
172 | #     want_1 = Offset(1, 0, 1)
173 | #     want_2 = Offset(2, 0, 1)
174 | #     want_30 = Offset(0, 30, 1)
175 | #     want_45 = Offset(0, 45, 1)
176 | #     storage.add(name_0, ZoneDST(TzDT(), TzDT(), want_0))
177 | #     storage.add(name_1, ZoneDST(TzDT(), TzDT(), want_1))
178 | #     storage.add(name_2, ZoneDST(TzDT(), TzDT(), want_2))
179 | #     storage.add(name_30, ZoneDST(TzDT(), TzDT(), want_30))
180 | #     storage.add(name_45, ZoneDST(TzDT(), TzDT(), want_45))
181 | #     got_0 = storage.get(name_0).value().from_hash()[2]
182 | #     got_1 = storage.get(name_1).value().from_hash()[2]
183 | #     got_2 = storage.get(name_2).value().from_hash()[2]
184 | #     got_30 = storage.get(name_30).value().from_hash()[2]
185 | #     got_45 = storage.get(name_45).value().from_hash()[2]
186 | #     _check_same_offset(got_0, want_0)
187 | #     _check_same_offset(got_1, want_1)
188 | #     _check_same_offset(got_2, want_2)
189 | #     _check_same_offset(got_30, want_30)
190 | #     _check_same_offset(got_45, want_45)
191 |
192 |
193 | # FIXME
194 | # def test_zoneinfofile8():
195 | #     storage = ZoneInfoFile8()
196 | #     name_0 = "tz0"
197 | #     name_1 = "tz1"
198 | #     name_2 = "tz2"
199 | #     name_30 = "tz30"
200 | #     name_45 = "tz45"
201 | #     want_0 = Offset(0, 0, 1)
202 | #     want_1 = Offset(1, 0, 1)
203 | #     want_2 = Offset(2, 0, 1)
204 | #     want_30 = Offset(0, 30, 1)
205 | #     want_45 = Offset(0, 45, 1)
206 | #     storage.add(name_0, want_0)
207 | #     storage.add(name_1, want_1)
208 | #     storage.add(name_2, want_2)
209 | #     storage.add(name_30, want_30)
210 | #     storage.add(name_45, want_45)
211 | #     got_0 = storage.get(name_0).value()
212 | #     got_1 = storage.get(name_1).value()
213 | #     got_2 = storage.get(name_2).value()
214 | #     got_30 = storage.get(name_30).value()
215 | #     got_45 = storage.get(name_45).value()
216 | #     _check_same_offset(got_0, want_0)
217 | #     _check_same_offset(got_1, want_1)
218 | #     _check_same_offset(got_2, want_2)
219 | #     _check_same_offset(got_30, want_30)
220 | #     _check_same_offset(got_45, want_45)
221 |
222 |
223 | def test_get_zoneinfo():
224 |     # TODO: exercise `get_zoneinfo`.
225 |     pass
226 |
227 |
228 | def test_get_leapsecs():
229 |     # TODO: exercise `get_leapsecs`.
230 |     pass
231 |
232 |
233 | def test_parse_iana_leapsecs():
234 |     # TODO: exercise `_parse_iana_leapsecs` (its import is commented out).
235 |     pass
236 |
237 |
238 | def test_parse_iana_zonenow():
239 |     # TODO: exercise `_parse_iana_zonenow` (its import is commented out).
240 |     pass
241 |
242 |
243 | def test_parse_iana_dst_transitions():
244 |     # TODO: exercise `_parse_iana_dst_transitions` (import commented out).
245 |     pass
246 |
247 |
248 | def main():
249 |     """Run every enabled zoneinfo test in declaration order."""
250 |     test_offset()
251 |     test_tzdst()
252 |     test_zonedst()
253 |     test_zoneinfomem32()
254 |     test_zoneinfomem8()
255 |     # The file-backed storage tests stay disabled (see their FIXME markers).
256 |     # test_zoneinfofile32()
257 |     # test_zoneinfofile8()
258 |     test_get_zoneinfo()
259 |     test_get_leapsecs()
260 |     test_parse_iana_leapsecs()
261 |     test_parse_iana_zonenow()
262 |     test_parse_iana_dst_transitions()
263 |
--------------------------------------------------------------------------------