10 #ifndef THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
11 #define THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
18 #if EIGEN_GNUC_AT_LEAST(5, 3)
// Declares a local Packet16f constant named p16f_<NAME> with all sixteen
// lanes broadcast from the float value X.
#define _EIGEN_DECLARE_CONST_Packet16f(NAME, X) \
  const Packet16f p16f_##NAME = pset1<Packet16f>(X)
// Declares a local Packet16f constant named p16f_<NAME> whose lanes carry the
// exact 32-bit pattern X (useful for sign/exponent/NaN bit masks).
// Use the official reinterpret intrinsic _mm512_castsi512_ps: the previous
// C-style (__m512) cast from __m512i is non-portable and rejected by some
// compilers (e.g. MSVC), where the AVX-512 vector types are not convertible
// by a plain cast.
#define _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(NAME, X) \
  const Packet16f p16f_##NAME = _mm512_castsi512_ps(pset1<Packet16i>(X))
// Declares a local Packet8d constant named p8d_<NAME> with all eight lanes
// broadcast from the double value X.
#define _EIGEN_DECLARE_CONST_Packet8d(NAME, X) \
  const Packet8d p8d_##NAME = pset1<Packet8d>(X)
// Declares a local Packet8d constant named p8d_<NAME> whose lanes carry the
// exact 64-bit pattern X (useful for sign/exponent/NaN bit masks).  The
// reinterpret goes through _mm512_castsi512_pd, which is a no-op cast (no
// value conversion) per the Intel intrinsics specification.
#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \
  const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X))
