31 """Code for decoding protocol buffer primitives.
33 This code is very similar to encoder.py -- read the docs for that module first.
35 A "decoder" is a function with the signature:
36 Decode(buffer, pos, end, message, field_dict)
38 buffer: The string containing the encoded message.
39 pos: The current position in the string.
40 end: The position in the string where the current message ends. May be
41 less than len(buffer) if we're reading a sub-message.
42 message: The message object into which we're parsing.
43 field_dict: message._fields (avoids a hashtable lookup).
44 The decoder reads the field and stores it into field_dict, returning the new
45 buffer position. A decoder for a repeated field may proactively decode all of
46 the elements of that field, if they appear consecutively.
48 Note that decoders may throw any of the following:
49 IndexError: Indicates a truncated message.
50 struct.error: Unpacking of a fixed-width field failed.
51 message.DecodeError: Other errors.
53 Decoders are expected to raise an exception if they are called with pos > end.
54 This allows callers to be lax about bounds checking: it's fine to read past
55 "end" as long as you are sure that someone else will notice and throw an exception later on.
58 Something up the call stack is expected to catch IndexError and struct.error
59 and convert them to message.DecodeError.
61 Decoders are constructed using decoder constructors with the signature:
62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
64 field_number: The field number of the field we want to decode.
65 is_repeated: Is the field a repeated field? (bool)
66 is_packed: Is the field a packed field? (bool)
67 key: The key to use when looking up the field within field_dict.
68 (This is actually the FieldDescriptor but nothing in this
69 file should depend on that.)
70 new_default: A function which takes a message object as a parameter and
71 returns a new instance of the default value for this field.
72 (This is called for repeated fields and sub-messages, when an
73 instance does not already exist.)
75 As with encoders, we define a decoder constructor for every type of field.
76 Then, for every field of every message class we construct an actual decoder.
77 That decoder goes into a dict indexed by tag, so when we decode a message
78 we repeatedly read a tag, look up the corresponding decoder, and invoke it.
81 __author__ =
'kenton@google.com (Kenton Varda)'
94 _DecodeError = message.DecodeError
98 """Return a decoder for a basic varint value (does not include tag).
100 Decoded values will be bitwise-anded with the given mask before being
101 returned, e.g. to limit them to 32 bits. The returned decoder does not
102 take the usual "end" parameter -- the caller is expected to do bounds checking
103 after the fact (often the caller can defer such checking until later). The
104 decoder returns a (value, new_pos) pair.
107 def DecodeVarint(buffer, pos):
112 result |= ((b & 0x7f) << shift)
120 raise _DecodeError(
'Too many bytes when decoding varint.')
125 """Like _VarintDecoder() but decodes signed values."""
127 signbit = 1 << (bits - 1)
128 mask = (1 << bits) - 1
130 def DecodeVarint(buffer, pos):
135 result |= ((b & 0x7f) << shift)
139 result = (result ^ signbit) - signbit
144 raise _DecodeError(
'Too many bytes when decoding varint.')
157 """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
159 We return the raw bytes of the tag rather than decoding them. The raw
160 bytes can then be used to look up the proper decoder. This effectively allows
161 us to trade some work that would be done in pure-python (decoding a varint)
162 for work that is done in C (searching for a byte string in a hash table).
163 In a low-level language it would be much cheaper to decode the varint and
164 use that, but not in Python.
167 buffer: memoryview object of the encoded bytes
168 pos: int of the current position to start from
171 Tuple[bytes, int] of the tag data and new position.
174 while buffer[pos] & 0x80:
178 tag_bytes = buffer[start:pos].tobytes()
179 return tag_bytes, pos
186 """Return a constructor for a decoder for fields of a particular type.
189 wire_type: The field's wire type.
190 decode_value: A function which decodes an individual value, e.g.
194 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
195 clear_if_default=False):
197 local_DecodeVarint = _DecodeVarint
198 def DecodePackedField(buffer, pos, end, message, field_dict):
199 value = field_dict.get(key)
201 value = field_dict.setdefault(key, new_default(message))
202 (endpoint, pos) = local_DecodeVarint(buffer, pos)
206 while pos < endpoint:
207 (element, pos) = decode_value(buffer, pos)
208 value.append(element)
213 return DecodePackedField
215 tag_bytes = encoder.TagBytes(field_number, wire_type)
216 tag_len =
len(tag_bytes)
217 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218 value = field_dict.get(key)
220 value = field_dict.setdefault(key, new_default(message))
222 (element, new_pos) = decode_value(buffer, pos)
223 value.append(element)
226 pos = new_pos + tag_len
227 if buffer[new_pos:pos] != tag_bytes
or new_pos >= end:
232 return DecodeRepeatedField
234 def DecodeField(buffer, pos, end, message, field_dict):
235 (new_value, pos) = decode_value(buffer, pos)
238 if clear_if_default
and not new_value:
239 field_dict.pop(key,
None)
241 field_dict[key] = new_value
245 return SpecificDecoder
249 """Like _SimpleDecoder but additionally invokes modify_value on every value
250 before storing it. Usually modify_value is ZigZagDecode.
def InnerDecode(buffer, pos):
  """Decode one raw value at pos and post-process it with modify_value.

  Returns a (modified value, new position) pair, where the new position is
  whatever the wrapped decode_value reported.
  """
  raw, next_pos = decode_value(buffer, pos)
  converted = modify_value(raw)
  return (converted, next_pos)
263 """Return a constructor for a decoder for a fixed-width field.
266 wire_type: The field's wire type.
267 format: The format string to pass to struct.unpack().
270 value_size = struct.calcsize(format)
271 local_unpack = struct.unpack
def InnerDecode(buffer, pos):
  """Unpack the fixed-width value that starts at pos.

  Returns a (value, new position) pair; the new position is pos advanced by
  value_size bytes.
  """
  stop = pos + value_size
  # struct.unpack always returns a tuple; only its first item is kept.
  fields = local_unpack(format, buffer[pos:stop])
  return (fields[0], stop)
