00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 """A lightweight wrapper around MySQLdb."""
00018
00019 from __future__ import absolute_import, division, with_statement
00020
00021 import copy
00022 import itertools
00023 import logging
00024 import time
00025
00026 try:
00027 import MySQLdb.constants
00028 import MySQLdb.converters
00029 import MySQLdb.cursors
00030 except ImportError:
00031
00032
00033
00034 MySQLdb = None
00035
00036
class Connection(object):
    """A lightweight wrapper around MySQLdb DB-API connections.

    The main value we provide is wrapping rows in a dict/object so that
    columns can be accessed by name. Typical usage::

        db = database.Connection("localhost", "mydatabase")
        for article in db.query("SELECT * FROM articles"):
            print(article.title)

    Cursors are hidden by the implementation, but other than that, the methods
    are very similar to the DB-API.

    We explicitly set the timezone to UTC and the character encoding to
    UTF-8 on all connections to avoid time zone and encoding errors.
    """

    def __init__(self, host, database, user=None, password=None,
                 max_idle_time=7 * 3600):
        """Store the connection settings and attempt an initial connect.

        Args:
            host: Either a path to a MySQL Unix socket file (any value
                containing "/") or a ``host`` / ``host:port`` string. The
                port defaults to 3306 when it is not given.
            database: Name of the database to select.
            user: Optional user name; left out of the connect arguments
                when None.
            password: Optional password; left out of the connect arguments
                when None.
            max_idle_time: Seconds of inactivity after which the connection
                is re-opened before the next query (7 hours by default).

        A failed initial connection is logged instead of raised; the next
        query tries to connect again.
        """
        self.host = host
        self.database = database
        self.max_idle_time = max_idle_time

        args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
                    db=database, init_command='SET time_zone = "+0:00"',
                    sql_mode="TRADITIONAL")
        if user is not None:
            args["user"] = user
        if password is not None:
            args["passwd"] = password

        # We accept a path to a MySQL socket file or a host(:port) string.
        if "/" in host:
            # Always define ``socket`` so the attribute exists on every
            # instance, not only on TCP connections.
            self.socket = host
            args["unix_socket"] = host
        else:
            self.socket = None
            pair = host.split(":")
            if len(pair) == 2:
                args["host"] = pair[0]
                args["port"] = int(pair[1])
            else:
                args["host"] = host
                args["port"] = 3306

        self._db = None
        self._db_args = args
        self._last_use_time = time.time()
        try:
            self.reconnect()
        except Exception:
            # Deliberate best effort: _ensure_connected() retries on the
            # next query because self._db is still None.
            logging.error("Cannot connect to MySQL on %s", self.host,
                          exc_info=True)

    def __del__(self):
        self.close()

    def close(self):
        """Closes this database connection."""
        # getattr() keeps this safe when __init__ never got as far as
        # assigning _db (close() is also reached through __del__).
        db = getattr(self, "_db", None)
        if db is not None:
            # Drop our reference first: if the driver's close() raises, we
            # must not keep a dead handle that would make every later
            # close()/reconnect() fail the same way.
            self._db = None
            db.close()

    def reconnect(self):
        """Closes the existing database connection and re-opens it."""
        self.close()
        self._db = MySQLdb.connect(**self._db_args)
        self._db.autocommit(True)

    def iter(self, query, *parameters):
        """Returns an iterator for the given query and parameters.

        Rows are yielded one at a time as ``Row`` objects from a
        ``MySQLdb.cursors.SSCursor``; the cursor is closed when the
        generator is exhausted or closed.
        """
        self._ensure_connected()
        cursor = MySQLdb.cursors.SSCursor(self._db)
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            for row in cursor:
                yield Row(zip(column_names, row))
        finally:
            cursor.close()

    def query(self, query, *parameters):
        """Returns a row list for the given query and parameters."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            # Built-in zip (as in iter()) instead of itertools.izip, which
            # does not exist on Python 3.
            return [Row(zip(column_names, row)) for row in cursor]
        finally:
            cursor.close()

    def get(self, query, *parameters):
        """Returns the first row returned for the given query.

        Returns None when the query yields no rows and raises ``Exception``
        when it yields more than one.
        """
        rows = self.query(query, *parameters)
        if not rows:
            return None
        elif len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        else:
            return rows[0]

    def execute(self, query, *parameters):
        """Executes the given query, returning the lastrowid from the query."""
        return self.execute_lastrowid(query, *parameters)

    def execute_lastrowid(self, query, *parameters):
        """Executes the given query, returning the lastrowid from the query."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def execute_rowcount(self, query, *parameters):
        """Executes the given query, returning the rowcount from the query."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            return cursor.rowcount
        finally:
            cursor.close()

    def executemany(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        return self.executemany_lastrowid(query, parameters)

    def executemany_lastrowid(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        cursor = self._cursor()
        try:
            self._executemany(cursor, query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def executemany_rowcount(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the rowcount from the query.
        """
        cursor = self._cursor()
        try:
            self._executemany(cursor, query, parameters)
            return cursor.rowcount
        finally:
            cursor.close()

    def _ensure_connected(self):
        """Re-open the connection if it is missing or has idled too long."""
        # NOTE(review): MySQL servers drop idle client connections (the
        # `wait_timeout` setting, 8 hours by default), and presumably the
        # client only notices when a query fails. Re-opening after
        # max_idle_time (7 hours by default) avoids hitting that case.
        if (self._db is None or
                (time.time() - self._last_use_time > self.max_idle_time)):
            self.reconnect()
        self._last_use_time = time.time()

    def _cursor(self):
        """Return a fresh cursor on a connection that is known to be open."""
        self._ensure_connected()
        return self._db.cursor()

    def _execute(self, cursor, query, parameters):
        """Run ``cursor.execute``; on OperationalError log, close and re-raise."""
        try:
            return cursor.execute(query, parameters)
        except OperationalError:
            logging.error("Error connecting to MySQL on %s", self.host)
            # Closing makes the next query reconnect via _ensure_connected().
            self.close()
            raise

    def _executemany(self, cursor, query, parameters):
        """Run ``cursor.executemany`` with the same handling as ``_execute``."""
        try:
            return cursor.executemany(query, parameters)
        except OperationalError:
            logging.error("Error connecting to MySQL on %s", self.host)
            self.close()
            raise