36 #if defined(EIGEN_VECTORIZE_AVX512DQ)
41 _EIGEN_DECLARE_CONST_Packet16f(1, 1.0f);
42 _EIGEN_DECLARE_CONST_Packet16f(half, 0.5f);
43 _EIGEN_DECLARE_CONST_Packet16f(126f, 126.0f);
45 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inv_mant_mask, ~0x7f800000);
48 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(min_norm_pos, 0x00800000);
49 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(minus_inf, 0xff800000);
50 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
53 _EIGEN_DECLARE_CONST_Packet16f(cephes_SQRTHF, 0.707106781186547524f);
54 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p0, 7.0376836292E-2f);
55 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p1, -1.1514610310E-1f);
56 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p2, 1.1676998740E-1f);
57 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p3, -1.2420140846E-1f);
58 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p4, +1.4249322787E-1f);
59 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p5, -1.6668057665E-1f);
60 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p6, +2.0000714765E-1f);
61 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p7, -2.4999993993E-1f);
62 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_p8, +3.3333331174E-1f);
63 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q1, -2.12194440e-4f);
64 _EIGEN_DECLARE_CONST_Packet16f(cephes_log_q2, 0.693359375f);
67 __mmask16 invalid_mask =
68 _mm512_cmp_ps_mask(
x, _mm512_setzero_ps(), _CMP_NGE_UQ);
69 __mmask16 iszero_mask =
70 _mm512_cmp_ps_mask(
x, _mm512_setzero_ps(), _CMP_EQ_UQ);
73 x =
pmax(
x, p16f_min_norm_pos);
76 Packet16f emm0 = _mm512_cvtepi32_ps(_mm512_srli_epi32((__m512i)
x, 23));
77 Packet16f e = _mm512_sub_ps(emm0, p16f_126f);
80 x = _mm512_and_ps(
x, p16f_inv_mant_mask);
81 x = _mm512_or_ps(
x, p16f_half);
90 __mmask16 mask = _mm512_cmp_ps_mask(
x, p16f_cephes_SQRTHF, _CMP_LT_OQ);
91 Packet16f tmp = _mm512_mask_blend_ps(mask, _mm512_setzero_ps(),
x);
93 e =
psub(e, _mm512_mask_blend_ps(mask, _mm512_setzero_ps(), p16f_1));
102 y =
pmadd(p16f_cephes_log_p0,
x, p16f_cephes_log_p1);
103 y1 =
pmadd(p16f_cephes_log_p3,
x, p16f_cephes_log_p4);
104 y2 =
pmadd(p16f_cephes_log_p6,
x, p16f_cephes_log_p7);
105 y =
pmadd(
y,
x, p16f_cephes_log_p2);
106 y1 =
pmadd(y1,
x, p16f_cephes_log_p5);
107 y2 =
pmadd(y2,
x, p16f_cephes_log_p8);
113 y1 =
pmul(e, p16f_cephes_log_q1);
114 tmp =
pmul(x2, p16f_half);
117 y2 =
pmul(e, p16f_cephes_log_q2);
122 return _mm512_mask_blend_ps(iszero_mask,
123 _mm512_mask_blend_ps(invalid_mask,
x, p16f_nan),
134 _EIGEN_DECLARE_CONST_Packet16f(1, 1.0f);
135 _EIGEN_DECLARE_CONST_Packet16f(half, 0.5f);
136 _EIGEN_DECLARE_CONST_Packet16f(127, 127.0f);
138 _EIGEN_DECLARE_CONST_Packet16f(exp_hi, 88.3762626647950f);
139 _EIGEN_DECLARE_CONST_Packet16f(exp_lo, -88.3762626647949f);
141 _EIGEN_DECLARE_CONST_Packet16f(cephes_LOG2EF, 1.44269504088896341f);
143 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p0, 1.9875691500E-4f);
144 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p1, 1.3981999507E-3f);
145 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p2, 8.3334519073E-3f);
146 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p3, 4.1665795894E-2f);
147 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p4, 1.6666665459E-1f);
148 _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p5, 5.0000001201E-1f);
155 Packet16f m = _mm512_floor_ps(
pmadd(
x, p16f_cephes_LOG2EF, p16f_half));
159 _EIGEN_DECLARE_CONST_Packet16f(nln2, -0.6931471805599453f);
160 Packet16f r = _mm512_fmadd_ps(m, p16f_nln2,
x);
166 y =
pmadd(
y, r, p16f_cephes_exp_p1);
167 y =
pmadd(
y, r, p16f_cephes_exp_p2);
168 y =
pmadd(
y, r, p16f_cephes_exp_p3);
169 y =
pmadd(
y, r, p16f_cephes_exp_p4);
170 y =
pmadd(
y, r, p16f_cephes_exp_p5);
176 emm0 = _mm512_slli_epi32(emm0, 23);
179 return pmax(
pmul(
y, _mm512_castsi512_ps(emm0)), _x);
261 _EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5f);
262 _EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5f);
263 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(flt_min, 0x00800000);
269 __mmask16 non_zero_mask = _mm512_cmp_ps_mask(_x, p16f_flt_min, _CMP_GE_OQ);
270 Packet16f x = _mm512_mask_blend_ps(non_zero_mask, _mm512_setzero_ps(), _mm512_rsqrt14_ps(_x));
282 psqrt<Packet8d>(
const Packet8d& _x) {
283 _EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
284 _EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
285 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(dbl_min, 0x0010000000000000LL);
291 __mmask8 non_zero_mask = _mm512_cmp_pd_mask(_x, p8d_dbl_min, _CMP_GE_OQ);
292 Packet8d x = _mm512_mask_blend_pd(non_zero_mask, _mm512_setzero_pd(), _mm512_rsqrt14_pd(_x));
307 return _mm512_sqrt_ps(
x);
311 return _mm512_sqrt_pd(
x);
320 #ifdef EIGEN_FAST_MATH
324 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inf, 0x7f800000);
325 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(nan, 0x7fc00000);
326 _EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5f);
327 _EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5f);
328 _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(flt_min, 0x00800000);
334 __mmask16 le_zero_mask = _mm512_cmp_ps_mask(_x, p16f_flt_min, _CMP_LT_OQ);
335 Packet16f x = _mm512_mask_blend_ps(le_zero_mask, _mm512_rsqrt14_ps(_x), _mm512_setzero_ps());
338 __mmask16 neg_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LT_OQ);
339 Packet16f infs_and_nans = _mm512_mask_blend_ps(
340 neg_mask, _mm512_mask_blend_ps(le_zero_mask, _mm512_setzero_ps(), p16f_inf), p16f_nan);
346 return _mm512_mask_blend_ps(le_zero_mask,
x, infs_and_nans);
351 prsqrt<Packet8d>(
const Packet8d& _x) {
352 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(inf, 0x7ff0000000000000LL);
353 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(nan, 0x7ff1000000000000LL);
354 _EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
355 _EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
356 _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(dbl_min, 0x0010000000000000LL);
362 __mmask8 le_zero_mask = _mm512_cmp_pd_mask(_x, p8d_dbl_min, _CMP_LT_OQ);
363 Packet8d x = _mm512_mask_blend_pd(le_zero_mask, _mm512_rsqrt14_pd(_x), _mm512_setzero_pd());
366 __mmask8 neg_mask = _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_LT_OQ);
367 Packet8d infs_and_nans = _mm512_mask_blend_pd(
368 neg_mask, _mm512_mask_blend_pd(le_zero_mask, _mm512_setzero_pd(), p8d_inf), p8d_nan);
377 return _mm512_mask_blend_pd(le_zero_mask,
x, infs_and_nans);
379 #elif defined(EIGEN_VECTORIZE_AVX512ER)
382 return _mm512_rsqrt28_ps(
x);
391 #endif // THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_