agd.AutomaticDifferentiation.Base

  1import numpy as np
  2import numbers
  3import functools
  4import operator
  5
  6from . import functional
  7
  8# ----- implementation note -----
  9# denseAD should not inherit from np.ndarray, otherwise silent casts of scalars
 10# denseAD_cupy must inherit cp.ndarray, otherwise operator overloading won't work
 11
 12
 13# import the cupy module only if available on the system
 14try: 
 15	import cupy as cp
 16	_cp_ndarray = cp.ndarray
 17except (ImportError,ModuleNotFoundError): 
 18	cp=None
 19	class _cp_ndarray: pass
 20
 21class ADCastError(ValueError):
 22	"Raised when attempting to cast between different AD types"
 23
 24# Elementary functions and their derivatives
 25# No implementation of arctan2, or hypot, which have two args
 26class Taylor1: # first order Taylor expansions
 27	def pow(x,n):	return (x**n,n*x**(n-1))
 28	def log(x): 	return (np.log(x),1./x)
 29	def exp(x): 	e=np.exp(x); return (e,e)
 30	def abs(x):	return (np.abs(x),np.sign(x))
 31	def sin(x):	return (np.sin(x),np.cos(x))
 32	def cos(x): 	return (np.cos(x),-np.sin(x))
 33	def tan(x):	t=np.tan(x); return (t,1.+t**2)
 34	def arcsin(x): return (np.arcsin(x),(1.-x**2)**-0.5)
 35	def arccos(c): return (np.arccos(x),-(1.-x**2)**-0.5)
 36	def arctan(x): return (np.arctan(x),1./(1+x**2))
 37	def sinh(x):	return (np.sinh(x),np.cosh(x))
 38	def cosh(x):	return (np.cosh(x),np.sinh(x))
 39	def tanh(x):	t=np.tanh(x); return (t,1.-t**2)
 40	def arcsinh(x): return (np.arcsinh(x),(1.+x**2)**-0.5)
 41	def arccosh(c): return (np.arccosh(x),(x**2-1.)**-0.5)
 42	def arctanh(x): return (np.arctanh(x),1./(1-x**2))
 43
 44class Taylor2: # second order Taylor expansions of classical functions
 45	def pow(x,n):	return (x**n,n*x**(n-1),(n*(n-1))*x**(n-2))
 46	def log(x):	y=1./x; return (np.log(x),y,-y**2)
 47	def exp(x): 	e=np.exp(x); return (e,e,e)
 48	def abs(x):	return (np.abs(x),np.sign(x),np.zeros_like(x))
 49	def sin(x):	s=np.sin(x); return (s,np.cos(x),-s)
 50	def cos(x):	c=np.cos(x); return (c,-np.sin(x),-c)
 51	def tan(x):	t=np.tan(x); u=1.+t**2; return (t,u,2.*u*t)
 52	def arcsin(x): y=1.-x**2; return (np.arcsin(x),y**-0.5,x*y**-1.5)
 53	def arccos(c): y=1.-x**2; return (np.arccos(x),-y**-0.5,-x*y**-1.5)
 54	def arctan(x): y=1./(1.+x**2); return (np.arctan(x),y,-2.*x*y**2)
 55	def sinh(x):	s=np.sinh(x); return (s,np.cosh(x),s)
 56	def cosh(x):	c=np.cosh(x); return (c,np.sinh(x),c)
 57	def tanh(x):	t=np.tanh(x); u=1.-t**2; return (t,u,-2.*u*t)
 58	def arcsinh(x): y=1.+x**2; return (np.arcsinh(x),y**-0.5,-x*y**-1.5)
 59	def arccosh(c): y=x**2-1.; return (np.arccosh(x),y**-0.5,-x*y**-1.5)
 60	def arctanh(x): y=1./(1-x**2); return (np.arctanh(x),y,2.*x*y**2)
 61
 62def _tuple_first(a): 	return a[0] if isinstance(a,tuple) else a
 63def _getitem(a,where):  return a if (where is True and not isndarray(a)) else a[where]
 64
 65def add(a,b,out=None,where=True): 
 66	if out is None: return a+b if is_ad(a) else b+a
 67	else: result=_tuple_first(out); result[where]=a[where]+_getitem(b,where); return result # Common failure : a+=b where b is an AD variable and a is not
 68
 69def subtract(a,b,out=None,where=True):
 70	if out is None: return a-b if is_ad(a) else b.__rsub__(a) 
 71	else: result=_tuple_first(out); result[where]=a[where]-_getitem(b,where); return result # Common failure : a-=b where b is an AD variable and a is not
 72
 73def multiply(a,b,out=None,where=True):
 74	if out is None: return a*b if is_ad(a) else b*a
 75	else: result=_tuple_first(out); result[where]=a[where]*_getitem(b,where); return result # Common failure : a*=b where b is an AD variable and a is not
 76
 77def true_divide(a,b,out=None,where=True): 
 78	if out is None: return a/b if is_ad(a) else b.__rtruediv__(a)
 79	else: result=_tuple_first(out); result[where]=a[where]/_getitem(b,where); return result # Common failure : a/=b where b is an AD variable and a is not
 80
 81def maximum(a,b): return np.where(a>b,a,b)
 82def minimum(a,b): return np.where(a<b,a,b)
 83
 84class baseAD:
 85
 86	@property
 87	def shape(self): return self.value.shape
 88	@property
 89	def ndim(self): return self.value.ndim
 90	@property
 91	def size(self): return self.value.size	
 92	def flatten(self):	return self.reshape( (self.size,) )
 93	def squeeze(self,axis=None): return self.reshape(self.value.squeeze(axis).shape)
 94	@property
 95	def T(self):	return self if self.ndim<2 else self.transpose()
 96
 97	@classmethod
 98	def stack(cls,elems,axis=0):
 99		return cls.concatenate(tuple(np.expand_dims(e,axis=axis) for e in elems),axis)
