00001 from itertools import chain
00002 import os
00003 import socket
00004 import sys
00005
00006 from ._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
00007 BytesIO, nativestr)
00008 from .exceptions import (
00009 RedisError,
00010 ConnectionError,
00011 ResponseError,
00012 InvalidResponse,
00013 AuthenticationError,
00014 NoScriptError,
00015 )
00016
# hiredis is an optional accelerator: record whether it could be imported so
# the rest of the module can pick a parser accordingly.
try:
    import hiredis
except ImportError:
    hiredis_available = False
else:
    hiredis_available = True
00022
00023
# Protocol framing tokens, built once at import time.
# NOTE(review): ``b`` comes from ``._compat``; presumably it converts a
# native string literal to a byte string on every supported interpreter --
# confirm against that module.
SYM_STAR = b('*')  # opens the argument-count header in pack_command
SYM_DOLLAR = b('$')  # opens each argument's length header in pack_command
SYM_CRLF = b('\r\n')  # terminates every header and argument in pack_command
SYM_LF = b('\n')  # HiredisParser checks received chunks for this suffix
00028
00029
class PythonParser(object):
    """Pure-Python reader for replies arriving on a connection's socket."""

    # Upper bound on the number of bytes requested from the socket file in
    # a single read call; larger payloads are fetched in several chunks.
    MAX_READ_LENGTH = 1000000
    # Codec applied to byte replies. Stays None unless the connection asks
    # for decoded responses in on_connect.
    encoding = None

    # Maps the first word of an error reply to the exception class that is
    # instantiated for it; anything else becomes a plain ResponseError.
    EXCEPTION_CLASSES = {
        'ERR': ResponseError,
        'NOSCRIPT': NoScriptError,
    }

    def __init__(self):
        # No file object exists until on_connect is called.
        self._fp = None

    def __del__(self):
        # Best-effort cleanup: failures while closing are deliberately
        # swallowed.
        try:
            self.on_disconnect()
        except:
            pass

    def on_connect(self, connection):
        """
        Attach to a freshly connected ``connection``.

        Wraps the connection's socket in a binary file object and, when the
        connection requests decoded responses, remembers its encoding.
        """
        self._fp = connection._sock.makefile('rb')
        if connection.decode_responses:
            self.encoding = connection.encoding

    def on_disconnect(self):
        """Close and forget the socket file object, if there is one."""
        if self._fp is None:
            return
        self._fp.close()
        self._fp = None

    def read(self, length=None):
        """
        Read one line from the socket if no ``length`` is given, otherwise
        read ``length`` bytes of payload. The trailing two-byte line ending
        is never part of the returned value.

        Raises ConnectionError when the underlying read fails with a socket
        error or timeout.
        """
        try:
            if length is None:
                return self._fp.readline()[:-2]

            # The payload is followed by a two-byte terminator that must
            # be consumed from the stream as well.
            bytes_left = length + 2
            if length <= self.MAX_READ_LENGTH:
                return self._fp.read(bytes_left)[:-2]

            # Oversized payload: pull it in pieces of at most
            # MAX_READ_LENGTH bytes instead of issuing one huge read.
            try:
                buf = BytesIO()
                while bytes_left > 0:
                    chunk_size = min(bytes_left, self.MAX_READ_LENGTH)
                    buf.write(self._fp.read(chunk_size))
                    bytes_left -= chunk_size
                buf.seek(0)
                # Returning only ``length`` bytes drops the terminator.
                return buf.read(length)
            finally:
                buf.close()
        except (socket.error, socket.timeout):
            e = sys.exc_info()[1]
            raise ConnectionError("Error while reading from socket: %s" %
                                  (e.args,))

    def parse_error(self, response):
        """
        Build (but do not raise) the exception for an error reply.

        The first space-separated word of ``response`` selects the class
        from EXCEPTION_CLASSES and is removed from the message; unknown
        codes produce a ResponseError carrying the whole text.
        """
        error_code = response.split(' ')[0]
        exception_class = self.EXCEPTION_CLASSES.get(error_code)
        if exception_class is None:
            return ResponseError(response)
        return exception_class(response[len(error_code) + 1:])

    def read_response(self):
        """
        Read and parse one complete reply from the socket.

        Returns an exception instance for error replies, an integer for
        ``:`` replies, ``None`` for a ``$`` or ``*`` reply whose length is
        -1, a list for ``*`` replies and bytes for everything else (decoded
        when an encoding is set).

        Raises ConnectionError when nothing could be read or when the
        server reports that it is still loading, and InvalidResponse when
        the reply starts with an unknown type marker.
        """
        line = self.read()
        if not line:
            raise ConnectionError("Socket closed on remote end")

        marker = byte_to_chr(line[0])
        payload = line[1:]

        if marker not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error")

        if marker == '-':
            message = nativestr(payload)
            if message.startswith('LOADING '):
                # Raised immediately rather than handed back as a value,
                # unlike every other error reply.
                raise ConnectionError("Redis is loading data into memory")
            # Other errors are returned so the caller decides when to raise.
            return self.parse_error(message)

        if marker == ':':
            return long(payload)

        if marker == '*':
            count = int(payload)
            if count == -1:
                return None
            return [self.read_response() for _ in xrange(count)]

        if marker == '$':
            size = int(payload)
            if size == -1:
                return None
            payload = self.read(size)

        # Only '+' and '$' replies reach this point.
        if isinstance(payload, bytes) and self.encoding:
            payload = payload.decode(self.encoding)
        return payload
