Skip to content

Commit

Permalink
fix: deserialize json
Browse files Browse the repository at this point in the history
  • Loading branch information
sauljabin committed Jan 17, 2025
1 parent 1997799 commit ba9695a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 31 deletions.
84 changes: 58 additions & 26 deletions kaskade/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroDeserializer as ConfluentAvroDeserializer
from confluent_kafka.schema_registry.json_schema import (
JSONDeserializer as ConfluentJsonDeserializer,
)
from confluent_kafka.schema_registry.protobuf import (
ProtobufDeserializer as ConfluentProtobufDeserializer,
)
Expand All @@ -18,7 +21,6 @@
from google.protobuf.message import Message
from google.protobuf.message_factory import GetMessages

from kaskade import logger
from kaskade.configs import SCHEMA_REGISTRY_MAGIC_BYTE
from kaskade.utils import unpack_bytes, file_to_bytes

Expand Down Expand Up @@ -112,18 +114,23 @@ class JsonDeserializer(Deserializer):
def deserialize(
self, data: bytes, topic: str | None = None, context: MessageField = MessageField.NONE
) -> Any:
try:
return json.loads(data)
except Exception:
# in case that the json has a confluent schema registry magic byte
# https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
return json.loads(data[5:])
if len(data) > 5:
magic, schema_id = unpack(">bI", data[:5])
if magic == SCHEMA_REGISTRY_MAGIC_BYTE:
# in case that the json has a confluent schema registry magic byte
# https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
return json.loads(data[5:])

return json.loads(data)


class RegistryDeserializer(Deserializer):
def __init__(self, registry_config: dict[str, str]):
registry_client = SchemaRegistryClient(registry_config)
self.confluent_deserializer = ConfluentAvroDeserializer(registry_client)
self.registry_client = SchemaRegistryClient(registry_config)
self.avro_deserializer = ConfluentAvroDeserializer(self.registry_client)
self.json_deserializer = ConfluentJsonDeserializer(
None, schema_registry_client=self.registry_client
)

def deserialize(
self, data: bytes, topic: str | None = None, context: MessageField = MessageField.NONE
Expand All @@ -134,7 +141,26 @@ def deserialize(
if context == MessageField.NONE:
raise Exception("Context is needed: KEY or VALUE")

return self.confluent_deserializer(data, SerializationContext(topic, context))
if len(data) <= 5:
raise Exception(
f"Expecting data framing of length 6 bytes or more but total data size is {len(data)} bytes. This message was not produced with a Confluent Schema Registry serializer"
)

magic, schema_id = unpack(">bI", data[:5])
if magic != SCHEMA_REGISTRY_MAGIC_BYTE:
raise Exception(
f"Unexpected magic byte {magic}. This message was not produced with a Confluent Schema Registry serializer"
)

schema = self.registry_client.get_schema(schema_id)

match schema.schema_type:
case "JSON":
return self.json_deserializer(data, SerializationContext(topic, context))
case "AVRO":
return self.avro_deserializer(data, SerializationContext(topic, context))
case _:
raise Exception("Schema type not supported")


class AvroDeserializer(Deserializer):
Expand Down Expand Up @@ -164,9 +190,13 @@ def deserialize(
if schema_path is None:
raise Exception("Avro schema file not found")

magic, schema_id = unpack(">bI", data[:5])
if magic == SCHEMA_REGISTRY_MAGIC_BYTE:
return schemaless_reader(BytesIO(data[5:]), load_schema(schema_path), None)
if len(data) > 5:
magic, schema_id = unpack(">bI", data[:5])
if magic == SCHEMA_REGISTRY_MAGIC_BYTE:
# in case that the avro has a confluent schema registry magic byte
# https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
return schemaless_reader(BytesIO(data[5:]), load_schema(schema_path), None)

return schemaless_reader(BytesIO(data), load_schema(schema_path), None)


Expand Down Expand Up @@ -208,19 +238,21 @@ def deserialize(
if deserialization_class is None:
raise Exception("Deserialization class not found")

try:
new_message = deserialization_class()
new_message.ParseFromString(data)
return MessageToDict(new_message, always_print_fields_with_no_presence=True)
except Exception as e:
logger.warning("Error deserializing protobuf: %s", e)
# in case that the protobuf has a confluent schema registry magic byte
# https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
protobuf_deserializer = ConfluentProtobufDeserializer(
deserialization_class, {"use.deprecated.format": False}
)
new_message = protobuf_deserializer(data, SerializationContext(topic, context))
return MessageToDict(new_message, always_print_fields_with_no_presence=True)
if len(data) > 5:
magic, schema_id = unpack(">bI", data[:5])
if magic == SCHEMA_REGISTRY_MAGIC_BYTE:
# in case that the protobuf has a confluent schema registry magic byte
# https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
deserializer_config = {"use.deprecated.format": False}
protobuf_deserializer = ConfluentProtobufDeserializer(
deserialization_class, deserializer_config
)
new_message = protobuf_deserializer(data, SerializationContext(topic, context))
return MessageToDict(new_message, always_print_fields_with_no_presence=True)

new_message = deserialization_class()
new_message.ParseFromString(data)
return MessageToDict(new_message, always_print_fields_with_no_presence=True)


class DeserializerPool:
Expand Down
9 changes: 5 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 18 additions & 1 deletion tests/tests_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,13 @@ def test_json_deserialization_with_magic_byte(self):
self.assertEqual(expected_value, result)

@patch("kaskade.deserializers.SchemaRegistryClient")
def test_registry_deserialization_with_magic_byte(self, mock_sr_client_class):
def test_registry_deserialization_avro(self, mock_sr_client_class):
expected_value = {"name": "Pedro Pascal"}

mock_sr_client_class.return_value.get_schema.return_value.schema_str = file_to_str(
AVRO_PATH
)
mock_sr_client_class.return_value.get_schema.return_value.schema_type = "AVRO"

encoded = py_to_avro(expected_value)

Expand All @@ -145,6 +146,22 @@ def test_registry_deserialization_with_magic_byte(self, mock_sr_client_class):

self.assertEqual(expected_value, result)

@patch("kaskade.deserializers.SchemaRegistryClient")
def test_registry_deserialization_json(self, mock_sr_client_class):
    """Check that a registry schema typed "JSON" yields the decoded JSON payload."""
    expected_value = {"name": "Pedro Pascal"}
    payload_text = json.dumps(expected_value)

    # Every schema lookup on the mocked client reports a JSON schema whose
    # text is the payload itself.
    fetched_schema = mock_sr_client_class.return_value.get_schema.return_value
    fetched_schema.schema_str = payload_text
    fetched_schema.schema_type = "JSON"

    registry_deserializer = RegistryDeserializer({})

    # Five zero bytes are placed ahead of the JSON body before it is handed
    # to the deserializer as a message value.
    framed_message = b"\x00\x00\x00\x00\x00" + payload_text.encode()
    result = registry_deserializer.deserialize(framed_message, "", MessageField.VALUE)

    self.assertEqual(expected_value, result)

def test_protobuf_deserialization(self):
deserializer = ProtobufDeserializer({"descriptor": DESCRIPTOR_PATH, "value": "User"})

Expand Down

0 comments on commit ba9695a

Please sign in to comment.