# Repository dump: the bitstream package, its tests and setup.py.
#
# ├── bitstream
# │   ├── __init__.py
# │   └── tests
# │       └── test_bitstream.py
# └── setup.py
#
# ======================================================================
# /bitstream/__init__.py
# ======================================================================
"""
Module for sequential reading (ReadStream) and writing (WriteStream) from/to bytes.
Also includes objects for converting datatypes from/to bytes, similar to the standard library struct module.
"""

import math
import struct
from abc import ABC, abstractmethod
from typing import AnyStr, cast, Generic, Optional, overload, SupportsBytes, Type, TypeVar, Union

T = TypeVar('T')

# typing.ByteString is deprecated and scheduled for removal; spell the union out instead.
_ByteString = Union[bytes, bytearray, memoryview]


class _Struct(Generic[T]):
    """
    Base class for fixed-size value types backed by a struct.Struct.

    Calling a subclass does not create an instance: __new__ packs the value
    with the subclass's _struct and returns the packed bytes directly,
    so the result can be passed straight to WriteStream.write.
    """
    _struct: struct.Struct

    def __new__(cls, value: T) -> bytes:
        return cls._struct.pack(value)

    def __str__(self) -> str:
        return "<%s %s>" % (type(self).__name__, type(self)._struct.format)

    @classmethod
    def deserialize(cls, stream: "ReadStream") -> T:
        """Read exactly cls._struct.size bytes from stream and unpack them into one value."""
        return cast(T, cls._struct.unpack(stream.read(bytes, length=cls._struct.size))[0])


class IntStruct(_Struct[int]):
    pass


class UnsignedIntStruct(IntStruct):
    @classmethod
    def _unpack_low_bytes(cls, low: bytes) -> int:
        """
        Unpack the given low-order bytes, zero-filling the missing high-order bytes.

        The padding is appended after the data, which is only correct for
        little-endian ("<") formats.
        """
        return cast(int, cls._struct.unpack(low + bytes(cls._struct.size - len(low)))[0])

    @classmethod
    def deserialize_compressed(cls, stream: "ReadStream") -> int:
        """
        Read an unsigned integer written by WriteStream.write_compressed.

        From the highest byte down to the second-lowest, each 1 bit means
        "this byte is zero". A 0 bit means the current byte is nonzero; it and
        every lower byte follow verbatim.
        If only the lowest byte remains, a 1 bit is followed by its low 4 bits
        (its upper nibble was zero), and a 0 bit by the full byte.
        """
        current_byte = cls._struct.size - 1

        while current_byte > 0:
            if not stream.read(c_bit):
                # Nonzero byte: it and every lower byte follow verbatim.
                return cls._unpack_low_bytes(stream.read(bytes, length=current_byte + 1))
            current_byte -= 1

        # All but the lowest byte are zero.
        if stream.read(c_bit):
            lowest = bytes([stream.read_bits(4)])
        else:
            lowest = stream.read(bytes, length=1)
        return cls._unpack_low_bytes(lowest)


class SignedIntStruct(IntStruct):
    pass


# NOTE(review): the text from c_bool's format string through the Serializable
# class header was truncated in the source this file was recovered from.
# c_bit, c_uint and Serializable are required by the code and tests below.
# The remaining little-endian ("<") type definitions are a presumed
# reconstruction - TODO confirm against the upstream repository.
class c_bool(_Struct[bool]):
    _struct = struct.Struct("<?")


class c_float(_Struct[float]):
    _struct = struct.Struct("<f")


class c_double(_Struct[float]):
    _struct = struct.Struct("<d")


class c_byte(SignedIntStruct):
    _struct = struct.Struct("<b")


class c_ubyte(UnsignedIntStruct):
    _struct = struct.Struct("<B")


class c_short(SignedIntStruct):
    _struct = struct.Struct("<h")


class c_ushort(UnsignedIntStruct):
    _struct = struct.Struct("<H")


class c_int(SignedIntStruct):
    _struct = struct.Struct("<i")


class c_uint(UnsignedIntStruct):
    _struct = struct.Struct("<I")


class c_int64(SignedIntStruct):
    _struct = struct.Struct("<q")


