EventCount.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
11 #define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
12 
13 namespace Eigen {
14 
15 // EventCount allows waiting for arbitrary predicates in non-blocking
16 // algorithms. Think of it as a condition variable, except that the wait
17 // predicate does not need to be protected by a mutex. Usage:
18 // Waiting thread does:
19 //
20 // if (predicate)
21 // return act();
22 // EventCount::Waiter& w = waiters[my_index];
23 // ec.Prewait();
24 // if (predicate) {
25 // ec.CancelWait();
26 // return act();
27 // }
28 // ec.CommitWait(&w);
29 //
30 // Notifying thread does:
31 //
32 // predicate = true;
33 // ec.Notify(true);
34 //
35 // Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
36 // cheap, but they are executed only if the preceding predicate check has
37 // failed.
38 //
39 // Algorithm outline:
40 // There are two main variables: predicate (managed by user) and state_.
41 // Operation closely resembles Dekker's mutual exclusion algorithm:
42 // https://en.wikipedia.org/wiki/Dekker%27s_algorithm
43 // The waiting thread sets state_ and then checks the predicate; the notifying
44 // thread sets the predicate and then checks state_. Due to seq_cst fences
45 // between these operations, it is guaranteed that either the waiter will see
46 // the predicate change and won't block, or the notifying thread will see the
47 // state_ change and will unblock the waiter, or both. It cannot happen that
48 // neither thread sees the other's change, which would lead to a deadlock.
49 class EventCount {
50  public:
51  class Waiter;
52 
54  : state_(kStackMask), waiters_(waiters) {
55  eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
56  }
57 
59  // Ensure there are no waiters.
61  }
62 
63  // Prewait prepares for waiting.
64  // After calling Prewait, the thread must re-check the wait predicate
65  // and then call either CancelWait or CommitWait.
66  void Prewait() {
67  uint64_t state = state_.load(std::memory_order_relaxed);
68  for (;;) {
69  CheckState(state);
70  uint64_t newstate = state + kWaiterInc;
71  CheckState(newstate);
72  if (state_.compare_exchange_weak(state, newstate,
73  std::memory_order_seq_cst))
74  return;
75  }
76  }
77 
78  // CommitWait commits waiting after Prewait.
79  void CommitWait(Waiter* w) {
80  eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
82  const uint64_t me = (w - &waiters_[0]) | w->epoch;
83  uint64_t state = state_.load(std::memory_order_seq_cst);
84  for (;;) {
85  CheckState(state, true);
86  uint64_t newstate;
87  if ((state & kSignalMask) != 0) {
88  // Consume the signal and return immediately.
89  newstate = state - kWaiterInc - kSignalInc;
90  } else {
91  // Remove this thread from pre-wait counter and add to the waiter stack.
92  newstate = ((state & kWaiterMask) - kWaiterInc) | me;
93  w->next.store(state & (kStackMask | kEpochMask),
94  std::memory_order_relaxed);
95  }
96  CheckState(newstate);
97  if (state_.compare_exchange_weak(state, newstate,
98  std::memory_order_acq_rel)) {
99  if ((state & kSignalMask) == 0) {
100  w->epoch += kEpochInc;
101  Park(w);
102  }
103  return;
104  }
105  }
106  }
107 
108  // CancelWait cancels effects of the previous Prewait call.
109  void CancelWait() {
110  uint64_t state = state_.load(std::memory_order_relaxed);
111  for (;;) {
112  CheckState(state, true);
113  uint64_t newstate = state - kWaiterInc;
114  // We don't know whether the thread was also notified or not,
115  // so we should not consume a signal unconditionally.
116  // Only if the number of waiters equals the number of signals
117  // do we know that the thread was notified and must take away the signal.
118  if (((state & kWaiterMask) >> kWaiterShift) ==
119  ((state & kSignalMask) >> kSignalShift))
120  newstate -= kSignalInc;
121  CheckState(newstate);
122  if (state_.compare_exchange_weak(state, newstate,
123  std::memory_order_acq_rel))
124  return;
125  }
126  }
127 
128  // Notify wakes one or all waiting threads.
129  // Must be called after changing the associated wait predicate.
130  void Notify(bool notifyAll) {
131  std::atomic_thread_fence(std::memory_order_seq_cst);
132  uint64_t state = state_.load(std::memory_order_acquire);
133  for (;;) {
134  CheckState(state);
135  const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
136  const uint64_t signals = (state & kSignalMask) >> kSignalShift;
137  // Easy case: no waiters.
138  if ((state & kStackMask) == kStackMask && waiters == signals) return;
139  uint64_t newstate;
140  if (notifyAll) {
141  // Empty wait stack and set signal to number of pre-wait threads.
142  newstate =
143  (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
144  } else if (signals < waiters) {
145  // There is a thread in pre-wait state, unblock it.
146  newstate = state + kSignalInc;
147  } else {
148  // Pop a waiter from list and unpark it.
149  Waiter* w = &waiters_[state & kStackMask];
150  uint64_t next = w->next.load(std::memory_order_relaxed);
151  newstate = (state & (kWaiterMask | kSignalMask)) | next;
152  }
153  CheckState(newstate);
154  if (state_.compare_exchange_weak(state, newstate,
155  std::memory_order_acq_rel)) {
156  if (!notifyAll && (signals < waiters))
157  return; // unblocked pre-wait thread
158  if ((state & kStackMask) == kStackMask) return;
159  Waiter* w = &waiters_[state & kStackMask];
160  if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
161  Unpark(w);
162  return;
163  }
164  }
165  }
166 
167  class Waiter {
168  friend class EventCount;
169  // Align to 128 byte boundary to prevent false sharing with other Waiter
170  // objects in the same vector.
171  EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
172  std::mutex mu;
173  std::condition_variable cv;
175  unsigned state = kNotSignaled;
176  enum {
180  };
181  };
182 
183  private:
184  // state_ layout:
185  // - the low kWaiterBits bits are a stack of waiters that have committed to
186  // waiting (indexes into the waiters_ array are used as stack elements;
187  // kStackMask means an empty stack).
188  // - the next kWaiterBits bits are the count of waiters in the prewait state.
189  // - the next kWaiterBits bits are the count of pending signals.
190  // - the remaining bits are an ABA counter for the stack
191  // (stored in the Waiter node and incremented on push).
192  static const uint64_t kWaiterBits = 14;
193  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
195  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
196  << kWaiterShift;
197  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
198  static const uint64_t kSignalShift = 2 * kWaiterBits;
199  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
200  << kSignalShift;
201  static const uint64_t kSignalInc = 1ull << kSignalShift;
202  static const uint64_t kEpochShift = 3 * kWaiterBits;
203  static const uint64_t kEpochBits = 64 - kEpochShift;
204  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
205  static const uint64_t kEpochInc = 1ull << kEpochShift;
206  std::atomic<uint64_t> state_;
208 
209  static void CheckState(uint64_t state, bool waiter = false) {
210  static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
211  const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
212  const uint64_t signals = (state & kSignalMask) >> kSignalShift;
213  eigen_plain_assert(waiters >= signals);
214  eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
215  eigen_plain_assert(!waiter || waiters > 0);
216  (void)waiters;
217  (void)signals;
218  }
219 
220  void Park(Waiter* w) {
221  std::unique_lock<std::mutex> lock(w->mu);
222  while (w->state != Waiter::kSignaled) {
223  w->state = Waiter::kWaiting;
224  w->cv.wait(lock);
225  }
226  }
227 
228  void Unpark(Waiter* w) {
229  for (Waiter* next; w; w = next) {
230  uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
231  next = wnext == kStackMask ? nullptr : &waiters_[wnext];
232  unsigned state;
233  {
234  std::unique_lock<std::mutex> lock(w->mu);
235  state = w->state;
237  }
238  // Avoid notifying if it wasn't waiting.
239  if (state == Waiter::kWaiting) w->cv.notify_one();
240  }
241  }
242 
243  EventCount(const EventCount&) = delete;
244  void operator=(const EventCount&) = delete;
245 };
246 
247 } // namespace Eigen
248 
249 #endif // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
static const uint64_t kSignalShift
Definition: EventCount.h:198
std::condition_variable cv
Definition: EventCount.h:173
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t size() const
static const uint64_t kSignalInc
Definition: EventCount.h:201
EventCount(MaxSizeVector< Waiter > &waiters)
Definition: EventCount.h:53
static const uint64_t kEpochMask
Definition: EventCount.h:204
Namespace containing all symbols from the Eigen library.
Definition: jet.h:637
Definition: BFloat16.h:88
static void CheckState(uint64_t state, bool waiter=false)
Definition: EventCount.h:209
static const uint64_t kEpochShift
Definition: EventCount.h:202
static const uint64_t kWaiterMask
Definition: EventCount.h:195
void operator=(const EventCount &)=delete
MaxSizeVector< Waiter > & waiters_
Definition: EventCount.h:207
static const uint64_t kSignalMask
Definition: EventCount.h:199
static const uint64_t kStackMask
Definition: EventCount.h:193
std::mutex mu
Definition: EventCount.h:172
static const uint64_t kWaiterBits
Definition: EventCount.h:192
std::atomic< uint64_t > state_
Definition: EventCount.h:206
unsigned __int64 uint64_t
Definition: ms_stdint.h:95
#define eigen_plain_assert(x)
Definition: Macros.h:1007
RowVector3d w
static const uint64_t kWaiterShift
Definition: EventCount.h:194
void Unpark(Waiter *w)
Definition: EventCount.h:228
The MaxSizeVector class.
Definition: MaxSizeVector.h:31
void Park(Waiter *w)
Definition: EventCount.h:220
void Notify(bool notifyAll)
Definition: EventCount.h:130
static const uint64_t kEpochBits
Definition: EventCount.h:203
static const uint64_t kWaiterInc
Definition: EventCount.h:197
static const uint64_t kEpochInc
Definition: EventCount.h:205
void CommitWait(Waiter *w)
Definition: EventCount.h:79


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:12