00213
00214
class Row(dict):
    """A dict whose keys can also be read with attribute syntax (``row.col``)."""

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup has failed, so
        # real dict attributes and methods always win over same-named keys.
        try:
            value = self[name]
        except KeyError:
            # Report a missing key the way a missing attribute is reported.
            raise AttributeError(name)
        return value
00222
if MySQLdb is not None:
    # Everything below needs the real driver, so it is skipped when the
    # guarded import at the top of the file left MySQLdb set to None.
    FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
    FLAG = MySQLdb.constants.FLAG

    # Work on a shallow copy so the driver's own conversion table is not
    # modified; entries are replaced below, never mutated in place.
    CONVERSIONS = copy.copy(MySQLdb.converters.conversions)

    # Text-like column types whose conversion gets a BINARY-flag rule.
    # VARCHAR is added only when this FIELD_TYPE module defines it.
    field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
    if 'VARCHAR' in vars(FIELD_TYPE):
        field_types.append(FIELD_TYPE.VARCHAR)

    for field_type in field_types:
        # Put a (FLAG.BINARY, str) rule ahead of the existing ones.
        # NOTE(review): presumably so BINARY-flagged columns come back as
        # str; also assumes each existing entry is a list - TODO confirm
        # against the installed MySQLdb version.
        CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]

    # Alias some common MySQL exceptions at module level.
    IntegrityError = MySQLdb.IntegrityError
    OperationalError = MySQLdb.OperationalError