288 """Returns a decoder for a float field.
290 This code works around a bug in struct.unpack for non-finite 32-bit
291 floating-point values.
294 local_unpack = struct.unpack
296 def InnerDecode(buffer, pos):
297 """Decode serialized float to a float and new position.
300 buffer: memoryview of the serialized bytes
301 pos: int, position in the memory view to start at.
304 Tuple[float, int] of the deserialized float value and new position
305 in the serialized data.
310 float_bytes = buffer[pos:new_pos].tobytes()
315 if (float_bytes[3:4]
in b
'\x7F\xFF' and float_bytes[2:3] >= b
'\x80'):
317 if float_bytes[0:3] != b
'\x00\x00\x80':
318 return (math.nan, new_pos)
320 if float_bytes[3:4] == b
'\xFF':
321 return (-math.inf, new_pos)
322 return (math.inf, new_pos)
327 result = local_unpack(
'<f', float_bytes)[0]
328 return (result, new_pos)
333 """Returns a decoder for a double field.
335 This code works around a bug in struct.unpack for not-a-number.
338 local_unpack = struct.unpack
340 def InnerDecode(buffer, pos):
341 """Decode serialized double to a double and new position.
344 buffer: memoryview of the serialized bytes.
345 pos: int, position in the memory view to start at.
348 Tuple[float, int] of the decoded double value and new position
349 in the serialized data.
354 double_bytes = buffer[pos:new_pos].tobytes()
359 if ((double_bytes[7:8]
in b
'\x7F\xFF')
360 and (double_bytes[6:7] >= b
'\xF0')
361 and (double_bytes[0:7] != b
'\x00\x00\x00\x00\x00\x00\xF0')):
362 return (math.nan, new_pos)
367 result = local_unpack(
'<d', double_bytes)[0]
368 return (result, new_pos)
372 def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
373 clear_if_default=False):
374 """Returns a decoder for enum field."""
375 enum_type = key.enum_type
377 local_DecodeVarint = _DecodeVarint
378 def DecodePackedField(buffer, pos, end, message, field_dict):
379 """Decode serialized packed enum to its value and a new position.
382 buffer: memoryview of the serialized bytes.
383 pos: int, position in the memory view to start at.
384 end: int, end position of serialized data
385 message: Message object to store unknown fields in
386 field_dict: Map[Descriptor, Any] to store decoded values in.
389 int, new position in serialized data.
391 value = field_dict.get(key)
393 value = field_dict.setdefault(key, new_default(message))
394 (endpoint, pos) = local_DecodeVarint(buffer, pos)
398 while pos < endpoint:
399 value_start_pos = pos
402 if element
in enum_type.values_by_number:
403 value.append(element)
405 if not message._unknown_fields:
406 message._unknown_fields = []
407 tag_bytes = encoder.TagBytes(field_number,
408 wire_format.WIRETYPE_VARINT)
410 message._unknown_fields.append(
411 (tag_bytes, buffer[value_start_pos:pos].tobytes()))
412 if message._unknown_field_set
is None:
414 message._unknown_field_set._add(
415 field_number, wire_format.WIRETYPE_VARINT, element)
418 if element
in enum_type.values_by_number:
421 del message._unknown_fields[-1]
423 del message._unknown_field_set._values[-1]
427 return DecodePackedField
429 tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
430 tag_len =
len(tag_bytes)
431 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
432 """Decode serialized repeated enum to its value and a new position.
435 buffer: memoryview of the serialized bytes.
436 pos: int, position in the memory view to start at.
437 end: int, end position of serialized data
438 message: Message object to store unknown fields in
439 field_dict: Map[Descriptor, Any] to store decoded values in.
442 int, new position in serialized data.
444 value = field_dict.get(key)
446 value = field_dict.setdefault(key, new_default(message))
450 if element
in enum_type.values_by_number:
451 value.append(element)
453 if not message._unknown_fields:
454 message._unknown_fields = []
455 message._unknown_fields.append(
456 (tag_bytes, buffer[pos:new_pos].tobytes()))
457 if message._unknown_field_set
is None:
459 message._unknown_field_set._add(
460 field_number, wire_format.WIRETYPE_VARINT, element)
464 pos = new_pos + tag_len
465 if buffer[new_pos:pos] != tag_bytes
or new_pos >= end:
470 return DecodeRepeatedField
472 def DecodeField(buffer, pos, end, message, field_dict):
473 """Decode serialized enum to its value and a new position.
476 buffer: memoryview of the serialized bytes.
477 pos: int, position in the memory view to start at.
478 end: int, end position of serialized data
479 message: Message object to store unknown fields in
480 field_dict: Map[Descriptor, Any] to store decoded values in.
483 int, new position in serialized data.
485 value_start_pos = pos
489 if clear_if_default
and not enum_value:
490 field_dict.pop(key,
None)
493 if enum_value
in enum_type.values_by_number:
494 field_dict[key] = enum_value
496 if not message._unknown_fields:
497 message._unknown_fields = []
498 tag_bytes = encoder.TagBytes(field_number,
499 wire_format.WIRETYPE_VARINT)
500 message._unknown_fields.append(
501 (tag_bytes, buffer[value_start_pos:pos].tobytes()))
502 if message._unknown_field_set
is None:
504 message._unknown_field_set._add(
505 field_number, wire_format.WIRETYPE_VARINT, enum_value)
515 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
518 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
520 UInt32Decoder =
_SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
521 UInt64Decoder =
_SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
524 wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
526 wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
540 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
543 def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
544 clear_if_default=False):
545 """Returns a decoder for a string field."""
547 local_DecodeVarint = _DecodeVarint
549 def _ConvertToUnicode(memview):
550 """Convert the bytes of a memoryview to a unicode string."""
551 byte_str = memview.tobytes()
553 value =
str(byte_str,
'utf-8')
554 except UnicodeDecodeError
as e:
556 e.reason =
'%s in field: %s' % (e, key.full_name)
563 tag_bytes = encoder.TagBytes(field_number,
564 wire_format.WIRETYPE_LENGTH_DELIMITED)
565 tag_len =
len(tag_bytes)
566 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
567 value = field_dict.get(key)
569 value = field_dict.setdefault(key, new_default(message))
571 (size, pos) = local_DecodeVarint(buffer, pos)
575 value.append(_ConvertToUnicode(buffer[pos:new_pos]))
577 pos = new_pos + tag_len
578 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
581 return DecodeRepeatedField
583 def DecodeField(buffer, pos, end, message, field_dict):
584 (size, pos) = local_DecodeVarint(buffer, pos)
588 if clear_if_default
and not size:
589 field_dict.pop(key,
None)
591 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
596 def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
597 clear_if_default=False):
598 """Returns a decoder for a bytes field."""
600 local_DecodeVarint = _DecodeVarint
604 tag_bytes = encoder.TagBytes(field_number,
605 wire_format.WIRETYPE_LENGTH_DELIMITED)
606 tag_len =
len(tag_bytes)
607 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
608 value = field_dict.get(key)
610 value = field_dict.setdefault(key, new_default(message))
612 (size, pos) = local_DecodeVarint(buffer, pos)
616 value.append(buffer[pos:new_pos].tobytes())
618 pos = new_pos + tag_len
619 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
622 return DecodeRepeatedField
624 def DecodeField(buffer, pos, end, message, field_dict):
625 (size, pos) = local_DecodeVarint(buffer, pos)
629 if clear_if_default
and not size:
630 field_dict.pop(key,
None)
632 field_dict[key] = buffer[pos:new_pos].tobytes()
637 def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
638 """Returns a decoder for a group field."""
640 end_tag_bytes = encoder.TagBytes(field_number,
641 wire_format.WIRETYPE_END_GROUP)
642 end_tag_len =
len(end_tag_bytes)
646 tag_bytes = encoder.TagBytes(field_number,
647 wire_format.WIRETYPE_START_GROUP)
648 tag_len =
len(tag_bytes)
649 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
650 value = field_dict.get(key)
652 value = field_dict.setdefault(key, new_default(message))
654 value = field_dict.get(key)
656 value = field_dict.setdefault(key, new_default(message))
660 new_pos = pos+end_tag_len
661 if buffer[pos:new_pos] != end_tag_bytes
or new_pos > end:
664 pos = new_pos + tag_len
665 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
668 return DecodeRepeatedField
670 def DecodeField(buffer, pos, end, message, field_dict):
671 value = field_dict.get(key)
673 value = field_dict.setdefault(key, new_default(message))
675 pos = value._InternalParse(buffer, pos, end)
677 new_pos = pos+end_tag_len
678 if buffer[pos:new_pos] != end_tag_bytes
or new_pos > end:
684 def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
685 """Returns a decoder for a message field."""
687 local_DecodeVarint = _DecodeVarint
691 tag_bytes = encoder.TagBytes(field_number,
692 wire_format.WIRETYPE_LENGTH_DELIMITED)
693 tag_len =
len(tag_bytes)
694 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
695 value = field_dict.get(key)
697 value = field_dict.setdefault(key, new_default(message))
700 (size, pos) = local_DecodeVarint(buffer, pos)
710 pos = new_pos + tag_len
711 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
714 return DecodeRepeatedField
716 def DecodeField(buffer, pos, end, message, field_dict):
717 value = field_dict.get(key)
719 value = field_dict.setdefault(key, new_default(message))
721 (size, pos) = local_DecodeVarint(buffer, pos)
726 if value._InternalParse(buffer, pos, new_pos) != new_pos:
736 MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
739 """Returns a decoder for a MessageSet item.
741 The parameter is the message Descriptor.
743 The message set message looks like this:
745 repeated group Item = 1 {
746 required int32 type_id = 2;
747 required string message = 3;
752 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
753 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
754 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
756 local_ReadTag = ReadTag
757 local_DecodeVarint = _DecodeVarint
758 local_SkipField = SkipField
760 def DecodeItem(buffer, pos, end, message, field_dict):
761 """Decode serialized message set to its value and new position.
764 buffer: memoryview of the serialized bytes.
765 pos: int, position in the memory view to start at.
766 end: int, end position of serialized data
767 message: Message object to store unknown fields in
768 field_dict: Map[Descriptor, Any] to store decoded values in.
771 int, new position in serialized data.
773 message_set_item_start = pos
781 (tag_bytes, pos) = local_ReadTag(buffer, pos)
782 if tag_bytes == type_id_tag_bytes:
783 (type_id, pos) = local_DecodeVarint(buffer, pos)
784 elif tag_bytes == message_tag_bytes:
785 (size, message_start) = local_DecodeVarint(buffer, pos)
786 pos = message_end = message_start + size
787 elif tag_bytes == item_end_tag_bytes:
790 pos =
SkipField(buffer, pos, end, tag_bytes)
799 if message_start == -1:
802 extension = message.Extensions._FindExtensionByNumber(type_id)
804 if extension
is not None:
805 value = field_dict.get(extension)
807 message_type = extension.message_type
808 if not hasattr(message_type,
'_concrete_class'):
810 message._FACTORY.GetPrototype(message_type)
811 value = field_dict.setdefault(
812 extension, message_type._concrete_class())
813 if value._InternalParse(buffer, message_start,message_end) != message_end:
818 if not message._unknown_fields:
819 message._unknown_fields = []
820 message._unknown_fields.append(
821 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
822 if message._unknown_field_set
is None:
824 message._unknown_field_set._add(
826 wire_format.WIRETYPE_LENGTH_DELIMITED,
827 buffer[message_start:message_end].tobytes())
836 def MapDecoder(field_descriptor, new_default, is_message_map):
837 """Returns a decoder for a map field."""
839 key = field_descriptor
840 tag_bytes = encoder.TagBytes(field_descriptor.number,
841 wire_format.WIRETYPE_LENGTH_DELIMITED)
842 tag_len =
len(tag_bytes)
843 local_DecodeVarint = _DecodeVarint
845 message_type = field_descriptor.message_type
847 def DecodeMap(buffer, pos, end, message, field_dict):
848 submsg = message_type._concrete_class()
849 value = field_dict.get(key)
851 value = field_dict.setdefault(key, new_default(message))
854 (size, pos) = local_DecodeVarint(buffer, pos)
860 if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
866 value[submsg.key].
CopyFrom(submsg.value)
868 value[submsg.key] = submsg.value
871 pos = new_pos + tag_len
872 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
883 """Skip a varint value. Returns the new position."""
887 while ord(buffer[pos:pos+1].tobytes()) & 0x80:
895 """Skip a fixed64 value. Returns the new position."""
904 """Decode a fixed64."""
906 return (struct.unpack(
'<Q', buffer[pos:new_pos])[0], new_pos)
910 """Skip a length-delimited value. Returns the new position."""
920 """Skip sub-group. Returns the new position."""
923 (tag_bytes, pos) =
ReadTag(buffer, pos)
924 new_pos =
SkipField(buffer, pos, end, tag_bytes)
931 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
934 while end_pos
is None or pos < end_pos:
935 (tag_bytes, pos) =
ReadTag(buffer, pos)
937 field_number, wire_type = wire_format.UnpackTag(tag)
938 if wire_type == wire_format.WIRETYPE_END_GROUP:
942 unknown_field_set._add(field_number, wire_type, data)
944 return (unknown_field_set, pos)
948 """Decode an unknown field. Returns the UnknownField and new position."""
950 if wire_type == wire_format.WIRETYPE_VARINT:
952 elif wire_type == wire_format.WIRETYPE_FIXED64:
954 elif wire_type == wire_format.WIRETYPE_FIXED32:
956 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
958 data = buffer[pos:pos+size].tobytes()
960 elif wire_type == wire_format.WIRETYPE_START_GROUP:
962 elif wire_type == wire_format.WIRETYPE_END_GROUP:
971 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
977 """Skip a fixed32 value. Returns the new position."""
986 """Decode a fixed32."""
989 return (struct.unpack(
'<I', buffer[pos:new_pos])[0], new_pos)
993 """Skip function for unknown wire types. Raises an exception."""
998 """Constructs the SkipField function."""
1000 WIRETYPE_TO_SKIPPER = [
1003 _SkipLengthDelimited,
1007 _RaiseInvalidWireType,
1008 _RaiseInvalidWireType,
1011 wiretype_mask = wire_format.TAG_TYPE_MASK
1013 def SkipField(buffer, pos, end, tag_bytes):
1014 """Skips a field with the specified tag.
1016 |pos| should point to the byte immediately after the tag.
1019 The new position (after the tag value), or -1 if the tag is an end-group
1020 tag (in which case the calling loop should break).
1024 wire_type =
ord(tag_bytes[0:1]) & wiretype_mask
1025 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)