14 """Tests server and client side compression."""
16 from concurrent
import futures
25 from grpc
import _grpcio_metadata
30 _UNARY_UNARY =
'/test/UnaryUnary'
31 _UNARY_STREAM =
'/test/UnaryStream'
32 _STREAM_UNARY =
'/test/StreamUnary'
33 _STREAM_STREAM =
'/test/StreamStream'
36 _STREAM_LENGTH = test_constants.STREAM_LENGTH // 16
40 _REQUEST = b
'\x00' * 100
41 _COMPRESSION_RATIO_THRESHOLD = 0.05
42 _COMPRESSION_METHODS = (
47 grpc.Compression.Gzip,
49 _COMPRESSION_NAMES = {
51 grpc.Compression.NoCompression:
'NoCompression',
52 grpc.Compression.Deflate:
'DeflateCompression',
53 grpc.Compression.Gzip:
'GzipCompression',
57 'client_streaming': (
True,
False),
58 'server_streaming': (
True,
False),
59 'channel_compression': _COMPRESSION_METHODS,
60 'multicallable_compression': _COMPRESSION_METHODS,
61 'server_compression': _COMPRESSION_METHODS,
62 'server_call_compression': _COMPRESSION_METHODS,
68 def _handle_unary(request, servicer_context):
69 if pre_response_callback:
70 pre_response_callback(request, servicer_context)
79 if pre_response_callback:
80 pre_response_callback(request, servicer_context)
81 for _
in range(_STREAM_LENGTH):
84 return _handle_unary_stream
90 if pre_response_callback:
91 pre_response_callback(request_iterator, servicer_context)
93 for request
in request_iterator:
98 return _handle_stream_unary
103 def _handle_stream(request_iterator, servicer_context):
106 for request
in request_iterator:
107 if pre_response_callback:
108 pre_response_callback(request, servicer_context)
111 return _handle_stream
116 del request_or_iterator
117 servicer_context.set_compression(compression_method)
122 servicer_context.disable_next_message_compression()
126 if int(request.decode(
'ascii')) == 0:
127 servicer_context.disable_next_message_compression()
132 def __init__(self, request_streaming, response_streaming,
133 pre_response_callback):
145 pre_response_callback)
160 if handler_call_details.method == _UNARY_UNARY:
162 elif handler_call_details.method == _UNARY_STREAM:
164 elif handler_call_details.method == _STREAM_UNARY:
166 elif handler_call_details.method == _STREAM_STREAM:
172 @contextlib.contextmanager
175 server =
grpc.server(futures.ThreadPoolExecutor(), **server_kwargs)
176 server.add_generic_rpc_handlers((server_handler,))
177 server_port = server.add_insecure_port(
'{}:0'.
format(_HOST))
180 proxy_port = proxy.get_port()
182 **channel_kwargs)
as client_channel:
184 yield client_channel, proxy, server
190 server_kwargs, server_handler, message):
192 server_handler)
as pipeline:
193 client_channel, proxy, server = pipeline
194 client_function(client_channel, multicallable_kwargs, message)
195 return proxy.get_byte_count()
199 first_multicallable_kwargs, first_server_kwargs,
200 first_server_handler, second_channel_kwargs,
201 second_multicallable_kwargs, second_server_kwargs,
202 second_server_handler, message):
204 first_channel_kwargs, first_multicallable_kwargs, client_function,
205 first_server_kwargs, first_server_handler, message)
207 second_channel_kwargs, second_multicallable_kwargs, client_function,
208 second_server_kwargs, second_server_handler, message)
209 return ((second_bytes_sent - first_bytes_sent) / float(first_bytes_sent),
210 (second_bytes_received - first_bytes_received) /
211 float(first_bytes_received))
215 multi_callable = channel.unary_unary(_UNARY_UNARY)
217 if response != message:
218 raise RuntimeError(
"Request '{}' != Response '{}'".
format(
223 multi_callable = channel.unary_stream(_UNARY_STREAM)
224 response_iterator =
multi_callable(message, **multicallable_kwargs)
225 for response
in response_iterator:
226 if response != message:
227 raise RuntimeError(
"Request '{}' != Response '{}'".
format(
232 multi_callable = channel.stream_unary(_STREAM_UNARY)
233 requests = (_REQUEST
for _
in range(_STREAM_LENGTH))
235 if response != message:
236 raise RuntimeError(
"Request '{}' != Response '{}'".
format(
241 multi_callable = channel.stream_stream(_STREAM_STREAM)
242 request_prefix =
str(0).
encode(
'ascii') * 100
244 request_prefix +
str(i).
encode(
'ascii')
for i
in range(_STREAM_LENGTH))
245 response_iterator =
multi_callable(requests, **multicallable_kwargs)
246 for i, response
in enumerate(response_iterator):
247 if int(response.decode(
'ascii')) != i:
248 raise RuntimeError(
"Request '{}' != Response '{}'".
format(
257 -1.0 * _COMPRESSION_RATIO_THRESHOLD,
258 msg=
'Actual compression ratio: {}'.
format(compression_ratio))
261 self.assertGreaterEqual(
263 -1.0 * _COMPRESSION_RATIO_THRESHOLD,
264 msg=
'Actual compession ratio: {}'.
format(compression_ratio))
268 multicallable_compression,
270 server_call_compression):
271 client_side_compressed = channel_compression
or multicallable_compression
272 server_side_compressed = server_compression
or server_call_compression
274 'compression': channel_compression,
275 }
if channel_compression
else {}
276 multicallable_kwargs = {
277 'compression': multicallable_compression,
278 }
if multicallable_compression
else {}
280 client_function =
None
281 if not client_streaming
and not server_streaming:
282 client_function = _unary_unary_client
283 elif not client_streaming
and server_streaming:
284 client_function = _unary_stream_client
285 elif client_streaming
and not server_streaming:
286 client_function = _stream_unary_client
288 client_function = _stream_stream_client
291 'compression': server_compression,
292 }
if server_compression
else {}
294 functools.partial(set_call_compression, grpc.Compression.Gzip)
298 multicallable_kwargs, server_kwargs,
299 server_handler, _REQUEST)
303 'compression': grpc.Compression.Deflate,
312 'compression': grpc.Compression.Deflate,
321 return '{}{}'.
format(name, _COMPRESSION_NAMES[value])
325 channel_compression, multicallable_compression,
326 server_compression, server_call_compression):
327 client_arity =
'Stream' if client_streaming
else 'Unary'
328 server_arity =
'Stream' if server_streaming
else 'Unary'
329 arity =
'{}{}'.
format(client_arity, server_arity)
333 'Multicallable', multicallable_compression)
336 server_call_compression)
337 return 'test{}{}{}{}{}'.
format(arity, channel_compression_str,
338 multicallable_compression_str,
339 server_compression_str,
340 server_call_compression_str)
344 for test_parameters
in itertools.product(*_TEST_OPTIONS.values()):
345 yield dict(zip(_TEST_OPTIONS.keys(), test_parameters))
352 def _test_compression(self):
353 self.assertConfigurationCompressed(**kwargs)
355 return _test_compression
360 if __name__ ==
'__main__':
361 logging.basicConfig()
362 unittest.main(verbosity=2)