8 from pybind11_tests
import numpy_dtypes
as m
10 np = pytest.importorskip(
"numpy")
13 @pytest.fixture(scope=
'module')
15 ld = np.dtype(
'longdouble')
16 return np.dtype({
'names': [
'bool_',
'uint_',
'float_',
'ldbl_'],
17 'formats': [
'?',
'u4',
'f4',
'f{}'.format(ld.itemsize)],
18 'offsets': [0, 4, 8, (16
if ld.alignment > 4
else 12)]})
21 @pytest.fixture(scope=
'module')
23 return np.dtype([(
'bool_',
'?'), (
'uint_',
'u4'), (
'float_',
'f4'), (
'ldbl_',
'g')])
27 from sys
import byteorder
28 e =
'<' if byteorder ==
'little' else '>' 29 return (
"{{'names':['bool_','uint_','float_','ldbl_']," 30 " 'formats':['?','" + e +
"u4','" + e +
"f4','" + e +
"f{}']," 31 " 'offsets':[0,4,8,{}], 'itemsize':{}}}")
35 ld = np.dtype(
'longdouble')
36 simple_ld_off = 12 + 4 * (ld.alignment > 4)
37 return dt_fmt().format(ld.itemsize, simple_ld_off, simple_ld_off + ld.itemsize)
41 from sys
import byteorder
42 return "[('bool_', '?'), ('uint_', '{e}u4'), ('float_', '{e}f4'), ('ldbl_', '{e}f{}')]".format(
43 np.dtype(
'longdouble').itemsize, e=
'<' if byteorder ==
'little' else '>')
47 return 12 + 4 * (np.dtype(
'uint64').alignment > 4) + 8 + 8 * (
48 np.dtype(
'longdouble').alignment > 8)
52 ld = np.dtype(
'longdouble')
54 return dt_fmt().format(ld.itemsize, partial_ld_off, partial_ld_off + ld.itemsize)
58 ld = np.dtype(
'longdouble')
59 partial_nested_off = 8 + 8 * (ld.alignment > 8)
61 partial_nested_size = partial_nested_off * 2 + partial_ld_off + ld.itemsize
62 return "{{'names':['a'], 'formats':[{}], 'offsets':[{}], 'itemsize':{}}}".format(
67 np.testing.assert_equal(actual, np.array(expected_data, dtype=expected_dtype))
71 with pytest.raises(RuntimeError)
as excinfo:
72 m.get_format_unbound()
73 assert re.match(
'^NumPy type info missing for .*UnboundStruct.*$',
str(excinfo.value))
75 ld = np.dtype(
'longdouble')
76 ldbl_fmt = (
'4x' if ld.alignment > 4
else '') + ld.char
77 ss_fmt =
"^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt +
":ldbl_:}" 78 dbl = np.dtype(
'double')
79 partial_fmt = (
"^T{?:bool_:3xI:uint_:f:float_:" +
80 str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
82 nested_extra =
str(
max(8, ld.alignment))
83 assert m.print_format_descriptors() == [
85 "^T{?:bool_:I:uint_:f:float_:g:ldbl_:}",
86 "^T{" + ss_fmt +
":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}",
88 "^T{" + nested_extra +
"x" + partial_fmt +
":a:" + nested_extra +
"x}",
90 "^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
92 '^T{Zf:cflt:Zd:cdbl:}' 97 from sys
import byteorder
98 e =
'<' if byteorder ==
'little' else '>' 100 assert m.print_dtypes() == [
106 "[('a', 'S3'), ('b', 'S3')]",
107 (
"{{'names':['a','b','c','d'], " +
108 "'formats':[('S4', (3,)),('" + e +
"i4', (2,)),('u1', (3,)),('" + e +
"f4', (4, 2))], " +
109 "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e),
110 "[('e1', '" + e +
"i8'), ('e2', 'u1')]",
111 "[('x', 'i1'), ('y', '" + e +
"u8')]",
112 "[('cflt', '" + e +
"c8'), ('cdbl', '" + e +
"c16')]" 115 d1 = np.dtype({
'names': [
'a',
'b'],
'formats': [
'int32',
'float64'],
116 'offsets': [1, 10],
'itemsize': 20})
117 d2 = np.dtype([(
'a',
'i4'), (
'b',
'f4')])
118 assert m.test_dtype_ctors() == [np.dtype(
'int32'), np.dtype(
'float64'),
119 np.dtype(
'bool'), d1, d1, np.dtype(
'uint32'), d2]
121 assert m.test_dtype_methods() == [np.dtype(
'int32'), simple_dtype,
False,
True,
122 np.dtype(
'int32').itemsize, simple_dtype.itemsize]
124 assert m.trailing_padding_dtype() == m.buffer_to_dtype(np.zeros(1, m.trailing_padding_dtype()))
128 elements = [(
False, 0, 0.0, -0.0), (
True, 1, 1.5, -2.5), (
False, 2, 3.0, -5.0)]
130 for func, dtype
in [(m.create_rec_simple, simple_dtype), (m.create_rec_packed, packed_dtype)]:
132 assert arr.dtype == dtype
137 assert arr.dtype == dtype
141 if dtype == simple_dtype:
142 assert m.print_rec_simple(arr) == [
148 assert m.print_rec_packed(arr) == [
154 nested_dtype = np.dtype([(
'a', simple_dtype), (
'b', packed_dtype)])
156 arr = m.create_rec_nested(0)
157 assert arr.dtype == nested_dtype
160 arr = m.create_rec_nested(3)
161 assert arr.dtype == nested_dtype
162 assert_equal(arr, [((
False, 0, 0.0, -0.0), (
True, 1, 1.5, -2.5)),
163 ((
True, 1, 1.5, -2.5), (
False, 2, 3.0, -5.0)),
164 ((
False, 2, 3.0, -5.0), (
True, 3, 4.5, -7.5))], nested_dtype)
165 assert m.print_rec_nested(arr) == [
166 "n:a=s:0,0,0,-0;b=p:1,1,1.5,-2.5",
167 "n:a=s:1,1,1.5,-2.5;b=p:0,2,3,-5",
168 "n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5" 171 arr = m.create_rec_partial(3)
173 partial_dtype = arr.dtype
174 assert '' not in arr.dtype.fields
175 assert partial_dtype.itemsize > simple_dtype.itemsize
179 arr = m.create_rec_partial_nested(3)
181 assert '' not in arr.dtype.fields
182 assert '' not in arr.dtype.fields[
'a'][0].fields
183 assert arr.dtype.itemsize > partial_dtype.itemsize
184 np.testing.assert_equal(arr[
'a'], m.create_rec_partial(3))
188 data = np.arange(1, 7, dtype=
'int32')
190 np.testing.assert_array_equal(m.test_array_ctors(10 + i), data.reshape((3, 2)))
191 np.testing.assert_array_equal(m.test_array_ctors(20 + i), data.reshape((3, 2)))
193 np.testing.assert_array_equal(m.test_array_ctors(30 + i), data)
194 np.testing.assert_array_equal(m.test_array_ctors(40 + i), data)
198 arr = m.create_string_array(
True)
199 assert str(arr.dtype) ==
"[('a', 'S3'), ('b', 'S3')]" 200 assert m.print_string_array(arr) == [
207 assert arr[
'a'].tolist() == [b
'', b
'a', b
'ab', b
'abc']
208 assert arr[
'b'].tolist() == [b
'', b
'a', b
'ab', b
'abc']
209 arr = m.create_string_array(
False)
210 assert dtype == arr.dtype
214 from sys
import byteorder
215 e =
'<' if byteorder ==
'little' else '>' 217 arr = m.create_array_array(3)
218 assert str(arr.dtype) == (
219 "{{'names':['a','b','c','d'], " +
220 "'formats':[('S4', (3,)),('" + e +
"i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " +
221 "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e)
222 assert m.print_array_array(arr) == [
223 "a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1}," +
224 "c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}",
225 "a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001}," +
226 "c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}",
227 "a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001}," +
228 "c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}",
230 assert arr[
'a'].tolist() == [[b
'ABCD', b
'KLMN', b
'UVWX'],
231 [b
'WXYZ', b
'GHIJ', b
'QRST'],
232 [b
'STUV', b
'CDEF', b
'MNOP']]
233 assert arr[
'b'].tolist() == [[0, 1], [1000, 1001], [2000, 2001]]
234 assert m.create_array_array(0).dtype == arr.dtype
238 from sys
import byteorder
239 e =
'<' if byteorder ==
'little' else '>' 241 arr = m.create_enum_array(3)
243 assert dtype == np.dtype([(
'e1', e +
'i8'), (
'e2',
'u1')])
244 assert m.print_enum_array(arr) == [
249 assert arr[
'e1'].tolist() == [-1, 1, -1]
250 assert arr[
'e2'].tolist() == [1, 2, 1]
251 assert m.create_enum_array(0).dtype == dtype
255 from sys
import byteorder
256 e =
'<' if byteorder ==
'little' else '>' 258 arr = m.create_complex_array(3)
260 assert dtype == np.dtype([(
'cflt', e +
'c8'), (
'cdbl', e +
'c16')])
261 assert m.print_complex_array(arr) == [
262 "c:(0,0.25),(0.5,0.75)",
263 "c:(1,1.25),(1.5,1.75)",
264 "c:(2,2.25),(2.5,2.75)" 266 assert arr[
'cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
267 assert arr[
'cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]
268 assert m.create_complex_array(0).dtype == dtype
272 assert doc(m.create_rec_nested) == \
273 "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]" 278 arrays = [m.create_rec_simple(n), m.create_rec_packed(n),
279 m.create_rec_nested(n), m.create_enum_array(n)]
280 funcs = [m.f_simple, m.f_packed, m.f_nested]
282 for i, func
in enumerate(funcs):
283 for j, arr
in enumerate(arrays):
285 assert [
func(arr[k])
for k
in range(n)] == [k * 10
for k
in range(n)]
287 with pytest.raises(TypeError)
as excinfo:
289 assert 'incompatible function arguments' in str(excinfo.value)
293 with pytest.raises(RuntimeError)
as excinfo:
295 assert 'dtype is already registered' in str(excinfo.value)
298 @pytest.mark.xfail(
"env.PYPY")
300 from sys
import getrefcount
303 start = getrefcount(fmt)
304 d = m.dtype_wrapper(fmt)
305 assert d
is np.dtype(
"f4")
308 assert getrefcount(fmt) == start
312 assert all(m.compare_buffer_info())
def test_compare_buffer_info()
Annotation for documentation.
def test_register_dtype()
def test_dtype(simple_dtype)
def test_format_descriptors()
def test_array_constructors()
def assert_equal(actual, expected_data, expected_dtype)
def test_scalar_conversion()
def test_recarray(simple_dtype, packed_dtype)