32 """Test for preservation of unknown fields in the pure Python implementation."""
34 __author__ =
'bohdank@google.com (Bohdan Koval)'
60 @testing_refleaks.TestCase
61 class UnknownFieldsTest(unittest.TestCase):
64 self.
descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
80 message = unittest_proto3_arena_pb2.TestEmptyMessage()
94 raw = unittest_mset_pb2.RawMessageSet()
98 item.type_id = 98218603
99 message1 = message_set_extensions_pb2.TestMessageSetExtension1()
101 item.message = message1.SerializeToString()
103 serialized = raw.SerializeToString()
106 proto = message_set_extensions_pb2.TestMessageSet()
107 proto.MergeFromString(serialized)
109 unknown_fields = proto.UnknownFields()
110 self.assertEqual(
len(unknown_fields), 1)
113 self.assertEqual(unknown_fields[0].field_number, item.type_id)
114 self.assertEqual(unknown_fields[0].wire_type,
115 wire_format.WIRETYPE_LENGTH_DELIMITED)
116 d = unknown_fields[0].data
117 message_new = message_set_extensions_pb2.TestMessageSetExtension1()
118 message_new.ParseFromString(d)
119 self.assertEqual(message1, message_new)
122 reserialized = proto.SerializeToString()
123 new_raw = unittest_mset_pb2.RawMessageSet()
124 new_raw.MergeFromString(reserialized)
125 self.assertEqual(raw, new_raw)
128 message = unittest_pb2.TestEmptyMessage()
140 message = unittest_pb2.TestAllTypes()
141 other_message = unittest_pb2.TestAllTypes()
142 other_message.optional_string =
'discard'
143 message.optional_nested_message.ParseFromString(
144 other_message.SerializeToString())
146 other_message.SerializeToString())
148 b
'', message.optional_nested_message.SerializeToString())
151 message.DiscardUnknownFields()
152 self.assertEqual(b
'', message.optional_nested_message.SerializeToString())
156 msg = map_unittest_pb2.TestMap()
157 msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
158 other_message.SerializeToString())
159 msg.map_string_string[
'1'] =
'test'
162 msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
163 msg.DiscardUnknownFields()
166 msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
169 @testing_refleaks.TestCase
173 self.
descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
187 if api_implementation.Type() ==
'cpp':
189 field_descriptor = self.
descriptor.fields_by_name[name]
190 wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
191 field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
194 if tag_bytes == field_tag:
195 decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
196 decoder(memoryview(value), 0,
len(value), self.
all_fields, result_dict)
197 self.assertEqual(expected_value, result_dict[field_descriptor])
200 field_descriptor = self.
descriptor.fields_by_name[name]
201 expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
202 field_descriptor.type]
203 for unknown_field
in unknown_fields:
204 if unknown_field.field_number == field_descriptor.number:
205 self.assertEqual(expected_type, unknown_field.wire_type)
206 if expected_type == 3:
208 self.assertEqual(expected_value[0],
209 unknown_field.data[0].field_number)
210 self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
211 self.assertEqual(expected_value[2], unknown_field.data[0].data)
213 if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
214 self.assertIn(
type(unknown_field.data), (str, bytes))
215 if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
216 self.assertIn(unknown_field.data, expected_value)
218 self.assertEqual(expected_value, unknown_field.data)
260 self.
all_fields.optional_string.encode(
'utf-8'))
271 self.assertEqual(97,
len(unknown_fields))
274 message = unittest_pb2.TestEmptyMessage()
279 message = unittest_pb2.TestAllTypes()
280 message.optional_int32 = 1
281 message.optional_uint32 = 2
282 source = unittest_pb2.TestEmptyMessage()
283 source.ParseFromString(message.SerializeToString())
285 message.ClearField(
'optional_int32')
286 message.optional_int64 = 3
287 message.optional_uint32 = 4
288 destination = unittest_pb2.TestEmptyMessage()
289 unknown_fields = destination.UnknownFields()
290 self.assertEqual(0,
len(unknown_fields))
291 destination.ParseFromString(message.SerializeToString())
293 with self.assertRaises(ValueError)
as context:
295 self.assertIn(
'UnknownFields does not exist.',
296 str(context.exception))
297 unknown_fields = destination.UnknownFields()
298 self.assertEqual(2,
len(unknown_fields))
299 destination.MergeFrom(source)
300 self.assertEqual(4,
len(unknown_fields))
303 message.ParseFromString(destination.SerializeToString())
304 self.assertEqual(message.optional_int32, 1)
305 self.assertEqual(message.optional_uint32, 2)
306 self.assertEqual(message.optional_int64, 3)
313 with self.assertRaises(ValueError)
as context:
315 self.assertIn(
'UnknownFields does not exist.',
316 str(context.exception))
318 @unittest.skipIf((sys.version_info.major, sys.version_info.minor) < (3, 4),
319 'tracemalloc requires python 3.4+')
324 def leaking_function():
325 for _
in range(nb_leaks):
329 snapshot1 = tracemalloc.take_snapshot()
331 snapshot2 = tracemalloc.take_snapshot()
332 top_stats = snapshot2.compare_to(snapshot1,
'lineno')
336 self.assertEqual([], [x
for x
in top_stats
if x.count_diff == nb_leaks])
339 message = unittest_pb2.TestAllTypes()
340 message.optionalgroup.a = 123
341 destination = unittest_pb2.TestEmptyMessage()
342 destination.ParseFromString(message.SerializeToString())
343 sub_unknown_fields = destination.UnknownFields()[0].data
344 self.assertEqual(1,
len(sub_unknown_fields))
345 self.assertEqual(sub_unknown_fields[0].data, 123)
347 with self.assertRaises(ValueError)
as context:
348 len(sub_unknown_fields)
349 self.assertIn(
'UnknownFields does not exist.',
350 str(context.exception))
351 with self.assertRaises(ValueError)
as context:
353 sub_unknown_fields[0]
354 self.assertIn(
'UnknownFields does not exist.',
355 str(context.exception))
357 message.optional_uint32 = 456
358 nested_message = unittest_pb2.NestedTestAllTypes()
359 nested_message.payload.optional_nested_message.ParseFromString(
360 message.SerializeToString())
362 nested_message.payload.optional_nested_message.UnknownFields())
363 self.assertEqual(unknown_fields[0].data, 456)
364 nested_message.ClearField(
'payload')
365 self.assertEqual(unknown_fields[0].data, 456)
367 nested_message.payload.optional_nested_message.UnknownFields())
368 self.assertEqual(0,
len(unknown_fields))
371 message = unittest_pb2.TestAllTypes()
372 message.optional_int32 = 123
373 destination = unittest_pb2.TestEmptyMessage()
374 destination.ParseFromString(message.SerializeToString())
375 unknown_field = destination.UnknownFields()[0]
377 with self.assertRaises(ValueError)
as context:
379 self.assertIn(
'The parent message might be cleared.',
380 str(context.exception))
383 message = unittest_pb2.TestEmptyMessageWithExtensions()
385 self.assertEqual(
len(message.UnknownFields()), 97)
389 @testing_refleaks.TestCase
393 self.
descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
395 self.
message = missing_enum_values_pb2.TestEnumValues()
397 self.
message.optional_nested_enum = (
398 missing_enum_values_pb2.TestEnumValues.ZERO)
399 self.
message.repeated_nested_enum.extend([
400 missing_enum_values_pb2.TestEnumValues.ZERO,
401 missing_enum_values_pb2.TestEnumValues.ONE,
403 self.
message.packed_nested_enum.extend([
404 missing_enum_values_pb2.TestEnumValues.ZERO,
405 missing_enum_values_pb2.TestEnumValues.ONE,
418 field_descriptor = self.
descriptor.fields_by_name[name]
421 for field
in unknown_fields:
422 if field.field_number == field_descriptor.number:
424 if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
425 self.assertIn(field.data, expected_value)
427 self.assertEqual(expected_value, field.data)
428 if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
429 self.assertEqual(count,
len(expected_value))
431 self.assertEqual(count, 1)
434 just_string = missing_enum_values_pb2.JustString()
435 just_string.dummy =
'blah'
437 missing = missing_enum_values_pb2.TestEnumValues()
440 missing.ParseFromString(just_string.SerializeToString())
444 self.assertEqual(missing.optional_nested_enum, 0)
462 self.assertEqual(
len(unknown_fields), 5)
464 self.
message.optional_nested_enum)
466 self.
message.repeated_nested_enum)
468 self.
message.packed_nested_enum)
471 new_message = missing_enum_values_pb2.TestEnumValues()
473 self.assertEqual(self.
message, new_message)
476 if __name__ ==
'__main__':