00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 """Test that pymongo is thread safe."""
00016
00017 import unittest
00018 import threading
00019
00020 from nose.plugins.skip import SkipTest
00021
00022 from test_connection import get_connection
00023 from pymongo.errors import AutoReconnect
00024
00025
class SaveAndFind(threading.Thread):
    """Worker thread that sums the ``x`` field of every document it finds.

    The check passes when the total equals ``expected``.  The default, 499500,
    is ``sum(range(1000))``, i.e. a collection holding documents with
    ``x`` = 0..999.

    Exceptions cannot cross thread boundaries on their own, so any failure in
    run() is stored and re-raised by join() in the thread that waits on us.
    """

    def __init__(self, collection, expected=499500):
        threading.Thread.__init__(self)
        self.collection = collection
        self.expected = expected
        # Exception raised inside run(), if any; re-raised by join().
        self.failure = None

    def run(self):
        import sys
        try:
            total = 0
            for document in self.collection.find():
                total += document["x"]
            if total != self.expected:
                # Raise explicitly rather than ``assert`` so the check is not
                # stripped when Python runs with -O.
                raise AssertionError("sum was %d not %d"
                                     % (total, self.expected))
        except Exception:
            # An exception escaping run() would only be printed by the
            # threading module and the test would still pass; keep it so
            # join() can report it.
            self.failure = sys.exc_info()[1]

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception from run()."""
        threading.Thread.join(self, timeout)
        if self.failure is not None:
            raise self.failure
00037
00038
class Insert(threading.Thread):
    """Worker thread that performs ``n`` safe inserts into ``collection``.

    When ``expect_exception`` is true, every insert must raise (for example
    because a unique index rejects the duplicate document); otherwise no
    insert may raise.  The first violation stops the thread, and join()
    re-raises it so the calling test actually fails.
    """

    def __init__(self, collection, n, expect_exception):
        threading.Thread.__init__(self)
        self.collection = collection
        self.n = n
        self.expect_exception = expect_exception
        # Exception raised inside run(), if any; re-raised by join().
        self.failure = None

    def run(self):
        import sys
        try:
            for _ in range(self.n):
                self._insert_once()
        except Exception:
            # Exceptions do not propagate out of a thread on their own;
            # store it for join() instead of letting it vanish.
            self.failure = sys.exc_info()[1]

    def _insert_once(self):
        """Do one safe insert and check it failed or succeeded as expected."""
        try:
            self.collection.insert({"test": "insert"}, safe=True)
        except Exception:
            if not self.expect_exception:
                raise
            return
        if self.expect_exception:
            raise AssertionError("insert succeeded but an error was expected")

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception from run()."""
        threading.Thread.join(self, timeout)
        if self.failure is not None:
            raise self.failure
00060
00061
class Update(threading.Thread):
    """Worker thread that performs ``n`` safe updates on ``collection``.

    Each update sets ``test`` to "update" on the document whose ``test`` is
    "unique".  When ``expect_exception`` is true, every update must raise
    (for example because a unique index would be violated); otherwise no
    update may raise.  The first violation stops the thread, and join()
    re-raises it so the calling test actually fails.
    """

    def __init__(self, collection, n, expect_exception):
        threading.Thread.__init__(self)
        self.collection = collection
        self.n = n
        self.expect_exception = expect_exception
        # Exception raised inside run(), if any; re-raised by join().
        self.failure = None

    def run(self):
        import sys
        try:
            for _ in range(self.n):
                self._update_once()
        except Exception:
            # Exceptions do not propagate out of a thread on their own;
            # store it for join() instead of letting it vanish.
            self.failure = sys.exc_info()[1]

    def _update_once(self):
        """Do one safe update and check it failed or succeeded as expected."""
        try:
            self.collection.update({"test": "unique"},
                                   {"$set": {"test": "update"}}, safe=True)
        except Exception:
            if not self.expect_exception:
                raise
            return
        if self.expect_exception:
            raise AssertionError("update succeeded but an error was expected")

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception from run()."""
        threading.Thread.join(self, timeout)
        if self.failure is not None:
            raise self.failure
00083
00084
class IgnoreAutoReconnect(threading.Thread):
    """Worker thread that calls ``find_one()`` ``n`` times on a collection.

    AutoReconnect is tolerated deliberately.  Any other exception stops the
    thread and is re-raised by join(), so the calling test fails instead of
    silently passing.
    """

    def __init__(self, collection, n):
        threading.Thread.__init__(self)
        self.c = collection
        self.n = n
        # Exception raised inside run(), if any; re-raised by join().
        self.failure = None

    def run(self):
        import sys
        try:
            for _ in range(self.n):
                try:
                    self.c.find_one()
                except AutoReconnect:
                    pass
        except Exception:
            # Exceptions do not propagate out of a thread on their own;
            # store it for join() instead of letting it vanish.
            self.failure = sys.exc_info()[1]

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception from run()."""
        threading.Thread.join(self, timeout)
        if self.failure is not None:
            raise self.failure
00098
00099
class TestThreads(unittest.TestCase):
    """Run pymongo operations from several threads at the same time."""

    def setUp(self):
        self.db = get_connection().pymongo_test

    def _run_concurrently(self, workers):
        """Start every worker thread, then wait for each one in list order."""
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    def test_threading(self):
        coll = self.db.test
        coll.remove({})
        for value in range(1000):
            coll.save({"x": value}, safe=True)

        self._run_concurrently([SaveAndFind(self.db.test) for _ in range(10)])

    def test_safe_insert(self):
        db = self.db
        db.drop_collection("test1")
        db.test1.insert({"test": "insert"})
        db.drop_collection("test2")
        db.test2.insert({"test": "insert"})

        # test2 already holds {"test": "insert"}, so with this unique index
        # every further identical insert into test2 must fail.
        db.test2.create_index("test", unique=True)
        db.test2.find_one()

        okay = Insert(db.test1, 2000, False)
        error = Insert(db.test2, 2000, True)
        self._run_concurrently([error, okay])

    def test_safe_update(self):
        db = self.db
        db.drop_collection("test1")
        db.test1.insert({"test": "update"})
        db.test1.insert({"test": "unique"})
        db.drop_collection("test2")
        db.test2.insert({"test": "update"})
        db.test2.insert({"test": "unique"})

        # With this unique index, renaming "unique" to "update" in test2
        # would create a duplicate, so those updates must fail.
        db.test2.create_index("test", unique=True)
        db.test2.find_one()

        okay = Update(db.test1, 2000, False)
        error = Update(db.test2, 2000, True)
        self._run_concurrently([error, okay])

    def test_low_network_timeout(self):
        # Even connecting may time out at this setting; give it ten tries
        # and skip the test if none of them succeeds.
        db = None
        for _ in range(10):
            try:
                db = get_connection(network_timeout=0.0001).pymongo_test
                break
            except AutoReconnect:
                pass
        if db is None:
            raise SkipTest()

        self._run_concurrently(
            [IgnoreAutoReconnect(db.test, 100) for _ in range(4)])
00177
00178
if __name__ == "__main__":
    # Executed only when this module is run directly, not when imported.
    unittest.main()