00001
00002
from __future__ import with_statement

try:
    import thread
except ImportError:
    # Python 3 renamed the low-level thread module to _thread.
    import _thread as thread
import threading
import weakref
00008
00009
00010
00011
00012
class DeadlockException(Exception):
    """Raised when a thread tries to block on a lock it already holds,
    which would otherwise hang forever."""
00015
class MultipleUnsubscribe(Exception):
    """Raised when a subscription handle is explicitly unsubscribed more
    than once."""
00018
class ReentrantDetectingLockEnterHelper:
    """Context manager returned by ReentrantDetectingLock.__call__.

    The lock is already held when this object is created, so entering does
    nothing; exiting releases the parent lock.
    """

    def __init__(self, parent):
        # The lock to release when the with-block ends.
        self.parent = parent

    def __enter__(self):
        return None

    def __exit__(self, *exc_info):
        # Hand back release()'s result unchanged; a falsy value means any
        # in-flight exception is not suppressed.
        outcome = self.parent.release()
        return outcome
00028
class ReentrantDetectingLock():
    """Non-recursive lock that detects same-thread re-acquisition.

    A plain threading.Lock blocks forever when its holder tries to acquire
    it again. This wrapper remembers the owning thread and raises
    DeadlockException instead (for blocking acquires only).

    Besides acquire()/release() and ``with lock:``, calling the lock with a
    message acquires it immediately and returns a context manager that
    releases it on exit; the message becomes the text of any
    DeadlockException raised by that acquisition.
    """

    def __init__(self, *args, **kwargs):
        self._lock = threading.Lock(*args, **kwargs)
        # thread.get_ident() of the current holder; None while unlocked.
        self._owner = None

    def acquire(self, blocking=1):
        """Acquire the lock.

        Returns True on success and False when a non-blocking acquire could
        not get the lock. Raises DeadlockException on a blocking acquire by
        the thread that already holds it."""
        me = thread.get_ident()
        if blocking and me == self._owner:
            raise DeadlockException("Tried to recursively lock a non recursive lock.")
        if not self._lock.acquire(blocking):
            return False
        self._owner = me
        return True

    def release(self):
        """Release the lock."""
        # Forget the owner *before* releasing so that a thread which grabs
        # the lock right after us cannot have its ownership wiped out.
        self._owner = None
        self._lock.release()

    def __enter__(self):
        self.acquire()

    def __exit__(self, *exc_info):
        self.release()

    def __call__(self, msg):
        """Acquire now; return a context manager that releases on exit.

        Intended for ``with lock("what a self-deadlock here would mean"):``.
        """
        try:
            self.acquire()
        except DeadlockException:
            raise DeadlockException(msg)
        return ReentrantDetectingLockEnterHelper(self)