100
101	@property
102	def dtype(self): return self.value.dtype
103	def __len__(self): return len(self.value)
104	def _ndarray(self): return type(self.value)
105	def cupy_based(self): return not isinstance(self.value,np.ndarray)
106	def isndarray(self,other): return isinstance(other,self._ndarray()) # same array module
107	@classmethod
108	def is_ad(cls,other): return isinstance(other,cls)
109	@classmethod
110	def new(cls,*args,**kwargs):
111		return cls(*args,**kwargs)	
112
113	@classmethod
114	def Taylor(cls): return Taylor1 if cls.order()==1 else Taylor2
115
116	def sqrt(self):			return self**0.5
117	def __pow__(self,n):	return self._math_helper(self.Taylor().pow(self.value,n))
118	def log(self):			return self._math_helper(self.Taylor().log(self.value))
119	def exp(self):			return self._math_helper(self.Taylor().exp(self.value))
120	def abs(self):			return self._math_helper(self.Taylor().abs(self.value))
121	def sin(self):			return self._math_helper(self.Taylor().sin(self.value))
122	def cos(self):			return self._math_helper(self.Taylor().cos(self.value))
123	def tan(self):			return self._math_helper(self.Taylor().tan(self.value))
124	def arcsin(self):		return self._math_helper(self.Taylor().arcsin(self.value))
125	def arccos(self):		return self._math_helper(self.Taylor().arccos(self.value))
126	def arctan(self):		return self._math_helper(self.Taylor().arctan(self.value))
127	def sinh(self):			return self._math_helper(self.Taylor().sinh(self.value))
128	def cosh(self):			return self._math_helper(self.Taylor().cosh(self.value))
129	def tanh(self):			return self._math_helper(self.Taylor().tanh(self.value))
130	def arcsinh(self):		return self._math_helper(self.Taylor().arcsinh(self.value))
131	def arccosh(self):		return self._math_helper(self.Taylor().arccosh(self.value))
132	def arctanh(self):		return self._math_helper(self.Taylor().arctanh(self.value))
133
134		# See https://docs.scipy.org/doc/numpy/reference/ufuncs.html
135	def __array_ufunc__(self,ufunc,method,*inputs,**kwargs):
136
137		# Return an np.ndarray for piecewise constant functions
138		if ufunc in [
139		# Comparison functions
140		np.greater,np.greater_equal,
141		np.less,np.less_equal,
142		np.equal,np.not_equal,
143
144		# Math
145		np.floor_divide,np.rint,np.sign,np.heaviside,
146
147		# Floating functions
148		np.isfinite,np.isinf,np.isnan,np.isnat,
149		np.signbit,np.floor,np.ceil,np.trunc
150		]:
151			inputs_ = (a.value if self.is_ad(a) else a for a in inputs)
152			return self.value.__array_ufunc__(ufunc,method,*inputs_,**kwargs)
153
154
155		if method=="__call__":
156
157			# Reimplemented
158			if ufunc==np.maximum: return maximum(*inputs,**kwargs)
159			if ufunc==np.minimum: return minimum(*inputs,**kwargs)
160
161			# Math functions
162			if ufunc==np.sqrt:		return self.sqrt()
163			if ufunc==np.log:		return self.log()
164			if ufunc==np.exp:		return self.exp()
165			if ufunc==np.abs:		return self.abs()
166			if ufunc==np.sin:		return self.sin()
167			if ufunc==np.cos:		return self.cos()
168			if ufunc==np.tan:		return self.tan()
169			if ufunc==np.arcsin:	return self.arcsin()
170			if ufunc==np.arccos:	return self.arccos()
171			if ufunc==np.arctan:	return self.arctan()
172			if ufunc==np.sinh:		return self.sinh()
173			if ufunc==np.cosh:		return self.cosh()
174			if ufunc==np.tanh:		return self.tanh()
175			if ufunc==np.arcsinh:	return self.arcsinh()
176			if ufunc==np.arccosh:	return self.arccosh()
177			if ufunc==np.arctanh:	return self.arctanh()
178
179			# Operators
180			if ufunc==np.add: return add(*inputs,**kwargs)
181			if ufunc==np.subtract: return subtract(*inputs,**kwargs)
182			if ufunc==np.multiply: return multiply(*inputs,**kwargs)
183			if ufunc==np.true_divide: return true_divide(*inputs,**kwargs)
184
185
186		return NotImplemented
187
188	def __array_function__(self,func,types,args,kwargs):
189		return _array_function_overload(self,func,types,args,kwargs)
190
191
192	# Support for +=, -=, *=, /=, <, <=, >, >=, ==, !=
193	def __iadd__(self,other): return add(self,other,out=self)
194	def __isub__(self,other): return subtract(self,other,out=self)
195	def __imul__(self,other): return multiply(self,other,out=self)
196	def __itruediv__(self,other): return true_divide(self,other,out=self)
197
198	def __lt__(self,other): return np.less(self,other)
199	def __le__(self,other): return np.less_equal(self,other)
200	def __gt__(self,other): return np.greater(self,other)
201	def __ge__(self,other): return np.greater_equal(self,other)
202	def __eq__(self,other): return np.equal(self,other)
203	def __ne__(self,other): return np.not_equal(self,other)
204
205	def argmin(self,*args,**kwargs): return self.value.argmin(*args,**kwargs)
206	def argmax(self,*args,**kwargs): return self.value.argmax(*args,**kwargs)
207
208	def min(array,axis=None,keepdims=False,out=None):
209		if axis is None: return array.reshape(-1).min(axis=0,out=out)
210		ai = np.expand_dims(np.argmin(array.value, axis=axis), axis=axis)
211		out = np.take_along_axis(array,ai,axis=axis)
212		if not keepdims: out = out.reshape(array.shape[:axis]+array.shape[axis+1:])
213		return out
214
215	def max(array,axis=None,keepdims=False,out=None):
216		if axis is None: return array.reshape(-1).max(axis=0,out=out)
217		ai = np.expand_dims(np.argmax(array.value, axis=axis), axis=axis)
218		out = np.take_along_axis(array,ai,axis=axis)
219		if not keepdims: out = out.reshape(array.shape[:axis]+array.shape[axis+1:])
220		return out
221
222	def prod(arr,axis=None,dtype=None,out=None,keepdims=False,initial=None):
223		"""Attempt to reproduce numpy prod function. (Rather inefficiently, and I presume partially)"""
224
225		shape_orig = arr.shape
226
227		if axis is None:
228			arr = arr.flatten()
229			axis = (0,)
230		elif isinstance(axis,numbers.Number):
231			axis=(axis,)
232
233
234		if axis!=(0,):
235			d = len(axis)
236			rd = tuple(range(len(axis)))
237			arr = np.moveaxis(arr,axis,rd)
238			shape1 = (np.prod(arr.shape[d:],dtype=int),)+arr.shape[d:]
239			arr = arr.reshape(shape1)
240
241		if len(arr)==0:
242			return initial
243
244		if dtype!=arr.dtype and dtype is not None:
245			if initial is None:
246				initial = dtype(1)
247			elif dtype!=initial.dtype:
248				initial = initial*dtype(1)
249
250		out = functools.reduce(operator.mul,arr) if initial is None \
251			else functools.reduce(operator.mul,arr,initial)
252
253		if keepdims:
254			shape_kept = tuple(1 if i in axis else ai for i,ai in enumerate(shape_orig)) \
255				if out.size>1 else (1,)*len(shape_orig)
256			out = out.reshape(shape_kept) 
257
258		return out
259
260# --------- Cupy support ----------
261
def is_ad(data,iterables=tuple()):
	"""Whether the object holds AD information.

	True if one of the items yielded by functional.rec_iter(data,iterables) is a
	baseAD instance (presumably rec_iter descends into containers whose types are
	listed in iterables).
	"""
	for item in functional.rec_iter(data,iterables):
		if isinstance(item,baseAD):
			return True
	return False
266
def isndarray(x):
	"""Whether x is a numpy ndarray, a cupy ndarray, or an AD type."""
	array_types = (np.ndarray,baseAD,_cp_ndarray)
	return isinstance(x,array_types)
270
def from_cupy(x):
	"""Whether x is a cupy ndarray, or an AD type whose value is not a numpy array."""
	if isinstance(x,_cp_ndarray):
		return True
	return isinstance(x,baseAD) and x.cupy_based()
274
def array(a,copy=True,caster=None):
	"""
	Similar to np.array, but does not cast AD subclasses of np.ndarray to the base class.
	- A non-empty list or tuple is converted elementwise (with asarray) and the
	  results are stacked along a new first axis.
	- An array (numpy, cupy or AD) is copied, or returned as is when copy is False.
	- Anything else (e.g. a scalar) is turned into an array scalar.
	Inputs :
	- caster : used to cast a scalar into an array scalar (overrides the default
	  array.caster, which is np.asarray)
	"""
	if isinstance(a,(list,tuple)) and len(a)>0:
		return stack([asarray(e,caster=caster) for e in a],axis=0)
	if isndarray(a):
		return a.copy() if copy else a
	cast = array.caster if caster is None else caster
	return cast(a)

array.caster = np.asarray
290
def asarray(a,**kwargs):
	"""Same as array(a,**kwargs) with copy=False: arrays (numpy, cupy or AD) are returned as is."""
	return array(a,
		copy=False,
		**kwargs)
292
def ascontiguousarray(a,**kwargs):
	"""Contiguous-memory version of asarray(a,**kwargs).

	numpy and cupy arrays use their module's ascontiguousarray. AD types are
	rebuilt from the contiguous versions of their as_tuple() components.
	Raises ValueError for any other type.
	"""
	a = asarray(a,**kwargs)
	if isinstance(a,np.ndarray):
		return np.ascontiguousarray(a)
	if isinstance(a,_cp_ndarray):
		return cp.ascontiguousarray(a)
	if is_ad(a):
		components = tuple(ascontiguousarray(e) for e in a.as_tuple())
		return type(a)(*components)
	raise ValueError(f"ascontiguousarray not applicable to variable of type {type(a)}")
