diff --git a/CHANGES.rst b/CHANGES.rst index 7299f660..945772be 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,14 @@ Changelog 0.14.0 (????-??-??) =================== +New features: + +* Simplify flexible versions in schema. + Defining an API request or response schemas that should support + flexible versions (KIP-482) is now achieved by setting `FLEXIBLE_VERSION` to True. + Tagged fields could be expressed with ("name", tag) instead of just a name. + (pr #1139 by @vmaurin) + Bugfixes: * Fix type annotation for `AIOKafkaAdminClient` (issue #1148) diff --git a/aiokafka/conn.py b/aiokafka/conn.py index fdf5d878..da0a9b73 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -428,10 +428,11 @@ def send(self, request, expect_response=True): ) from err log.debug( - "Request to %s:%d %d: %s", + "Request to %s:%d %d: %s, %s", self._host, self._port, correlation_id, + header, request_struct, ) @@ -565,10 +566,11 @@ def _handle_frame(self, resp): if not fut.done(): response = resp_type.decode(resp) log.debug( - "Response from %s:%d %d: %s", + "Response from %s:%d %d: %s, %s", self._host, self._port, correlation_id, + response_header, response, ) fut.set_result(response) diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py index c466357e..4f963710 100644 --- a/aiokafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -8,11 +8,11 @@ class AbstractType(Generic[T], metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def encode(cls, value: T) -> bytes: ... + def encode(cls, value: T, flexible: bool) -> bytes: ... @classmethod @abc.abstractmethod - def decode(cls, data: BytesIO) -> T: ... + def decode(cls, data: BytesIO, flexible: bool) -> T: ... @classmethod def repr(cls, value: T) -> str: diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py index 8a2d6e89..e4801294 100644 --- a/aiokafka/protocol/admin.py +++ b/aiokafka/protocol/admin.py @@ -8,8 +8,6 @@ Array, Boolean, Bytes, - CompactArray, - CompactString, Float64, Int8, Int16, @@ -17,7 +15,6 @@ Int64, Schema, String, - TaggedFields, ) @@ -1453,53 +1450,48 @@ def build( class AlterPartitionReassignmentsResponse_v0(Response): API_KEY = 45 API_VERSION = 0 + FLEXIBLE_VERSION = True SCHEMA = Schema( ("throttle_time_ms", Int32), ("error_code", Int16), - ("error_message", CompactString("utf-8")), + ("error_message", String("utf-8")), ( "responses", - CompactArray( - ("name", CompactString("utf-8")), + Array( + ("name", String("utf-8")), ( "partitions", - CompactArray( + Array( ("partition_index", Int32), ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("tags", TaggedFields), + ("error_message", String("utf-8")), ), ), - ("tags", TaggedFields), ), ), - ("tags", TaggedFields), ) class AlterPartitionReassignmentsRequest_v0(RequestStruct): - FLEXIBLE_VERSION = True API_KEY = 45 API_VERSION = 0 + FLEXIBLE_VERSION = True RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0 SCHEMA = Schema( ("timeout_ms", Int32), ( "topics", - CompactArray( - ("name", CompactString("utf-8")), + Array( + ("name", String("utf-8")), ( "partitions", - CompactArray( + Array( ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("tags", TaggedFields), + ("replicas", Array(Int32)), ), ), - ("tags", TaggedFields), ), ), - ("tags", TaggedFields), ) @@ -1516,44 +1508,40 @@ class AlterPartitionReassignmentsRequest( def __init__( self, timeout_ms: int, - topics: list[tuple[str, tuple[int, list[int], TaggedFields], TaggedFields]], - tags: TaggedFields, + topics: list[tuple[str, tuple[int, list[int]]]], ): self._timeout_ms = timeout_ms self._topics = topics - self._tags = tags def build( self, request_struct_class: type[AlterPartitionReassignmentsRequestStruct] ) -> AlterPartitionReassignmentsRequestStruct: - return request_struct_class(self._timeout_ms, self._topics, self._tags) + return request_struct_class(self._timeout_ms, self._topics) class ListPartitionReassignmentsResponse_v0(Response): API_KEY = 46 API_VERSION = 0 + FLEXIBLE_VERSION = True SCHEMA = Schema( ("throttle_time_ms", Int32), ("error_code", Int16), - ("error_message", CompactString("utf-8")), + ("error_message", String("utf-8")), ( "topics", - CompactArray( - ("name", CompactString("utf-8")), + Array( + ("name", String("utf-8")), ( "partitions", - CompactArray( + Array( ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("adding_replicas", CompactArray(Int32)), - ("removing_replicas", CompactArray(Int32)), - ("tags", TaggedFields), + ("replicas", Array(Int32)), + ("adding_replicas", Array(Int32)), + ("removing_replicas", Array(Int32)), ), ), - ("tags", TaggedFields), ), ), - ("tags", TaggedFields), ) @@ -1566,13 +1554,11 @@ class ListPartitionReassignmentsRequest_v0(RequestStruct): ("timeout_ms", Int32), ( "topics", - CompactArray( - ("name", CompactString("utf-8")), - ("partition_index", CompactArray(Int32)), - ("tags", TaggedFields), + Array( + ("name", String("utf-8")), + ("partition_index", Array(Int32)), ), ), - ("tags", TaggedFields), ) @@ -1589,17 +1575,15 @@ class ListPartitionReassignmentsRequest( def __init__( self, timeout_ms: int, - topics: list[tuple[str, tuple[int, list[int], TaggedFields], TaggedFields]], - tags: TaggedFields, + topics: list[tuple[str, tuple[int, list[int]]]], ): self._timeout_ms = timeout_ms self._topics = topics - self._tags = tags def build( self, request_struct_class: type[ListPartitionReassignmentsRequestStruct] ) -> ListPartitionReassignmentsRequestStruct: - return request_struct_class(self._timeout_ms, self._topics, self._tags) + return request_struct_class(self._timeout_ms, self._topics) class DeleteRecordsResponse_v0(Response): @@ -1633,26 +1617,8 @@ class DeleteRecordsResponse_v1(Response): class DeleteRecordsResponse_v2(Response): API_KEY = 21 API_VERSION = 2 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ( - "topics", - CompactArray( - ("name", CompactString("utf-8")), - ( - "partitions", - CompactArray( - ("partition_index", Int32), - ("low_watermark", Int64), - ("error_code", Int16), - ("tags", TaggedFields), - ), - ), - ("tags", TaggedFields), - ), - ), - ("tags", TaggedFields), - ) + FLEXIBLE_VERSION = True + SCHEMA = DeleteRecordsResponse_v0.SCHEMA class DeleteRecordsRequest_v0(RequestStruct): @@ -1689,25 +1655,7 @@ class DeleteRecordsRequest_v2(RequestStruct): API_VERSION = 2 FLEXIBLE_VERSION = True RESPONSE_TYPE = DeleteRecordsResponse_v2 - SCHEMA = Schema( - ( - "topics", - CompactArray( - ("name", CompactString("utf-8")), - ( - "partitions", - CompactArray( - ("partition_index", Int32), - ("offset", Int64), - ("tags", TaggedFields), - ), - ), - ("tags", TaggedFields), - ), - ), - ("timeout_ms", Int32), - ("tags", TaggedFields), - ) + SCHEMA = DeleteRecordsRequest_v0.SCHEMA DeleteRecordsRequestStruct: TypeAlias = ( @@ -1722,43 +1670,20 @@ def __init__( self, topics: Iterable[tuple[str, Iterable[tuple[int, int]]]], timeout_ms: int, - tags: dict[int, bytes] | None = None, ) -> None: self._topics = topics self._timeout_ms = timeout_ms - self._tags = tags def build( self, request_struct_class: type[DeleteRecordsRequestStruct] ) -> DeleteRecordsRequestStruct: - if request_struct_class.API_VERSION < 2: - if self._tags is not None: - raise IncompatibleBrokerVersion( - "tags requires DeleteRecordsRequest >= v2" - ) - - return request_struct_class( - [ - ( - topic, - list(partitions), - ) - for (topic, partitions) in self._topics - ], - self._timeout_ms, - ) return request_struct_class( [ ( topic, - [ - (partition, before_offset, {}) - for partition, before_offset in partitions - ], - {}, + list(partitions), ) for (topic, partitions) in self._topics ], self._timeout_ms, - self._tags or {}, ) diff --git a/aiokafka/protocol/api.py b/aiokafka/protocol/api.py index b7ae72f9..97cfcefd 100644 --- a/aiokafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -8,7 +8,7 @@ from aiokafka.errors import IncompatibleBrokerVersion from .struct import Struct -from .types import Array, Int16, Int32, Schema, String, TaggedFields +from .types import Array, Int16, Int32, Schema, String class RequestHeader_v1(Struct): @@ -36,19 +36,18 @@ class RequestHeader_v2(Struct): ("api_key", Int16), ("api_version", Int16), ("correlation_id", Int32), - ("client_id", String("utf-8")), - ("tags", TaggedFields), + ("client_id", String("utf-8", allow_flexible=False)), ) + FLEXIBLE_VERSION = True def __init__( self, request: RequestStruct, correlation_id: int = 0, client_id: str = "aiokafka", - tags: dict[int, bytes] | None = None, ): super().__init__( - request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} + request.API_KEY, request.API_VERSION, correlation_id, client_id ) @@ -61,8 +60,8 @@ class ResponseHeader_v0(Struct): class ResponseHeader_v1(Struct): SCHEMA = Schema( ("correlation_id", Int32), - ("tags", TaggedFields), ) + FLEXIBLE_VERSION = True T = TypeVar("T", bound="RequestStruct") @@ -150,7 +149,7 @@ class RequestStruct(Struct, metaclass=abc.ABCMeta): Attributes ---------- FLEXIBLE_VERSION : bool - Use request header with flexible tags + Support flexible versions/compact format API_KEY : int The unique API key identifying the request. API_VERSION : int @@ -161,11 +160,9 @@ class RequestStruct(Struct, metaclass=abc.ABCMeta): An instance of Schema() representing the request structure. """ - FLEXIBLE_VERSION: ClassVar[bool] = False API_KEY: ClassVar[int] API_VERSION: ClassVar[int] RESPONSE_TYPE: ClassVar[type[Response]] - SCHEMA: ClassVar[Schema] def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -203,15 +200,23 @@ def parse_response_header( class Response(Struct, metaclass=abc.ABCMeta): - @property - @abc.abstractmethod - def API_KEY(self) -> int: - """Integer identifier for api request/response""" + """ + Base structure for API responses. - @property - @abc.abstractmethod - def API_VERSION(self) -> int: - """Integer of api request/response version""" + Attributes + ---------- + FLEXIBLE_VERSION : bool + Support flexible versions/compact format + API_KEY : int + The unique API key identifying the response. + API_VERSION : int + Which API version the Response class is. + SCHEMA : Schema + An instance of Schema() representing the response structure. + """ + + API_KEY: ClassVar[int] + API_VERSION: ClassVar[int] def to_object(self) -> dict[str, Any]: return _to_object(self.SCHEMA, self) diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index 67f9d4ed..f4d7803a 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -132,11 +132,13 @@ def encode(self, recalc_crc: bool = True) -> bytes: self.timestamp, self.key, self.value, - ) + ), + flexible=False, ) elif version == 0: message = Message.SCHEMAS[version].encode( - (self.crc, self.magic, self.attributes, self.key, self.value) + (self.crc, self.magic, self.attributes, self.key, self.value), + flexible=False, ) else: raise ValueError(f"Unrecognized message version: {version}") @@ -144,7 +146,7 @@ def encode(self, recalc_crc: bool = True) -> bytes: return message self.crc = crc32(message[4:]) crc_field = self.BASE_FIELDS[0][1] - return crc_field.encode(self.crc) + message[4:] + return crc_field.encode(self.crc, flexible=False) + message[4:] @classmethod def decode(cls, data: io.BytesIO | bytes) -> Self: @@ -154,16 +156,16 @@ def decode(cls, data: io.BytesIO | bytes) -> Self: data = io.BytesIO(data) # Partial decode required to determine message version crc, magic, attributes = ( - cls.BASE_FIELDS[0][1].decode(data), - cls.BASE_FIELDS[1][1].decode(data), - cls.BASE_FIELDS[2][1].decode(data), + cls.BASE_FIELDS[0][1].decode(data, flexible=False), + cls.BASE_FIELDS[1][1].decode(data, flexible=False), + cls.BASE_FIELDS[2][1].decode(data, flexible=False), ) if magic == 1: magic = cast(Literal[1], magic) timestamp, key, value = ( - cls.MAGIC1_FIELDS[0][1].decode(data), - cls.MAGIC1_FIELDS[1][1].decode(data), - cls.MAGIC1_FIELDS[2][1].decode(data), + cls.MAGIC1_FIELDS[0][1].decode(data, flexible=False), + cls.MAGIC1_FIELDS[1][1].decode(data, flexible=False), + cls.MAGIC1_FIELDS[2][1].decode(data, flexible=False), ) msg = cls( value=value, @@ -176,8 +178,8 @@ def decode(cls, data: io.BytesIO | bytes) -> Self: elif magic == 0: magic = cast(Literal[0], magic) key, value = ( - cls.MAGIC0_FIELDS[0][1].decode(data), - cls.MAGIC0_FIELDS[1][1].decode(data), + cls.MAGIC0_FIELDS[0][1].decode(data, flexible=False), + cls.MAGIC0_FIELDS[1][1].decode(data, flexible=False), ) msg = cls( value=value, @@ -247,7 +249,7 @@ def encode( ) -> bytes: # RecordAccumulator encodes messagesets internally if isinstance(items, io.BytesIO): - size = Int32.decode(items) + size = Int32.decode(items, flexible=False) if prepend_size: # rewind and return all the bytes items.seek(items.tell() - 4) @@ -256,11 +258,11 @@ def encode( encoded_values: list[bytes] = [] for offset, message in items: - encoded_values.append(Int64.encode(offset)) - encoded_values.append(Bytes.encode(message)) + encoded_values.append(Int64.encode(offset, flexible=False)) + encoded_values.append(Bytes.encode(message, flexible=False)) encoded = b"".join(encoded_values) if prepend_size: - return Bytes.encode(encoded) + return Bytes.encode(encoded, flexible=False) else: return encoded @@ -274,7 +276,7 @@ def decode( if isinstance(data, bytes): data = io.BytesIO(data) if bytes_to_read is None: - bytes_to_read = Int32.decode(data) + bytes_to_read = Int32.decode(data, flexible=False) # if FetchRequest max_bytes is smaller than the available message set # the server returns partial data for the final message @@ -284,8 +286,8 @@ def decode( items: list[tuple[int, int, Message] | tuple[None, None, PartialMessage]] = [] try: while bytes_to_read: - offset = Int64.decode(raw) - msg_bytes = Bytes.decode(raw) + offset = Int64.decode(raw, flexible=False) + msg_bytes = Bytes.decode(raw, flexible=False) assert msg_bytes is not None bytes_to_read -= 8 + 4 + len(msg_bytes) items.append( diff --git a/aiokafka/protocol/struct.py b/aiokafka/protocol/struct.py index 38649d09..5cb6d09b 100644 --- a/aiokafka/protocol/struct.py +++ b/aiokafka/protocol/struct.py @@ -8,6 +8,7 @@ class Struct: SCHEMA: ClassVar = Schema() + FLEXIBLE_VERSION: ClassVar[bool] = False def __init__(self, *args: Any, **kwargs: Any) -> None: if len(args) == len(self.SCHEMA.fields): @@ -26,13 +27,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: ) def encode(self) -> bytes: - return self.SCHEMA.encode([self.__dict__[name] for name in self.SCHEMA.names]) + return self.SCHEMA.encode( + [self.__dict__[name] for name in self.SCHEMA.names], self.FLEXIBLE_VERSION + ) @classmethod def decode(cls, data: BytesIO | bytes) -> Self: if isinstance(data, bytes): data = BytesIO(data) - return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) + return cls(*cls.SCHEMA.decode(data, cls.FLEXIBLE_VERSION)) def get_item(self, name: str) -> Any: if name not in self.SCHEMA.names: diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 5a315dd7..48a1a779 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -19,6 +19,10 @@ ValueT: TypeAlias = Union[type[AbstractType[Any]], "String", "Array", "Schema"] +TaggedFieldId: TypeAlias = tuple[str, int] + +FieldId: TypeAlias = TaggedFieldId | str + def _pack(f: Callable[[T], bytes], value: T) -> bytes: try: @@ -47,11 +51,11 @@ class Int8(AbstractType[int]): _unpack = struct.Struct(">b").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(1)) @@ -60,11 +64,11 @@ class Int16(AbstractType[int]): _unpack = struct.Struct(">h").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(2)) @@ -73,11 +77,11 @@ class Int32(AbstractType[int]): _unpack = struct.Struct(">i").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(4)) @@ -86,11 +90,11 @@ class UInt32(AbstractType[int]): _unpack = struct.Struct(">I").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(4)) @@ -99,11 +103,11 @@ class Int64(AbstractType[int]): _unpack = struct.Struct(">q").unpack @classmethod - def encode(cls, value: int) -> bytes: + def encode(cls, value: int, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> int: + def decode(cls, data: BytesIO, flexible: bool = False) -> int: return _unpack(cls._unpack, data.read(8)) @@ -112,26 +116,39 @@ class Float64(AbstractType[float]): _unpack = struct.Struct(">d").unpack @classmethod - def encode(cls, value: float) -> bytes: + def encode(cls, value: float, flexible: bool = False) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> float: + def decode(cls, data: BytesIO, flexible: bool = False) -> float: return _unpack(cls._unpack, data.read(8)) class String: - def __init__(self, encoding: str = "utf-8"): + def __init__(self, encoding: str = "utf-8", allow_flexible: bool = True): self.encoding = encoding + self.allow_flexible = allow_flexible - def encode(self, value: str | None) -> bytes: + def encode(self, value: str | None, flexible: bool) -> bytes: if value is None: - return Int16.encode(-1) + return ( + UnsignedVarInt32.encode(0) + if flexible and self.allow_flexible + else Int16.encode(-1, flexible) + ) encoded_value = str(value).encode(self.encoding) - return Int16.encode(len(encoded_value)) + encoded_value + return ( + UnsignedVarInt32.encode(len(encoded_value) + 1) + encoded_value + if flexible and self.allow_flexible + else Int16.encode(len(encoded_value), flexible) + encoded_value + ) - def decode(self, data: BytesIO) -> str | None: - length = Int16.decode(data) + def decode(self, data: BytesIO, flexible: bool) -> str | None: + length = ( + UnsignedVarInt32.decode(data) - 1 + if flexible and self.allow_flexible + else Int16.decode(data, flexible) + ) if length < 0: return None value = data.read(length) @@ -146,15 +163,25 @@ def repr(cls, value: str) -> str: class Bytes(AbstractType[bytes | None]): @classmethod - def encode(cls, value: bytes | None) -> bytes: + def encode(cls, value: bytes | None, flexible: bool) -> bytes: if value is None: - return Int32.encode(-1) + return ( + UnsignedVarInt32.encode(0) if flexible else Int32.encode(-1, flexible) + ) else: - return Int32.encode(len(value)) + value + return ( + UnsignedVarInt32.encode(len(value) + 1) + value + if flexible + else Int32.encode(len(value), flexible) + value + ) @classmethod - def decode(cls, data: BytesIO) -> bytes | None: - length = Int32.decode(data) + def decode(cls, data: BytesIO, flexible: bool) -> bytes | None: + length = ( + UnsignedVarInt32.decode(data) - 1 + if flexible + else Int32.decode(data, flexible) + ) if length < 0: return None value = data.read(length) @@ -174,33 +201,94 @@ class Boolean(AbstractType[bool]): _unpack = struct.Struct(">?").unpack @classmethod - def encode(cls, value: bool) -> bytes: + def encode(cls, value: bool, flexible: bool) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: BytesIO) -> bool: + def decode(cls, data: BytesIO, flexible: bool) -> bool: return _unpack(cls._unpack, data.read(1)) class Schema: names: tuple[str, ...] + tags: tuple[int, ...] fields: tuple[ValueT, ...] - def __init__(self, *fields: tuple[str, ValueT]): + def __init__(self, *fields: tuple[FieldId, ValueT]): if fields: - self.names, self.fields = zip(*fields, strict=False) + tagged_names, values = zip( + *( + (key, value) if isinstance(key, tuple) else ((key, None), value) + for key, value in fields + ), + strict=False, + ) + self.names = tuple(name for name, _ in tagged_names) + self.tags = tuple(tag for _, tag in tagged_names) + self.fields = tuple(values) else: - self.names, self.fields = (), () + self.names, self.tags, self.fields = (), (), () - def encode(self, item: Sequence[Any]) -> bytes: + def encode(self, item: Sequence[Any], flexible: bool) -> bytes: if len(item) != len(self.fields): raise ValueError("Item field count does not match Schema") - return b"".join(field.encode(item[i]) for i, field in enumerate(self.fields)) + return b"".join( + field.encode(item[i], flexible) + for i, field in enumerate(self.fields) + if self.tags[i] is None + ) + ( + self._encode_tagged_fields( + { + self.tags[i]: field.encode(item[i], flexible) + for i, field in enumerate(self.fields) + if self.tags[i] is not None + } + ) + if flexible + else b"" + ) def decode( - self, data: BytesIO + self, data: BytesIO, flexible: bool ) -> tuple[Any | str | None | list[Any | tuple[Any, ...]], ...]: - return tuple(field.decode(data) for field in self.fields) + result = [ + field.decode(data, flexible) if self.tags[i] is None else None + for i, field in enumerate(self.fields) + ] + if flexible: + tagged_fields = self._decode_tagged_fields(data) + for i, tag in enumerate(self.tags): + if tag is not None: + encoded_value = tagged_fields.get(tag) + if encoded_value is not None: + result[i] = self.fields[i].decode( + BytesIO(encoded_value), flexible + ) + + return tuple(result) + + @staticmethod + def _encode_tagged_fields(value: dict[int, bytes]) -> bytes: + ret = UnsignedVarInt32.encode(len(value)) + for k, v in value.items(): + assert isinstance(k, int) and k >= 0, f"Key {k} is not a positive integer" + ret += UnsignedVarInt32.encode(k) + ret += UnsignedVarInt32.encode(len(v)) + ret += v + return ret + + @staticmethod + def _decode_tagged_fields(data: BytesIO) -> dict[int, bytes]: + num_fields = UnsignedVarInt32.decode(data) + ret: dict[int, bytes] = {} + if not num_fields: + return ret + for _ in range(num_fields): + tag = UnsignedVarInt32.decode(data) + size = UnsignedVarInt32.decode(data) + val = data.read(size) + ret[tag] = val + return ret def __len__(self) -> int: return len(self.fields) @@ -227,16 +315,18 @@ def __init__(self, array_of_0: ValueT): ... @overload def __init__( - self, array_of_0: tuple[str, ValueT], *array_of: tuple[str, ValueT] + self, + array_of_0: tuple[FieldId, ValueT], + *array_of: tuple[FieldId, ValueT], ): ... def __init__( self, - array_of_0: ValueT | tuple[str, ValueT], - *array_of: tuple[str, ValueT], + array_of_0: ValueT | tuple[FieldId, ValueT], + *array_of: tuple[FieldId, ValueT], ) -> None: if array_of: - array_of_0 = cast(tuple[str, ValueT], array_of_0) + array_of_0 = cast(tuple[FieldId, ValueT], array_of_0) self.array_of = Schema(array_of_0, *array_of) else: array_of_0 = cast(ValueT, array_of_0) @@ -247,19 +337,33 @@ def __init__( else: raise ValueError("Array instantiated with no array_of type") - def encode(self, items: Sequence[Any] | None) -> bytes: + def encode(self, items: Sequence[Any] | None, flexible: bool) -> bytes: if items is None: - return Int32.encode(-1) - encoded_items = (self.array_of.encode(item) for item in items) - return b"".join( - (Int32.encode(len(items)), *encoded_items), + return ( + UnsignedVarInt32.encode(0) if flexible else Int32.encode(-1, flexible) + ) + encoded_items = (self.array_of.encode(item, flexible) for item in items) + return ( + b"".join( + (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), + ) + if flexible + else b"".join( + (Int32.encode(len(items), flexible), *encoded_items), + ) ) - def decode(self, data: BytesIO) -> list[Any | tuple[Any, ...]] | None: - length = Int32.decode(data) + def decode( + self, data: BytesIO, flexible: bool + ) -> list[Any | tuple[Any, ...]] | None: + length = ( + UnsignedVarInt32.decode(data) - 1 + if flexible + else Int32.decode(data, flexible) + ) if length == -1: return None - return [self.array_of.decode(data) for _ in range(length)] + return [self.array_of.decode(data, flexible) for _ in range(length)] def repr(self, list_of_items: Sequence[Any] | None) -> str: if list_of_items is None: @@ -267,7 +371,7 @@ def repr(self, list_of_items: Sequence[Any] | None) -> str: return "[" + ", ".join(self.array_of.repr(item) for item in list_of_items) + "]" -class UnsignedVarInt32(AbstractType[int]): +class UnsignedVarInt32: @classmethod def decode(cls, data: BytesIO) -> int: value, i = 0, 0 @@ -293,128 +397,3 @@ def encode(cls, value: int) -> bytes: value >>= 7 ret += struct.pack("B", value) return ret - - -class VarInt32(AbstractType[int]): - @classmethod - def decode(cls, data: BytesIO) -> int: - value = UnsignedVarInt32.decode(data) - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value: int) -> bytes: - # bring it in line with the java binary repr - value &= 0xFFFFFFFF - return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) - - -class VarInt64(AbstractType[int]): - @classmethod - def decode(cls, data: BytesIO) -> int: - value, i = 0, 0 - b: int - while True: - (b,) = struct.unpack("B", data.read(1)) - if not (b & 0x80): - break - value |= (b & 0x7F) << i - i += 7 - if i > 63: - raise ValueError(f"Invalid value {value}") - value |= b << i - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value: int) -> bytes: - # bring it in line with the java binary repr - value &= 0xFFFFFFFFFFFFFFFF - v = (value << 1) ^ (value >> 63) - ret = b"" - while (v & 0xFFFFFFFFFFFFFF80) != 0: - b = (value & 0x7F) | 0x80 - ret += struct.pack("B", b) - v >>= 7 - ret += struct.pack("B", v) - return ret - - -class CompactString(String): - def decode(self, data: BytesIO) -> str | None: - length = UnsignedVarInt32.decode(data) - 1 - if length < 0: - return None - value = data.read(length) - if len(value) != length: - raise ValueError("Buffer underrun decoding string") - return value.decode(self.encoding) - - def encode(self, value: str | None) -> bytes: - if value is None: - return UnsignedVarInt32.encode(0) - encoded_value = str(value).encode(self.encoding) - return UnsignedVarInt32.encode(len(encoded_value) + 1) + encoded_value - - -class TaggedFields(AbstractType[dict[int, bytes]]): - @classmethod - def decode(cls, data: BytesIO) -> dict[int, bytes]: - num_fields = UnsignedVarInt32.decode(data) - ret: dict[int, bytes] = {} - if not num_fields: - return ret - prev_tag = -1 - for _ in range(num_fields): - tag = UnsignedVarInt32.decode(data) - if tag <= prev_tag: - raise ValueError(f"Invalid or out-of-order tag {tag}") - prev_tag = tag - size = UnsignedVarInt32.decode(data) - val = data.read(size) - ret[tag] = val - return ret - - @classmethod - def encode(cls, value: dict[int, bytes]) -> bytes: - ret = UnsignedVarInt32.encode(len(value)) - for k, v in value.items(): - # do we allow for other data types ?? It could get complicated really fast - assert isinstance(v, bytes), f"Value {v!r} is not a byte array" - assert isinstance(k, int) and k > 0, f"Key {k} is not a positive integer" - ret += UnsignedVarInt32.encode(k) - ret += v - return ret - - -class CompactBytes(AbstractType[bytes | None]): - @classmethod - def decode(cls, data: BytesIO) -> bytes | None: - length = UnsignedVarInt32.decode(data) - 1 - if length < 0: - return None - value = data.read(length) - if len(value) != length: - raise ValueError("Buffer underrun decoding Bytes") - return value - - @classmethod - def encode(cls, value: bytes | None) -> bytes: - if value is None: - return UnsignedVarInt32.encode(0) - else: - return UnsignedVarInt32.encode(len(value) + 1) + value - - -class CompactArray(Array): - def encode(self, items: Sequence[Any] | None) -> bytes: - if items is None: - return UnsignedVarInt32.encode(0) - encoded_items = (self.array_of.encode(item) for item in items) - return b"".join( - (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), - ) - - def decode(self, data: BytesIO) -> list[Any | tuple[Any, ...]] | None: - length = UnsignedVarInt32.decode(data) - 1 - if length == -1: - return None - return [self.array_of.decode(data) for _ in range(length)] diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8b2d5d03..fb48abbb 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -9,9 +9,8 @@ from aiokafka.protocol.message import Message, MessageSet, PartialMessage from aiokafka.protocol.metadata import MetadataRequest_v0 from aiokafka.protocol.types import ( - CompactArray, - CompactBytes, - CompactString, + Array, + Bytes, Int16, Int32, Int64, @@ -239,7 +238,7 @@ def test_decode_fetch_response_partial() -> None: encoded = b"".join( [ Int32.encode(1), # Num Topics (Array) - String("utf-8").encode("foobar"), + String("utf-8").encode("foobar", flexible=False), Int32.encode(2), # Num Partitions (Array) Int32.encode(0), # Partition id Int16.encode(0), # Error Code @@ -328,33 +327,38 @@ def test_unsigned_varint_serde() -> None: def test_compact_data_structs() -> None: - cs = CompactString() - encoded = cs.encode(None) + cs = String() + encoded = cs.encode(None, flexible=True) assert encoded == struct.pack("B", 0) - decoded = cs.decode(io.BytesIO(encoded)) + decoded = cs.decode(io.BytesIO(encoded), flexible=True) assert decoded is None - assert cs.encode("") == b"\x01" - assert cs.decode(io.BytesIO(b"\x01")) == "" - encoded = cs.encode("foobarbaz") - assert cs.decode(io.BytesIO(encoded)) == "foobarbaz" - - arr = CompactArray(CompactString()) - assert arr.encode(None) == b"\x00" - assert arr.decode(io.BytesIO(b"\x00")) is None - enc = arr.encode([]) + assert cs.encode("", flexible=True) == b"\x01" + assert cs.decode(io.BytesIO(b"\x01"), flexible=True) == "" + encoded = cs.encode("foobarbaz", flexible=True) + assert cs.decode(io.BytesIO(encoded), flexible=True) == "foobarbaz" + + arr = Array(String()) + assert arr.encode(None, flexible=True) == b"\x00" + assert arr.decode(io.BytesIO(b"\x00"), flexible=True) is None + enc = arr.encode([], flexible=True) assert enc == b"\x01" - assert arr.decode(io.BytesIO(enc)) == [] - encoded = arr.encode(["foo", "bar", "baz", "quux"]) - assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"] + assert arr.decode(io.BytesIO(enc), flexible=True) == [] + encoded = arr.encode(["foo", "bar", "baz", "quux"], flexible=True) + assert arr.decode(io.BytesIO(encoded), flexible=True) == [ + "foo", + "bar", + "baz", + "quux", + ] - enc = CompactBytes.encode(None) + enc = Bytes.encode(None, flexible=True) assert enc == b"\x00" - assert CompactBytes.decode(io.BytesIO(b"\x00")) is None - enc = CompactBytes.encode(b"") + assert Bytes.decode(io.BytesIO(b"\x00"), flexible=True) is None + enc = Bytes.encode(b"", flexible=True) assert enc == b"\x01" - assert CompactBytes.decode(io.BytesIO(b"\x01")) == b"" - enc = CompactBytes.encode(b"foo") - assert CompactBytes.decode(io.BytesIO(enc)) == b"foo" + assert Bytes.decode(io.BytesIO(b"\x01"), flexible=True) == b"" + enc = Bytes.encode(b"foo", flexible=True) + assert Bytes.decode(io.BytesIO(enc), flexible=True) == b"foo" attr_names = [ diff --git a/tests/test_protocol_object_conversion.py b/tests/test_protocol_object_conversion.py index cdfb9705..1d00e25f 100644 --- a/tests/test_protocol_object_conversion.py +++ b/tests/test_protocol_object_conversion.py @@ -10,7 +10,7 @@ def _make_test_class( - klass: type[RequestStruct | Response], schema: Schema + klass: type[RequestStruct | Response], schema: Schema, flexible: bool = False ) -> type[RequestStruct | Response]: if klass is RequestStruct: @@ -19,6 +19,7 @@ class RequestTestClass(RequestStruct): API_VERSION = 0 RESPONSE_TYPE = Response SCHEMA = schema + FLEXIBLE_VERSION = flexible return RequestTestClass else: @@ -27,6 +28,7 @@ class ResponseTestClass(Response): API_KEY = 0 API_VERSION = 0 SCHEMA = schema + FLEXIBLE_VERSION = flexible return ResponseTestClass @@ -188,6 +190,53 @@ def test_with_complex_nested_array( assert myarray[1]["subarray"][0]["innertest"] == "hello" assert myarray[1]["subarray"][0]["otherinnertest"] == "hello again" + def test_flexible_version(self, superclass: type[RequestStruct | Response]) -> None: + TestClass = _make_test_class( + superclass, + Schema( + ("name", String("utf-8")), + ("myarray", Array(Int16)), + (("tagged_field1", 0), String("utf-8")), + (("tagged_field2", 42), Int16), + ( + ("tagged_field3", 53), + Array( + ("name", String("utf-8")), + (("tag1", 0), Int16), + (("tag2", 1), Int16), + ), + ), + ), + flexible=True, + ) + + tc = TestClass( + name="foo", + myarray=[1, 2, 3], + tagged_field1="bar", + tagged_field2=23, + tagged_field3=[("hello", 1, 2), ("world", 3, 4)], + ) + encoded = tc.encode() + assert tc.to_object()["name"] == "foo" + assert tc.to_object()["myarray"] == [1, 2, 3] + assert tc.to_object()["tagged_field1"] == "bar" + assert tc.to_object()["tagged_field2"] == 23 + assert tc.to_object()["tagged_field3"] == [ + {"name": "hello", "tag1": 1, "tag2": 2}, + {"name": "world", "tag1": 3, "tag2": 4}, + ] + + tc = TestClass.decode(encoded) + assert tc.to_object()["name"] == "foo" + assert tc.to_object()["myarray"] == [1, 2, 3] + assert tc.to_object()["tagged_field1"] == "bar" + assert tc.to_object()["tagged_field2"] == 23 + assert tc.to_object()["tagged_field3"] == [ + {"name": "hello", "tag1": 1, "tag2": 2}, + {"name": "world", "tag1": 3, "tag2": 4}, + ] + def test_with_metadata_response() -> None: tc = MetadataResponse_v5( diff --git a/tests/test_requests.py b/tests/test_requests.py index 0b1e7178..64300b1f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -357,22 +357,22 @@ def __init__(self, expected, min_version=None, max_version=None): ], ), ( - AlterPartitionReassignmentsRequest(timeout_ms=100, topics=[], tags={}), + AlterPartitionReassignmentsRequest(timeout_ms=100, topics=[]), [ Versions(expected=IncompatibleBrokerVersion), Versions( max_version=0, - expected=AlterPartitionReassignmentsRequest_v0(100, [], {}), + expected=AlterPartitionReassignmentsRequest_v0(100, []), ), ], ), ( - ListPartitionReassignmentsRequest(timeout_ms=200, topics=[], tags={}), + ListPartitionReassignmentsRequest(timeout_ms=200, topics=[]), [ Versions(expected=IncompatibleBrokerVersion), Versions( max_version=0, - expected=ListPartitionReassignmentsRequest_v0(200, [], {}), + expected=ListPartitionReassignmentsRequest_v0(200, []), ), ], ), @@ -388,16 +388,9 @@ def __init__(self, expected, min_version=None, max_version=None): max_version=1, expected=DeleteRecordsRequest_v1([("t1", [(0, 123)])], 50), ), - ], - ), - ( - DeleteRecordsRequest(topics=[("t1", [(0, 123)])], timeout_ms=50, tags={}), - [ - Versions(expected=IncompatibleBrokerVersion), - Versions(max_version=1, expected=IncompatibleBrokerVersion), Versions( max_version=2, - expected=DeleteRecordsRequest_v2([("t1", [(0, 123, {})], {})], 50, {}), + expected=DeleteRecordsRequest_v2([("t1", [(0, 123)])], 50), ), ], ),