00060
class EventCallbackHandle:
    """Handle for a single subscription to an Event.

    Created by Event.subscribe(). Holds the callback together with the
    extra positional and keyword arguments given at subscription time. Can
    be used as a context manager: leaving the with-block unsubscribes.
    """

    def __init__(self, event, cb, args, kwargs, repeating):
        # Only a weak reference is kept, so a handle does not keep its event
        # alive.
        self._event = weakref.ref(event)
        self._cb = cb
        self._args = args
        self._kwargs = kwargs
        # Serializes calls of this callback and lets a blocking unsubscribe
        # wait for an in-progress call to finish.
        self._call_lock = ReentrantDetectingLock()
        # False for once-subscriptions. unsubscribe() sets it to True after
        # an explicit unsubscribe so that a further explicit unsubscribe
        # raises MultipleUnsubscribe.
        self._repeating = repeating

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.unsubscribe()

    def _trigger(self, args, kwargs):
        """Invoke the callback for one trigger of the event.

        The callback receives the subscription args followed by *args*;
        keyword arguments given at subscription time override those passed
        to the trigger. Does nothing once the handle is unsubscribed. A
        non-repeating handle unsubscribes itself just before the call.

        Raises DeadlockException if the callback re-triggers itself.
        """
        with self._call_lock("Callback recursively triggered itself."):
            # State is read while holding the call lock. Reading it before
            # the lock let a blocking unsubscribe() from another thread
            # return and the callback still run afterwards, and let a
            # once-subscription fire twice under concurrent triggers.
            #
            # Read order matters: unsubscribe() clears _cb first and then
            # _args/_kwargs, so reading _cb last guarantees a non-None
            # callback is paired with its original arguments even when a
            # non-blocking unsubscribe runs concurrently.
            allargs = self._args + args
            allkwargs = dict(kwargs)
            allkwargs.update(self._kwargs)
            cb = self._cb
            if cb is None:
                return
            if not self._repeating:
                self.unsubscribe(blocking = False, _not_called_from_trigger = False)
            cb(*allargs, **allkwargs)

    def unsubscribe(self, blocking = True, _not_called_from_trigger = True):
        """Remove this handle from its event.

        blocking: if true, wait for a call of this callback that is running
            in another thread to return. Calling with blocking=True from
            inside the callback itself raises DeadlockException.
        _not_called_from_trigger: internal; False when called by _trigger()
            or Event.unsubscribe_all(), which skips the multiple-unsubscribe
            check.

        Raises MultipleUnsubscribe when a handle that was already
        explicitly unsubscribed (or a repeating handle already removed) is
        explicitly unsubscribed again.
        """
        # Clear _cb before _args/_kwargs: _trigger() reads them in the
        # opposite order.
        self._cb = None
        self._args = ()
        self._kwargs = {}
        event = self._event()
        if event is not None:
            try:
                del event._subscribers[self]
            except KeyError:
                # Already removed: a once-subscription that fired, an
                # unsubscribe_all(), or a repeated explicit unsubscribe.
                if self._repeating and _not_called_from_trigger:
                    raise MultipleUnsubscribe("Repeating callback unsubscribed multiple times.")
            else:
                event._subscription_change()
        # After one explicit unsubscribe, any further explicit unsubscribe
        # must raise.
        if _not_called_from_trigger:
            self._repeating = True
        if blocking:
            # Taking and dropping the call lock waits for a callback running
            # in another thread to return.
            with self._call_lock("Callback tried to blocking unsubscribe itself."):
                pass
00121
class Event:
    """A named event that callbacks can subscribe to and that can be
    triggered with arbitrary positional and keyword arguments.

    Derived classes may override _subscription_change() to be notified
    whenever the set of subscribers changes."""

    def __init__(self, name = "Unnamed Event"):
        self._name = name
        # Subscription handles; the dict is used as a set (values unused).
        self._subscribers = {}
        # Report the initial (empty) subscriber set to derived classes.
        self._subscription_change()

    def subscribe(self, cb, args, kwargs, repeating = True):
        """Subscribes to an event.

        Can be called at any time and from any thread. Subscriptions that
        occur while an event is being triggered will not be called until
        the next time the event is triggered.

        args must be a tuple: cb is called with args followed by the
        trigger's positional arguments, and with the trigger's keyword
        arguments overridden by kwargs. If repeating is false the
        subscription removes itself the first time it fires. Returns the
        subscription handle."""
        h = EventCallbackHandle(self, cb, args, kwargs, repeating)
        self._subscribers[h] = None
        self._subscription_change()
        return h

    def subscribe_once(*args, **kwargs):
        """subscribe_once(cb, *args, **kwargs): subscribe for one trigger.

        self and cb are taken from *args so that keyword arguments with any
        name (even 'self' or 'cb') can be forwarded to the callback."""
        return args[0].subscribe(args[1], args[2:], kwargs, repeating = False)

    def subscribe_repeating(*args, **kwargs):
        """subscribe_repeating(cb, *args, **kwargs): subscribe until
        explicitly unsubscribed.

        self and cb are taken from *args so that keyword arguments with any
        name (even 'self' or 'cb') can be forwarded to the callback."""
        return args[0].subscribe(args[1], args[2:], kwargs, repeating = True)

    def trigger(*args, **kwargs):
        """Triggers an event.

        Calls every subscriber present when the trigger starts; subscribers
        added during the trigger are not called until the next one.

        Concurrent triggers of a given callback are serialized using a
        per-callback lock. If a callback re-triggers the event while it is
        still subscribed, that lock detects the recursion and
        DeadlockException is raised instead of deadlocking."""
        # self is taken from *args so any keyword name can be forwarded.
        self = args[0]
        args = args[1:]
        # Iterate over a snapshot: callbacks may (un)subscribe while we run,
        # and once-subscriptions remove themselves when they fire. On
        # Python 3, keys() is a live view and would raise RuntimeError.
        for h in list(self._subscribers):
            h._trigger(args, kwargs)

    def _subscription_change(self):
        """Called at the end of each subscribe/unsubscribe. Can be
        overloaded in derived classes."""
        pass

    def unsubscribe_all(self, blocking = True):
        """Unsubscribes all subscribers that were present at the start of
        the call.

        Each handle is marked non-repeating and unsubscribed the way a
        trigger would do it, so no MultipleUnsubscribe is raised here."""
        for s in list(self._subscribers):
            s._repeating = False
            s.unsubscribe(blocking, _not_called_from_trigger = False)