300
def array_members(data,iterables=(tuple,list,dict)):
	"""
	Returns the list of all arrays in given structure, with their access paths.
	Usage : for key,value in ad.Base.array_members(self): print(key,value.nbytes/2**20)

	Each entry is a pair (names,array), where names lists the dotted access paths
	under which that array storage was found. The top level data, and any member
	whose type is in iterables, is explored through .items() (mappings),
	enumeration (lists and tuples) or its __dict__.
	"""
	arrays = []

	def same_storage(arr,value):
		# memoryview == memoryview compares the *contents*, so two distinct numpy
		# arrays holding equal values used to be merged into a single entry.
		# For numpy arrays, compare the address of the underlying buffer instead.
		if isinstance(arr,np.ndarray) and isinstance(value,np.ndarray):
			return arr.__array_interface__['data'][0]==value.__array_interface__['data'][0]
		return arr.data==value.data

	def check(path,arr):
		if isndarray(arr):
			name = ".".join(map(str,path))
			for namelist,value in arrays:
				if same_storage(arr,value):
					namelist.append(name)
					break
			else:
				arrays.append(([name],arr))

	def check_members(prefix,data):
		for name,value in items(data):
			name2 = prefix+(name,)
			if isinstance(value,iterables): check_members(name2,value)
			else: check(name2,value)

	def items(data):
		if hasattr(data,'items'): return data.items()
		if isinstance(data,(list,tuple)):
			return [(str(i),value) for i,value in enumerate(data)]
		else: return data.__dict__.items()

	check_members(tuple(),data)
	return arrays
331
332# -------- numpy __array_function__ mechanism ---------
333
334"""
335We use the __array_function__ mechanism of numpy to reimplement 
336a number of numpy functions in a way that is compatible with AD information.
337"""
338
339#https://docs.scipy.org/doc/numpy/reference/arrays.classes.html
340numpy_overloads = {}
341cupy_alt_overloads = {} # Used for numpy function unsupported by cupy
342numpy_implementation = {# Use original numpy implementation
343	np.ndim,np.shape,np.moveaxis,np.squeeze,
344	np.amin,np.amax,np.argmin,np.argmax,np.max,np.min,
345	np.sum,np.prod,
346	np.full_like,np.ones_like,np.zeros_like,np.reshape,np.take_along_axis,
347	} 
348
349def implements(numpy_function):
350	"""Register an __array_function__ implementation for MyArray objects."""
351	def decorator(func):
352		numpy_overloads[numpy_function] = func
353		return func
354	return decorator
355
356def implements_cupy_alt(numpy_function,exception):
357	"""Register an alternative to a numpy function only partially supported by cupy"""
358	def decorator(func):
359		cupy_alt_overloads[numpy_function] = (func,exception)
360		return functional.func_except_alt(numpy_function,exception,func)
361	return decorator
362
363def _array_function_overload(self,func,types,args,kwargs,cupy_alt=True):
364	if cupy_alt and self.cupy_based() and func in cupy_alt_overloads:
365		func_alt,exception = cupy_alt_overloads[func]
366		try: return _array_function_overload(self,func,types,args,kwargs,cupy_alt=False)
367		except exception: return func_alt(*args,**kwargs)
368
369	if func in numpy_overloads:
370		return numpy_overloads[func](*args,**kwargs)
371	elif func in numpy_implementation: 
372		return func._implementation(*args,**kwargs)
373	else: return NotImplemented
374
375# ---- overloads ----
376
@implements(np.stack)
def stack(elems,axis=0):
	"""AD-aware np.stack.

	If some element is AD, delegates to the stack method of the type of the first
	such element; otherwise falls back to np.stack.
	"""
	for element in elems:
		if not is_ad(element):
			continue
		return type(element).stack(elems,axis)
	return np.stack(elems,axis)
382
@implements(np.expand_dims)
def expand_dims(a,axis):
	"""AD-aware np.expand_dims: reshape a to the shape its value takes under np.expand_dims."""
	return np.reshape(a,np.expand_dims(a.value,axis).shape)
387
@implements(np.empty_like)
def empty_like(a,*args,subok=True,**kwargs):
	"""AD-aware np.empty_like: wraps an empty array shaped like a.value into type(a).

	For cupy based inputs, subok=None is passed instead of the given subok.
	"""
	subok_arg = None if from_cupy(a) else subok
	blank = np.empty_like(a.value,*args,subok=subok_arg,**kwargs)
	return type(a)(blank)
392
@implements(np.copyto)
def copy_to(dst,src,*args,**kwargs):
	"""AD-aware np.copyto: copies a non-AD src into dst.value. An AD src raises ValueError."""
	if is_ad(src):
		raise ValueError("copyto is not supported with an AD source")
	np.copyto(dst.value,src,*args,**kwargs)
397
@implements(np.broadcast_to)
def broadcast_to(array,shape):
	"""AD-aware np.broadcast_to, delegating to the broadcast_to method of array."""
	method = array.broadcast_to
	return method(shape)
401	
@implements(np.where)
def where(mask,a,b):
	"""AD-aware np.where(mask,a,b).

	b is copied if it is AD (otherwise a is). The entries selected from the other
	operand are then written into that copy; a non-array operand is broadcast.
	"""
	if is_ad(b):
		base,other,selector = b,a,mask
	else:
		base,other,selector = a,b,np.logical_not(mask)
	result = base.copy()
	result[selector] = other[selector] if isndarray(other) else other
	return result
408
@implements(np.sort)
def sort(array,axis=-1,*varargs,**kwargs):
	"""AD-aware np.sort: sorts along axis according to array.value, carrying the AD information along.

	Extra positional arguments (e.g. kind, order) are forwarded to np.argsort after axis.
	"""
	# axis must be passed positionally: with `axis=axis,*varargs`, the first extra
	# positional argument (e.g. np.sort(a,-1,'stable')) landed in argsort's axis slot,
	# raising "got multiple values for argument 'axis'".
	ai = np.argsort(array.value,axis,*varargs,**kwargs)
	return np.take_along_axis(array,ai,axis=axis)
413
@implements(np.concatenate)
def concatenate(elems,axis=0):
	"""AD-aware np.concatenate.

	If some element is AD, delegates to the concatenate method of the type of the
	first such element; otherwise falls back to np.concatenate.
	"""
	for element in elems:
		if not is_ad(element):
			continue
		return type(element).concatenate(elems,axis)
	return np.concatenate(elems,axis)
419
@implements(np.pad)
def pad(array, pad_width, *args,**kwargs):
	"""AD-aware np.pad: normalizes pad_width, then delegates to array.pad.

	An integer n, or a one-element sequence (n,), becomes ((n,n),). Any
	one-element pad_width is then repeated for each of the array.ndim axes.
	"""
	if isinstance(pad_width,numbers.Integral):
		pad_width = (pad_width,)
	head = pad_width[0]
	if len(pad_width)==1:
		if isinstance(head,numbers.Integral):
			pad_width = ((head,head),)
		pad_width = pad_width*array.ndim
	return array.pad(pad_width,*args,**kwargs)
429
@implements(np.mean)
def mean(array, *args, **kwargs):
	"""AD-aware np.mean: np.sum rescaled in place by (number of outputs)/(number of inputs)."""
	total = np.sum(array, *args, **kwargs)
	scale = total.size / array.size
	total *= scale
	return total
435
@implements(np.roll)
def roll(array,shift,axis=None):
	"""AD-aware np.roll along a single axis, applied to each as_tuple() component.

	axis=None is only accepted for arrays with at most one dimension.
	Raises ValueError for an unsupported axis.
	"""
	if axis is None:
		if array.ndim>1:
			raise ValueError("Unsupported None axis when ndim>=2")
		axis = 0
	elif axis<0:
		axis += array.ndim

	if axis<0 or axis>=array.ndim:
		raise ValueError(f"Unsupported axis {axis} with ndim {array.ndim}")

	rolled = [np.roll(component,shift,axis) for component in array.as_tuple()]
	return array.new(*rolled)
