// NOTE(review): garbled text extraction of Eigen's AVX512 MathFunctions.h.
// Original source line numbers are fused into the text (e.g. "41 "), float
// literals are split across lines ("1.0" / "f);"), and interior statements
// are missing wherever the fused numbers jump. Code tokens are preserved
// byte-for-byte below; only comments are added. This span covers the header
// guard, the packet-constant helper macros, and a fragment of
// plog<Packet16f> (cephes-style natural log); the function signature and
// several body lines (e.g. "Packet16f x = _x;", x2/x3/y declarations) are
// missing from this view.
10 #ifndef THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_ 11 #define THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_ 18 #if EIGEN_GNUC_AT_LEAST(5, 3) 20 #define _EIGEN_DECLARE_CONST_Packet16f(NAME, X) \ 21 const Packet16f p16f_##NAME = pset1<Packet16f>(X) 23 #define _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(NAME, X) \ 24 const Packet16f p16f_##NAME = (__m512)pset1<Packet16i>(X) 26 #define _EIGEN_DECLARE_CONST_Packet8d(NAME, X) \ 27 const Packet8d p8d_##NAME = pset1<Packet8d>(X) 29 #define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \ 30 const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X)) 36 #if defined(EIGEN_VECTORIZE_AVX512DQ) 41 _EIGEN_DECLARE_CONST_Packet16f(1, 1.0
f);
42 _EIGEN_DECLARE_CONST_Packet16f(half, 0.5
f);
43 _EIGEN_DECLARE_CONST_Packet16f(126
f, 126.0
f);
// IEEE-754 single-precision bit patterns, splatted across the 16 lanes:
// ~0x7f800000 keeps sign+mantissa (clears exponent), 0x00800000 is the
// smallest positive normal, 0xff800000 is -inf, 0x7fc00000 is quiet NaN.
45 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inv_mant_mask, ~0x7f800000);
48 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(min_norm_pos, 0x00800000);
49 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(minus_inf, 0xff800000);
50 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
// Cephes logf() minimax polynomial coefficients (p0..p8) plus the split
// representation of ln(2): q2 = 0.693359375 and q1 = its small correction.
53 _EIGEN_DECLARE_CONST_Packet16f(cephes_SQRTHF, 0.707106781186547524
f);
54 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p0, 7.0376836292E-2
f);
55 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p1, -1.1514610310E-1
f);
56 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p2, 1.1676998740E-1
f);
57 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p3, -1.2420140846E-1
f);
58 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p4, +1.4249322787E-1
f);
59 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p5, -1.6668057665E-1
f);
60 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p6, +2.0000714765E-1
f);
61 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p7, -2.4999993993E-1
f);
62 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p8, +3.3333331174E-1
f);
63 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q1, -2.12194440e-4
f);
64 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q2, 0.693359375
f);
// Special-case lane masks: _CMP_NGE_UQ is true for x < 0 and for NaN
// (unordered), _CMP_EQ_UQ flags exact zeros; both are applied in the final
// blend at the bottom of the function.
67 __mmask16 invalid_mask =
68 _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_NGE_UQ);
69 __mmask16 iszero_mask =
70 _mm512_cmp_ps_mask(x, _mm512_setzero_ps(), _CMP_EQ_UQ);
// Flush denormals up to the smallest normal so the exponent-bit extraction
// below is valid.
73 x =
pmax(x, p16f_min_norm_pos);
// Pull the biased exponent out of the float bits (shift right by the 23
// mantissa bits). The (__m512i)x cast relies on the AVX512DQ branch guarded
// above. Subtracting 126 (not 127) biases e so the mantissa step below
// lands x in [0.5, 1) — TODO confirm against the unabridged source.
76 Packet16f emm0 = _mm512_cvtepi32_ps(_mm512_srli_epi32((__m512i)x, 23));
77 Packet16f e = _mm512_sub_ps(emm0, p16f_126f);
// Keep only the mantissa and OR in 0.5's exponent: x is now in [0.5, 1).
80 x = _mm512_and_ps(x, p16f_inv_mant_mask);
81 x = _mm512_or_ps(x, p16f_half);
// Range reduction around sqrt(1/2): for lanes with x < sqrt(0.5), decrement
// e and (on lines missing from this extraction) fold tmp back into x so the
// polynomial argument stays near zero — assumed from the cephes scheme;
// verify against the full source.
90 __mmask16 mask = _mm512_cmp_ps_mask(x, p16f_cephes_SQRTHF, _CMP_LT_OQ);
91 Packet16f tmp = _mm512_mask_blend_ps(mask, x, _mm512_setzero_ps());
93 e =
psub(e, _mm512_mask_blend_ps(mask, p16f_1, _mm512_setzero_ps()));
// Degree-8 polynomial evaluated as three interleaved Horner chains
// (y, y1, y2) then merged with x3 steps; x2/x3 are defined on lines not
// visible here.
102 y =
pmadd(p16f_cephes_log_p0, x, p16f_cephes_log_p1);
103 y1 =
pmadd(p16f_cephes_log_p3, x, p16f_cephes_log_p4);
104 y2 =
pmadd(p16f_cephes_log_p6, x, p16f_cephes_log_p7);
105 y =
pmadd(y, x, p16f_cephes_log_p2);
106 y1 =
pmadd(y1, x, p16f_cephes_log_p5);
107 y2 =
pmadd(y2, x, p16f_cephes_log_p8);
108 y =
pmadd(y, x3, y1);
109 y =
pmadd(y, x3, y2);
// Recombine: add e*q1 (small part of e*ln2), subtract 0.5*x^2, and add
// e*q2 (large part of e*ln2). The statements that sum these terms into the
// result fall on lines missing from this extraction.
113 y1 =
pmul(e, p16f_cephes_log_q1);
114 tmp =
pmul(x2, p16f_half);
117 y2 =
pmul(e, p16f_cephes_log_q2);
// Final special-case blend: zero inputs -> -inf, negative/NaN inputs -> NaN,
// everything else -> the computed log.
122 return _mm512_mask_blend_ps(iszero_mask, p16f_minus_inf,
123 _mm512_mask_blend_ps(invalid_mask, p16f_nan, x));
// Fragment of pexp<Packet16f> (cephes-style expf). NOTE(review): the
// function signature, the clamp of x into [exp_lo, exp_hi], and the first
// polynomial step live on lines missing from this extraction; code tokens
// below are preserved byte-for-byte, comments only.
133 _EIGEN_DECLARE_CONST_Packet16f(1, 1.0
f);
134 _EIGEN_DECLARE_CONST_Packet16f(half, 0.5
f);
135 _EIGEN_DECLARE_CONST_Packet16f(127, 127.0
f);
// Overflow/underflow clamp bounds for single-precision exp.
137 _EIGEN_DECLARE_CONST_Packet16f(exp_hi, 88.3762626647950
f);
138 _EIGEN_DECLARE_CONST_Packet16f(exp_lo, -88.3762626647949
f);
// log2(e), used to reduce exp(x) to 2^m * exp(r).
140 _EIGEN_DECLARE_CONST_Packet16f(cephes_LOG2EF, 1.44269504088896341
f);
// Cephes expf() polynomial coefficients for exp(r), r in [-ln2/2, ln2/2].
142 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p0, 1.9875691500E-4
f);
143 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p1, 1.3981999507E-3
f);
144 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p2, 8.3334519073E-3
f);
145 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p3, 4.1665795894E-2
f);
146 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p4, 1.6666665459E-1
f);
147 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p5, 5.0000001201E-1
f);
// m = round-toward-nearest of x * log2(e), computed as floor(x*log2e + 0.5).
154 Packet16f m = _mm512_floor_ps(
pmadd(x, p16f_cephes_LOG2EF, p16f_half));
// r = x - m*ln2 in a single FMA; using -ln2 at full double-entry precision
// keeps the reduction accurate.
158 _EIGEN_DECLARE_CONST_Packet16f(nln2, -0.6931471805599453
f);
159 Packet16f r = _mm512_fmadd_ps(m, p16f_nln2, x);
// Horner evaluation of the exp(r) polynomial; y's initialization with
// cephes_exp_p0 is on a missing line.
165 y =
pmadd(y, r, p16f_cephes_exp_p1);
166 y =
pmadd(y, r, p16f_cephes_exp_p2);
167 y =
pmadd(y, r, p16f_cephes_exp_p3);
168 y =
pmadd(y, r, p16f_cephes_exp_p4);
169 y =
pmadd(y, r, p16f_cephes_exp_p5);
// Build 2^m by moving the (biased) integer exponent into the float exponent
// field; emm0's computation from m + 127 is on missing lines.
175 emm0 = _mm512_slli_epi32(emm0, 23);
// Scale by 2^m. The trailing pmax against the original argument presumably
// makes huge inputs return _x (inf-ish) rather than a wrapped-around value —
// NOTE(review): confirm intent against the unabridged source.
178 return pmax(
pmul(y, _mm512_castsi512_ps(emm0)), _x);
// Fragment of the EIGEN_FAST_MATH psqrt<Packet16f>: sqrt(x) computed as
// x * rsqrt_estimate(x) refined by one Newton-Raphson step. NOTE(review):
// the signature, the line deriving neg_half (presumably -0.5 * _x), and the
// final multiply by _x are on lines missing from this extraction.
260 _EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5
f);
261 _EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5
f);
// 0x00800000 = smallest positive normal float; used as the denormal cutoff.
262 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(flt_min, 0x00800000);
// Lanes below FLT_MIN (denormal/zero/negative) get 0 instead of the rsqrt
// estimate, so sqrt(0) = 0 exactly and no denormal garbage propagates.
268 __mmask16 non_zero_mask = _mm512_cmp_ps_mask(_x, p16f_flt_min, _CMP_GE_OQ);
269 Packet16f x = _mm512_mask_blend_ps(non_zero_mask, _mm512_rsqrt14_ps(_x),
270 _mm512_setzero_ps());
// One Newton-Raphson iteration: x <- x * (1.5 + neg_half * x * x), which
// roughly doubles the ~14-bit accuracy of rsqrt14.
273 x =
pmul(x,
pmadd(neg_half,
pmul(x, x), p16f_one_point_five));
// EIGEN_FAST_MATH psqrt<Packet8d>: double-precision analogue of the float
// version above, but with TWO Newton-Raphson refinements since rsqrt14's
// ~14-bit estimate needs two doublings to approach double precision.
// NOTE(review): the template/return-type prefix of the signature, the
// neg_half derivation, and the final multiply by _x are on missing lines.
282 psqrt<Packet8d>(
const Packet8d& _x) {
283 _EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
284 _EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
// 0x0010000000000000 = smallest positive normal double (DBL_MIN).
285 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(dbl_min, 0x0010000000000000LL);
// Zero out lanes below DBL_MIN so sqrt(0) = 0 and denormals don't feed the
// estimate.
291 __mmask8 non_zero_mask = _mm512_cmp_pd_mask(_x, p8d_dbl_min, _CMP_GE_OQ);
292 Packet8d x = _mm512_mask_blend_pd(non_zero_mask, _mm512_rsqrt14_pd(_x),
293 _mm512_setzero_pd());
// First Newton-Raphson step: x <- x * (1.5 + neg_half * x * x).
296 x =
pmul(x,
pmadd(neg_half,
pmul(x, x), p8d_one_point_five));
// Second Newton-Raphson step for double-precision accuracy.
299 x =
pmul(x,
pmadd(neg_half,
pmul(x, x), p8d_one_point_five));
// Non-EIGEN_FAST_MATH psqrt fallbacks: defer to the correctly-rounded
// hardware square-root instructions (float16 packet, then double8 packet).
308 return _mm512_sqrt_ps(x);
312 return _mm512_sqrt_pd(x);
// EIGEN_FAST_MATH prsqrt<Packet16f>: 1/sqrt(x) via rsqrt14 estimate plus one
// Newton-Raphson step, with explicit special-case handling. NOTE(review):
// the signature and the middle arguments of the infs_and_nans blend (the
// neg_mask/p16f_nan selection on original line ~342) are on lines missing
// from this extraction; code tokens preserved byte-for-byte, comments only.
321 #ifdef EIGEN_FAST_MATH 325 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inf, 0x7f800000);
326 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
327 _EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5
f);
328 _EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5
f);
329 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(flt_min, 0x00800000);
// Lanes below FLT_MIN bypass the estimate (blended to 0 here) and are
// replaced by infs_and_nans at the end.
335 __mmask16 le_zero_mask = _mm512_cmp_ps_mask(_x, p16f_flt_min, _CMP_LT_OQ);
336 Packet16f x = _mm512_mask_blend_ps(le_zero_mask, _mm512_setzero_ps(),
337 _mm512_rsqrt14_ps(_x));
// Special values for the bypassed lanes: negative inputs -> NaN, tiny
// non-negative inputs -> +inf (the NaN selection sits on a missing line).
340 __mmask16 neg_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LT_OQ);
341 Packet16f infs_and_nans = _mm512_mask_blend_ps(
343 _mm512_mask_blend_ps(le_zero_mask, p16f_inf, _mm512_setzero_ps()));
// One Newton-Raphson refinement: x <- x * (1.5 + neg_half * x * x);
// neg_half is presumably -0.5 * _x, derived on a missing line.
346 x =
pmul(x,
pmadd(neg_half,
pmul(x, x), p16f_one_point_five));
// Merge refined estimates with the special-case lanes.
349 return _mm512_mask_blend_ps(le_zero_mask, infs_and_nans, x);
// EIGEN_FAST_MATH prsqrt<Packet8d>: double-precision analogue with TWO
// Newton-Raphson steps. NOTE(review): the template/return-type prefix, the
// neg_half derivation, and the NaN arm of the infs_and_nans blend (original
// line ~372) are on lines missing from this extraction.
354 prsqrt<Packet8d>(
const Packet8d& _x) {
355 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(inf, 0x7ff0000000000000LL);
// Non-canonical NaN payload (0x7ff1...) — matches the bit pattern used
// upstream rather than the default quiet NaN; any NaN encoding is valid.
356 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(nan, 0x7ff1000000000000LL);
357 _EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
358 _EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
359 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(dbl_min, 0x0010000000000000LL);
// Lanes below DBL_MIN bypass the estimate and take the special values below.
365 __mmask8 le_zero_mask = _mm512_cmp_pd_mask(_x, p8d_dbl_min, _CMP_LT_OQ);
366 Packet8d x = _mm512_mask_blend_pd(le_zero_mask, _mm512_setzero_pd(),
367 _mm512_rsqrt14_pd(_x));
// Negative inputs -> NaN, tiny non-negative inputs -> +inf (NaN selection
// is on a missing line).
370 __mmask8 neg_mask = _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_LT_OQ);
371 Packet8d infs_and_nans = _mm512_mask_blend_pd(
373 _mm512_mask_blend_pd(le_zero_mask, p8d_inf, _mm512_setzero_pd()));
// First Newton-Raphson step: x <- x * (1.5 + neg_half * x * x).
376 x =
pmul(x,
pmadd(neg_half,
pmul(x, x), p8d_one_point_five));
// Second Newton-Raphson step for double-precision accuracy.
379 x =
pmul(x,
pmadd(neg_half,
pmul(x, x), p8d_one_point_five));
// Merge refined estimates with the special-case lanes.
382 return _mm512_mask_blend_pd(le_zero_mask, infs_and_nans, x);
// AVX512ER path: single-instruction 28-bit-accurate reciprocal sqrt, no
// Newton-Raphson refinement needed (surrounding #ifdef is not visible here).
387 return _mm512_rsqrt28_ps(x);
396 #endif // THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_ #define EIGEN_STRONG_INLINE
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
EIGEN_DEVICE_FUNC const Scalar & x
#define EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_DEVICE_FUNC Packet padd(const Packet &a, const Packet &b)
EIGEN_DEVICE_FUNC Packet pmin(const Packet &a, const Packet &b)
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
EIGEN_DEVICE_FUNC Packet psub(const Packet &a, const Packet &b)
EIGEN_DEVICE_FUNC Packet pmul(const Packet &a, const Packet &b)
EIGEN_DEVICE_FUNC Packet pmax(const Packet &a, const Packet &b)