special_functions.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 #include "../Eigen/SpecialFunctions"
12 
13 template<typename X, typename Y>
14 void verify_component_wise(const X& x, const Y& y)
15 {
16  for(Index i=0; i<x.size(); ++i)
17  {
18  if((numext::isfinite)(y(i)))
19  VERIFY_IS_APPROX( x(i), y(i) );
20  else if((numext::isnan)(y(i)))
21  VERIFY((numext::isnan)(x(i)));
22  else
23  VERIFY_IS_EQUAL( x(i), y(i) );
24  }
25 }
26 
// Exercises Eigen's special functions for the array type ArrayType
// (instantiated with ArrayXf and ArrayXd at the bottom of this file):
//  - agreement between member and free-function forms (lgamma, digamma, erf, erfc);
//  - igamma/igammac complement and recurrence identities, plus a table of
//    reference values;
//  - zeta, digamma, polygamma and betainc against hard-coded reference values
//    (generated with numpy/scipy according to the comments below);
//  - betainc symmetry and recurrence identities on random inputs.
// Sections that rely on C99 math are guarded by EIGEN_HAS_C99_MATH.
27 template<typename ArrayType> void array_special_functions()
28 {
29  using std::abs;
30  using std::sqrt;
  // NOTE(review): neither std::abs nor std::sqrt is referenced by name in the
  // body below (abs() is only called as an array member); these look unused.
31  typedef typename ArrayType::Scalar Scalar;
32  typedef typename NumTraits<Scalar>::Real RealScalar;
33 
  // Special values used when filling the reference tables below.
34  Scalar plusinf = std::numeric_limits<Scalar>::infinity();
35  Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
36 
  // Random column vectors with between 1 and 30 rows.
37  Index rows = internal::random<Index>(1,30);
38  Index cols = 1;
39 
40  // API
  // The member-function form and the free-function form must give the same result.
41  {
42  ArrayType m1 = ArrayType::Random(rows,cols);
43 #if EIGEN_HAS_C99_MATH
44  VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
45  VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
46  VERIFY_IS_APPROX(m1.erf(), erf(m1));
47  VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
48 #endif // EIGEN_HAS_C99_MATH
49  }
50 
51 
52 #if EIGEN_HAS_C99_MATH
53  // check special functions (comparing against numpy implementation)
  // NOTE(review): this listing skips original line 54 here, so the block
  // below appears unconditional; confirm against the upstream source.
55  {
56 
57  {
58  ArrayType m1 = ArrayType::Random(rows,cols);
59  ArrayType m2 = ArrayType::Random(rows,cols);
60 
61  // Test various properties of igamma & igammac. These are normalized
62  // gamma integrals where
63  // igammac(a, x) = Gamma(a, x) / Gamma(a)
64  // igamma(a, x) = gamma(a, x) / Gamma(a)
65  // where Gamma and gamma are considered the standard unnormalized
66  // upper and lower incomplete gamma functions, respectively.
  // abs() + 2 gives a >= 2, so a - 1 (used in the recurrences below) stays >= 1.
67  ArrayType a = m1.abs() + 2;
68  ArrayType x = m2.abs() + 2;
69  ArrayType zero = ArrayType::Zero(rows, cols);
70  ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
71  ArrayType a_m1 = a - one;
  // Un-normalize by multiplying with Gamma(a) = exp(lgamma(a)).
72  ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
73  ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
74  ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
75  ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
76 
77  // Gamma(a, 0) == Gamma(a)
78  VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
79 
80  // Gamma(a, x) + gamma(a, x) == Gamma(a)
81  VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
82 
83  // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
84  VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp());
85 
86  // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
87  VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
88  }
89 
90  {
91  // Check exact values of igamma and igammac against a third party calculation.
92  Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
93  Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
94 
95  // location i*6+j corresponds to a_s[i], x_s[j].
  // A nan entry means the implementation is expected to return NaN (checked with isnan).
96  Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan},
97  {0.0, 0.6321205588285578, 0.7768698398515702,
98  0.9816843611112658, 9.999500016666262e-05, 1.0},
99  {0.0, 0.4275932955291202, 0.608374823728911,
100  0.9539882943107686, 7.522076445089201e-07, 1.0},
101  {0.0, 0.01898815687615381, 0.06564245437845008,
102  0.5665298796332909, 4.166333347221828e-18, 1.0},
103  {0.0, 0.9999780593618628, 0.9999899967080838,
104  0.9999996219837988, 0.9991370418689945, 1.0},
105  {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}};
106  Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan},
107  {1.0, 0.36787944117144233, 0.22313016014842982,
108  0.018315638888734182, 0.9999000049998333, 0.0},
109  {1.0, 0.5724067044708798, 0.3916251762710878,
110  0.04601170568923136, 0.9999992477923555, 0.0},
111  {1.0, 0.9810118431238462, 0.9343575456215499,
112  0.4334701203667089, 1.0, 0.0},
113  {1.0, 2.1940638138146658e-05, 1.0003291916285e-05,
114  3.7801620118431334e-07, 0.0008629581310054535,
115  0.0},
116  {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}};
117  for (int i = 0; i < 6; ++i) {
118  for (int j = 0; j < 6; ++j) {
119  if ((std::isnan)(igamma_s[i][j])) {
120  VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
121  } else {
122  VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
123  }
124 
125  if ((std::isnan)(igammac_s[i][j])) {
126  VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
127  } else {
128  VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
129  }
130  }
131  }
132  }
133  }
134 #endif // EIGEN_HAS_C99_MATH
135 
136  // Check the zeta function against scipy.special.zeta
  // The last two entries cover x == 1 (expected +inf) and x < 1 (expected NaN).