449
@implements(np.allclose)
def allclose(a,b,*args,**kwargs):
	"""AD-aware np.allclose, delegating to a.allclose (presumably a is the AD operand)."""
	method = a.allclose
	return method(b,*args,**kwargs)
class ADCastError(builtins.ValueError):
22class ADCastError(ValueError): # a ValueError subclass: `except ValueError` also catches it
23	"Raised when attempting to cast between different AD types"

Raised when attempting to cast between different AD types

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class Taylor1:
27class Taylor1: # first order Taylor expansions: each function returns (f(x), f'(x))
28	def pow(x,n):	return (x**n,n*x**(n-1))
29	def log(x): 	return (np.log(x),1./x)
30	def exp(x): 	e=np.exp(x); return (e,e)
31	def abs(x):	return (np.abs(x),np.sign(x))
32	def sin(x):	return (np.sin(x),np.cos(x))
33	def cos(x): 	return (np.cos(x),-np.sin(x))
34	def tan(x):	t=np.tan(x); return (t,1.+t**2)
35	def arcsin(x): return (np.arcsin(x),(1.-x**2)**-0.5)
36	def arccos(c): return (np.arccos(x),-(1.-x**2)**-0.5) # NOTE(review): parameter is c but the body reads x -> NameError
37	def arctan(x): return (np.arctan(x),1./(1+x**2))
38	def sinh(x):	return (np.sinh(x),np.cosh(x))
39	def cosh(x):	return (np.cosh(x),np.sinh(x))
40	def tanh(x):	t=np.tanh(x); return (t,1.-t**2)
41	def arcsinh(x): return (np.arcsinh(x),(1.+x**2)**-0.5)
42	def arccosh(c): return (np.arccosh(x),(x**2-1.)**-0.5) # NOTE(review): parameter is c but the body reads x -> NameError
43	def arctanh(x): return (np.arctanh(x),1./(1-x**2))
def pow(x, n):
28	def pow(x,n):	return (x**n,n*x**(n-1))
def log(x):
29	def log(x): 	return (np.log(x),1./x)
def exp(x):
30	def exp(x): 	e=np.exp(x); return (e,e)
def abs(x):
31	def abs(x):	return (np.abs(x),np.sign(x))
def sin(x):
32	def sin(x):	return (np.sin(x),np.cos(x))
def cos(x):
33	def cos(x): 	return (np.cos(x),-np.sin(x))
def tan(x):
34	def tan(x):	t=np.tan(x); return (t,1.+t**2)
def arcsin(x):
35	def arcsin(x): return (np.arcsin(x),(1.-x**2)**-0.5)
def arccos(c):
36	def arccos(c): return (np.arccos(x),-(1.-x**2)**-0.5) # NOTE(review): parameter is c but the body reads x -> NameError
def arctan(x):
37	def arctan(x): return (np.arctan(x),1./(1+x**2))
def sinh(x):
38	def sinh(x):	return (np.sinh(x),np.cosh(x))
def cosh(x):
39	def cosh(x):	return (np.cosh(x),np.sinh(x))
def tanh(x):
40	def tanh(x):	t=np.tanh(x); return (t,1.-t**2)
def arcsinh(x):
41	def arcsinh(x): return (np.arcsinh(x),(1.+x**2)**-0.5)
def arccosh(c):
42	def arccosh(c): return (np.arccosh(x),(x**2-1.)**-0.5) # NOTE(review): parameter is c but the body reads x -> NameError
def arctanh(x):
43	def arctanh(x): return (np.arctanh(x),1./(1-x**2))
class Taylor2:
45class Taylor2: # second order Taylor expansions of classical functions: each returns (f(x), f'(x), f''(x))
46	def pow(x,n):	return (x**n,n*x**(n-1),(n*(n-1))*x**(n-2))
47	def log(x):	y=1./x; return (np.log(x),y,-y**2)
48	def exp(x): 	e=np.exp(x); return (e,e,e)
49	def abs(x):	return (np.abs(x),np.sign(x),np.zeros_like(x))
50	def sin(x):	s=np.sin(x); return (s,np.cos(x),-s)
51	def cos(x):	c=np.cos(x); return (c,-np.sin(x),-c)
52	def tan(x):	t=np.tan(x); u=1.+t**2; return (t,u,2.*u*t)
53	def arcsin(x): y=1.-x**2; return (np.arcsin(x),y**-0.5,x*y**-1.5)
54	def arccos(c): y=1.-x**2; return (np.arccos(x),-y**-0.5,-x*y**-1.5) # NOTE(review): parameter is c but the body reads x -> NameError
55	def arctan(x): y=1./(1.+x**2); return (np.arctan(x),y,-2.*x*y**2)
56	def sinh(x):	s=np.sinh(x); return (s,np.cosh(x),s)
57	def cosh(x):	c=np.cosh(x); return (c,np.sinh(x),c)
58	def tanh(x):	t=np.tanh(x); u=1.-t**2; return (t,u,-2.*u*t)
59	def arcsinh(x): y=1.+x**2; return (np.arcsinh(x),y**-0.5,-x*y**-1.5)
60	def arccosh(c): y=x**2-1.; return (np.arccosh(x),y**-0.5,-x*y**-1.5) # NOTE(review): parameter is c but the body reads x -> NameError
61	def arctanh(x): y=1./(1-x**2); return (np.arctanh(x),y,2.*x*y**2)
def pow(x, n):
46	def pow(x,n):	return (x**n,n*x**(n-1),(n*(n-1))*x**(n-2))
def log(x):
47	def log(x):	y=1./x; return (np.log(x),y,-y**2)
def exp(x):
48	def exp(x): 	e=np.exp(x); return (e,e,e)
def abs(x):
49	def abs(x):	return (np.abs(x),np.sign(x),np.zeros_like(x))
def sin(x):
50	def sin(x):	s=np.sin(x); return (s,np.cos(x),-s)
def cos(x):
51	def cos(x):	c=np.cos(x); return (c,-np.sin(x),-c)
def tan(x):
52	def tan(x):	t=np.tan(x); u=1.+t**2; return (t,u,2.*u*t)
def arcsin(x):
53	def arcsin(x): y=1.-x**2; return (np.arcsin(x),y**-0.5,x*y**-1.5)
def arccos(c):
54	def arccos(c): y=1.-x**2; return (np.arccos(x),-y**-0.5,-x*y**-1.5) # NOTE(review): parameter is c but the body reads x -> NameError
def arctan(x):
55	def arctan(x): y=1./(1.+x**2); return (np.arctan(x),y,-2.*x*y**2)
def sinh(x):
56	def sinh(x):	s=np.sinh(x); return (s,np.cosh(x),s)
def cosh(x):
57	def cosh(x):	c=np.cosh(x); return (c,np.sinh(x),c)
def tanh(x):
58	def tanh(x):	t=np.tanh(x); u=1.-t**2; return (t,u,-2.*u*t)
def arcsinh(x):
59	def arcsinh(x): y=1.+x**2; return (np.arcsinh(x),y**-0.5,-x*y**-1.5)
def arccosh(c):
60	def arccosh(c): y=x**2-1.; return (np.arccosh(x),y**-0.5,-x*y**-1.5) # NOTE(review): parameter is c but the body reads x -> NameError
def arctanh(x):
61	def arctanh(x): y=1./(1-x**2); return (np.arctanh(x),y,2.*x*y**2)
def add(a, b, out=None, where=True):
66def add(a,b,out=None,where=True): 
67	if out is None: return a+b if is_ad(a) else b+a # without out, evaluate from the AD operand's side
68	else: result=_tuple_first(out); result[where]=a[where]+_getitem(b,where); return result # Common failure : a+=b where b is an AD variable and a is not
def subtract(a, b, out=None, where=True):
70def subtract(a,b,out=None,where=True):
71	if out is None: return a-b if is_ad(a) else b.__rsub__(a) 
72	else: result=_tuple_first(out); result[where]=a[where]-_getitem(b,where); return result # Common failure : a-=b where b is an AD variable and a is not
def multiply(a, b, out=None, where=True):
74def multiply(a,b,out=None,where=True):
75	if out is None: return a*b if is_ad(a) else b*a
76	else: result=_tuple_first(out); result[where]=a[where]*_getitem(b,where); return result # Common failure : a*=b where b is an AD variable and a is not
def true_divide(a, b, out=None, where=True):
78def true_divide(a,b,out=None,where=True): 
79	if out is None: return a/b if is_ad(a) else b.__rtruediv__(a)
80	else: result=_tuple_first(out); result[where]=a[where]/_getitem(b,where); return result # Common failure : a/=b where b is an AD variable and a is not
def maximum(a, b):
82def maximum(a,b): return np.where(a>b,a,b) # ties (a==b) select b
def minimum(a, b):
83def minimum(a,b): return np.where(a<b,a,b) # ties (a==b) select b
class baseAD:
 85class baseAD:
 86
 87	@property
 88	def shape(self): return self.value.shape
 89	@property
 90	def ndim(self): return self.value.ndim
 91	@property
 92	def size(self): return self.value.size	
 93	def flatten(self):	return self.reshape( (self.size,) )
 94	def squeeze(self,axis=None): return self.reshape(self.value.squeeze(axis).shape)
 95	@property
 96	def T(self):	return self if self.ndim<2 else self.transpose()
 97
 98	@classmethod
 99	def stack(cls,elems,axis=0):
