31 """Code for decoding protocol buffer primitives.
33 This code is very similar to encoder.py -- read the docs for that module first.
35 A "decoder" is a function with the signature:
36 Decode(buffer, pos, end, message, field_dict)
38 buffer: The string containing the encoded message.
39 pos: The current position in the string.
40 end: The position in the string where the current message ends. May be
41 less than len(buffer) if we're reading a sub-message.
42 message: The message object into which we're parsing.
43 field_dict: message._fields (avoids a hashtable lookup).
44 The decoder reads the field and stores it into field_dict, returning the new
45 buffer position. A decoder for a repeated field may proactively decode all of
46 the elements of that field, if they appear consecutively.
48 Note that decoders may throw any of the following:
49 IndexError: Indicates a truncated message.
50 struct.error: Unpacking of a fixed-width field failed.
51 message.DecodeError: Other errors.
53 Decoders are expected to raise an exception if they are called with pos > end.
54 This allows callers to be lax about bounds checking: it's fine to read past
55 "end" as long as you are sure that someone else will notice and throw an exception later on.
58 Something up the call stack is expected to catch IndexError and struct.error
59 and convert them to message.DecodeError.
61 Decoders are constructed using decoder constructors with the signature:
62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
64 field_number: The field number of the field we want to decode.
65 is_repeated: Is the field a repeated field? (bool)
66 is_packed: Is the field a packed field? (bool)
67 key: The key to use when looking up the field within field_dict.
68 (This is actually the FieldDescriptor but nothing in this
69 file should depend on that.)
70 new_default: A function which takes a message object as a parameter and
71 returns a new instance of the default value for this field.
72 (This is called for repeated fields and sub-messages, when an
73 instance does not already exist.)
75 As with encoders, we define a decoder constructor for every type of field.
76 Then, for every field of every message class we construct an actual decoder.
77 That decoder goes into a dict indexed by tag, so when we decode a message
78 we repeatedly read a tag, look up the corresponding decoder, and invoke it.
81 __author__ =
'kenton@google.com (Kenton Varda)'
87 _UCS2_MAXUNICODE = 65535
92 _SURROGATE_PATTERN = re.compile(six.u(
r'[\ud800-\udfff]'))
109 _DecodeError = message.DecodeError
113 """Return a decoder for a basic varint value (does not include tag).
115 Decoded values will be bitwise-anded with the given mask before being
116 returned, e.g. to limit them to 32 bits. The returned decoder does not
117 take the usual "end" parameter -- the caller is expected to do bounds checking
118 after the fact (often the caller can defer such checking until later). The
119 decoder returns a (value, new_pos) pair.
122 def DecodeVarint(buffer, pos):
126 b = six.indexbytes(buffer, pos)
127 result |= ((b & 0x7f) << shift)
135 raise _DecodeError(
'Too many bytes when decoding varint.')
140 """Like _VarintDecoder() but decodes signed values."""
142 signbit = 1 << (bits - 1)
143 mask = (1 << bits) - 1
145 def DecodeVarint(buffer, pos):
149 b = six.indexbytes(buffer, pos)
150 result |= ((b & 0x7f) << shift)
154 result = (result ^ signbit) - signbit
159 raise _DecodeError(
'Too many bytes when decoding varint.')
175 """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
177 We return the raw bytes of the tag rather than decoding them. The raw
178 bytes can then be used to look up the proper decoder. This effectively allows
179 us to trade some work that would be done in pure-python (decoding a varint)
180 for work that is done in C (searching for a byte string in a hash table).
181 In a low-level language it would be much cheaper to decode the varint and
182 use that, but not in Python.
185 buffer: memoryview object of the encoded bytes
186 pos: int of the current position to start from
189 Tuple[bytes, int] of the tag data and new position.
192 while six.indexbytes(buffer, pos) & 0x80:
196 tag_bytes = buffer[start:pos].tobytes()
197 return tag_bytes, pos
204 """Return a constructor for a decoder for fields of a particular type.
207 wire_type: The field's wire type.
208 decode_value: A function which decodes an individual value, e.g.
212 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
214 local_DecodeVarint = _DecodeVarint
215 def DecodePackedField(buffer, pos, end, message, field_dict):
216 value = field_dict.get(key)
218 value = field_dict.setdefault(key, new_default(message))
219 (endpoint, pos) = local_DecodeVarint(buffer, pos)
223 while pos < endpoint:
224 (element, pos) = decode_value(buffer, pos)
225 value.append(element)
230 return DecodePackedField
232 tag_bytes = encoder.TagBytes(field_number, wire_type)
233 tag_len =
len(tag_bytes)
234 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
235 value = field_dict.get(key)
237 value = field_dict.setdefault(key, new_default(message))
239 (element, new_pos) = decode_value(buffer, pos)
240 value.append(element)
243 pos = new_pos + tag_len
244 if buffer[new_pos:pos] != tag_bytes
or new_pos >= end:
249 return DecodeRepeatedField
251 def DecodeField(buffer, pos, end, message, field_dict):
252 (field_dict[key], pos) = decode_value(buffer, pos)
259 return SpecificDecoder
263 """Like SimpleDecoder but additionally invokes modify_value on every value
264 before storing it. Usually modify_value is ZigZagDecode.
def InnerDecode(buffer, pos):
  """Decode one value with decode_value and post-process it with modify_value.

  Returns a (modified_value, new_pos) pair; new_pos is passed through
  unchanged from decode_value.
  """
  raw_value, next_pos = decode_value(buffer, pos)
  converted = modify_value(raw_value)
  return (converted, next_pos)
277 """Return a constructor for a decoder for a fixed-width field.
280 wire_type: The field's wire type.
281 format: The format string to pass to struct.unpack().
284 value_size = struct.calcsize(format)
285 local_unpack = struct.unpack
def InnerDecode(buffer, pos):
  """Unpack one fixed-width value starting at pos.

  Slices value_size bytes out of buffer, unpacks them with the enclosing
  struct format, and returns a (value, new_pos) pair where new_pos is the
  position just past the consumed bytes.
  """
  value_end = pos + value_size
  unpacked_fields = local_unpack(format, buffer[pos:value_end])
  # Only the first unpacked field is the decoded value.
  return (unpacked_fields[0], value_end)