00142
00143
class HiredisParser(object):
    """Reply parser that delegates protocol decoding to ``hiredis``."""

    def __init__(self):
        # Refuse to be constructed when the optional module is missing.
        if not hiredis_available:
            raise RedisError("Hiredis is not installed")

    def __del__(self):
        # Best-effort cleanup: failures are deliberately swallowed.
        try:
            self.on_disconnect()
        except:
            pass

    def on_connect(self, connection):
        """
        Attach to a freshly connected ``connection``.

        Keeps a reference to its socket and creates a new hiredis reader,
        passing the connection's encoding only when decoded responses were
        requested.
        """
        self._sock = connection._sock
        reader_options = {
            'protocolError': InvalidResponse,
            'replyError': ResponseError,
        }
        if connection.decode_responses:
            reader_options['encoding'] = connection.encoding
        self._reader = hiredis.Reader(**reader_options)

    def on_disconnect(self):
        """Drop the references to the socket and the reader."""
        self._sock = None
        self._reader = None

    def read_response(self):
        """
        Return the next complete reply, receiving from the socket as needed.

        Raises ConnectionError when there is no reader, when receiving
        fails with a socket error or timeout, or when the peer sends an
        empty chunk.
        """
        reader = self._reader
        if not reader:
            raise ConnectionError("Socket closed on remote end")

        reply = reader.gets()
        # The loop treats ``False`` from gets() as "no complete reply yet"
        # and keeps feeding the reader until something else comes back.
        while reply is False:
            try:
                chunk = self._sock.recv(4096)
            except (socket.error, socket.timeout):
                e = sys.exc_info()[1]
                raise ConnectionError("Error while reading from socket: %s" %
                                      (e.args,))
            if not chunk:
                raise ConnectionError("Socket closed on remote end")
            reader.feed(chunk)
            # Only ask the reader for a reply when the chunk ends with a
            # line feed; otherwise go straight back to receiving.
            if chunk.endswith(SYM_LF):
                reply = reader.gets()
        return reply