100		return cls.concatenate(tuple(np.expand_dims(e,axis=axis) for e in elems),axis)
101
102	@property
103	def dtype(self): return self.value.dtype
104	def __len__(self): return len(self.value)
105	def _ndarray(self): return type(self.value)
106	def cupy_based(self): return not isinstance(self.value,np.ndarray)
107	def isndarray(self,other): return isinstance(other,self._ndarray()) # same array module
108	@classmethod
109	def is_ad(cls,other): return isinstance(other,cls)
110	@classmethod
111	def new(cls,*args,**kwargs):
112		return cls(*args,**kwargs)	
113
114	@classmethod
115	def Taylor(cls): return Taylor1 if cls.order()==1 else Taylor2
116
117	def sqrt(self):			return self**0.5
118	def __pow__(self,n):	return self._math_helper(self.Taylor().pow(self.value,n))
119	def log(self):			return self._math_helper(self.Taylor().log(self.value))
120	def exp(self):			return self._math_helper(self.Taylor().exp(self.value))
121	def abs(self):			return self._math_helper(self.Taylor().abs(self.value))
122	def sin(self):			return self._math_helper(self.Taylor().sin(self.value))
123	def cos(self):			return self._math_helper(self.Taylor().cos(self.value))
124	def tan(self):			return self._math_helper(self.Taylor().tan(self.value))
125	def arcsin(self):		return self._math_helper(self.Taylor().arcsin(self.value))
126	def arccos(self):		return self._math_helper(self.Taylor().arccos(self.value))
127	def arctan(self):		return self._math_helper(self.Taylor().arctan(self.value))
128	def sinh(self):			return self._math_helper(self.Taylor().sinh(self.value))
129	def cosh(self):			return self._math_helper(self.Taylor().cosh(self.value))
130	def tanh(self):			return self._math_helper(self.Taylor().tanh(self.value))
131	def arcsinh(self):		return self._math_helper(self.Taylor().arcsinh(self.value))
132	def arccosh(self):		return self._math_helper(self.Taylor().arccosh(self.value))
133	def arctanh(self):		return self._math_helper(self.Taylor().arctanh(self.value))
134
135		# See https://docs.scipy.org/doc/numpy/reference/ufuncs.html
136	def __array_ufunc__(self,ufunc,method,*inputs,**kwargs):
137
138		# Return an np.ndarray for piecewise constant functions
139		if ufunc in [
140		# Comparison functions
141		np.greater,np.greater_equal,
142		np.less,np.less_equal,
143		np.equal,np.not_equal,
144
145		# Math
146		np.floor_divide,np.rint,np.sign,np.heaviside,
147
148		# Floating functions
149		np.isfinite,np.isinf,np.isnan,np.isnat,
150		np.signbit,np.floor,np.ceil,np.trunc
151		]:
152			inputs_ = (a.value if self.is_ad(a) else a for a in inputs)
153			return self.value.__array_ufunc__(ufunc,method,*inputs_,**kwargs)
154
155
156		if method=="__call__":
157
158			# Reimplemented
159			if ufunc==np.maximum: return maximum(*inputs,**kwargs)
160			if ufunc==np.minimum: return minimum(*inputs,**kwargs)
161
162			# Math functions
163			if ufunc==np.sqrt:		return self.sqrt()
164			if ufunc==np.log:		return self.log()
165			if ufunc==np.exp:		return self.exp()
166			if ufunc==np.abs:		return self.abs()
167			if ufunc==np.sin:		return self.sin()
168			if ufunc==np.cos:		return self.cos()
169			if ufunc==np.tan:		return self.tan()
170			if ufunc==np.arcsin:	return self.arcsin()
171			if ufunc==np.arccos:	return self.arccos()
172			if ufunc==np.arctan:	return self.arctan()
173			if ufunc==np.sinh:		return self.sinh()
174			if ufunc==np.cosh:		return self.cosh()
175			if ufunc==np.tanh:		return self.tanh()
176			if ufunc==np.arcsinh:	return self.arcsinh()
177			if ufunc==np.arccosh:	return self.arccosh()
178			if ufunc==np.arctanh:	return self.arctanh()
179
180			# Operators
181			if ufunc==np.add: return add(*inputs,**kwargs)
182			if ufunc==np.subtract: return subtract(*inputs,**kwargs)
183			if ufunc==np.multiply: return multiply(*inputs,**kwargs)
184			if ufunc==np.true_divide: return true_divide(*inputs,**kwargs)
185
186
187		return NotImplemented
188
189	def __array_function__(self,func,types,args,kwargs):
190		return _array_function_overload(self,func,types,args,kwargs)
191
192
193	# Support for +=, -=, *=, /=, <, <=, >, >=, ==, !=
194	def __iadd__(self,other): return add(self,other,out=self)
195	def __isub__(self,other): return subtract(self,other,out=self)
196	def __imul__(self,other): return multiply(self,other,out=self)
197	def __itruediv__(self,other): return true_divide(self,other,out=self)
198
199	def __lt__(self,other): return np.less(self,other)
200	def __le__(self,other): return np.less_equal(self,other)
201	def __gt__(self,other): return np.greater(self,other)
202	def __ge__(self,other): return np.greater_equal(self,other)
203	def __eq__(self,other): return np.equal(self,other)
204	def __ne__(self,other): return np.not_equal(self,other)
205
206	def argmin(self,*args,**kwargs): return self.value.argmin(*args,**kwargs)
207	def argmax(self,*args,**kwargs): return self.value.argmax(*args,**kwargs)
208
209	def min(array,axis=None,keepdims=False,out=None):
210		if axis is None: return array.reshape(-1).min(axis=0,out=out)
211		ai = np.expand_dims(np.argmin(array.value, axis=axis), axis=axis)
212		out = np.take_along_axis(array,ai,axis=axis)
213		if not keepdims: out = out.reshape(array.shape[:axis]+array.shape[axis+1:]) # NOTE(review): wrong for negative axis (axis=-1 gives shape[:-1]+whole shape); normalize axis first
214		return out
215
216	def max(array,axis=None,keepdims=False,out=None):
217		if axis is None: return array.reshape(-1).max(axis=0,out=out)
218		ai = np.expand_dims(np.argmax(array.value, axis=axis), axis=axis)
219		out = np.take_along_axis(array,ai,axis=axis)
220		if not keepdims: out = out.reshape(array.shape[:axis]+array.shape[axis+1:]) # NOTE(review): wrong for negative axis (axis=-1 gives shape[:-1]+whole shape); normalize axis first
221		return out
222
223	def prod(arr,axis=None,dtype=None,out=None,keepdims=False,initial=None):
224		"""Attempt to reproduce numpy prod function. (Rather inefficiently, and I presume partially)"""
225
226		shape_orig = arr.shape
227
228		if axis is None:
229			arr = arr.flatten()
230			axis = (0,)
231		elif isinstance(axis,numbers.Number):
232			axis=(axis,)
233
234
235		if axis!=(0,):
236			d = len(axis)
237			rd = tuple(range(len(axis)))
238			arr = np.moveaxis(arr,axis,rd)
239			shape1 = (np.prod(arr.shape[d:],dtype=int),)+arr.shape[d:] # NOTE(review): should be arr.shape[:d] (the merged reduced axes); as written the reshape fails
240			arr = arr.reshape(shape1)
241
242		if len(arr)==0:
243			return initial
244
245		if dtype!=arr.dtype and dtype is not None:
246			if initial is None:
247				initial = dtype(1) # NOTE(review): fails if dtype is an np.dtype instance (not callable); dtype.type(1) would work
248			elif dtype!=initial.dtype:
249				initial = initial*dtype(1)
250
251		out = functools.reduce(operator.mul,arr) if initial is None \
252			else functools.reduce(operator.mul,arr,initial)
253
254		if keepdims:
255			shape_kept = tuple(1 if i in axis else ai for i,ai in enumerate(shape_orig)) \
256				if out.size>1 else (1,)*len(shape_orig)
257			out = out.reshape(shape_kept) 
258
259		return out
shape
87	@property
88	def shape(self): return self.value.shape
ndim
89	@property
90	def ndim(self): return self.value.ndim
size
91	@property
92	def size(self): return self.value.size	
def flatten(self):
93	def flatten(self):	return self.reshape( (self.size,) )
def squeeze(self, axis=None):
94	def squeeze(self,axis=None): return self.reshape(self.value.squeeze(axis).shape)
T
95	@property
96	def T(self):	return self if self.ndim<2 else self.transpose()
@classmethod
def stack(cls, elems, axis=0):
 98	@classmethod
 99	def stack(cls,elems,axis=0):
