special_functions.cpp
Go to the documentation of this file.
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include <limits.h>
11 #include "main.h"
12 #include "../Eigen/SpecialFunctions"
13 
14 // Hack to allow "implicit" conversions from double to Scalar via comma-initialization.
15 template<typename Derived>
16 Eigen::CommaInitializer<Derived> operator<<(Eigen::DenseBase<Derived>& dense, double v) {
17  return (dense << static_cast<typename Derived::Scalar>(v));
18 }
19 
20 template<typename XprType>
22  return (ci, static_cast<typename XprType::Scalar>(v));
23 }
24 
25 template<typename X, typename Y>
26 void verify_component_wise(const X& x, const Y& y)
27 {
28  for(Index i=0; i<x.size(); ++i)
29  {
30  if((numext::isfinite)(y(i)))
31  VERIFY_IS_APPROX( x(i), y(i) );
32  else if((numext::isnan)(y(i)))
33  VERIFY((numext::isnan)(x(i)));
34  else
35  VERIFY_IS_EQUAL( x(i), y(i) );
36  }
37 }
38 
39 template<typename ArrayType> void array_special_functions()
40 {
41  using std::abs;
42  using std::sqrt;
43  typedef typename ArrayType::Scalar Scalar;
44  typedef typename NumTraits<Scalar>::Real RealScalar;
45 
46  Scalar plusinf = std::numeric_limits<Scalar>::infinity();
47  Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
48 
49  Index rows = internal::random<Index>(1,30);
50  Index cols = 1;
51 
52  // API
53  {
54  ArrayType m1 = ArrayType::Random(rows,cols);
55 #if EIGEN_HAS_C99_MATH
56  VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
57  VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
58  VERIFY_IS_APPROX(m1.erf(), erf(m1));
59  VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
60 #endif // EIGEN_HAS_C99_MATH
61  }
62 
63 
64 #if EIGEN_HAS_C99_MATH
65  // check special functions (comparing against numpy implementation)
67  {
68 
69  {
70  ArrayType m1 = ArrayType::Random(rows,cols);
71  ArrayType m2 = ArrayType::Random(rows,cols);
72 
73  // Test various propreties of igamma & igammac. These are normalized
74  // gamma integrals where
75  // igammac(a, x) = Gamma(a, x) / Gamma(a)
76  // igamma(a, x) = gamma(a, x) / Gamma(a)
77  // where Gamma and gamma are considered the standard unnormalized
78  // upper and lower incomplete gamma functions, respectively.
79  ArrayType a = m1.abs() + Scalar(2);
80  ArrayType x = m2.abs() + Scalar(2);
81  ArrayType zero = ArrayType::Zero(rows, cols);
82  ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
83  ArrayType a_m1 = a - one;
84  ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
85  ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
86  ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
87  ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
88 
89 
90  // Gamma(a, 0) == Gamma(a)
91  VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
92 
93  // Gamma(a, x) + gamma(a, x) == Gamma(a)
94  VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
95 
96  // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
97  VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp());
98 
99  // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
100  VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp());
101  }
102  {
103  // Verify for large a and x that values are between 0 and 1.
104  ArrayType m1 = ArrayType::Random(rows,cols);
105  ArrayType m2 = ArrayType::Random(rows,cols);
106  int max_exponent = std::numeric_limits<Scalar>::max_exponent10;
107  ArrayType a = m1.abs() * Scalar(pow(10., max_exponent - 1));
108  ArrayType x = m2.abs() * Scalar(pow(10., max_exponent - 1));
109  for (int i = 0; i < a.size(); ++i) {
110  Scalar igam = numext::igamma(a(i), x(i));
111  VERIFY(0 <= igam);
112  VERIFY(igam <= 1);
113  }
114  }
115 
116  {
117  // Check exact values of igamma and igammac against a third party calculation.
118  Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
119  Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
120 
121  // location i*6+j corresponds to a_s[i], x_s[j].
122  Scalar igamma_s[][6] = {
123  {Scalar(0.0), nan, nan, nan, nan, nan},
124  {Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702),
125  Scalar(0.9816843611112658), Scalar(9.999500016666262e-05),
126  Scalar(1.0)},
127  {Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911),
128  Scalar(0.9539882943107686), Scalar(7.522076445089201e-07),
129  Scalar(1.0)},
130  {Scalar(0.0), Scalar(0.01898815687615381),
131  Scalar(0.06564245437845008), Scalar(0.5665298796332909),
132  Scalar(4.166333347221828e-18), Scalar(1.0)},
133  {Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838),
134  Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)},
135  {Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0),
136  Scalar(0.5042041932513908)}};
137  Scalar igammac_s[][6] = {
138  {nan, nan, nan, nan, nan, nan},
139  {Scalar(1.0), Scalar(0.36787944117144233),
140  Scalar(0.22313016014842982), Scalar(0.018315638888734182),
141  Scalar(0.9999000049998333), Scalar(0.0)},
142  {Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878),
143  Scalar(0.04601170568923136), Scalar(0.9999992477923555),
144  Scalar(0.0)},
145  {Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499),
146  Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)},
147  {Scalar(1.0), Scalar(2.1940638138146658e-05),
148  Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07),
149  Scalar(0.0008629581310054535), Scalar(0.0)},
150  {Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0),
151  Scalar(0.49579580674813944)}};
152 
153  for (int i = 0; i < 6; ++i) {
154  for (int j = 0; j < 6; ++j) {
155  if ((std::isnan)(igamma_s[i][j])) {
156  VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
157  } else {
158  VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
159  }
160 
161  if ((std::isnan)(igammac_s[i][j])) {
162  VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
163  } else {
164  VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
165  }
166  }
167  }
168  }
169  }
170 #endif // EIGEN_HAS_C99_MATH
171 
172  // Check the ndtri function against scipy.special.ndtri
173  {
174  ArrayType x(7), res(7), ref(7);
175  x << 0.5, 0.2, 0.8, 0.9, 0.1, 0.99, 0.01;
176  ref << 0., -0.8416212335729142, 0.8416212335729142, 1.2815515655446004, -1.2815515655446004, 2.3263478740408408, -2.3263478740408408;
177  CALL_SUBTEST( verify_component_wise(ref, ref); );
178  CALL_SUBTEST( res = x.ndtri(); verify_component_wise(res, ref); );
180 
181  // ndtri(normal_cdf(x)) ~= x
182  CALL_SUBTEST(
183  ArrayType m1 = ArrayType::Random(32);
184  using std::sqrt;
185 
186  ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf();
187  cdf_val = (cdf_val + Scalar(1)) / Scalar(2);
188  verify_component_wise(cdf_val.ndtri(), m1););
189 
190  }
191 
192  // Check the zeta function against scipy.special.zeta
193  {
194  ArrayType x(10), q(10), res(10), ref(10);
195  x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9, 2, 3, 4;
196  q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3;
197  ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf;
198  CALL_SUBTEST( verify_component_wise(ref, ref); );
199  CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
201  }
202 
203  // digamma
204  {
205  ArrayType x(9), res(9), ref(9);
206  x << 1, 1.5, 4, -10.5, 10000.5, 0, -1, -2, -3;
207  ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, nan, nan, nan, nan;
208  CALL_SUBTEST( verify_component_wise(ref, ref); );
209 
210  CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
212  }
213 
214 #if EIGEN_HAS_C99_MATH
215  {
216  ArrayType n(16), x(16), res(16), ref(16);
217  n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170, -1, 0, 1, 2, 3;
218  x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64, -1, -2, -3, -4, -5;
219  ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927, nan, nan, plusinf, nan, plusinf;
220  CALL_SUBTEST( verify_component_wise(ref, ref); );
221 
222  if(sizeof(RealScalar)>=8) { // double
223  // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
224  // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
226  }
227  else {
228  // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
229  CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
230  }
231  }
232 #endif
233 
234 #if EIGEN_HAS_C99_MATH
235  {
236  // Inputs and ground truth generated with scipy via:
237  // a = np.logspace(-3, 3, 5) - 1e-3
238  // b = np.logspace(-3, 3, 5) - 1e-3
239  // x = np.linspace(-0.1, 1.1, 5)
240  // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
241  // full_a = full_a.flatten().tolist() # same for full_b, full_x
242  // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
243  //
244  // Note in Eigen, we call betainc with arguments in the order (x, a, b).
245  ArrayType a(125);
246  ArrayType b(125);
247  ArrayType x(125);
248  ArrayType v(125);
249  ArrayType res(125);
250 
251  a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
252  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
253  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
254  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
255  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
256  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
257  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
258  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
259  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
260  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
261  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
262  0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
263  0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
264  31.62177660168379, 31.62177660168379, 31.62177660168379,
265  31.62177660168379, 31.62177660168379, 31.62177660168379,
266  31.62177660168379, 31.62177660168379, 31.62177660168379,
267  31.62177660168379, 31.62177660168379, 31.62177660168379,
268  31.62177660168379, 31.62177660168379, 31.62177660168379,
269  31.62177660168379, 31.62177660168379, 31.62177660168379,
270  31.62177660168379, 31.62177660168379, 31.62177660168379,
271  31.62177660168379, 31.62177660168379, 31.62177660168379,
272  31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
273  999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
274  999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
275  999.999, 999.999, 999.999;
276 
277  b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
278  0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
279  0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
280  31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
281  999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
282  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
283  0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
284  0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
285  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
286  999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
287  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
288  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
289  31.62177660168379, 31.62177660168379, 31.62177660168379,
290  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
291  999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
292  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
293  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
294  31.62177660168379, 31.62177660168379, 31.62177660168379,
295  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
296  999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
297  0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
298  0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
299  31.62177660168379, 31.62177660168379, 31.62177660168379,
300  31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
301  999.999, 999.999;
302 
303  x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
304  0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
305  0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
306  0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
307  -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
308  1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
309  0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
310  0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
311  0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
312  0.8, 1.1;
313 
314  v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
315  nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
316  nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
317  0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
318  0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
319  0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
320  nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
321  0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
322  0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
323  0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
324  0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
325  1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
326  nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
327  0.0008598571564165444, nan, nan, 6.031987710123844e-08,
328  0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
329  0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
330  nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
331  0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
332  3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
333  2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
334 
335  CALL_SUBTEST(res = betainc(a, b, x);
336  verify_component_wise(res, v););
337  }
338 
339  // Test various properties of betainc
340  {
341  ArrayType m1 = ArrayType::Random(32);
342  ArrayType m2 = ArrayType::Random(32);
343  ArrayType m3 = ArrayType::Random(32);
344  ArrayType one = ArrayType::Constant(32, Scalar(1.0));
345  const Scalar eps = std::numeric_limits<Scalar>::epsilon();
346  ArrayType a = (m1 * Scalar(4)).exp();
347  ArrayType b = (m2 * Scalar(4)).exp();
348  ArrayType x = m3.abs();
349 
350  // betainc(a, 1, x) == x**a
351  CALL_SUBTEST(
352  ArrayType test = betainc(a, one, x);
353  ArrayType expected = x.pow(a);
354  verify_component_wise(test, expected););
355 
356  // betainc(1, b, x) == 1 - (1 - x)**b
357  CALL_SUBTEST(
358  ArrayType test = betainc(one, b, x);
359  ArrayType expected = one - (one - x).pow(b);
360  verify_component_wise(test, expected););
361 
362  // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
363  CALL_SUBTEST(
364  ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
365  ArrayType expected = one;
366  verify_component_wise(test, expected););
367 
368  // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
369  CALL_SUBTEST(
370  ArrayType num = x.pow(a) * (one - x).pow(b);
371  ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
372  // Add eps to rhs and lhs so that component-wise test doesn't result in
373  // nans when both outputs are zeros.
374  ArrayType expected = betainc(a, b, x) - num / denom + eps;
375  ArrayType test = betainc(a + one, b, x) + eps;
376  if (sizeof(Scalar) >= 8) { // double
377  verify_component_wise(test, expected);
378  } else {
379  // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
380  verify_component_wise(test.head(8), expected.head(8));
381  });
382 
383  // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
384  CALL_SUBTEST(
385  // Add eps to rhs and lhs so that component-wise test doesn't result in
386  // nans when both outputs are zeros.
387  ArrayType num = x.pow(a) * (one - x).pow(b);
388  ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
389  ArrayType expected = betainc(a, b, x) + num / denom + eps;
390  ArrayType test = betainc(a, b + one, x) + eps;
391  verify_component_wise(test, expected););
392  }
393 #endif // EIGEN_HAS_C99_MATH
394 
395  /* Code to generate the data for the following two test cases.
396  N = 5
397  np.random.seed(3)
398 
399  a = np.logspace(-2, 3, 6)
400  a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N]))
401  x = np.random.gamma(a, 1.0)
402  x = np.maximum(x, np.finfo(np.float32).tiny)
403 
404  def igamma(a, x):
405  return mpmath.gammainc(a, 0, x, regularized=True)
406 
407  def igamma_der_a(a, x):
408  res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a)
409  return np.float64(res)
410 
411  def gamma_sample_der_alpha(a, x):
412  igamma_x = igamma(a, x)
413  def igammainv_of_igamma(a_prime):
414  return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) -
415  igamma_x, x, solver='newton')
416  return np.float64(mpmath.diff(igammainv_of_igamma, a))
417 
418  v_igamma_der_a = np.vectorize(igamma_der_a)(a, x)
419  v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x)
420  */
421 
422 #if EIGEN_HAS_C99_MATH
423  // Test igamma_der_a
424  {
425  ArrayType a(30);
426  ArrayType x(30);
427  ArrayType res(30);
428  ArrayType v(30);
429 
430  a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0,
431  1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
432  100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
433 
434  x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
435  1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
436  0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
437  0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
438  0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
439  10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
440  86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
441  969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
442 
443  v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514,
444  -36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596,
445  -0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323,
446  -0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522,
447  -0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294,
448  -0.0317670267286, -0.0359974812869, -0.0154359225363, -0.0375775365921,
449  -0.00794899153653, -0.00777303219211, -0.00796085782042,
450  -0.0125850719397, -0.00455500206958, -0.00476436993148;
451 
452  CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v););
453  }
454 
455  // Test gamma_sample_der_alpha
456  {
457  ArrayType alpha(30);
458  ArrayType sample(30);
459  ArrayType res(30);
460  ArrayType v(30);
461 
462  alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
463  1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
464  100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
465 
466  sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
467  1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
468  0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
469  0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
470  0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
471  10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
472  86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
473  969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
474 
475  v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
476  1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
477  0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
478  1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
479  0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
480  1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
481  0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
482  1.00106492525, 0.97734200649, 1.02198794179;
483 
484  CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample);
485  verify_component_wise(res, v););
486  }
487 #endif // EIGEN_HAS_C99_MATH
488 }
489 
490 EIGEN_DECLARE_TEST(special_functions)
491 {
492  CALL_SUBTEST_1(array_special_functions<ArrayXf>());
493  CALL_SUBTEST_2(array_special_functions<ArrayXd>());
494  // TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above.
495  // CALL_SUBTEST_3(array_special_functions<ArrayX<Eigen::half>>());
496  // CALL_SUBTEST_4(array_special_functions<ArrayX<Eigen::bfloat16>>());
497 }
SCALAR Scalar
Definition: bench_gemm.cpp:46
const char Y
Eigen::CommaInitializer< XprType > & operator,(Eigen::CommaInitializer< XprType > &ci, double v)
Scalar * y
Scalar * b
Definition: benchVecAdd.cpp:17
EIGEN_STRONG_INLINE const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_der_a_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma_der_a(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
Matrix expected
Definition: testMatrix.cpp:971
EIGEN_DEVICE_FUNC const NdtriReturnType ndtri() const
EIGEN_DEVICE_FUNC const ErfReturnType erf() const
MatrixType m2(n_dims)
EIGEN_DONT_INLINE Scalar zero()
Definition: svd_common.h:296
Definition: test.py:1
int n
EIGEN_STRONG_INLINE const Eigen::CwiseBinaryOp< Eigen::internal::scalar_gamma_sample_der_alpha_op< typename AlphaDerived::Scalar >, const AlphaDerived, const SampleDerived > gamma_sample_der_alpha(const Eigen::ArrayBase< AlphaDerived > &alpha, const Eigen::ArrayBase< SampleDerived > &sample)
const Eigen::CwiseBinaryOp< Eigen::internal::scalar_pow_op< typename Derived::Scalar, typename ExponentDerived::Scalar >, const Derived, const ExponentDerived > pow(const Eigen::ArrayBase< Derived > &x, const Eigen::ArrayBase< ExponentDerived > &exponents)
Holds information about the various numeric (i.e. scalar) types allowed by Eigen. ...
Definition: NumTraits.h:232
#define isfinite(X)
Definition: main.h:95
EIGEN_DEVICE_FUNC const ErfcReturnType erfc() const
void verify_component_wise(const X &x, const Y &y)
EIGEN_DEVICE_FUNC const LgammaReturnType lgamma() const
static double epsilon
Definition: testRot3.cpp:37
const CwiseBinaryOp< internal::scalar_zeta_op< Scalar >, const Derived, const DerivedQ > zeta(const EIGEN_CURRENT_STORAGE_BASE_CLASS< DerivedQ > &q) const
Helper class used by the comma initializer operator.
EIGEN_DEVICE_FUNC const ExpReturnType exp() const
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
void array_special_functions()
#define VERIFY_IS_APPROX(a, b)
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
#define CALL_SUBTEST_1(FUNC)
Matrix3d m1
Definition: IOFormat.cpp:2
EIGEN_DEVICE_FUNC const DigammaReturnType digamma() const
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
Array< int, Dynamic, 1 > v
RealScalar alpha
Array< double, 1, 3 > e(1./3., 0.5, 2.)
EIGEN_DEVICE_FUNC const Scalar & q
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:47
#define CALL_SUBTEST(FUNC)
Definition: main.h:399
#define VERIFY(a)
Definition: main.h:380
EIGEN_STRONG_INLINE const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igamma_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igamma(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
#define CALL_SUBTEST_2(FUNC)
Jet< T, N > sqrt(const Jet< T, N > &f)
Definition: jet.h:418
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseTernaryOp< internal::scalar_betainc_op< typename XDerived::Scalar >, const ADerived, const BDerived, const XDerived > betainc(const ADerived &a, const BDerived &b, const XDerived &x)
#define X
Definition: icosphere.cpp:20
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
#define abs(x)
Definition: datatypes.h:17
EIGEN_STRONG_INLINE const Eigen::CwiseBinaryOp< Eigen::internal::scalar_igammac_op< typename Derived::Scalar >, const Derived, const ExponentDerived > igammac(const Eigen::ArrayBase< Derived > &a, const Eigen::ArrayBase< ExponentDerived > &x)
EIGEN_STRONG_INLINE const Eigen::CwiseBinaryOp< Eigen::internal::scalar_polygamma_op< typename DerivedX::Scalar >, const DerivedN, const DerivedX > polygamma(const Eigen::ArrayBase< DerivedN > &n, const Eigen::ArrayBase< DerivedX > &x)
std::ptrdiff_t j
#define isnan(X)
Definition: main.h:93
EIGEN_DECLARE_TEST(special_functions)


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:36:18