00190
# Use the hiredis-backed parser whenever the module could be imported,
# falling back to the pure-Python one otherwise.
DefaultParser = HiredisParser if hiredis_available else PythonParser
00195
00196
class Connection(object):
    """Manages TCP communication to and from a Redis server."""

    def __init__(self, host='localhost', port=6379, db=0, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser):
        """
        Store the connection settings; no socket is opened here.

        ``host`` / ``port`` form the address handed to the socket.
        ``db`` is SELECTed after connecting when it is truthy.
        ``password`` is sent with AUTH after connecting when it is truthy.
        ``socket_timeout`` is applied to the socket via ``settimeout``.
        ``encoding`` / ``encoding_errors`` are used to encode unicode
        command arguments.
        ``decode_responses`` is read by the parser in its ``on_connect``.
        ``parser_class`` is called without arguments to build the parser.
        """
        # Process id at creation time; the pool compares it on release.
        self.pid = os.getpid()
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.socket_timeout = socket_timeout
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self._sock = None
        self._parser = parser_class()

    def __del__(self):
        # Best-effort cleanup: failures are deliberately swallowed.
        try:
            self.disconnect()
        except:
            pass

    def connect(self):
        """
        Connect to the Redis server if not already connected.

        Raises ConnectionError when the socket cannot be connected. Any
        exception raised while initialising the connection (AUTH / SELECT)
        is propagated after the socket has been torn down again.
        """
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.error:
            e = sys.exc_info()[1]
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            self.on_connect()
        except:
            # A failed AUTH or SELECT must not leave a connected socket
            # behind: the next command would otherwise reuse it without
            # authentication or on the wrong database.
            self.disconnect()
            raise

    def _connect(self):
        """Create a TCP socket connected to ``(host, port)``."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect((self.host, self.port))
        except:
            # Release the descriptor right away instead of waiting for the
            # garbage collector, then let the caller see the original error.
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        """Format a connect failure; args are (message,) or (errno, message)."""
        if len(exception.args) == 1:
            return "Error connecting to %s:%s. %s." % \
                (self.host, self.port, exception.args[0])
        else:
            return "Error %s connecting %s:%s. %s." % \
                (exception.args[0], self.host, self.port, exception.args[1])

    def on_connect(self):
        """
        Initialize the connection, authenticate and select a database.

        Raises AuthenticationError or ConnectionError when the server's
        reply to AUTH or SELECT is anything other than ``OK``.
        """
        self._parser.on_connect(self)

        # Authenticate first when a password was configured.
        if self.password:
            self.send_command('AUTH', self.password)
            if nativestr(self.read_response()) != 'OK':
                raise AuthenticationError('Invalid Password')

        # Switch database only when a non-zero one was requested.
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def disconnect(self):
        """
        Disconnect from the Redis server.

        Safe to call repeatedly. Socket errors raised while shutting down
        or closing are ignored.
        """
        self._parser.on_disconnect()
        if self._sock is None:
            return
        # shutdown() and close() are guarded separately so that a failing
        # shutdown does not prevent the descriptor from being closed.
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            pass
        try:
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """
        Send an already packed command to the Redis server.

        Connects first when there is no socket. Any failure while sending
        disconnects; socket errors are re-raised as ConnectionError.
        """
        if not self._sock:
            self.connect()
        try:
            self._sock.sendall(command)
        except socket.error:
            e = sys.exc_info()[1]
            self.disconnect()
            if len(e.args) == 1:
                _errno, errmsg = 'UNKNOWN', e.args[0]
            else:
                _errno, errmsg = e.args
            raise ConnectionError("Error %s while writing to socket. %s." %
                                  (_errno, errmsg))
        except:
            self.disconnect()
            raise

    def send_command(self, *args):
        """Pack ``args`` into a single command and send it."""
        self.send_packed_command(self.pack_command(*args))

    def read_response(self):
        """
        Read the response to a previously sent command.

        Disconnects and re-raises when the parser fails. A ResponseError
        instance returned by the parser is raised here without
        disconnecting.
        """
        try:
            response = self._parser.read_response()
        except:
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def encode(self, value):
        """
        Return a bytestring representation of ``value``.

        Bytes pass through untouched, other non-unicode objects go through
        ``str()``, and any unicode result is encoded with the configured
        encoding and error policy.
        """
        if isinstance(value, bytes):
            return value
        if not isinstance(value, unicode):
            value = str(value)
        if isinstance(value, unicode):
            value = value.encode(self.encoding, self.encoding_errors)
        return value

    def pack_command(self, *args):
        """
        Pack a series of arguments into a single Redis command.

        The result is ``*<argc>\\r\\n`` followed by ``$<len>\\r\\n<arg>\\r\\n``
        for every encoded argument.
        """
        # Collect the pieces and join once instead of growing a byte
        # string with repeated concatenation.
        pieces = [SYM_STAR, b(str(len(args))), SYM_CRLF]
        for enc_value in imap(self.encode, args):
            pieces.extend((SYM_DOLLAR, b(str(len(enc_value))), SYM_CRLF,
                           enc_value, SYM_CRLF))
        return b('').join(pieces)
00333
00334
class UnixDomainSocketConnection(Connection):
    """Connection variant that connects to a Unix domain socket path."""

    def __init__(self, path='', db=0, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser):
        """
        Store the connection settings; no socket is opened here.

        ``path`` is the address handed to the socket's ``connect``. The
        remaining arguments are stored as-is. The base ``__init__`` is not
        called, so ``host`` and ``port`` are never set on the instance.
        """
        # Process id at creation time.
        self.pid = os.getpid()
        self.path = path
        self.db = db
        self.password = password
        self.socket_timeout = socket_timeout
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses
        self._sock = None
        self._parser = parser_class()

    def _connect(self):
        """Create a Unix domain socket connected to ``path``."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect(self.path)
        except:
            # Release the descriptor right away instead of waiting for the
            # garbage collector, then let the caller see the original error.
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        """Format a connect failure; args are (message,) or (errno, message)."""
        if len(exception.args) == 1:
            return "Error connecting to unix socket: %s. %s." % \
                (self.path, exception.args[0])
        else:
            return "Error %s connecting to unix socket: %s. %s." % \
                (exception.args[0], self.path, exception.args[1])
