00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
"""Test built-in connection pooling."""
00016
00017 import os
00018 import random
00019 import sys
00020 import threading
00021 import time
00022 import unittest
00023 sys.path[0:0] = [""]
00024
00025 from nose.plugins.skip import SkipTest
00026
00027 from pymongo.connection import _Pool
00028 from test_connection import get_connection
00029
# Number of operations each worker thread performs in its run() loop.
N = 50
# Scratch database used by these tests; TestPooling.setUp drops it first.
DB = "pymongo-pooling-tests"
00032
class MongoThread(threading.Thread):
    """Base class for worker threads that share one TestCase's connection.

    Subclasses implement ``run``. They reach the test case (for assertions)
    through ``self.ut``, its connection ``test_case.c`` through
    ``self.connection``, and the scratch database ``DB`` through ``self.db``.

    NOTE(review): an exception raised from a subclass's ``run`` (including a
    failed ``self.ut`` assertion) is reported inside the worker thread.
    ``Thread.join()`` does not re-raise it, so it may not fail the calling
    test. Consider collecting such failures explicitly.
    """

    def __init__(self, test_case):
        threading.Thread.__init__(self)
        self.ut = test_case
        self.connection = test_case.c
        self.db = self.connection[DB]
00040
00041
class SaveAndFind(MongoThread):
    """Worker that saves random documents and reads each one back.

    Every iteration saves ``{"x": <random int in [0, N]>}`` into the ``sf``
    collection with ``safe=True``. It then fetches the document by the id
    that ``save`` returned, checks ``x`` round-tripped, and ends the request.
    """

    def run(self):
        collection = self.db.sf
        for _ in xrange(N):
            value = random.randint(0, N)
            oid = collection.save({"x": value}, safe=True)
            self.ut.assertEqual(value, collection.find_one(oid)["x"])
            self.connection.end_request()
00050
00051
class Unique(MongoThread):
    """Worker that inserts empty documents, each of which must succeed.

    After every insert into the ``unique`` collection, ``db.error()`` must be
    ``None``. The request is ended after each iteration.
    """

    def run(self):
        collection = self.db.unique
        for _ in xrange(N):
            collection.insert({})
            self.ut.assertEqual(None, self.db.error())
            self.connection.end_request()
00059
00060
class NonUnique(MongoThread):
    """Worker that re-inserts an existing ``_id``, which must fail each time.

    ``TestPooling.setUp`` seeds ``{"_id": "mike"}`` in the ``unique``
    collection, so every insert here should leave ``db.error()`` non-None.
    The request is ended after each iteration.
    """

    def run(self):
        collection = self.db.unique
        for _ in xrange(N):
            collection.insert({"_id": "mike"})
            self.ut.assertNotEqual(None, self.db.error())
            self.connection.end_request()
00068
00069
class Disconnect(MongoThread):
    """Worker that calls ``disconnect()`` on the shared connection N times."""

    def run(self):
        disconnect = self.connection.disconnect
        for _ in xrange(N):
            disconnect()
00075
00076
class NoRequest(MongoThread):
    """Worker that re-inserts a duplicate ``_id`` without ever ending a request.

    It counts the inserts after which ``db.error()`` reported nothing. It
    asserts once, at the end, that this count is zero.
    """

    def run(self):
        collection = self.db.unique
        missed = 0
        for _ in xrange(N):
            collection.insert({"_id": "mike"})
            if self.db.error() is None:
                missed += 1
        self.ut.assertEqual(0, missed)
00087
00088
def run_cases(ut, cases):
    """Start ten workers of every class in *cases*, then wait for all of them.

    Each worker is built as ``case(ut)`` and started immediately. Workers of
    all classes run concurrently, and the function returns after joining
    every one of them.
    """
    workers = []
    for case in cases:
        for _ in range(10):
            worker = case(ut)
            worker.start()
            workers.append(worker)

    for worker in workers:
        worker.join()
00099
00100
class OneOp(threading.Thread):
    """Thread that checks one socket check-out/check-in on a shared pool.

    On entry the connection's pool must hold exactly one idle socket. It must
    hold none while ``find_one`` has the socket checked out, and one again
    after ``end_request``.
    """

    def __init__(self, connection):
        threading.Thread.__init__(self)
        self.c = connection

    def _idle_sockets(self):
        # Re-read the pool each time rather than caching it.
        return len(self.c._Connection__pool.sockets)

    def run(self):
        assert self._idle_sockets() == 1
        self.c.test.test.find_one()
        assert self._idle_sockets() == 0
        self.c.end_request()
        assert self._idle_sockets() == 1
00113
00114
class CreateAndReleaseSocket(threading.Thread):
    """Thread that runs one query, waits a second, then ends its request."""

    def __init__(self, connection):
        threading.Thread.__init__(self)
        self.c = connection

    def run(self):
        conn = self.c
        conn.test.test.find_one()
        # NOTE(review): presumably the pause keeps many of the threads started
        # by test_max_pool_size inside a request at the same time - confirm.
        time.sleep(1)
        conn.end_request()