class c_uint64(UnsignedIntStruct):
    _struct = struct.Struct("<Q")


class c_bit:
    """
    A single bit.

    Write one with WriteStream.write(c_bit(True)); read one with
    ReadStream.read(c_bit), which returns a bool.
    """

    def __init__(self, value: bool) -> None:
        self.value = value


class Serializable(ABC):
    """Interface for objects that write themselves to a WriteStream and are read back from a ReadStream."""

    @abstractmethod
    def serialize(self, stream: "WriteStream") -> None:
        """Write this object to the bitstream."""
        pass

    @classmethod
    @abstractmethod
    def deserialize(cls, stream: "ReadStream") -> "Serializable":
        """Create a new object from the bitstream."""
        pass


S = TypeVar('S', bound=Serializable)


class ReadStream:
    """Allows simple sequential reading from bytes."""
    _data: bytes

    def __init__(self, data: bytes, unlocked: bool = False):
        self._data = data
        # Only unlocked streams allow reading or setting read_offset directly.
        self._unlocked = unlocked
        # Position of the next bit to read, in bits from the start of _data.
        self._read_offset = 0

    @property
    def read_offset(self) -> int:
        """Current read position in bits. Raises RuntimeError on a locked stream."""
        if not self._unlocked:
            raise RuntimeError("access to read offset on locked stream")
        return self._read_offset

    @read_offset.setter
    def read_offset(self, value: int) -> None:
        if not self._unlocked:
            raise RuntimeError("access to read offset on locked stream")
        self._read_offset = value

    def skip_read(self, byte_length: int) -> None:
        """Skips reading byte_length number of bytes."""
        self._read_offset += byte_length * 8

    @overload
    def read(self, arg_type: Type[_Struct[T]]) -> T:
        pass

    @overload
    def read(self, arg_type: Type[c_bit]) -> bool:
        pass

    @overload
    def read(self, arg_type: Type[S]) -> S:
        pass

    @overload
    def read(self, arg_type: Type[bytes], length: int) -> bytes:
        pass

    @overload
    def read(self, arg_type: Type[bytes], allocated_length: Optional[int] = None, length_type: Optional[Type[UnsignedIntStruct]] = None) -> bytes:
        pass

    @overload
    def read(self, arg_type: Type[str], allocated_length: Optional[int] = None, length_type: Optional[Type[UnsignedIntStruct]] = None) -> str:
        pass

    def read(self, arg_type, length=None, allocated_length=None, length_type=None):
        """
        Read a value of type arg_type from the bitstream.

        arg_type may be a _Struct subclass (c_uint, ...), c_bit, a Serializable
        subclass, bytes, or str.
        length is the byte count for raw bytes reads.
        allocated_length is for fixed-length strings.
        length_type is for variable-length strings.

        Raises TypeError for unsupported types or a bytes read without a
        length, and EOFError when not enough bytes remain.
        """
        if issubclass(arg_type, _Struct):
            return arg_type.deserialize(self)
        if issubclass(arg_type, c_bit):
            return self._read_bit()
        if issubclass(arg_type, Serializable):
            return arg_type.deserialize(self)
        if allocated_length is not None or length_type is not None:
            return self._read_str(arg_type, allocated_length, length_type)
        if issubclass(arg_type, bytes):
            if length is None:
                raise TypeError("length is required when reading raw bytes")
            return self._read_bytes(length)
        raise TypeError(arg_type)

    def _read_str(self, arg_type: Type[AnyStr], allocated_length: Optional[int] = None, length_type: Optional[Type[UnsignedIntStruct]] = None) -> AnyStr:
        """
        Read a str (UTF-16-LE, 2 bytes per code unit) or bytes (1 byte per char) string.

        With length_type, a length prefix counting characters (code units) precedes the data.
        With allocated_length, exactly that many characters are read and the
        value is cut at the first null character; RuntimeError if there is none.
        """
        char_size = 2 if issubclass(arg_type, str) else 1

        if length_type is not None:
            # Variable-length string
            length = self.read(length_type)
            value = self._read_bytes(length * char_size)
        elif allocated_length is not None:
            # Fixed-length string: find the null terminator, one character at a time.
            value = self._read_bytes(allocated_length * char_size)
            terminator = bytes(char_size)
            for start in range(0, len(value), char_size):
                if value[start:start + char_size] == terminator:
                    value = value[:start]
                    break
            else:
                raise RuntimeError("String doesn't have null terminator")
        else:
            raise ValueError("either allocated_length or length_type is required")

        if issubclass(arg_type, str):
            return value.decode("utf-16-le")
        return value

    def _read_bit(self) -> bool:
        """Read a single bit; bits are taken from each byte most significant first."""
        mask = 0x80 >> (self._read_offset % 8)
        bit = (self._data[self._read_offset // 8] & mask) != 0
        self._read_offset += 1
        return bit

    def read_bits(self, number_of_bits: int) -> int:
        """
        Read number_of_bits (1 to 8) bits and return them right-aligned in an int.

        The first bit read becomes the most significant bit of the result.
        Raises ValueError for an out-of-range bit count.
        """
        if not 0 < number_of_bits <= 8:
            raise ValueError("number_of_bits must be between 1 and 8, got %r" % (number_of_bits,))
        shift = self._read_offset % 8
        index = self._read_offset // 8

        # Move the unread bits of the current byte to the top of an 8-bit window.
        output = (self._data[index] << shift) & 0xff  # First half
        if number_of_bits > 8 - shift:
            # The requested bits continue into the next byte.
            output |= self._data[index + 1] >> (8 - shift)  # Second half (overlaps byte boundary)
        output >>= 8 - number_of_bits
        self._read_offset += number_of_bits
        return output

    def _read_bytes(self, length: int) -> bytes:
        """Read length whole bytes starting at the current, possibly unaligned, bit offset."""
        shift = self._read_offset % 8
        start = self._read_offset // 8
        # An unaligned read straddles one extra byte.
        num_bytes_read = length if shift == 0 else length + 1

        # check whether there is enough left to read
        if len(self._data) - start < num_bytes_read:
            raise EOFError("Trying to read %i bytes but only %i remain" % (num_bytes_read, len(self._data) - start))

        if shift == 0:
            output = self._data[start:start + num_bytes_read]
        else:
            # Clear the already-consumed high bits of the first byte, then shift the
            # whole span right so the wanted bits end on a byte boundary.
            first_byte = self._data[start] & ((1 << (8 - shift)) - 1)
            output = first_byte.to_bytes(1, "big") + self._data[start + 1:start + num_bytes_read]
            output = (int.from_bytes(output, "big") >> (8 - shift)).to_bytes(length, "big")
        self._read_offset += length * 8
        return output

    def read_compressed(self, arg_type: Type[UnsignedIntStruct]) -> int:
        """Read an unsigned integer in the format produced by WriteStream.write_compressed."""
        return arg_type.deserialize_compressed(self)

    def read_remaining(self) -> bytes:
        """Read all whole bytes remaining after the current position."""
        return self._read_bytes(len(self._data) - int(math.ceil(self._read_offset / 8)))

    def align_read(self) -> None:
        """Advance the read offset to the next byte boundary, if not already on one."""
        if self._read_offset % 8 != 0:
            self._read_offset += 8 - self._read_offset % 8

    def all_read(self) -> bool:
        """Return whether the read position has reached the end of the data."""
        # This is not accurate to the bit, just to the byte
        return math.ceil(self._read_offset / 8) == len(self._data)


# Note: a ton of the logic here assumes that the write offset is never moved back, that is, that you never overwrite things
# Doing so may break everything
class WriteStream(SupportsBytes):
    """Allows simple sequential writing to bytes."""
    _data: bytearray

    def __init__(self) -> None:
        self._data = bytearray()
        # Position of the next bit to write, in bits from the start of _data.
        self._write_offset = 0
        self._was_cast_to_bytes = False

    def __bytes__(self) -> bytes:
        """Return the written data. Raises RuntimeError if called more than once."""
        if self._was_cast_to_bytes:
            raise RuntimeError("WriteStream can only be cast to bytes once")
        self._was_cast_to_bytes = True
        return bytes(self._data)

    @overload
    def write(self, arg: _ByteString) -> None:
        pass

    @overload
    def write(self, arg: _Struct) -> None:
        pass

    @overload
    def write(self, arg: c_bit) -> None:
        pass

    @overload
    def write(self, arg: Serializable) -> None:
        pass

    @overload
    def write(self, arg: AnyStr, allocated_length: Optional[int] = None, length_type: Optional[Type[UnsignedIntStruct]] = None) -> None:
        pass

    def write(self, arg, allocated_length=None, length_type=None):
        """
        Write a value to the bitstream.

        arg may be bytes-like, which includes the packed bytes returned by
        c_uint(...) and the other _Struct types. It may also be a c_bit or a
        Serializable. A str or bytes string is accepted when allocated_length
        or length_type is given.
        allocated_length is for fixed-length strings.
        length_type is for variable-length strings.
        Raises TypeError for unsupported arguments.
        """
        if isinstance(arg, c_bit):
            self._write_bit(arg.value)
        elif isinstance(arg, Serializable):
            arg.serialize(self)
        elif allocated_length is not None or length_type is not None:
            self._write_str(arg, allocated_length, length_type)
        elif isinstance(arg, (bytes, bytearray)):
            self._write_bytes(arg)
        elif isinstance(arg, memoryview):
            self._write_bytes(arg.tobytes())
        else:
            raise TypeError(arg)

    def _write_str(self, str_: AnyStr, allocated_length: Optional[int] = None, length_type: Optional[Type[UnsignedIntStruct]] = None) -> None:
        """
        Write a str (as UTF-16-LE) or bytes string.

        With length_type, the data is prefixed by its length in characters.
        For str, characters are UTF-16 code units, matching how the reader
        counts them.
        With allocated_length, the data is null-padded to exactly that many
        characters. Raises ValueError if the string plus a terminator would
        not fit.
        """
        # NOTE: possibly include default encoded length for non-variable-length strings (seems to be 33)
        if isinstance(str_, str):
            encoded_str = str_.encode("utf-16-le")
            char_size = 2
        else:
            # Copy, so the padding below never mutates a caller's bytearray in place.
            encoded_str = bytes(str_)
            char_size = 1
        # Count characters the way the reader does: code units, not code points.
        char_count = len(encoded_str) // char_size

        if length_type is not None:
            # Variable-length string
            self.write(length_type(char_count))
        elif allocated_length is not None:
            # Fixed-length string; at least one null character must fit as terminator.
            if char_count + 1 > allocated_length:
                raise ValueError("String too long!")
            encoded_str += bytes((allocated_length - char_count) * char_size)
        self._write_bytes(encoded_str)

    def _write_bit(self, bit: bool) -> None:
        """Write a single bit; bits fill each byte most significant first."""
        self._alloc_bits(1)
        if bit:  # we don't actually have to do anything if the bit is 0
            self._data[self._write_offset // 8] |= 0x80 >> (self._write_offset % 8)

        self._write_offset += 1

    def write_bits(self, value: int, number_of_bits: int) -> None:
        """
        Write the low number_of_bits (1 to 8) bits of value, most significant first.

        Raises ValueError for an out-of-range bit count.
        """
        if not 0 < number_of_bits <= 8:
            raise ValueError("number_of_bits must be between 1 and 8, got %r" % (number_of_bits,))
        self._alloc_bits(number_of_bits)

        # The bits are right-aligned in value; move them to the left, as in our internal representation
        value = (value << (8 - number_of_bits)) & 0xff
        shift = self._write_offset % 8
        index = self._write_offset // 8
        if shift == 0:
            self._data[index] = value
        else:
            self._data[index] |= value >> shift  # First half
            if 8 - shift < number_of_bits:  # 8 - shift is the number of bits written in the first half
                self._data[index + 1] = (value << (8 - shift)) & 0xff  # Second half (overlaps byte boundary)

        self._write_offset += number_of_bits

    def _write_bytes(self, byte_arg: bytes) -> None:
        """Append byte_arg at the current, possibly unaligned, bit offset."""
        shift = self._write_offset % 8
        index = self._write_offset // 8
        if shift == 0:
            # The offset is at the end of _data, so slice assignment extends it.
            self._data[index:index + len(byte_arg)] = byte_arg
        else:
            # shift new input to current shift, spread over one extra byte
            new = (int.from_bytes(byte_arg, "big") << (8 - shift)).to_bytes(len(byte_arg) + 1, "big")
            # update current byte (its high bits are already in use)
            self._data[index] |= new[0]
            # add rest
            self._data[index + 1:index + 1 + len(byte_arg)] = new[1:]
        self._write_offset += len(byte_arg) * 8

    @overload
    def write_compressed(self, byte_arg: UnsignedIntStruct) -> None:
        pass

    @overload
    def write_compressed(self, byte_arg: bytes) -> None:
        pass

    def write_compressed(self, byte_arg) -> None:
        """
        Write little-endian packed integer bytes in compressed form.

        Typically byte_arg is the packed result of an UnsignedIntStruct such as c_uint(x).
        Read the value back with ReadStream.read_compressed.
        """
        current_byte = len(byte_arg) - 1

        # Write upper bytes with a single 1
        # From high byte to low byte, if high byte is 0 then write 1. Otherwise write 0 and the remaining bytes
        while current_byte > 0:
            is_zero = byte_arg[current_byte] == 0
            self._write_bit(is_zero)
            if not is_zero:
                # Write the remainder of the data
                self._write_bytes(byte_arg[:current_byte + 1])
                return
            current_byte -= 1

        # If the upper half of the last byte is 0 then write 1 and the remaining 4 bits. Otherwise write 0 and the 8 bits.
        is_zero = byte_arg[0] & 0xF0 == 0x00
        self._write_bit(is_zero)
        if is_zero:
            self.write_bits(byte_arg[0], 4)
        else:
            self._write_bytes(byte_arg[:1])

    def align_write(self) -> None:
        """Align the write offset to the byte boundary."""
        if self._write_offset % 8 != 0:
            self._alloc_bits(8 - self._write_offset % 8)
            self._write_offset += 8 - self._write_offset % 8

    def _alloc_bits(self, number_of_bits: int) -> None:
        """Grow _data with zero bytes so number_of_bits more bits fit after the write offset."""
        bytes_to_allocate: int = math.ceil((self._write_offset + number_of_bits) / 8) - len(self._data)
        if bytes_to_allocate > 0:
            self._data += bytes(bytes_to_allocate)


# ======================================================================
# /bitstream/tests/test_bitstream.py
# ======================================================================
import random
import unittest

from bitstream import c_uint, ReadStream, WriteStream


class ReadStreamTest(unittest.TestCase):
    def setUp(self):
        self.locked = ReadStream(b"test")
        self.unlocked = ReadStream(b"test", unlocked=True)

    def test_read_offset(self):
        with self.assertRaises(RuntimeError):
            self.locked.read_offset
        self.assertEqual(self.unlocked.read_offset, 0)

        with self.assertRaises(RuntimeError):
            self.locked.read_offset = 1
        self.unlocked.read_offset = 1
        self.assertEqual(self.unlocked.read_offset, 1)

    def test_align_read(self):
        self.unlocked.read_offset = 5
        self.unlocked.align_read()
        self.assertEqual(self.unlocked.read_offset, 8)

    def test_skip_read(self):
        self.unlocked.skip_read(4)
        self.assertEqual(self.unlocked.read_offset, 32)

    def test_all_read(self):
        self.assertFalse(self.locked.all_read())
        self.locked.skip_read(4)
        self.assertTrue(self.locked.all_read())

    def test_read_remaining(self):
        self.assertEqual(self.locked.read_remaining(), b"test")


class WriteStreamTest(unittest.TestCase):
    def setUp(self):
        self.stream = WriteStream()

    def test_cast_multiple(self):
        bytes(self.stream)
        with self.assertRaises(RuntimeError):
            bytes(self.stream)

    def test_align_write(self):
        self.stream.write_bits(255, 5)
        self.stream.align_write()
        self.assertEqual(bytes(self.stream), b"\xf8")


class _BitStream(WriteStream, ReadStream):
    """A stream that reads back what it writes (both share _data)."""

    def __init__(self):
        super().__init__()
        self._unlocked = False
        self._read_offset = 0


class BitStreamTest(unittest.TestCase):
    def setUp(self):
        # Start at a random bit offset so aligned and unaligned paths both get exercised.
        self.stream = _BitStream()
        shift = random.randrange(0, 8)
        if shift > 0:
            self.stream.write_bits(0xff, shift)
            self.stream.read_bits(shift)


class GeneralTest(BitStreamTest):
    def test_compressed(self):
        value = 42
        self.stream.write_compressed(c_uint(value))
        self.assertEqual(self.stream.read_compressed(c_uint), value)
        value = 1 << 16
        self.stream.write_compressed(c_uint(value))
        self.assertEqual(self.stream.read_compressed(c_uint), value)

    def test_read_bytes_too_much(self):
        with self.assertRaises(EOFError):
            self.stream.read(bytes, length=2)

    def test_read_bytes_too_much_shifted(self):
        self.stream.write_bits(0xff, 1)
        self.stream.read_bits(1)
        with self.assertRaises(EOFError):
            self.stream.read(bytes, length=1)

    def test_unaligned_bits(self):
        self.stream.align_write()
        self.stream.align_read()
        self.stream.write_bits(0xff, 7)
        self.stream.write_bits(0xff, 7)
        self.stream.read_bits(7)
        self.assertEqual(self.stream.read_bits(4), 0x0f)

    def test_full_byte_bits(self):
        self.stream.write_bits(0xa5, 8)
        self.assertEqual(self.stream.read_bits(8), 0xa5)


class StringTest:
    STRING = None

    @classmethod
    def setUpClass(cls):
        if isinstance(cls.STRING, str):
            cls.CHAR_SIZE = 2
        else:
            cls.CHAR_SIZE = 1

    def test_write_allocated_long(self):
        with self.assertRaises(ValueError):
            self.stream.write(self.STRING, allocated_length=len(self.STRING)-2)

    def test_allocated(self):
        self.stream.write(self.STRING, allocated_length=len(self.STRING) + 10)
        value = self.stream.read(type(self.STRING), allocated_length=len(self.STRING)+10)
        self.assertEqual(value, self.STRING)

    def test_read_allocated_buffergarbage(self):
        self.stream.write(self.STRING, allocated_length=len(self.STRING)+1)
        self.stream.write(b"\xdf"*10*self.CHAR_SIZE)
        value = self.stream.read(type(self.STRING), allocated_length=len(self.STRING)+1+10)
        self.assertEqual(value, self.STRING)

    def test_read_allocated_no_terminator(self):
        self.stream.write(b"\xff"*33*self.CHAR_SIZE)
        with self.assertRaises(RuntimeError):
            self.stream.read(type(self.STRING), allocated_length=33)

    def test_variable_length(self):
        self.stream.write(self.STRING, length_type=c_uint)
        value = self.stream.read(type(self.STRING), length_type=c_uint)
        self.assertEqual(value, self.STRING)


class UnicodeStringTest(StringTest, BitStreamTest):
    STRING = "Hello world"

    def test_variable_length_non_bmp(self):
        # A non-BMP character takes two UTF-16 code units; the prefix must count code units.
        string = "Hello \U0001F600"
        self.stream.write(string, length_type=c_uint)
        self.assertEqual(self.stream.read(str, length_type=c_uint), string)

    def test_write_allocated_non_bmp_too_long(self):
        # Two code units plus a terminator do not fit in 2 characters.
        with self.assertRaises(ValueError):
            self.stream.write("\U0001F600", allocated_length=2)


class ByteStringTest(StringTest, BitStreamTest):
    STRING = UnicodeStringTest.STRING.encode("latin1")

    def test_allocated_does_not_mutate_argument(self):
        arg = bytearray(self.STRING)
        self.stream.write(arg, allocated_length=len(self.STRING) + 5)
        self.assertEqual(arg, bytearray(self.STRING))


# ======================================================================
# /setup.py
# ======================================================================
import os.path

from setuptools import setup

setup(
    name="bitstream",
    version="0.1.0",
    description="Sequential bit-level reading and writing.",
    author="lcdr",
    url="https://github.com/lcdr/bitstream/",
    license="GPL v3",
    packages=["bitstream"],
    python_requires=">=3.6",
)