|
""" |
|
MIT License |
|
|
|
Copyright (c) 2024-present Simon Sawicki <contact@grub4k.xyz> |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated |
|
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the |
|
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, |
|
and to permit persons to whom the Software is furnished to do so, subject to the following conditions: |
|
|
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the |
|
Software. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE |
|
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
|
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
|
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
|
|
|
https://github.com/Grub4K/qpb |
|
""" |
|
|
|
import ast |
|
import base64 |
|
import contextlib |
|
import enum |
|
import io |
|
import struct |
|
from collections import defaultdict |
|
|
|
|
|
def decode_protobuf(value): |
|
data = base64.b64decode(value) |
|
decoded = _decode(data) |
|
return decoded |
|
|
|
|
|
def encode_protobuf(value): |
|
try: |
|
data = ast.literal_eval(value.strip()) |
|
|
|
result = _encode(data) |
|
encoded = base64.b64encode(result).decode() |
|
return encoded |
|
except SyntaxError: |
|
raise SyntaxError(f"invalid input: {value}") |
|
|
|
|
|
class WireType(enum.IntEnum): |
|
VARINT = 0 |
|
I64 = 1 |
|
LEN = 2 |
|
SGROUP = 3 |
|
EGROUP = 4 |
|
I32 = 5 |
|
|
|
|
|
_float_struct = struct.Struct(b"<f") |
|
_double_struct = struct.Struct(b"<d") |
|
|
|
|
|
def _encode(data) -> bytes: |
|
if not isinstance(data, dict): |
|
message = "type to encode has to be a dict" |
|
raise TypeError(message) |
|
|
|
return b"".join(_encode_record(value, wire_id) for wire_id, value in data.items()) |
|
|
|
|
|
def _decode(data): |
|
reader = data if isinstance(data, io.BufferedIOBase) else io.BytesIO(data) |
|
result = defaultdict(list) |
|
|
|
record = _read_record(reader) |
|
while record: |
|
key, value = record |
|
result[key].append(value) |
|
record = _read_record(reader) |
|
|
|
for key, values in result.items(): |
|
for index, value in enumerate(values): |
|
if not isinstance(value, bytes): |
|
continue |
|
with contextlib.suppress(Exception): |
|
values[index] = _decode(value) |
|
if len(values) == 1: |
|
result[key] = values[0] |
|
|
|
return dict(result) |
|
|
|
|
|
def _read_record(reader: io.BufferedIOBase): |
|
tag = _read_tag(reader) |
|
if tag is None: |
|
return None |
|
wire_id, wire_type = tag |
|
if wire_type == WireType.VARINT: |
|
value = _read_varint(reader) |
|
elif wire_type == WireType.I64: |
|
value = reader.read(8) |
|
elif wire_type == WireType.I32: |
|
value = reader.read(4) |
|
elif wire_type == WireType.LEN: |
|
value = reader.read(_read_varint(reader)) |
|
else: |
|
message = "Unknown wire type" |
|
raise TypeError(message) |
|
|
|
return wire_id, value |
|
|
|
|
|
def _encode_record(data, wire_id) -> bytes: |
|
if isinstance(data, int): |
|
if data < 0: |
|
data = _signed_to_zigzag(data) |
|
return _encode_tag(wire_id, WireType.VARINT) + _encode_varint(data) |
|
|
|
if isinstance(data, list): |
|
encoded = b"".join(map(_encode_record, data)) |
|
elif isinstance(data, dict): |
|
encoded = _encode(data) |
|
elif isinstance(data, str): |
|
encoded = data.encode() |
|
elif isinstance(data, bytes): |
|
encoded = data |
|
else: |
|
message = f"Unencodable type: {type(data)}" |
|
raise TypeError(message) |
|
|
|
return _encode_tag(wire_id, WireType.LEN) + _encode_varint(len(encoded)) + encoded |
|
|
|
|
|
def _read_varint(reader: io.BufferedIOBase): |
|
shift = 0 |
|
data = 0 |
|
|
|
byte = 0b1000_0000 |
|
while byte & 0b1000_0000: |
|
result = reader.read(1) |
|
if not result: |
|
return None |
|
(byte,) = result |
|
data |= (byte & 0b0111_1111) << shift |
|
shift += 7 |
|
|
|
return data |
|
|
|
|
|
def _encode_varint(value: int) -> bytes: |
|
data_length = (value.bit_length() + 6) // 7 or 1 |
|
data = bytearray(data_length) |
|
for index in range(data_length - 1): |
|
data[index] = value & 0b0111_1111 | 0b1000_0000 |
|
value >>= 7 |
|
|
|
data[-1] = value |
|
return bytes(data) |
|
|
|
|
|
def _read_tag(reader: io.BufferedIOBase): |
|
value = _read_varint(reader) |
|
if value is None: |
|
return None |
|
return value >> 3, WireType(value & 0b111) |
|
|
|
|
|
def _encode_tag(wire_id, wire_type: WireType) -> bytes: |
|
if wire_id is None: |
|
return b"" |
|
return _encode_varint((wire_id << 3) | wire_type) |
|
|
|
|
|
def _signed_to_zigzag(value: int): |
|
result = value << 1 |
|
if value < 0: |
|
result = -result - 1 |
|
return result |
|
|