100		return cls.concatenate(tuple(np.expand_dims(e,axis=axis) for e in elems),axis)
dtype
102	@property
103	def dtype(self): return self.value.dtype
def cupy_based(self):
106	def cupy_based(self): return not isinstance(self.value,np.ndarray)
def isndarray(self, other):
107	def isndarray(self,other): return isinstance(other,self._ndarray()) # same array module
@classmethod
def is_ad(cls, other):
108	@classmethod
109	def is_ad(cls,other): return isinstance(other,cls)
@classmethod
def new(cls, *args, **kwargs):
110	@classmethod
111	def new(cls,*args,**kwargs):
112		return cls(*args,**kwargs)	
@classmethod
def Taylor(cls):
@classmethod
def Taylor(cls):
	"""Class of elementary-function expansions: Taylor1 when `cls.order()` is 1, Taylor2 otherwise."""
	if cls.order() == 1:
		return Taylor1
	return Taylor2
def sqrt(self):
def sqrt(self):
	"""Square root, computed as the power 0.5."""
	exponent = 0.5
	return self**exponent
def log(self):
def log(self):
	"""Natural logarithm; the expansion from `self.Taylor().log` is passed to `_math_helper`."""
	expansion = self.Taylor().log(self.value)
	return self._math_helper(expansion)
def exp(self):
def exp(self):
	"""Exponential; the expansion from `self.Taylor().exp` is passed to `_math_helper`."""
	expansion = self.Taylor().exp(self.value)
	return self._math_helper(expansion)
def abs(self):
def abs(self):
	"""Absolute value; the expansion from `self.Taylor().abs` is passed to `_math_helper`."""
	expansion = self.Taylor().abs(self.value)
	return self._math_helper(expansion)
def sin(self):
def sin(self):
	"""Sine; the expansion from `self.Taylor().sin` is passed to `_math_helper`."""
	expansion = self.Taylor().sin(self.value)
	return self._math_helper(expansion)
def cos(self):
def cos(self):
	"""Cosine; the expansion from `self.Taylor().cos` is passed to `_math_helper`."""
	expansion = self.Taylor().cos(self.value)
	return self._math_helper(expansion)
def tan(self):
def tan(self):
	"""Tangent; the expansion from `self.Taylor().tan` is passed to `_math_helper`."""
	expansion = self.Taylor().tan(self.value)
	return self._math_helper(expansion)
def arcsin(self):
def arcsin(self):
	"""Inverse sine; the expansion from `self.Taylor().arcsin` is passed to `_math_helper`."""
	expansion = self.Taylor().arcsin(self.value)
	return self._math_helper(expansion)
def arccos(self):
def arccos(self):
	"""Inverse cosine; the expansion from `self.Taylor().arccos` is passed to `_math_helper`."""
	expansion = self.Taylor().arccos(self.value)
	return self._math_helper(expansion)
def arctan(self):
def arctan(self):
	"""Inverse tangent; the expansion from `self.Taylor().arctan` is passed to `_math_helper`."""
	expansion = self.Taylor().arctan(self.value)
	return self._math_helper(expansion)
def sinh(self):
def sinh(self):
	"""Hyperbolic sine; the expansion from `self.Taylor().sinh` is passed to `_math_helper`."""
	expansion = self.Taylor().sinh(self.value)
	return self._math_helper(expansion)
def cosh(self):
def cosh(self):
	"""Hyperbolic cosine; the expansion from `self.Taylor().cosh` is passed to `_math_helper`."""
	expansion = self.Taylor().cosh(self.value)
	return self._math_helper(expansion)
def tanh(self):
def tanh(self):
	"""Hyperbolic tangent; the expansion from `self.Taylor().tanh` is passed to `_math_helper`."""
	expansion = self.Taylor().tanh(self.value)
	return self._math_helper(expansion)
def arcsinh(self):
def arcsinh(self):
	"""Inverse hyperbolic sine; the expansion from `self.Taylor().arcsinh` is passed to `_math_helper`."""
	expansion = self.Taylor().arcsinh(self.value)
	return self._math_helper(expansion)
def arccosh(self):
def arccosh(self):
	"""Inverse hyperbolic cosine; the expansion from `self.Taylor().arccosh` is passed to `_math_helper`."""
	expansion = self.Taylor().arccosh(self.value)
	return self._math_helper(expansion)
def arctanh(self):
def arctanh(self):
	"""Inverse hyperbolic tangent; the expansion from `self.Taylor().arctanh` is passed to `_math_helper`."""
	expansion = self.Taylor().arctanh(self.value)
	return self._math_helper(expansion)
def argmin(self, *args, **kwargs):
def argmin(self,*args,**kwargs):
	"""Index of the minimum, computed on the `value` field only (arguments as for ndarray.argmin)."""
	value = self.value
	return value.argmin(*args,**kwargs)
def argmax(self, *args, **kwargs):
def argmax(self,*args,**kwargs):
	"""Index of the maximum, computed on the `value` field only (arguments as for ndarray.argmax)."""
	value = self.value
	return value.argmax(*args,**kwargs)
def min(array, axis=None, keepdims=False, out=None):
def min(array,axis=None,keepdims=False,out=None):
	"""
	Minimum along an axis. The argmin is computed on the `value` field, and the
	corresponding entries are gathered from the full array with take_along_axis.
	- axis : None (minimum over all entries, via a flattened copy), or an integer (may be negative).
	- keepdims : whether to keep the reduced axis, with length one.
	- out : forwarded only when axis is None; otherwise ignored.
	"""
	if axis is None: return array.reshape(-1).min(axis=0,out=out)
	# Normalize negative axes: the slicing shape[:axis]+shape[axis+1:] below
	# would otherwise build a wrong shape (e.g. for axis=-1).
	if axis<0: axis+=array.ndim
	ai = np.expand_dims(np.argmin(array.value, axis=axis), axis=axis)
	out = np.take_along_axis(array,ai,axis=axis)
	if not keepdims: out = out.reshape(array.shape[:axis]+array.shape[axis+1:])
	return out