00174
if __name__ == "__main__":
    import unittest
    import sys

    def append_cb(l, *args, **kwargs):
        """Test callback: records its positional and keyword args in l."""
        l.append((args, kwargs))

    class Unsubscriber:
        """Callback holder whose callback records the call and then
        unsubscribes its own handle (self.h, assigned after subscribing)."""
        def __init__(self, l, blocking):
            self.l = l
            self.blocking = blocking

        def cb(self, *args, **kwargs):
            append_cb(self.l, *args, **kwargs)
            self.h.unsubscribe(blocking = self.blocking)

    class BasicTest(unittest.TestCase):
        def test_basic(self):
            """Tests basic functionality.

            Adds a couple of callbacks. Makes sure they are called the
            right number of times. Checks that parameters are correct,
            including keyword arguments giving priority to subscribe over
            trigger."""
            e = Event()
            l1 = []
            l2 = []
            h1 = e.subscribe_repeating(append_cb, l1, 'd', e = 'f')
            e.subscribe_once(append_cb, l2, 'a', b = 'c')
            e.trigger('1', g = 'h')
            e.trigger('2', e = 'x')
            h1.unsubscribe()
            e.trigger('3')
            self.assertEqual(l1, [
                (('d', '1'), { 'e' : 'f', 'g' : 'h'}),
                (('d', '2'), { 'e' : 'f'}),
                ])
            self.assertEqual(l2, [
                (('a', '1'), { 'b' : 'c', 'g': 'h'}),
                ])

        def test_subscription_change(self):
            """Test that the _subscription_change is called appropriately."""
            l = []
            class SubChangeEvent(Event):
                def _subscription_change(self):
                    l.append(len(self._subscribers))
            e = SubChangeEvent()
            h1 = e.subscribe_repeating(None)
            h2 = e.subscribe_repeating(None)
            h1.unsubscribe()
            h3 = e.subscribe_repeating(None)
            h2.unsubscribe()
            h3.unsubscribe()
            self.assertEqual(l, [0, 1, 2, 1, 2, 1, 0])

        def test_unsub_myself(self):
            """Tests that a callback can unsubscribe itself."""
            e = Event()
            l = []
            u = Unsubscriber(l, False)
            u.h = e.subscribe_repeating(u.cb)
            e.trigger('t1')
            e.trigger('t2')
            self.assertEqual(l, [
                (('t1',), {}),
                ])

        def test_multiple_unsubscribe_repeating(self):
            """Tests exception on multiple unsubscribe for repeating subscribers."""
            e = Event()
            h = e.subscribe_repeating(None)
            h.unsubscribe()
            self.assertRaises(MultipleUnsubscribe, h.unsubscribe)

        def test_multiple_unsubscribe_once(self):
            """Tests exception on multiple unsubscribe for non-repeating subscribers."""
            e = Event()
            h = e.subscribe_once(None)
            h.unsubscribe()
            self.assertRaises(MultipleUnsubscribe, h.unsubscribe)

        def test_unsubscribe_all(self):
            """Tests basic unsubscribe_all functionality."""
            e = Event()
            e.subscribe_repeating(None)
            e.subscribe_repeating(None)
            e.subscribe_repeating(None)
            e.subscribe_repeating(None)
            e.subscribe_repeating(None)
            e.unsubscribe_all()
            self.assertEqual(len(e._subscribers), 0)

        def test_unsub_myself_blocking(self):
            """Tests that a blocking unsubscribe on myself raises exception."""
            e = Event()
            l = []
            u = Unsubscriber(l, True)
            u.h = e.subscribe_repeating(u.cb)
            self.assertRaises(DeadlockException, e.trigger, 't1')

        def test_unsub_myself_nonblocking(self):
            """Tests that a nonblocking unsubscribe on myself does not raise."""
            e = Event()
            l = []
            u = Unsubscriber(l, False)
            u.h = e.subscribe_repeating(u.cb)
            e.trigger('t1')
            self.assertEqual(l, [
                (('t1',), {}),
                ])

    def wait_cv(cv, l, cb, trig):
        """ThreadTest callback: records 'pre', wakes the main thread, waits
        to be notified, then records 'post'."""
        with cv:
            l.append((cb, trig, 'pre'))
            cv.notify_all()
            cv.wait()
            l.append((cb, trig, 'post'))

    class ThreadTest(unittest.TestCase):
        def setUp(self):
            self.e = Event()
            self.cv = threading.Condition()
            self.l = []
            # Both once-subscriptions use the label 'cb1', so the expected
            # log is the same whichever one is called first.
            self.h1 = self.e.subscribe_once(wait_cv, self.cv, self.l, 'cb1')
            self.h2 = self.e.subscribe_once(wait_cv, self.cv, self.l, 'cb1')
            # Trigger from another thread; each callback blocks in wait_cv
            # until the main thread notifies it.
            self.t = threading.Thread(target = self.e.trigger, args=['t1'])
            self.t.start()

        def test_norun_sub_during_trig(self):
            """Tests that a callback that gets added during a trigger is
            not run."""
            with self.cv:
                # Wait for the first callback to start.
                while not self.l:
                    self.cv.wait()
                # Subscribe while the trigger is in progress, then let the
                # first callback finish.
                self.l.append('main')
                self.e.subscribe_repeating(append_cb, self.l, 'cb2')
                self.cv.notify_all()
                # Wait for the second callback to start ('pre', 'main',
                # 'post', 'pre'), then let it finish.
                while len(self.l) != 4:
                    self.cv.wait()
                self.l.append('main2')
                self.cv.notify_all()

            self.t.join()

            self.expected = [
                ('cb1', 't1', 'pre'),
                'main',
                ('cb1', 't1', 'post'),
                ('cb1', 't1', 'pre'),
                'main2',
                ('cb1', 't1', 'post'),
                (('cb2', 't2'), {}),
                ]

            # The subscription added during 't1' must only see 't2'.
            self.e.trigger('t2')

            self.assertEqual(self.l, self.expected)

        def test_norun_unsub_during_trig(self):
            """Tests that a callback that gets deleted during a trigger is
            not run."""
            with self.cv:
                # Wait for the first callback to start.
                while not self.l:
                    self.cv.wait()
                self.l.append('main')
                # Non-blocking: the running callback holds its call lock
                # while it waits on cv, which this thread holds.
                self.h1.unsubscribe(blocking = False)
                self.h2.unsubscribe(blocking = False)
                self.cv.notify_all()

            self.t.join()

            self.expected = [
                ('cb1', 't1', 'pre'),
                'main',
                ('cb1', 't1', 'post'),
                ]

            # Both handles are gone, so this trigger must add nothing.
            self.e.trigger('t2')

            self.assertEqual(self.l, self.expected)

    if len(sys.argv) > 1 and sys.argv[1].startswith("--gtest_output="):
        import roslib
        roslib.load_manifest('multi_interface_roam')
        import rostest
        rostest.unitrun('multi_interface_roam', 'event_basic', BasicTest)
        rostest.unitrun('multi_interface_roam', 'event_thread', ThreadTest)
    else:
        unittest.main()