302 """Returns a decoder for a float field.
304 This code works around a bug in struct.unpack for non-finite 32-bit
305 floating-point values.
308 local_unpack = struct.unpack
310 def InnerDecode(buffer, pos):
311 """Decode serialized float to a float and new position.
314 buffer: memoryview of the serialized bytes
315 pos: int, position in the memory view to start at.
318 Tuple[float, int] of the deserialized float value and new position
319 in the serialized data.
324 float_bytes = buffer[pos:new_pos].tobytes()
329 if (float_bytes[3:4]
in b
'\x7F\xFF' and float_bytes[2:3] >= b
'\x80'):
331 if float_bytes[0:3] != b
'\x00\x00\x80':
332 return (_NAN, new_pos)
334 if float_bytes[3:4] == b
'\xFF':
335 return (_NEG_INF, new_pos)
336 return (_POS_INF, new_pos)
341 result = local_unpack(
'<f', float_bytes)[0]
342 return (result, new_pos)
347 """Returns a decoder for a double field.
349 This code works around a bug in struct.unpack for not-a-number.
352 local_unpack = struct.unpack
354 def InnerDecode(buffer, pos):
355 """Decode serialized double to a double and new position.
358 buffer: memoryview of the serialized bytes.
359 pos: int, position in the memory view to start at.
362 Tuple[float, int] of the decoded double value and new position
363 in the serialized data.
368 double_bytes = buffer[pos:new_pos].tobytes()
373 if ((double_bytes[7:8]
in b
'\x7F\xFF')
374 and (double_bytes[6:7] >= b
'\xF0')
375 and (double_bytes[0:7] != b
'\x00\x00\x00\x00\x00\x00\xF0')):
376 return (_NAN, new_pos)
381 result = local_unpack(
'<d', double_bytes)[0]
382 return (result, new_pos)
386 def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
387 enum_type = key.enum_type
389 local_DecodeVarint = _DecodeVarint
390 def DecodePackedField(buffer, pos, end, message, field_dict):
391 """Decode serialized packed enum to its value and a new position.
394 buffer: memoryview of the serialized bytes.
395 pos: int, position in the memory view to start at.
396 end: int, end position of serialized data
397 message: Message object to store unknown fields in
398 field_dict: Map[Descriptor, Any] to store decoded values in.
401 int, new position in serialized data.
403 value = field_dict.get(key)
405 value = field_dict.setdefault(key, new_default(message))
406 (endpoint, pos) = local_DecodeVarint(buffer, pos)
410 while pos < endpoint:
411 value_start_pos = pos
414 if element
in enum_type.values_by_number:
415 value.append(element)
417 if not message._unknown_fields:
418 message._unknown_fields = []
419 tag_bytes = encoder.TagBytes(field_number,
420 wire_format.WIRETYPE_VARINT)
422 message._unknown_fields.append(
423 (tag_bytes, buffer[value_start_pos:pos].tobytes()))
426 if element
in enum_type.values_by_number:
429 del message._unknown_fields[-1]
432 return DecodePackedField
434 tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
435 tag_len =
len(tag_bytes)
436 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
437 """Decode serialized repeated enum to its value and a new position.
440 buffer: memoryview of the serialized bytes.
441 pos: int, position in the memory view to start at.
442 end: int, end position of serialized data
443 message: Message object to store unknown fields in
444 field_dict: Map[Descriptor, Any] to store decoded values in.
447 int, new position in serialized data.
449 value = field_dict.get(key)
451 value = field_dict.setdefault(key, new_default(message))
455 if element
in enum_type.values_by_number:
456 value.append(element)
458 if not message._unknown_fields:
459 message._unknown_fields = []
460 message._unknown_fields.append(
461 (tag_bytes, buffer[pos:new_pos].tobytes()))
465 pos = new_pos + tag_len
466 if buffer[new_pos:pos] != tag_bytes
or new_pos >= end:
471 return DecodeRepeatedField
473 def DecodeField(buffer, pos, end, message, field_dict):
474 """Decode serialized enum to its value and a new position.
477 buffer: memoryview of the serialized bytes.
478 pos: int, position in the memory view to start at.
479 end: int, end position of serialized data
480 message: Message object to store unknown fields in
481 field_dict: Map[Descriptor, Any] to store decoded values in.
484 int, new position in serialized data.
486 value_start_pos = pos
491 if enum_value
in enum_type.values_by_number:
492 field_dict[key] = enum_value
494 if not message._unknown_fields:
495 message._unknown_fields = []
496 tag_bytes = encoder.TagBytes(field_number,
497 wire_format.WIRETYPE_VARINT)
498 message._unknown_fields.append(
499 (tag_bytes, buffer[value_start_pos:pos].tobytes()))
509 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
512 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
518 wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
520 wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
534 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
538 is_strict_utf8=False):
539 """Returns a decoder for a string field."""
541 local_DecodeVarint = _DecodeVarint
542 local_unicode = six.text_type
544 def _ConvertToUnicode(memview):
545 """Convert bytes to unicode."""
546 byte_str = memview.tobytes()
548 value = local_unicode(byte_str,
'utf-8')
549 except UnicodeDecodeError
as e:
551 e.reason =
'%s in field: %s' % (e, key.full_name)
554 if is_strict_utf8
and six.PY2
and sys.maxunicode > _UCS2_MAXUNICODE:
556 if _SURROGATE_PATTERN.search(value):
557 reason = (
'String field %s contains invalid UTF-8 data when parsing'
558 'a protocol buffer: surrogates not allowed. Use'
559 'the bytes type if you intend to send raw bytes.') % (
561 raise message.DecodeError(reason)
567 tag_bytes = encoder.TagBytes(field_number,
568 wire_format.WIRETYPE_LENGTH_DELIMITED)
569 tag_len =
len(tag_bytes)
570 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
571 value = field_dict.get(key)
573 value = field_dict.setdefault(key, new_default(message))
575 (size, pos) = local_DecodeVarint(buffer, pos)
579 value.append(_ConvertToUnicode(buffer[pos:new_pos]))
581 pos = new_pos + tag_len
582 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
585 return DecodeRepeatedField
587 def DecodeField(buffer, pos, end, message, field_dict):
588 (size, pos) = local_DecodeVarint(buffer, pos)
592 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
597 def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
598 """Returns a decoder for a bytes field."""
600 local_DecodeVarint = _DecodeVarint
604 tag_bytes = encoder.TagBytes(field_number,
605 wire_format.WIRETYPE_LENGTH_DELIMITED)
606 tag_len =
len(tag_bytes)
607 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
608 value = field_dict.get(key)
610 value = field_dict.setdefault(key, new_default(message))
612 (size, pos) = local_DecodeVarint(buffer, pos)
616 value.append(buffer[pos:new_pos].tobytes())
618 pos = new_pos + tag_len
619 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
622 return DecodeRepeatedField
624 def DecodeField(buffer, pos, end, message, field_dict):
625 (size, pos) = local_DecodeVarint(buffer, pos)
629 field_dict[key] = buffer[pos:new_pos].tobytes()
634 def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
635 """Returns a decoder for a group field."""
637 end_tag_bytes = encoder.TagBytes(field_number,
638 wire_format.WIRETYPE_END_GROUP)
639 end_tag_len =
len(end_tag_bytes)
643 tag_bytes = encoder.TagBytes(field_number,
644 wire_format.WIRETYPE_START_GROUP)
645 tag_len =
len(tag_bytes)
646 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
647 value = field_dict.get(key)
649 value = field_dict.setdefault(key, new_default(message))
651 value = field_dict.get(key)
653 value = field_dict.setdefault(key, new_default(message))
657 new_pos = pos+end_tag_len
658 if buffer[pos:new_pos] != end_tag_bytes
or new_pos > end:
661 pos = new_pos + tag_len
662 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
665 return DecodeRepeatedField
667 def DecodeField(buffer, pos, end, message, field_dict):
668 value = field_dict.get(key)
670 value = field_dict.setdefault(key, new_default(message))
672 pos = value._InternalParse(buffer, pos, end)
674 new_pos = pos+end_tag_len
675 if buffer[pos:new_pos] != end_tag_bytes
or new_pos > end:
682 """Returns a decoder for a message field."""
684 local_DecodeVarint = _DecodeVarint
688 tag_bytes = encoder.TagBytes(field_number,
689 wire_format.WIRETYPE_LENGTH_DELIMITED)
690 tag_len =
len(tag_bytes)
691 def DecodeRepeatedField(buffer, pos, end, message, field_dict):
692 value = field_dict.get(key)
694 value = field_dict.setdefault(key, new_default(message))
697 (size, pos) = local_DecodeVarint(buffer, pos)
707 pos = new_pos + tag_len
708 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
711 return DecodeRepeatedField
713 def DecodeField(buffer, pos, end, message, field_dict):
714 value = field_dict.get(key)
716 value = field_dict.setdefault(key, new_default(message))
718 (size, pos) = local_DecodeVarint(buffer, pos)
723 if value._InternalParse(buffer, pos, new_pos) != new_pos:
733 MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
736 """Returns a decoder for a MessageSet item.
738 The parameter is the message Descriptor.
740 The message set message looks like this:
742 repeated group Item = 1 {
743 required int32 type_id = 2;
744 required string message = 3;
749 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
750 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
751 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
753 local_ReadTag = ReadTag
754 local_DecodeVarint = _DecodeVarint
755 local_SkipField = SkipField
757 def DecodeItem(buffer, pos, end, message, field_dict):
758 """Decode serialized message set to its value and new position.
761 buffer: memoryview of the serialized bytes.
762 pos: int, position in the memory view to start at.
763 end: int, end position of serialized data
764 message: Message object to store unknown fields in
765 field_dict: Map[Descriptor, Any] to store decoded values in.
768 int, new position in serialized data.
770 message_set_item_start = pos
778 (tag_bytes, pos) = local_ReadTag(buffer, pos)
779 if tag_bytes == type_id_tag_bytes:
780 (type_id, pos) = local_DecodeVarint(buffer, pos)
781 elif tag_bytes == message_tag_bytes:
782 (size, message_start) = local_DecodeVarint(buffer, pos)
783 pos = message_end = message_start + size
784 elif tag_bytes == item_end_tag_bytes:
787 pos =
SkipField(buffer, pos, end, tag_bytes)
796 if message_start == -1:
799 extension = message.Extensions._FindExtensionByNumber(type_id)
801 if extension
is not None:
802 value = field_dict.get(extension)
804 value = field_dict.setdefault(
805 extension, extension.message_type._concrete_class())
806 if value._InternalParse(buffer, message_start,message_end) != message_end:
811 if not message._unknown_fields:
812 message._unknown_fields = []
813 message._unknown_fields.append(
814 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
823 def MapDecoder(field_descriptor, new_default, is_message_map):
824 """Returns a decoder for a map field."""
826 key = field_descriptor
827 tag_bytes = encoder.TagBytes(field_descriptor.number,
828 wire_format.WIRETYPE_LENGTH_DELIMITED)
829 tag_len =
len(tag_bytes)
830 local_DecodeVarint = _DecodeVarint
832 message_type = field_descriptor.message_type
834 def DecodeMap(buffer, pos, end, message, field_dict):
835 submsg = message_type._concrete_class()
836 value = field_dict.get(key)
838 value = field_dict.setdefault(key, new_default(message))
841 (size, pos) = local_DecodeVarint(buffer, pos)
847 if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
853 value[submsg.key].
MergeFrom(submsg.value)
855 value[submsg.key] = submsg.value
858 pos = new_pos + tag_len
859 if buffer[new_pos:pos] != tag_bytes
or new_pos == end:
870 """Skip a varint value. Returns the new position."""
874 while ord(buffer[pos:pos+1].tobytes()) & 0x80:
882 """Skip a fixed64 value. Returns the new position."""
891 """Decode a fixed64."""
893 return (struct.unpack(
'<Q', buffer[pos:new_pos])[0], new_pos)
897 """Skip a length-delimited value. Returns the new position."""
907 """Skip sub-group. Returns the new position."""
910 (tag_bytes, pos) =
ReadTag(buffer, pos)
911 new_pos =
SkipField(buffer, pos, end, tag_bytes)
918 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
921 while end_pos
is None or pos < end_pos:
922 (tag_bytes, pos) =
ReadTag(buffer, pos)
924 field_number, wire_type = wire_format.UnpackTag(tag)
925 if wire_type == wire_format.WIRETYPE_END_GROUP:
929 unknown_field_set._add(field_number, wire_type, data)
931 return (unknown_field_set, pos)
935 """Decode an unknown field. Returns the UnknownField and new position."""
937 if wire_type == wire_format.WIRETYPE_VARINT:
939 elif wire_type == wire_format.WIRETYPE_FIXED64:
941 elif wire_type == wire_format.WIRETYPE_FIXED32:
943 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
945 data = buffer[pos:pos+size]
947 elif wire_type == wire_format.WIRETYPE_START_GROUP:
949 elif wire_type == wire_format.WIRETYPE_END_GROUP:
958 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
964 """Skip a fixed32 value. Returns the new position."""
973 """Decode a fixed32."""
976 return (struct.unpack(
'<I', buffer[pos:new_pos])[0], new_pos)
980 """Skip function for unknown wire types. Raises an exception."""
985 """Constructs the SkipField function."""
987 WIRETYPE_TO_SKIPPER = [
990 _SkipLengthDelimited,
994 _RaiseInvalidWireType,
995 _RaiseInvalidWireType,
998 wiretype_mask = wire_format.TAG_TYPE_MASK
1000 def SkipField(buffer, pos, end, tag_bytes):
1001 """Skips a field with the specified tag.
1003 |pos| should point to the byte immediately after the tag.
1006 The new position (after the tag value), or -1 if the tag is an end-group
1007 tag (in which case the calling loop should break).
1011 wire_type = ord(tag_bytes[0:1]) & wiretype_mask
1012 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)