def max(array, axis=None, keepdims=False, out=None):
def max(array,axis=None,keepdims=False,out=None):
	"""
	Maximum along an axis. The argmax is computed on the `value` field, and the
	corresponding entries are gathered from the full array with take_along_axis.
	- axis : None (maximum over all entries, via a flattened copy), or an integer (may be negative).
	- keepdims : whether to keep the reduced axis, with length one.
	- out : forwarded only when axis is None; otherwise ignored.
	"""
	if axis is None: return array.reshape(-1).max(axis=0,out=out)
	# Normalize negative axes: the slicing shape[:axis]+shape[axis+1:] below
	# would otherwise build a wrong shape (e.g. for axis=-1).
	if axis<0: axis+=array.ndim
	ai = np.expand_dims(np.argmax(array.value, axis=axis), axis=axis)
	out = np.take_along_axis(array,ai,axis=axis)
	if not keepdims: out = out.reshape(array.shape[:axis]+array.shape[axis+1:])
	return out
def prod(arr, axis=None, dtype=None, out=None, keepdims=False, initial=None):
def prod(arr,axis=None,dtype=None,out=None,keepdims=False,initial=None):
	"""
	Attempt to reproduce numpy's prod function (rather inefficiently, and only partially),
	as a functools.reduce of operator.mul over the reduced axes.
	- axis : None (product of all entries), an integer, or a sequence of integers (may be negative).
	- dtype : if given and different from arr.dtype, the product is seeded with dtype(1).
	- out : ignored (kept for compatibility with numpy's signature).
	- keepdims : whether to keep the reduced axes, with length one.
	- initial : starting value of the product.
	Note : if the reduced entries are empty, `initial` is returned (None by default).
	"""
	shape_orig = arr.shape
	ndim = len(shape_orig)

	if axis is None:
		arr = arr.flatten()
		axis = (0,)
		reduced = tuple(range(ndim)) # All the original axes are reduced
	else:
		if isinstance(axis,numbers.Number): axis=(axis,)
		# Normalize negative axes, so that the keepdims membership test is meaningful
		axis = tuple(a+ndim if a<0 else a for a in axis)
		reduced = axis

	if axis!=(0,):
		# Move the reduced axes to the front, then merge them into a single leading axis
		d = len(axis)
		arr = np.moveaxis(arr,axis,tuple(range(d)))
		shape1 = (np.prod(arr.shape[:d],dtype=int),)+arr.shape[d:]
		arr = arr.reshape(shape1)

	if len(arr)==0:
		return initial

	if dtype!=arr.dtype and dtype is not None:
		if initial is None:
			initial = dtype(1)
		elif dtype!=initial.dtype:
			initial = initial*dtype(1)

	out = functools.reduce(operator.mul,arr) if initial is None \
		else functools.reduce(operator.mul,arr,initial)

	if keepdims:
		shape_kept = tuple(1 if i in reduced else s for i,s in enumerate(shape_orig))
		out = out.reshape(shape_kept)

	return out

Attempts to reproduce the numpy `prod` function (rather inefficiently, and presumably only partially).

def is_ad(data, iterables=()):
def is_ad(data,iterables=tuple()):
	"""
	Whether the object holds AD information, i.e. whether any of the elements
	enumerated by functional.rec_iter(data,iterables) is a baseAD instance.
	"""
	for element in functional.rec_iter(data,iterables):
		if isinstance(element,baseAD): return True
	return False

Whether the object holds AD information.

def isndarray(x):
def isndarray(x):
	"""Whether the object is a numpy or cupy ndarray, or an AD type."""
	array_types = (np.ndarray,baseAD,_cp_ndarray)
	return isinstance(x,array_types)

Whether the object is a numpy or cupy ndarray, or an AD type.

def from_cupy(x):
def from_cupy(x):
	"""Whether the variable is a cupy ndarray, or an AD type whose data is cupy based."""
	if isinstance(x,_cp_ndarray): return True
	return isinstance(x,baseAD) and x.cupy_based()

Whether the variable is a cupy ndarray, or an AD type based on one.

def array(a, copy=True, caster=None):
def array(a,copy=True,caster=None):
	"""
	Similar to np.array, but does not cast AD subclasses of np.ndarray to the base class.
	- A non-empty list or tuple is turned into a single array, by stacking its elements along axis 0.
	- An ndarray or AD array is copied (or returned as is if copy is False).
	- Anything else (e.g. a scalar) is turned into an array scalar by the caster.
	Inputs :
	- caster : used to cast a scalar into an array scalar (overrides the default array.caster)
	"""
	if isinstance(a,(list,tuple)) and len(a)>0:
		return stack([asarray(e,caster=caster) for e in a],axis=0)
	if isndarray(a):
		return a.copy() if copy else a
	if caster is None:
		caster = array.caster
	return caster(a)

Similar to np.array, but does not cast AD subclasses of np.ndarray to the base class. Turns a list or tuple of arrays with the same dimensions into a single stacked array. Turns a scalar into an array scalar. Inputs:

  • caster : used to cast a scalar into an array scalar (overrides default)
def asarray(a, **kwargs):
def asarray(a,**kwargs):
	"""Same as `array`, but ndarray and AD inputs are returned without copy."""
	no_copy = False
	return array(a,copy=no_copy,**kwargs)
def ascontiguousarray(a, **kwargs):
def ascontiguousarray(a,**kwargs):
	"""
	Contiguous version of `asarray(a,**kwargs)`: dispatches to numpy or cupy, and for AD
	types rebuilds the object from contiguous versions of the fields of `as_tuple()`.
	Raises ValueError for any other type.
	"""
	a = asarray(a,**kwargs)
	if isinstance(a,np.ndarray):
		return np.ascontiguousarray(a)
	if isinstance(a,_cp_ndarray):
		return cp.ascontiguousarray(a)
	if is_ad(a):
		fields = a.as_tuple()
		return type(a)(*(ascontiguousarray(field) for field in fields))
	raise ValueError(f"ascontiguousarray not applicable to variable of type {type(a)}")
def array_members(data, iterables=(<class 'tuple'>, <class 'list'>, <class 'dict'>)):
def array_members(data,iterables=(tuple,list,dict)):
	"""
	Returns the list of all arrays in the given structure, with their access paths.
	Each entry is a pair (names, array), where names lists the dot-separated access
	paths under which the same array data was found.
	Usage : for key,value in ad.Base.array_members(self): print(key,value.nbytes/2**20)
	"""
	arrays = []

	def same_data(a,b):
		"""Whether a and b refer to the same array data (not merely equal contents)."""
		if a is b: return True
		if isinstance(a,np.ndarray) and isinstance(b,np.ndarray):
			# a.data==b.data would compare two memoryviews by *content*, wrongly
			# merging distinct arrays that happen to hold equal values.
			ia,ib = a.__array_interface__,b.__array_interface__
			return (ia['data'][0]==ib['data'][0] and a.shape==b.shape
				and a.strides==b.strides and a.dtype==b.dtype)
		return a.data==b.data

	def check(path,arr):
		if isndarray(arr):
			name = ".".join(map(str,path))
			for namelist,value in arrays:
				if same_data(arr,value):
					namelist.append(name)
					break
			else:
				arrays.append(([name],arr))

	def check_members(prefix,data):
		for name,value in items(data):
			name2 = prefix+(name,)
			if isinstance(value,iterables): check_members(name2,value)
			else: check(name2,value)

	def items(data):
		# Mappings yield their items, sequences are indexed, other objects expose their attributes
		if hasattr(data,'items'): return data.items()
		if isinstance(data,(list,tuple)):
			return [(str(i),value) for i,value in enumerate(data)]
		else: return data.__dict__.items()

	check_members(tuple(),data)
	return arrays

Returns the list of all arrays in the given structure, with their access paths. Usage: for key,value in ad.Base.array_members(self): print(key,value.nbytes/2**20)