00367
00368
00369
class ConnectionPool(object):
    """Generic pool that creates, hands out and takes back connections."""

    def __init__(self, connection_class=Connection, max_connections=None,
                 **connection_kwargs):
        """
        Set up an empty pool.

        ``connection_class`` is called with ``connection_kwargs`` whenever
        a new connection is needed. A falsy ``max_connections`` is replaced
        by ``2 ** 31``.
        """
        if not max_connections:
            max_connections = 2 ** 31
        self.connection_class = connection_class
        self.connection_kwargs = connection_kwargs
        self.max_connections = max_connections
        # Remembered so _checkpid can notice a change of process id.
        self.pid = os.getpid()
        self._created_connections = 0
        self._available_connections = []
        self._in_use_connections = set()

    def _checkpid(self):
        """Disconnect and re-initialise the pool if the process id changed."""
        if self.pid == os.getpid():
            return
        self.disconnect()
        self.__init__(self.connection_class, self.max_connections,
                      **self.connection_kwargs)

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection from the pool.

        An idle connection is reused when there is one, otherwise a new
        one is created. The arguments are accepted but not used here.
        """
        self._checkpid()
        idle = self._available_connections
        if idle:
            connection = idle.pop()
        else:
            connection = self.make_connection()
        self._in_use_connections.add(connection)
        return connection

    def make_connection(self):
        """
        Create a new connection.

        Raises ConnectionError once ``max_connections`` have been created.
        """
        if self._created_connections >= self.max_connections:
            raise ConnectionError("Too many connections")
        self._created_connections += 1
        return self.connection_class(**self.connection_kwargs)

    def release(self, connection):
        """
        Release ``connection`` back to the pool.

        Connections whose recorded pid differs from the pool's are ignored.
        """
        self._checkpid()
        if connection.pid != self.pid:
            return
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)

    def disconnect(self):
        """Disconnect every connection in the pool, idle ones first."""
        groups = (self._available_connections, self._in_use_connections)
        for group in groups:
            for connection in group:
                connection.disconnect()