137  {
138  ArrayType x(7), q(7), res(7), ref(7);
139  x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9;
140  q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345;
141  ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan;
  // Sanity check: the comparison helper accepts the reference against itself.
142  CALL_SUBTEST( verify_component_wise(ref, ref); );
143  CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
144  CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
145  }
146 
147  // digamma
  // x = 0 and x = -1 are expected to yield +inf; all other entries are finite.
148  {
149  ArrayType x(7), res(7), ref(7);
150  x << 1, 1.5, 4, -10.5, 10000.5, 0, -1;
151  ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, plusinf, plusinf;
152  CALL_SUBTEST( verify_component_wise(ref, ref); );
153 
154  CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
155  CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); );
156  }
157 
158 
159 #if EIGEN_HAS_C99_MATH
  // polygamma(n, x): a non-integer order (n = 1.5) is expected to give NaN.
160  {
161  ArrayType n(11), x(11), res(11), ref(11);
162  n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170;
163  x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64;
164  ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
165  CALL_SUBTEST( verify_component_wise(ref, ref); );
166 
  // Scalars of at least 8 bytes (double) check all 11 entries; narrower
  // scalars (float) only check the first 8.
167  if(sizeof(RealScalar)>=8) { // double
168  // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
169  // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
170  CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
171  }
172  else {
173  // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
174  CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
175  }
176  }
177 #endif
178 
179 #if EIGEN_HAS_C99_MATH
180  {
181  // Inputs and ground truth generated with scipy via:
182  // a = np.logspace(-3, 3, 5) - 1e-3
183  // b = np.logspace(-3, 3, 5) - 1e-3
184  // x = np.linspace(-0.1, 1.1, 5)
185  // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
186  // full_a = full_a.flatten().tolist() # same for full_b, full_x
187  // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
188  //
189  // Note in Eigen, we call betainc with arguments in the order (x, a, b).
  // NOTE(review): the call below is written betainc(a, b, x), which does not match
  // the (x, a, b) order stated above; confirm which order the note refers to.
  // The grid is 5 (a) x 5 (b) x 5 (x) = 125 points, flattened with x varying fastest.
190  ArrayType a(125);
191  ArrayType b(125);
192  ArrayType x(125);
193  ArrayType v(125);
194  ArrayType res(125);
195 
196  a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
197  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
198  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
199  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
200  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
201  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
202  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
203  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
204  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
205  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
206  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
207  0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
208  0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
209  31.62177660168379, 31.62177660168379, 31.62177660168379,
210  31.62177660168379, 31.62177660168379, 31.62177660168379,
211  31.62177660168379, 31.62177660168379, 31.62177660168379,
212  31.62177660168379, 31.62177660168379, 31.62177660168379,
213  31.62177660168379, 31.62177660168379, 31.62177660168379,
214  31.62177660168379, 31.62177660168379, 31.62177660168379,
215  31.62177660168379, 31.62177660168379, 31.62177660168379,
216  31.62177660168379, 31.62177660168379, 31.62177660168379,
217  31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
218  999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
219  999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
220  999.999, 999.999, 999.999;
221 
222  b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
223  0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
224  0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
225  31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
226  999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
227  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
228  0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
229  0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
230  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
231  999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
232  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
233  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
234  31.62177660168379, 31.62177660168379, 31.62177660168379,
235  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
236  999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
237  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
238  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
239  31.62177660168379, 31.62177660168379, 31.62177660168379,
240  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
241  999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
242  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
243  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
244  31.62177660168379, 31.62177660168379, 31.62177660168379,
245  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
246  999.999, 999.999;
247 
248  x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
249  0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
250  0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
251  0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
252  -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
253  1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
254  0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
255  0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
256  0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
257  0.8, 1.1;
258 
  // Expected values; nan marks inputs for which NaN is expected (e.g. x outside [0, 1]).
259  v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
260  nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
261  nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
262  0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
263  0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
264  0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
265  nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
266  0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
267  0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
268  0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
269  0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
270  1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
271  nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
272  0.0008598571564165444, nan, nan, 6.031987710123844e-08,
273  0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
274  0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
275  nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
276  0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
277  3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
278  2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
279 
280  CALL_SUBTEST(res = betainc(a, b, x);
281  verify_component_wise(res, v););
282  }
283 
284  // Test various properties of betainc
285  {
286  ArrayType m1 = ArrayType::Random(32);
287  ArrayType m2 = ArrayType::Random(32);
288  ArrayType m3 = ArrayType::Random(32);
289  ArrayType one = ArrayType::Constant(32, Scalar(1.0));
290  const Scalar eps = std::numeric_limits<Scalar>::epsilon();
  // a, b = exp(4*u); presumably u is uniform in [-1, 1] (Random), so a and b
  // span roughly [exp(-4), exp(4)]. x = |u| stays in [0, 1].
291  ArrayType a = (m1 * 4.0).exp();
292  ArrayType b = (m2 * 4.0).exp();
293  ArrayType x = m3.abs();
294 
295  // betainc(a, 1, x) == x**a
296  CALL_SUBTEST(
297  ArrayType test = betainc(a, one, x);
298  ArrayType expected = x.pow(a);
299  verify_component_wise(test, expected););
300 
301  // betainc(1, b, x) == 1 - (1 - x)**b
302  CALL_SUBTEST(
303  ArrayType test = betainc(one, b, x);
304  ArrayType expected = one - (one - x).pow(b);
305  verify_component_wise(test, expected););
306 
307  // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
308  CALL_SUBTEST(
309  ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
310  ArrayType expected = one;
311  verify_component_wise(test, expected););
312 
313  // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
  // beta(a, b) is formed as exp(lgamma(a) + lgamma(b) - lgamma(a + b)).
314  CALL_SUBTEST(
315  ArrayType num = x.pow(a) * (one - x).pow(b);
316  ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
317  // Add eps to rhs and lhs so that component-wise test doesn't result in
318  // nans when both outputs are zeros.
319  ArrayType expected = betainc(a, b, x) - num / denom + eps;
320  ArrayType test = betainc(a + one, b, x) + eps;
321  if (sizeof(Scalar) >= 8) { // double
322  verify_component_wise(test, expected);
323  } else {
324  // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
325  verify_component_wise(test.head(8), expected.head(8));
326  });
327 
328  // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
329  CALL_SUBTEST(
330  // Add eps to rhs and lhs so that component-wise test doesn't result in
331  // nans when both outputs are zeros.
332  ArrayType num = x.pow(a) * (one - x).pow(b);
333  ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
334  ArrayType expected = betainc(a, b, x) + num / denom + eps;
335  ArrayType test = betainc(a, b + one, x) + eps;
336  verify_component_wise(test, expected););
337  }
338 #endif
339 }
340 
342 {
343  CALL_SUBTEST_1(array_special_functions<ArrayXf>());
344  CALL_SUBTEST_2(array_special_functions<ArrayXd>());
345 }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool() isfinite(const half &a)
Definition: Half.h:379
const CwiseBinaryOp< internal::scalar_zeta_op< Scalar >, const Derived, const DerivedQ > zeta(const EIGEN_CURRENT_STORAGE_BASE_CLASS< DerivedQ > &q) const
EIGEN_DEVICE_FUNC const ExpReturnType exp() const
EIGEN_DEVICE_FUNC const ErfReturnType erf() const
EIGEN_DEVICE_FUNC const SqrtReturnType sqrt() const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const AbsReturnType abs() const
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
void verify_component_wise(const X &x, const Y &y)
void array_special_functions()
internal::enable_if< !(internal::is_same< typename Derived::Scalar, ScalarExponent >::value)&&EIGEN_SCALAR_BINARY_SUPPORTED(pow, typename Derived::Scalar, ScalarExponent), const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived, ScalarExponent, pow) >::type pow(const Eigen::ArrayBase< Derived > &x, const ScalarExponent &exponent)
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:33
EIGEN_DEVICE_FUNC const Scalar & q
void test_special_functions()
EIGEN_DEVICE_FUNC const ErfcReturnType erfc() const
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool() isnan(const half &a)
Definition: Half.h:372
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const ADerived &a, const BDerived &b, const XDerived &x)
EIGEN_DEVICE_FUNC const Scalar & b
EIGEN_DEVICE_FUNC const DigammaReturnType digamma() const
EIGEN_DEVICE_FUNC const LgammaReturnType lgamma() const


hebiros
Author(s): Xavier Artache , Matthew Tesch
autogenerated on Thu Sep 3 2020 04:09:02