10 #define EIGEN_USE_THREADS
15 #include <Eigen/CXX11/Tensor>
49 Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
50 out.device(thread_pool_device) = (in1 + in2 * 3.14f).cast<double>();
52 for (
int i = 0;
i < 200; ++
i) {
53 for (
int j = 0;
j < 30; ++
j) {
54 for (
int k = 0; k < 70; ++k) {
71 Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
74 out.device(thread_pool_device, [&
b]() {
b.Notify(); }) = (in1 + in2 * 3.14
f).cast<
double>();
77 for (
int i = 0;
i < 200; ++
i) {
78 for (
int j = 0;
j < 30; ++
j) {
79 for (
int k = 0; k < 70; ++k) {
96 Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
97 out.device(thread_pool_device) = in1;
98 out.device(thread_pool_device) += in2 * 3.14f;
100 for (
int i = 0;
i < 2; ++
i) {
101 for (
int j = 0;
j < 3; ++
j) {
102 for (
int k = 0; k < 7; ++k) {
109 template<
int DataLayout>
124 MapXf m_left(t_left.
data(), 1500, 1147);
125 MapXf m_right(t_right.
data(), 1147, 1400);
129 Eigen::ThreadPoolDevice thread_pool_device(&tp, 4);
132 t_result.
device(thread_pool_device) = t_left.contract(t_right, dims);
133 m_result = m_left * m_right;
135 for (ptrdiff_t
i = 0;
i < t_result.
size();
i++) {
137 if (fabsf(t_result(
i) - m_result(
i)) < 1
e-4
f) {
143 std::cout <<
"mismatch detected at index " <<
i <<
": " << t_result(
i)
144 <<
" vs " << m_result(
i) << std::endl;
149 template<
int DataLayout>
156 t_left = (t_left.constant(-0.5
f) + t_left.random()) * 2.0
f;
157 t_right = (t_right.constant(-0.6
f) + t_right.random()) * 2.0
f;
158 t_result = t_result.constant(NAN);
165 MapXf m_left(t_left.
data(), 32, 500);
166 MapXf m_right(t_right.
data(), 32, 28*28);
170 Eigen::ThreadPoolDevice thread_pool_device(&tp, 12);
173 t_result.
device(thread_pool_device) = t_left.contract(t_right, dims);
174 m_result = m_left.transpose() * m_right;
176 for (ptrdiff_t
i = 0;
i < t_result.
size();
i++) {
178 if (fabsf(t_result.
data()[
i] - m_result.
data()[
i]) >= 1
e-4
f) {
179 std::cout <<
"mismatch detected at index " <<
i <<
" : " << t_result.
data()[
i] <<
" vs " << m_result.
data()[
i] << std::endl;
185 t_left = (t_left.constant(-0.5
f) + t_left.random()) * 2.0
f;
186 t_result.
resize (1, 28*28);
187 t_result = t_result.constant(NAN);
188 t_result.
device(thread_pool_device) = t_left.contract(t_right, dims);
189 new(&m_left) MapXf(t_left.
data(), 32, 1);
190 m_result = m_left.transpose() * m_right;
191 for (ptrdiff_t
i = 0;
i < t_result.
size();
i++) {
193 if (fabsf(t_result.
data()[
i] - m_result.
data()[
i]) >= 1
e-4
f) {
194 std::cout <<
"mismatch detected: " << t_result.
data()[
i] <<
" vs " << m_result.
data()[
i] << std::endl;
201 t_left = (t_left.constant(-0.5
f) + t_left.random()) * 2.0
f;
202 t_right = (t_right.constant(-0.6
f) + t_right.random()) * 2.0
f;
204 t_result = t_result.constant(NAN);
205 t_result.
device(thread_pool_device) = t_left.contract(t_right, dims);
206 new(&m_left) MapXf(t_left.
data(), 32, 500);
207 new(&m_right) MapXf(t_right.
data(), 32, 4);
208 m_result = m_left.transpose() * m_right;
209 for (ptrdiff_t
i = 0;
i < t_result.
size();
i++) {
211 if (fabsf(t_result.
data()[
i] - m_result.
data()[
i]) >= 1
e-4
f) {
212 std::cout <<
"mismatch detected: " << t_result.
data()[
i] <<
" vs " << m_result.
data()[
i] << std::endl;
219 t_left = (t_left.constant(-0.5
f) + t_left.random()) * 2.0
f;
220 t_right = (t_right.constant(-0.6
f) + t_right.random()) * 2.0
f;
222 t_result = t_result.constant(NAN);
223 t_result.
device(thread_pool_device) = t_left.contract(t_right, dims);
224 new(&m_left) MapXf(t_left.
data(), 32, 1);
225 new(&m_right) MapXf(t_right.
data(), 32, 4);
226 m_result = m_left.transpose() * m_right;
227 for (ptrdiff_t
i = 0;
i < t_result.
size();
i++) {
229 if (fabsf(t_result.
data()[
i] - m_result.
data()[
i]) >= 1
e-4
f) {
230 std::cout <<
"mismatch detected: " << t_result.
data()[
i] <<
" vs " << m_result.
data()[
i] << std::endl;
236 template<
int DataLayout>
238 int contract_size = internal::random<int>(1, 5000);
242 internal::random<int>(1, 100));
245 internal::random<int>(1, 37),
247 internal::random<int>(1, 51));
260 Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(2, 11));
269 for (ptrdiff_t
i = 0;
i < st_result.
size();
i++) {
280 template <
typename Index,
typename Scalar>
282 const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
284 Index num_cols)
const {
285 for (
int i = 0;
i < num_rows; ++
i) {
286 for (
int j = 0;
j < num_cols; ++
j) {
293 template <
int DataLayout>
297 const int num_threads = internal::random<int>(2, 11);
299 Eigen::ThreadPoolDevice device(&threads, num_threads);
311 t_left += t_left.constant(1.0
f);
312 t_right += t_right.constant(1.0
f);
315 MapXf m_left(t_left.
data(), 1500, 248);
316 MapXf m_right(t_right.
data(), 248, 1400);
325 m_result = m_left * m_right;
333 template<
int DataLayout>
336 int contract_size = internal::random<int>(100, 500);
340 internal::random<int>(10, 40));
343 internal::random<int>(1, 20), internal::random<int>(1, 20), contract_size,
344 internal::random<int>(1, 20));
357 Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(8, 32));
365 tp_result.
device(thread_pool_device, [&barrier]() { barrier.
Notify(); }) =
370 for (ptrdiff_t
i = 0;
i < st_result.
size();
i++) {
380 template <
int DataLayout>
385 const int num_threads = internal::random<int>(4, 16);
387 Eigen::ThreadPoolDevice device(&threads, num_threads);
399 t_left += t_left.constant(1.0
f);
400 t_right += t_right.constant(1.0
f);
403 MapXf m_left(t_left.
data(), 2, 10000);
404 MapXf m_right(t_right.
data(), 10000, 10);
411 t_result.
device(device) = t_left.contract(t_right, dims);
412 m_result = m_left * m_right;
420 template <
int DataLayout>
425 const int num_threads = internal::random<int>(4, 16);
427 Eigen::ThreadPoolDevice device(&threads, num_threads);
439 t_left += t_left.constant(1.0
f);
440 t_right += t_right.constant(1.0
f);
443 MapXf m_left(t_left.
data(), 2, 10000);
444 MapXf m_right(t_right.
data(), 10000, 10);
452 m_result = m_left * m_right;
460 template <
int DataLayout>
465 const int num_threads = internal::random<int>(4, 16);
467 Eigen::ThreadPoolDevice device(&threads, num_threads);
479 t_left += t_left.constant(1.0
f);
480 t_right += t_right.constant(1.0
f);
483 MapXf m_left(t_left.
data(), 2, 10000);
484 MapXf m_right(t_right.
data(), 10000, 10);
492 t_result.
device(device, [&barrier]() { barrier.Notify(); }) =
493 t_left.contract(t_right, dims);
496 m_result = m_left * m_right;
504 template <
int DataLayout>
509 const int num_threads = internal::random<int>(4, 16);
511 Eigen::ThreadPoolDevice device(&threads, num_threads);
523 t_left += t_left.constant(1.0
f);
524 t_right += t_right.constant(1.0
f);
527 MapXf m_left(t_left.
data(), 2, 10000);
528 MapXf m_right(t_right.
data(), 10000, 10);
536 t_result.
device(device, [&barrier]() { barrier.Notify(); }) =
539 m_result = m_left * m_right;
546 template<
int DataLayout>
548 int contract_size1 = internal::random<int>(1, 500);
549 int contract_size2 = internal::random<int>(1, 500);
566 Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(2, 11));
582 template<
int DataLayout>
584 const int num_threads = internal::random<int>(3, 11);
586 Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads);
588 const int num_rows = internal::random<int>(13, 732);
589 const int num_cols = internal::random<int>(13, 732);
594 full_redux = t1.sum();
597 full_redux_tp.
device(thread_pool_device) = t1.sum();
607 for (
int i = 0;
i < 5; ++
i) {
608 const int num_threads = internal::random<int>(3, 11);
610 Eigen::ThreadPoolDevice thread_pool_device(&tp, num_threads);
612 const int size = internal::random<int>(13, 7632);
616 thread_pool_device.memcpy(&
result[0], t1.
data(),
size*
sizeof(
float));
617 for (
int j = 0;
j <
size;
j++) {
627 Eigen::ThreadPoolDevice device(&tp, 2);
632 template<
int DataLayout>
638 const int num_threads = internal::random<int>(2, 11);
640 Eigen::ThreadPoolDevice device(&threads, num_threads, allocator);
646 for (
int i = 0;
i < 17; ++
i) {
647 for (
int j = 0;
j < 5; ++
j) {
648 for (
int k = 0; k < 7; ++k) {
649 for (
int l = 0;
l < 11; ++
l) {
659 const int num_threads = internal::random<int>(2, 11);
660 const int num_allocs = internal::random<int>(2, 11);
662 Eigen::ThreadPoolDevice device(&threads, num_threads, allocator);
664 for (
int a = 0;
a < num_allocs; ++
a) {
665 void* ptr = device.allocate(512);
666 device.deallocate(ptr);
682 CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
683 CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
684 CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>());
685 CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>());
687 CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
688 CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
691 CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<ColMajor>());
692 CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<RowMajor>());
693 CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
694 CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
696 CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<ColMajor>());
697 CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<RowMajor>());
698 CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
699 CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());