agd.AutomaticDifferentiation.cupy_support

 1import numpy as np
 2from .Base import implements_cupy_alt,cp,is_ad
 3
 4"""
 5This file implements a few numpy functions that are not well supported by the
 6cupy version (6.0, thus outdated) that is available on windows by conda at the 
 7time of writing.
 8"""
 9
# -- NOT -- fixed in cupy 8.2
@implements_cupy_alt(np.max,TypeError)
def max(a,*args,**kwargs):
	"""
	Fallback for np.max emulating the `initial` keyword argument.

	Old cupy versions do not accept `initial`. It is removed from kwargs,
	np.max is evaluated with the remaining arguments, and the reduction is
	then combined with `initial` through np.maximum.

	Raises KeyError if `initial` is not among the keyword arguments.
	"""
	# cupy (old version ?) does not accept initial argument: strip it here,
	# apply it afterwards.
	initial = kwargs.pop('initial')
	reduced = np.max(a,*args,**kwargs)
	return np.maximum(initial,reduced)
15
16def _along_axis(arr,indices,axis):
17	axis%=arr.ndim
18	def indices_(ax):
19		if ax==axis: return indices
20		sax = arr.shape[ax]
21		ind = np.arange(sax).reshape((1,)*ax + (sax,)+(1,)*(arr.ndim-ax-1))
22		return np.broadcast_to(ind,indices.shape)
23	return tuple(indices_(ax) for ax in range(arr.ndim))
24
# -- NOT -- fixed in cupy 8.2
@implements_cupy_alt(np.put_along_axis,TypeError)
def put_along_axis(arr,indices,values,axis):
	"""
	Fallback for np.put_along_axis.

	Writes `values` into `arr` in place, at the positions selected by
	`indices` along `axis` (the fancy index is built by _along_axis).
	Returns None.
	"""
	target = _along_axis(arr,indices,axis)
	arr[target] = values
29
# -- NOT -- fixed in cupy 8.2
@implements_cupy_alt(np.packbits,TypeError)
def packbits(arr,bitorder='big'):
	"""
	Implements the bitorder option of np.packbits in cupy.

	cp.packbits is called without an axis argument, so the flattened array
	is packed (as numpy does for axis=None). As in numpy, the bit stream is
	padded with zero bits at the end up to a full byte.

	- arr : array of integers or booleans to be packed into bits.
	- bitorder : 'big' (default) or 'little', the order of the input bits.

	Returns the uint8 array produced by cp.packbits.
	Raises ValueError if bitorder is neither 'big' nor 'little'.
	"""
	if bitorder not in ('big','little'):
		raise ValueError("bitorder must be either 'little' or 'big'")
	if bitorder=='little':
		flat = arr.reshape(-1)
		pad = (-flat.size)%8
		if pad:
			# reshape(-1,8) needs a multiple of 8 entries. Appending zeros
			# before the per-byte reversal puts them at the high end of the
			# last byte, matching numpy's trailing zero padding.
			flat = cp.concatenate((flat,cp.zeros(pad,dtype=flat.dtype)))
		# Reversing each group of 8 bits converts little-endian bit order
		# into the big-endian order that cp.packbits implements.
		arr = flat.reshape(-1,8)[:,::-1].ravel()
	return cp.packbits(arr)
def packbits(unknown):

packbits(a, /, axis=None, bitorder='big')

Packs the elements of a binary-valued array into bits in a uint8 array.

The result is padded to full bytes by inserting zero bits at the end.

Parameters

a : array_like
    An array of integers or booleans whose elements should be packed to bits.
axis : int, optional
    The dimension over which bit-packing is done. None implies packing the flattened array.
bitorder : {'big', 'little'}, optional
    The order of the input bits. 'big' will mimic bin(val), [0, 0, 0, 0, 0, 0, 1, 1] => 3 = 0b00000011, 'little' will reverse the order so [1, 1, 0, 0, 0, 0, 0, 0] => 3. Defaults to 'big'.

*New in version 1.17.0.*

Returns

packed : ndarray
    Array of type uint8 whose elements represent bits corresponding to the logical (0 or nonzero) value of the input elements. The shape of packed has the same number of dimensions as the input (unless axis is None, in which case the output is 1-D).

See Also

unpackbits: Unpacks elements of a uint8 array into a binary-valued output array.

Examples

>>> a = np.array([[[1,0,1],
...                [0,1,0]],
...               [[1,1,0],
...                [0,0,1]]])
>>> b = np.packbits(a, axis=-1)
>>> b
array([[[160],
        [ 64]],
       [[192],
        [ 32]]], dtype=uint8)

Note that in binary 160 = 1010 0000, 64 = 0100 0000, 192 = 1100 0000, and 32 = 0010 0000.