Package redis :: Module connection
[frames | no frames]

Source Code for Module redis.connection

  1  from itertools import chain 
  2  import os 
  3  import socket 
  4  import sys 
  5   
  6  from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, 
  7                             BytesIO, nativestr) 
  8  from redis.exceptions import ( 
  9      RedisError, 
 10      ConnectionError, 
 11      ResponseError, 
 12      InvalidResponse, 
 13      AuthenticationError, 
 14      NoScriptError, 
 15  ) 
 16   
 17  try: 
 18      import hiredis 
 19      hiredis_available = True 
 20  except ImportError: 
 21      hiredis_available = False 
 22   
 23   
 24  SYM_STAR = b('*')  # prefix of a multi-bulk header (the argument/element count line)
 25  SYM_DOLLAR = b('$')  # prefix of a bulk header (the byte length of the value that follows)
 26  SYM_CRLF = b('\r\n')  # terminator written after every header and value in a packed command
 27  SYM_LF = b('\n')  # trailing byte HiredisParser looks for before polling its reader again
 28   
 29   
class PythonParser(object):
    "Plain Python parsing class"
    # upper bound on the number of bytes requested from the socket file in
    # a single read() call
    MAX_READ_LENGTH = 1000000
    # response encoding; stays None (bulk replies are returned as raw bytes)
    # unless on_connect() sees a connection with decode_responses enabled
    encoding = None

    # error replies whose first word matches a key here are turned into the
    # mapped exception class; every other error becomes a ResponseError
    EXCEPTION_CLASSES = {
        'ERR': ResponseError,
        'NOSCRIPT': NoScriptError,
    }

    def __init__(self):
        self._fp = None

    def __del__(self):
        # best-effort cleanup: a failure while tearing down must never
        # propagate out of the finalizer
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._fp = connection._sock.makefile('rb')
        if connection.decode_responses:
            self.encoding = connection.encoding

    def on_disconnect(self):
        "Called when the socket disconnects"
        if self._fp is not None:
            self._fp.close()
            self._fp = None

    def _read_exact(self, nbytes):
        """
        Read exactly ``nbytes`` bytes from the socket file, asking for at
        most MAX_READ_LENGTH bytes per call. Raises ConnectionError if the
        stream ends before ``nbytes`` bytes have arrived.
        """
        # apparently reading more than 1MB or so from a windows
        # socket can cause MemoryErrors. See:
        # https://github.com/andymccurdy/redis-py/issues/205
        # read smaller chunks at a time to work around this
        chunks = []
        bytes_left = nbytes
        while bytes_left > 0:
            chunk = self._fp.read(min(bytes_left, self.MAX_READ_LENGTH))
            if not chunk:
                # an empty read means EOF; returning what we have would
                # hand the caller a silently truncated value
                raise ConnectionError("Socket closed on remote end")
            chunks.append(chunk)
            # count what actually arrived, not what was requested
            bytes_left -= len(chunk)
        return b('').join(chunks)

    def read(self, length=None):
        """
        Read a line from the socket if no length is specified,
        otherwise read ``length`` bytes. Always strip away the newlines.

        Raises ConnectionError on socket errors, timeouts, or when the
        stream ends before ``length`` bytes were received.
        """
        try:
            if length is None:
                # no length, read a full line
                return self._fp.readline()[:-2]
            # the value is followed by a line ending: read it too, then
            # drop it
            return self._read_exact(length + 2)[:-2]
        except (socket.error, socket.timeout):
            e = sys.exc_info()[1]
            raise ConnectionError("Error while reading from socket: %s" %
                                  (e.args,))

    def parse_error(self, response):
        "Parse an error response"
        error_code = response.split(' ')[0]
        if error_code in self.EXCEPTION_CLASSES:
            # strip the error code and the space that follows it
            response = response[len(error_code) + 1:]
            return self.EXCEPTION_CLASSES[error_code](response)
        return ResponseError(response)

    def read_response(self):
        """
        Read and parse one complete reply.

        Returns the parsed value: bytes (or text when an encoding is set),
        an integer, a list for multi-bulk replies, None for a -1 length, or
        an exception *instance* for error replies. Raises ConnectionError
        when the socket is closed or the server is still loading its
        dataset, and InvalidResponse on an unknown reply type.
        """
        response = self.read()
        if not response:
            raise ConnectionError("Socket closed on remote end")

        # the first byte identifies the reply type, the rest is its header
        byte, response = byte_to_chr(response[0]), response[1:]

        if byte not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error")

        # server returned an error
        if byte == '-':
            response = nativestr(response)
            if response.startswith('LOADING '):
                # if we're loading the dataset into memory, kill the socket
                # so we re-initialize (and re-SELECT) next time.
                raise ConnectionError("Redis is loading data into memory")
            # *return*, not raise the exception class. if it is meant to be
            # raised, it will be at a higher level.
            return self.parse_error(response)
        # single value
        elif byte == '+':
            pass
        # int value
        elif byte == ':':
            response = long(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = self.read(length)
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            response = [self.read_response() for i in xrange(length)]
        if isinstance(response, bytes) and self.encoding:
            response = response.decode(self.encoding)
        return response
142 143
class HiredisParser(object):
    "Parser class for connections using Hiredis"
    # class-level defaults so that read_response() on a parser that was
    # never connected reports a closed socket instead of failing with an
    # AttributeError; on_connect()/on_disconnect() shadow them per instance
    _sock = None
    _reader = None

    def __init__(self):
        if not hiredis_available:
            raise RedisError("Hiredis is not installed")

    def __del__(self):
        # best-effort cleanup: a failure while tearing down must never
        # propagate out of the finalizer
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects; builds a fresh hiredis reader"
        self._sock = connection._sock
        kwargs = {
            'protocolError': InvalidResponse,
            'replyError': ResponseError,
        }
        if connection.decode_responses:
            kwargs['encoding'] = connection.encoding
        self._reader = hiredis.Reader(**kwargs)

    def on_disconnect(self):
        "Called when the socket disconnects; drops the socket and reader"
        self._sock = None
        self._reader = None

    def read_response(self):
        """
        Return the next complete reply from the reader, receiving more data
        from the socket for as long as the reader reports (by returning
        False) that it has no complete reply yet.

        Raises ConnectionError when there is no reader, when the socket
        errors or times out, or when the remote end closes the connection.
        """
        if not self._reader:
            raise ConnectionError("Socket closed on remote end")
        response = self._reader.gets()
        while response is False:
            try:
                data = self._sock.recv(4096)
            except (socket.error, socket.timeout):
                e = sys.exc_info()[1]
                raise ConnectionError("Error while reading from socket: %s" %
                                      (e.args,))
            if not data:
                raise ConnectionError("Socket closed on remote end")
            self._reader.feed(data)
            # proactively, but not conclusively, check if more data is in the
            # buffer. if the data received doesn't end with \n, there's more.
            if not data.endswith(SYM_LF):
                continue
            response = self._reader.gets()
        return response
# use the hiredis-backed parser whenever the hiredis import succeeded,
# otherwise fall back to the pure Python implementation
DefaultParser = HiredisParser if hiredis_available else PythonParser
class Connection(object):
    "Manages TCP communication to and from a Redis server"
    def __init__(self, host='localhost', port=6379, db=0, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser):
        # pid of the creating process; ConnectionPool.release() compares it
        # against the pool's own pid
        self.pid = os.getpid()
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.socket_timeout = socket_timeout
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        # no socket is opened here; connect() creates it on demand
        self._sock = None
        self._parser = parser_class()

    def __del__(self):
        # best-effort cleanup: a failure while tearing down must never
        # propagate out of the finalizer
        try:
            self.disconnect()
        except Exception:
            pass

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error:
            e = sys.exc_info()[1]
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            self.on_connect()
        except:
            # a rejected AUTH or SELECT must not leave the socket in place:
            # the next connect() would return early and commands would run
            # on a connection that was never fully initialized
            self.disconnect()
            raise

    def _connect(self):
        "Create a TCP socket connection"
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect((self.host, self.port))
        except:
            # release the socket right away when the attempt fails
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        "Build the ConnectionError text for a failed connection attempt"
        # args for socket.error can either be (errno, "message")
        # or just "message"
        if len(exception.args) == 1:
            return "Error connecting to %s:%s. %s." % \
                (self.host, self.port, exception.args[0])
        else:
            return "Error %s connecting %s:%s. %s." % \
                (exception.args[0], self.host, self.port, exception.args[1])

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if a password is specified, authenticate
        if self.password:
            self.send_command('AUTH', self.password)
            if nativestr(self.read_response()) != 'OK':
                raise AuthenticationError('Invalid Password')

        # if a database is specified, switch to it
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def disconnect(self):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()
        if self._sock is None:
            return
        try:
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        "Send an already packed command to the Redis server"
        if not self._sock:
            self.connect()
        try:
            self._sock.sendall(command)
        except socket.error:
            e = sys.exc_info()[1]
            self.disconnect()
            # args for socket.error can be (errno, "message"), just
            # "message", or occasionally some other length; never let the
            # error report itself fail on an unexpected shape
            err_args = e.args
            if len(err_args) >= 2:
                _errno, errmsg = err_args[0], err_args[1]
            elif err_args:
                _errno, errmsg = 'UNKNOWN', err_args[0]
            else:
                _errno, errmsg = 'UNKNOWN', e
            raise ConnectionError("Error %s while writing to socket. %s." %
                                  (_errno, errmsg))
        except:
            # anything else also leaves the stream in an unknown state
            self.disconnect()
            raise

    def send_command(self, *args):
        "Pack and send a command to the Redis server"
        self.send_packed_command(self.pack_command(*args))

    def read_response(self):
        "Read the response from a previously sent command"
        try:
            response = self._parser.read_response()
        except:
            self.disconnect()
            raise
        # the parser hands error replies back as exception instances;
        # this is the level at which they are raised
        if isinstance(response, ResponseError):
            raise response
        return response

    def encode(self, value):
        "Return a bytestring representation of the value"
        if isinstance(value, bytes):
            return value
        if not isinstance(value, unicode):
            value = str(value)
        if isinstance(value, unicode):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def pack_command(self, *args):
        "Pack a series of arguments into a valid Redis command"
        # header: '*' + number of arguments, then for every argument
        # '$' + byte length, the bytes themselves, each CRLF-terminated
        pieces = [SYM_STAR, b(str(len(args))), SYM_CRLF]
        for enc_value in imap(self.encode, args):
            pieces.append(SYM_DOLLAR)
            pieces.append(b(str(len(enc_value))))
            pieces.append(SYM_CRLF)
            pieces.append(enc_value)
            pieces.append(SYM_CRLF)
        # a single join avoids re-copying the growing output per argument
        return b('').join(pieces)
332 333
class UnixDomainSocketConnection(Connection):
    "Manages Unix domain socket communication to and from a Redis server"
    def __init__(self, path='', db=0, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser):
        # same attributes as Connection, with ``path`` in place of the
        # host/port pair
        self.pid = os.getpid()
        self.path = path
        self.db = db
        self.password = password
        self.socket_timeout = socket_timeout
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self._sock = None
        self._parser = parser_class()

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect(self.path)
        except:
            # release the socket right away when the attempt fails
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        "Build the ConnectionError text for a failed connection attempt"
        # args for socket.error can either be (errno, "message")
        # or just "message"
        if len(exception.args) == 1:
            return "Error connecting to unix socket: %s. %s." % \
                (self.path, exception.args[0])
        else:
            return "Error %s connecting to unix socket: %s. %s." % \
                (exception.args[0], self.path, exception.args[1])
366 367 368 # TODO: add ability to block waiting on a connection to be released
class ConnectionPool(object):
    "Generic connection pool"
    def __init__(self, connection_class=Connection, max_connections=None,
                 **connection_kwargs):
        # a missing (or otherwise falsy) limit means 2 ** 31 connections
        if not max_connections:
            max_connections = 2 ** 31
        self.max_connections = max_connections
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        # bookkeeping: how many connections were ever made, which ones are
        # idle and which ones are currently handed out
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()
        # remembered so _checkpid() can notice a change of process
        self.pid = os.getpid()

    def _checkpid(self):
        "Reset the pool when it is used from a different process"
        if self.pid == os.getpid():
            return
        self.disconnect()
        self.__init__(self.connection_class, self.max_connections,
                      **self.connection_kwargs)

    def get_connection(self, command_name, *keys, **options):
        "Get a connection from the pool"
        self._checkpid()
        try:
            conn = self._available_connections.pop()
        except IndexError:
            # nothing idle, so create a new one
            conn = self.make_connection()
        self._in_use_connections.add(conn)
        return conn

    def make_connection(self):
        "Create a new connection"
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        # counted before the connection object is constructed
        self._created_connections += 1
        factory = self.connection_class
        return factory(**self.connection_kwargs)

    def release(self, connection):
        "Releases the connection back to the pool"
        self._checkpid()
        # connections created under a different pid are not taken back
        if connection.pid != self.pid:
            return
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)

    def disconnect(self):
        "Disconnects all connections in the pool"
        # idle connections first, then the ones currently in use
        groups = (self._available_connections, self._in_use_connections)
        for group in groups:
            for conn in group:
                conn.disconnect()
417