numpy_overloads = {<function stack>: <function stack>, <function expand_dims>: <function expand_dims>, <function empty_like>: <function empty_like>, <function copyto>: <function copy_to>, <function broadcast_to>: <function broadcast_to>, <function where>: <function where>, <function sort>: <function sort>, <function concatenate>: <function concatenate>, <function pad>: <function pad>, <function mean>: <function mean>, <function roll>: <function roll>, <function allclose>: <function allclose>}
cupy_alt_overloads = {<function max>: (<function max>, <class 'TypeError'>), <function put_along_axis>: (<function put_along_axis>, <class 'TypeError'>), <function packbits>: (<function packbits>, <class 'TypeError'>)}
numpy_implementation = {<function take_along_axis>, <function full_like>, <function argmin>, <function ones_like>, <function ndim>, <function max>, <function reshape>, <function amin>, <function moveaxis>, <function min>, <function prod>, <function squeeze>, <function amax>, <function zeros_like>, <function sum>, <function shape>, <function argmax>}
def implements(numpy_function):
def implements(numpy_function):
	"""
	Decorator registering the decorated function in `numpy_overloads`, as the
	`__array_function__` implementation of `numpy_function`. The function is returned unchanged.
	"""
	def register(func):
		numpy_overloads[numpy_function] = func
		return func
	return register

Register an `__array_function__` implementation for AD array objects.

def implements_cupy_alt(numpy_function, exception):
def implements_cupy_alt(numpy_function,exception):
	"""
	Decorator registering an alternative to a numpy function only partially supported by cupy.
	The pair (func,exception) is stored in `cupy_alt_overloads`, and the decorator returns
	functional.func_except_alt(numpy_function,exception,func)
	(presumably falling back to func when numpy_function raises exception).
	"""
	def register(func):
		cupy_alt_overloads[numpy_function] = (func,exception)
		wrapped = functional.func_except_alt(numpy_function,exception,func)
		return wrapped
	return register

Register an alternative to a numpy function only partially supported by cupy

@implements(np.stack)
def stack(elems, axis=0):
@implements(np.stack)
def stack(elems,axis=0):
	"""
	Stack arrays along a new axis. If some element is an AD array, delegates to the
	`stack` classmethod of the first such element's type; otherwise uses np.stack.
	"""
	first_ad = next((e for e in elems if is_ad(e)), None)
	if first_ad is None:
		return np.stack(elems,axis)
	return type(first_ad).stack(elems,axis)
@implements(np.expand_dims)
def expand_dims(a, axis):
@implements(np.expand_dims)
def expand_dims(a,axis):
	"""Insert a length-one axis; the target shape is obtained by expanding the `value` field."""
	new_shape = np.expand_dims(a.value,axis).shape
	return np.reshape(a,new_shape)
@implements(np.empty_like)
def empty_like(a, *args, subok=True, **kwargs):
@implements(np.empty_like)
def empty_like(a,*args,subok=True,**kwargs):
	"""Uninitialized array like `a`: np.empty_like of the `value` field, wrapped by type(a)."""
	# subok is replaced with None for cupy-based inputs
	# NOTE(review): presumably because cupy's empty_like does not accept subok=True; confirm.
	subok = None if from_cupy(a) else subok
	empty_value = np.empty_like(a.value,*args,subok=subok,**kwargs)
	return type(a)(empty_value)
@implements(np.copyto)
def copy_to(dst, src, *args, **kwargs):
@implements(np.copyto)
def copy_to(dst,src,*args,**kwargs):
	"""
	np.copyto for an AD destination: copies a non-AD `src` into `dst.value` (other fields
	of dst are not written). Raises ValueError if src is AD.
	"""
	if not is_ad(src):
		np.copyto(dst.value,src,*args,**kwargs)
		return
	raise ValueError("copyto is not supported with an AD source")
@implements(np.broadcast_to)
def broadcast_to(array, shape):
@implements(np.broadcast_to)
def broadcast_to(array,shape):
	"""np.broadcast_to, delegated to the array's own broadcast_to method."""
	broadcaster = array.broadcast_to
	return broadcaster(shape)
@implements(np.where)
def where(mask, a, b):
@implements(np.where)
def where(mask,a,b):
	"""
	Elementwise selection: entries of a where mask holds, and of b elsewhere.
	The result starts as a copy of b if b is AD (otherwise a copy of a, with the mask
	inverted), and the other operand is written at the selected positions.
	Assumes the copied operand already has the output shape (no broadcasting is done here).
	"""
	if is_ad(b):
		base,other,selection = b,a,mask
	else:
		base,other,selection = a,b,np.logical_not(mask)
	result = base.copy()
	if isndarray(other):
		result[selection] = other[selection]
	else:
		result[selection] = other
	return result
@implements(np.sort)
def sort(array, axis=-1, *varargs, **kwargs):
@implements(np.sort)
def sort(array,axis=-1,*varargs,**kwargs):
	"""Sort along an axis: argsort the `value` field, then reorder the whole array with take_along_axis."""
	order = np.argsort(array.value,axis=axis,*varargs,**kwargs)
	return np.take_along_axis(array,order,axis=axis)
@implements(np.concatenate)
def concatenate(elems, axis=0):
@implements(np.concatenate)
def concatenate(elems,axis=0):
	"""
	Join arrays along an existing axis. If some element is an AD array, delegates to the
	`concatenate` classmethod of the first such element's type; otherwise uses np.concatenate.
	"""
	ad_elements = (e for e in elems if is_ad(e))
	first_ad = next(ad_elements, None)
	if first_ad is not None:
		return type(first_ad).concatenate(elems,axis)
	return np.concatenate(elems,axis)
@implements(np.pad)
def pad(array, pad_width, *args, **kwargs):
@implements(np.pad)
def pad(array, pad_width, *args,**kwargs):
	"""
	np.pad for AD arrays, delegated to `array.pad` with pad_width normalized to the
	explicit per-axis form ((before_1,after_1),...,(before_N,after_N)), as a tuple.
	Accepted shortcuts (as in numpy) : an int, (pad,), (before,after), ((before,after),).
	"""
	if isinstance(pad_width,numbers.Integral):
		pad_width = (pad_width,)
	if isinstance(pad_width[0],numbers.Integral):
		if len(pad_width)==1: # (pad,) : same before and after, for all axes
			pad_width = ((pad_width[0],pad_width[0]),)
		elif len(pad_width)==2: # (before,after) : same pair for all axes
			pad_width = (tuple(pad_width),)
	if len(pad_width)==1:
		# Convert to tuple first, so that a list input is replicated into a consistent type
		pad_width = tuple(pad_width)*array.ndim
	return array.pad(pad_width,*args,**kwargs)
@implements(np.mean)
def mean(array, *args, **kwargs):
@implements(np.mean)
def mean(array, *args, **kwargs):
	"""Mean, computed as np.sum(array,*args,**kwargs) rescaled in place."""
	total = np.sum(array, *args, **kwargs)
	# output size / input size is one over the number of summed entries
	total *= total.size / array.size
	return total
@implements(np.roll)
def roll(array, shift, axis=None):
@implements(np.roll)
def roll(array,shift,axis=None):
	"""
	Roll an AD array along a single axis, by rolling each field of `as_tuple()` and
	rebuilding with `array.new`. axis=None is only supported when ndim<=1.
	Raises ValueError for unsupported axes.
	"""
	ndim = array.ndim
	if axis is None:
		if ndim>1: raise ValueError("Unsupported None axis when ndim>=2")
		axis = 0
	elif axis<0:
		axis += ndim

	if axis<0 or axis>=ndim:
		raise ValueError(f"Unsupported axis {axis} with ndim {array.ndim}")

	rolled = [np.roll(field,shift,axis) for field in array.as_tuple()]
	return array.new(*rolled)
@implements(np.allclose)
def allclose(a, b, *args, **kwargs):
@implements(np.allclose)
def allclose(a,b,*args,**kwargs):
	"""np.allclose, delegated to the allclose method of the first argument."""
	comparer = a.allclose
	return comparer(b,*args,**kwargs)