00125
00126
class TestPooling(unittest.TestCase):
    """Tests for the socket pool behind a ``Connection``.

    Several tests inspect the private ``_Connection__pool`` attribute to count
    idle sockets. A query takes a socket out of ``sockets``, and
    ``end_request()`` puts it back (see ``test_dependent_pools``).
    """

    def setUp(self):
        """Create a connection and reset the scratch database.

        The database is seeded with one ``{"_id": "mike"}`` document in the
        ``unique`` collection. The NoRequest and NonUnique workers try to
        insert it again.
        """
        self.c = get_connection()
        self.c.drop_database(DB)
        self.c[DB].unique.insert({"_id": "mike"})
        # NOTE(review): presumably issued so the unacknowledged insert above
        # is processed before any worker thread starts - confirm.
        self.c[DB].unique.find_one()

    def test_no_disconnect(self):
        """Run concurrent workers that never call disconnect()."""
        run_cases(self, [NoRequest, NonUnique, Unique, SaveAndFind])

    def test_disconnect(self):
        """Run workers alongside threads that repeatedly call disconnect()."""
        run_cases(self, [SaveAndFind, Disconnect, Unique])

    def test_independent_pools(self):
        """A standalone _Pool shares no state with other pools."""
        p = _Pool(None)
        self.assertEqual([], p.sockets)
        # Ending a request on the test's connection must not add a socket to
        # an unrelated pool.
        self.c.end_request()
        self.assertEqual([], p.sockets)

        # Each pool keeps the socket factory it was constructed with.
        p1 = _Pool(5)
        self.assertEqual(None, p.socket_factory)
        self.assertEqual(5, p1.socket_factory)

    def test_dependent_pools(self):
        """Threads sharing a connection check sockets in and out of its pool."""
        c = get_connection()
        self.assertEqual(1, len(c._Connection__pool.sockets))
        c.test.test.find_one()
        self.assertEqual(0, len(c._Connection__pool.sockets))
        c.end_request()
        self.assertEqual(1, len(c._Connection__pool.sockets))

        # Repeat the same check-out/check-in cycle from another thread.
        t = OneOp(c)
        t.start()
        t.join()

        self.assertEqual(1, len(c._Connection__pool.sockets))
        c.test.test.find_one()
        self.assertEqual(0, len(c._Connection__pool.sockets))

    def test_multiple_connections(self):
        """Separate connections keep separate pools and their own sockets."""
        a = get_connection()
        b = get_connection()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(1, len(b._Connection__pool.sockets))

        a.test.test.find_one()
        a.end_request()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(1, len(b._Connection__pool.sockets))
        a_sock = a._Connection__pool.sockets[0]

        b.end_request()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(1, len(b._Connection__pool.sockets))

        # A query on b checks out b's socket without touching a's pool.
        b.test.test.find_one()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(0, len(b._Connection__pool.sockets))

        b.end_request()
        b_sock = b._Connection__pool.sockets[0]
        b.test.test.find_one()
        a.test.test.find_one()
        # Each pool's socket() is expected to hand back the socket that
        # connection used before.
        self.assertEqual(b_sock, b._Connection__pool.socket())
        self.assertEqual(a_sock, a._Connection__pool.socket())

    def test_pool_with_fork(self):
        """Child processes must use sockets distinct from the parent's."""
        if sys.platform == "win32":
            raise SkipTest()

        try:
            from multiprocessing import Process, Pipe
        except ImportError:
            raise SkipTest()

        a = get_connection()
        a.test.test.find_one()
        a.end_request()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        a_sock = a._Connection__pool.sockets[0]

        def loop(pipe):
            # Runs in a child process: exercise a new connection, then send
            # the local address of its pooled socket back to the parent.
            c = get_connection()
            self.assertEqual(1, len(c._Connection__pool.sockets))
            c.test.test.find_one()
            self.assertEqual(0, len(c._Connection__pool.sockets))
            c.end_request()
            self.assertEqual(1, len(c._Connection__pool.sockets))
            pipe.send(c._Connection__pool.sockets[0].getsockname())

        cp1, cc1 = Pipe()
        cp2, cc2 = Pipe()

        p1 = Process(target=loop, args=(cc1,))
        p2 = Process(target=loop, args=(cc2,))

        p1.start()
        p2.start()

        # Give the children a second to finish, then make sure they are gone.
        p1.join(1)
        p2.join(1)

        p1.terminate()
        p2.terminate()

        p1.join()
        p2.join()

        # Close the parent's copies of the child ends, so recv() can raise
        # EOFError instead of blocking if a child never sent anything.
        cc1.close()
        cc2.close()

        b_sock = cp1.recv()
        c_sock = cp2.recv()
        self.assertNotEqual(a_sock.getsockname(), b_sock)
        self.assertNotEqual(a_sock.getsockname(), c_sock)
        self.assertNotEqual(b_sock, c_sock)
        self.assertEqual(a_sock, a._Connection__pool.socket())

    def test_max_pool_size(self):
        """After 40 overlapping requests, the pool keeps a bounded socket count."""
        c = get_connection()

        threads = []
        for _ in range(40):
            t = CreateAndReleaseSocket(c)
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        idle = len(c._Connection__pool.sockets)
        # Equivalent to the original abs(10 - idle) < 10, i.e. 1..19 sockets.
        # NOTE(review): 10 presumably mirrors the pool's maximum size -
        # confirm against _Pool.
        self.assertTrue(0 < idle < 20,
                        "unexpected idle socket count: %d" % idle)
00263
00264
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()