agd.AutomaticDifferentiation.Base
"""Common base for the AD types: Taylor expansions of elementary functions, the
baseAD class, cupy helpers, and numpy __array_function__ overloads."""

import numpy as np
import numbers
import functools
import operator

from . import functional

# ----- implementation note -----
# denseAD should not inherit from np.ndarray, otherwise silent casts of scalars
# denseAD_cupy must inherit cp.ndarray, otherwise operator overloading won't work


# import the cupy module only if available on the system
try:
    import cupy as cp
    _cp_ndarray = cp.ndarray
except (ImportError,ModuleNotFoundError):
    cp=None
    class _cp_ndarray: pass  # placeholder type, so that isinstance checks against it never match

class ADCastError(ValueError):
    "Raised when attempting to cast between different AD types"

# Elementary functions and their derivatives
# No implementation of arctan2, or hypot, which have two args
class Taylor1:
    """First order Taylor expansions: each function returns (f(x), f'(x))."""
    @staticmethod
    def pow(x,n): return (x**n,n*x**(n-1))
    @staticmethod
    def log(x): return (np.log(x),1./x)
    @staticmethod
    def exp(x): e=np.exp(x); return (e,e)
    @staticmethod
    def abs(x): return (np.abs(x),np.sign(x))
    @staticmethod
    def sin(x): return (np.sin(x),np.cos(x))
    @staticmethod
    def cos(x): return (np.cos(x),-np.sin(x))
    @staticmethod
    def tan(x): t=np.tan(x); return (t,1.+t**2)
    @staticmethod
    def arcsin(x): return (np.arcsin(x),(1.-x**2)**-0.5)
    @staticmethod
    def arccos(x): return (np.arccos(x),-(1.-x**2)**-0.5)
    @staticmethod
    def arctan(x): return (np.arctan(x),1./(1+x**2))
    @staticmethod
    def sinh(x): return (np.sinh(x),np.cosh(x))
    @staticmethod
    def cosh(x): return (np.cosh(x),np.sinh(x))
    @staticmethod
    def tanh(x): t=np.tanh(x); return (t,1.-t**2)
    @staticmethod
    def arcsinh(x): return (np.arcsinh(x),(1.+x**2)**-0.5)
    @staticmethod
    def arccosh(x): return (np.arccosh(x),(x**2-1.)**-0.5)
    @staticmethod
    def arctanh(x): return (np.arctanh(x),1./(1-x**2))

class Taylor2:
    """Second order Taylor expansions: each function returns (f(x), f'(x), f''(x))."""
    @staticmethod
    def pow(x,n): return (x**n,n*x**(n-1),(n*(n-1))*x**(n-2))
    @staticmethod
    def log(x): y=1./x; return (np.log(x),y,-y**2)
    @staticmethod
    def exp(x): e=np.exp(x); return (e,e,e)
    @staticmethod
    def abs(x): return (np.abs(x),np.sign(x),np.zeros_like(x))
    @staticmethod
    def sin(x): s=np.sin(x); return (s,np.cos(x),-s)
    @staticmethod
    def cos(x): c=np.cos(x); return (c,-np.sin(x),-c)
    @staticmethod
    def tan(x): t=np.tan(x); u=1.+t**2; return (t,u,2.*u*t)
    @staticmethod
    def arcsin(x): y=1.-x**2; return (np.arcsin(x),y**-0.5,x*y**-1.5)
    @staticmethod
    def arccos(x): y=1.-x**2; return (np.arccos(x),-y**-0.5,-x*y**-1.5)
    @staticmethod
    def arctan(x): y=1./(1.+x**2); return (np.arctan(x),y,-2.*x*y**2)
    @staticmethod
    def sinh(x): s=np.sinh(x); return (s,np.cosh(x),s)
    @staticmethod
    def cosh(x): c=np.cosh(x); return (c,np.sinh(x),c)
    @staticmethod
    def tanh(x): t=np.tanh(x); u=1.-t**2; return (t,u,-2.*u*t)
    @staticmethod
    def arcsinh(x): y=1.+x**2; return (np.arcsinh(x),y**-0.5,-x*y**-1.5)
    @staticmethod
    def arccosh(x): y=x**2-1.; return (np.arccosh(x),y**-0.5,-x*y**-1.5)
    @staticmethod
    def arctanh(x): y=1./(1-x**2); return (np.arctanh(x),y,2.*x*y**2)

def _tuple_first(a):
    """First element of ``a`` if it is a tuple (as ufunc ``out`` arguments may be), else ``a``."""
    return a[0] if isinstance(a,tuple) else a

def _getitem(a,where):
    """``a[where]``, except that a non-array ``a`` is returned unchanged when ``where is True``."""
    return a if (where is True and not isndarray(a)) else a[where]

def add(a,b,out=None,where=True):
    """a+b dispatched to the AD operand; with ``out``, the result is written into it in place."""
    if out is None: return a+b if is_ad(a) else b+a
    result=_tuple_first(out)
    # Common failure : a+=b where b is an AD variable and a is not
    result[where]=a[where]+_getitem(b,where)
    return result

def subtract(a,b,out=None,where=True):
    """a-b dispatched to the AD operand; with ``out``, the result is written into it in place."""
    if out is None: return a-b if is_ad(a) else b.__rsub__(a)
    result=_tuple_first(out)
    # Common failure : a-=b where b is an AD variable and a is not
    result[where]=a[where]-_getitem(b,where)
    return result

def multiply(a,b,out=None,where=True):
    """a*b dispatched to the AD operand; with ``out``, the result is written into it in place."""
    if out is None: return a*b if is_ad(a) else b*a
    result=_tuple_first(out)
    # Common failure : a*=b where b is an AD variable and a is not
    result[where]=a[where]*_getitem(b,where)
    return result

def true_divide(a,b,out=None,where=True):
    """a/b dispatched to the AD operand; with ``out``, the result is written into it in place."""
    if out is None: return a/b if is_ad(a) else b.__rtruediv__(a)
    result=_tuple_first(out)
    # Common failure : a/=b where b is an AD variable and a is not
    result[where]=a[where]/_getitem(b,where)
    return result

def maximum(a,b):
    """Elementwise maximum, via np.where (which is overloaded for AD types below)."""
    return np.where(a>b,a,b)

def minimum(a,b):
    """Elementwise minimum, via np.where (which is overloaded for AD types below)."""
    return np.where(a<b,a,b)

class baseAD:
    """Common base class of the AD types.

    Members used here but supplied by subclasses: the ``value`` attribute, and
    ``order``, ``_math_helper``, ``reshape``, ``transpose``, ``concatenate``.
    """

    @property
    def shape(self): return self.value.shape
    @property
    def ndim(self): return self.value.ndim
    @property
    def size(self): return self.value.size
    def flatten(self): return self.reshape( (self.size,) )
    def squeeze(self,axis=None): return self.reshape(self.value.squeeze(axis).shape)
    @property
    def T(self): return self if self.ndim<2 else self.transpose()

    @classmethod
    def stack(cls,elems,axis=0):
        """Stack along a new axis, by expanding each element and concatenating."""
        return cls.concatenate(tuple(np.expand_dims(e,axis=axis) for e in elems),axis)

    @property
    def dtype(self): return self.value.dtype
    def __len__(self): return len(self.value)
    def _ndarray(self): return type(self.value)
    def cupy_based(self): return not isinstance(self.value,np.ndarray)
    def isndarray(self,other): return isinstance(other,self._ndarray()) # same array module
    @classmethod
    def is_ad(cls,other): return isinstance(other,cls)
    @classmethod
    def new(cls,*args,**kwargs):
        return cls(*args,**kwargs)

    @classmethod
    def Taylor(cls):
        """Taylor expansion table matching the AD order (Taylor1 for order 1, else Taylor2)."""
        return Taylor1 if cls.order()==1 else Taylor2

    def sqrt(self): return self**0.5
    def __pow__(self,n): return self._math_helper(self.Taylor().pow(self.value,n))
    def log(self): return self._math_helper(self.Taylor().log(self.value))
    def exp(self): return self._math_helper(self.Taylor().exp(self.value))
    def abs(self): return self._math_helper(self.Taylor().abs(self.value))
    def sin(self): return self._math_helper(self.Taylor().sin(self.value))
    def cos(self): return self._math_helper(self.Taylor().cos(self.value))
    def tan(self): return self._math_helper(self.Taylor().tan(self.value))
    def arcsin(self): return self._math_helper(self.Taylor().arcsin(self.value))
    def arccos(self): return self._math_helper(self.Taylor().arccos(self.value))
    def arctan(self): return self._math_helper(self.Taylor().arctan(self.value))
    def sinh(self): return self._math_helper(self.Taylor().sinh(self.value))
    def cosh(self): return self._math_helper(self.Taylor().cosh(self.value))
    def tanh(self): return self._math_helper(self.Taylor().tanh(self.value))
    def arcsinh(self): return self._math_helper(self.Taylor().arcsinh(self.value))
    def arccosh(self): return self._math_helper(self.Taylor().arccosh(self.value))
    def arctanh(self): return self._math_helper(self.Taylor().arctanh(self.value))

    # See https://docs.scipy.org/doc/numpy/reference/ufuncs.html
    def __array_ufunc__(self,ufunc,method,*inputs,**kwargs):

        # Return an np.ndarray for piecewise constant functions
        # (evaluated on the underlying values, AD information dropped)
        if ufunc in [
            # Comparison functions
            np.greater,np.greater_equal,
            np.less,np.less_equal,
            np.equal,np.not_equal,

            # Math
            np.floor_divide,np.rint,np.sign,np.heaviside,

            # Floating functions
            np.isfinite,np.isinf,np.isnan,np.isnat,
            np.signbit,np.floor,np.ceil,np.trunc
            ]:
            inputs_ = (a.value if self.is_ad(a) else a for a in inputs)
            return self.value.__array_ufunc__(ufunc,method,*inputs_,**kwargs)


        if method=="__call__":

            # Reimplemented
            if ufunc==np.maximum: return maximum(*inputs,**kwargs)
            if ufunc==np.minimum: return minimum(*inputs,**kwargs)

            # Math functions
            if ufunc==np.sqrt: return self.sqrt()
            if ufunc==np.log: return self.log()
            if ufunc==np.exp: return self.exp()
            if ufunc==np.abs: return self.abs()
            if ufunc==np.sin: return self.sin()
            if ufunc==np.cos: return self.cos()
            if ufunc==np.tan: return self.tan()
            if ufunc==np.arcsin: return self.arcsin()
            if ufunc==np.arccos: return self.arccos()
            if ufunc==np.arctan: return self.arctan()
            if ufunc==np.sinh: return self.sinh()
            if ufunc==np.cosh: return self.cosh()
            if ufunc==np.tanh: return self.tanh()
            if ufunc==np.arcsinh: return self.arcsinh()
            if ufunc==np.arccosh: return self.arccosh()
            if ufunc==np.arctanh: return self.arctanh()

            # Operators
            if ufunc==np.add: return add(*inputs,**kwargs)
            if ufunc==np.subtract: return subtract(*inputs,**kwargs)
            if ufunc==np.multiply: return multiply(*inputs,**kwargs)
            if ufunc==np.true_divide: return true_divide(*inputs,**kwargs)


        return NotImplemented

    def __array_function__(self,func,types,args,kwargs):
        return _array_function_overload(self,func,types,args,kwargs)


    # Support for +=, -=, *=, /=, <, <=, >, >=, ==, !=
    def __iadd__(self,other): return add(self,other,out=self)
    def __isub__(self,other): return subtract(self,other,out=self)
    def __imul__(self,other): return multiply(self,other,out=self)
    def __itruediv__(self,other): return true_divide(self,other,out=self)

    def __lt__(self,other): return np.less(self,other)
    def __le__(self,other): return np.less_equal(self,other)
    def __gt__(self,other): return np.greater(self,other)
    def __ge__(self,other): return np.greater_equal(self,other)
    def __eq__(self,other): return np.equal(self,other)
    def __ne__(self,other): return np.not_equal(self,other)

    def argmin(self,*args,**kwargs): return self.value.argmin(*args,**kwargs)
    def argmax(self,*args,**kwargs): return self.value.argmax(*args,**kwargs)

    def min(array,axis=None,keepdims=False,out=None):
        """Minimum along ``axis``, keeping the AD information of the selected entries.

        With ``axis=None`` the array is flattened first (``keepdims`` is then ignored).
        ``out`` is never written to.
        """
        if axis is None: return array.reshape(-1).min(axis=0,out=out)
        if axis<0: axis+=array.ndim  # the shape slicing below needs a non-negative axis
        ai = np.expand_dims(np.argmin(array.value, axis=axis), axis=axis)
        result = np.take_along_axis(array,ai,axis=axis)
        if not keepdims: result = result.reshape(array.shape[:axis]+array.shape[axis+1:])
        return result

    def max(array,axis=None,keepdims=False,out=None):
        """Maximum along ``axis``, keeping the AD information of the selected entries.

        With ``axis=None`` the array is flattened first (``keepdims`` is then ignored).
        ``out`` is never written to.
        """
        if axis is None: return array.reshape(-1).max(axis=0,out=out)
        if axis<0: axis+=array.ndim  # the shape slicing below needs a non-negative axis
        ai = np.expand_dims(np.argmax(array.value, axis=axis), axis=axis)
        result = np.take_along_axis(array,ai,axis=axis)
        if not keepdims: result = result.reshape(array.shape[:axis]+array.shape[axis+1:])
        return result

    def prod(arr,axis=None,dtype=None,out=None,keepdims=False,initial=None):
        """Attempt to reproduce numpy's prod function (rather inefficiently, and presumably only partially).

        The reduced axes are moved to the front and merged into one leading axis,
        which is then reduced by repeated multiplication.
        ``out`` is accepted for signature compatibility but ignored.
        If the reduced axis is empty, ``initial`` is returned as is.
        """

        shape_orig = arr.shape

        if axis is None:
            arr = arr.flatten()
            axis = (0,)
        elif isinstance(axis,numbers.Number):
            axis=(axis,)
        # Normalize negative axes (required by the keepdims membership test below)
        ndim = arr.ndim
        axis = tuple(a+ndim if a<0 else a for a in axis)

        if axis!=(0,):
            d = len(axis)
            arr = np.moveaxis(arr,axis,tuple(range(d)))
            # Merge the d leading (reduced) axes into a single one
            shape1 = (np.prod(arr.shape[:d],dtype=int),)+arr.shape[d:]
            arr = arr.reshape(shape1)

        if len(arr)==0:
            return initial

        if dtype is not None and dtype!=arr.dtype:
            if initial is None:
                initial = dtype(1)
            elif dtype!=initial.dtype:
                initial = initial*dtype(1)

        result = functools.reduce(operator.mul,arr) if initial is None \
            else functools.reduce(operator.mul,arr,initial)

        if keepdims:
            shape_kept = tuple(1 if i in axis else ai for i,ai in enumerate(shape_orig)) \
                if result.size>1 else (1,)*len(shape_orig)
            result = result.reshape(shape_kept)

        return result

# --------- Cupy support ----------

def is_ad(data,iterables=tuple()):
    """Whether the object holds AD information, i.e. any element yielded by
    ``functional.rec_iter(data,iterables)`` is a baseAD instance."""
    return any(isinstance(x,baseAD) for x in functional.rec_iter(data,iterables))

def isndarray(x):
    """Whether the object is a numpy or cupy ndarray, or an AD type (baseAD instance)."""
    return isinstance(x,(np.ndarray,baseAD,_cp_ndarray))

def from_cupy(x):
    """Whether the variable is a cupy ndarray, or an AD type based on those."""
    return isinstance(x,_cp_ndarray) or (isinstance(x,baseAD) and x.cupy_based())

def array(a,copy=True,caster=None):
    """
    Similar to np.array, but does not cast AD subclasses of np.ndarray to the base class.
    Turns a non-empty list or tuple of arrays with the same dimensions into a single
    array, stacked along a new leading axis.
    Turns a scalar into an array scalar.
    Inputs :
     - copy : whether to copy an input which is already an array (numpy, cupy or AD)
     - caster : used to cast a scalar into an array scalar (overrides default array.caster)
    """
    if isinstance(a,(list,tuple)) and len(a)>0:
        return stack([asarray(e,caster=caster) for e in a],axis=0)
    elif isndarray(a): return a.copy() if copy else a
    elif caster is not None: return caster(a)
    else: return array.caster(a)

array.caster = np.asarray

def asarray(a,**kwargs):
    """Same as array, but an input which is already an array is returned without copy."""
    return array(a,copy=False,**kwargs)

def ascontiguousarray(a,**kwargs):
    """Contiguous version of ``asarray(a)``; AD types are rebuilt field by field via ``as_tuple``."""
    a = asarray(a,**kwargs)
    if isinstance(a,np.ndarray): return np.ascontiguousarray(a)
    elif isinstance(a,_cp_ndarray): return cp.ascontiguousarray(a)
    elif is_ad(a):
        return type(a)(*tuple(ascontiguousarray(e) for e in a.as_tuple()))
    else: raise ValueError(f"ascontiguousarray not applicable to variable of type {type(a)}")

def array_members(data,iterables=(tuple,list,dict)):
    """
    Returns the list of all arrays in given structure, with their access paths.
    Usage : for key,value in ad.Base.array_members(self): print(key,value.nbytes/2**20)
    """
    arrays = []
    def check(path,arr):
        if isndarray(arr):
            name = ".".join(map(str,path))
            for namelist,value in arrays:
                # NOTE(review): for numpy arrays, .data is a memoryview and == compares
                # contents, not addresses -- confirm this is the intended dedup criterion.
                if arr.data==value.data:
                    namelist.append(name)
                    break
            else:
                arrays.append(([name],arr))

    def check_members(prefix,data):
        for name,value in items(data):
            name2 = prefix+(name,)
            if isinstance(value,iterables): check_members(name2,value)
            else: check(name2,value)

    def items(data):
        if hasattr(data,'items'): return data.items()
        if isinstance(data,(list,tuple)):
            return [(str(i),value) for i,value in enumerate(data)]
        else: return data.__dict__.items()

    check_members(tuple(),data)
    return arrays

# -------- numpy __array_function__ mechanism ---------

# We use the __array_function__ mechanism of numpy to reimplement
# a number of numpy functions in a way that is compatible with AD information.

#https://docs.scipy.org/doc/numpy/reference/arrays.classes.html
numpy_overloads = {}
cupy_alt_overloads = {} # Used for numpy function unsupported by cupy
numpy_implementation = {# Use original numpy implementation
    np.ndim,np.shape,np.moveaxis,np.squeeze,
    np.amin,np.amax,np.argmin,np.argmax,np.max,np.min,
    np.sum,np.prod,
    np.full_like,np.ones_like,np.zeros_like,np.reshape,np.take_along_axis,
    }

def implements(numpy_function):
    """Register an __array_function__ implementation for AD objects."""
    def decorator(func):
        numpy_overloads[numpy_function] = func
        return func
    return decorator

def implements_cupy_alt(numpy_function,exception):
    """Register an alternative to a numpy function only partially supported by cupy"""
    def decorator(func):
        cupy_alt_overloads[numpy_function] = (func,exception)
        return functional.func_except_alt(numpy_function,exception,func)
    return decorator

def _array_function_overload(self,func,types,args,kwargs,cupy_alt=True):
    """Dispatch a numpy function call on an AD object: cupy alternative (tried when the
    regular path raises the registered exception), registered overload, or numpy's own
    implementation; NotImplemented otherwise."""
    if cupy_alt and self.cupy_based() and func in cupy_alt_overloads:
        func_alt,exception = cupy_alt_overloads[func]
        try: return _array_function_overload(self,func,types,args,kwargs,cupy_alt=False)
        except exception: return func_alt(*args,**kwargs)

    if func in numpy_overloads:
        return numpy_overloads[func](*args,**kwargs)
    elif func in numpy_implementation:
        return func._implementation(*args,**kwargs)
    else: return NotImplemented

# ---- overloads ----

@implements(np.stack)
def stack(elems,axis=0):
    for e in elems:
        if is_ad(e): return type(e).stack(elems,axis)
    return np.stack(elems,axis)

@implements(np.expand_dims)
def expand_dims(a,axis):
    shape = np.expand_dims(a.value,axis).shape
    return np.reshape(a,shape)

@implements(np.empty_like)
def empty_like(a,*args,subok=True,**kwargs):
    if from_cupy(a): subok=None
    return type(a)(np.empty_like(a.value,*args,subok=subok,**kwargs))

@implements(np.copyto)
def copy_to(dst,src,*args,**kwargs):
    if is_ad(src): raise ValueError("copyto is not supported with an AD source")
    np.copyto(dst.value,src,*args,**kwargs)

@implements(np.broadcast_to)
def broadcast_to(array,shape):
    return array.broadcast_to(shape)

@implements(np.where)
def where(mask,a,b):
    # Copy the AD operand (b if AD, else a), then overwrite the entries selected from the other one
    A,B,Mask = (a,b,mask) if is_ad(b) else (b,a,np.logical_not(mask))
    result = B.copy()
    result[Mask] = A[Mask] if isndarray(A) else A
    return result

@implements(np.sort)
def sort(array,axis=-1,*varargs,**kwargs):
    ai = np.argsort(array.value,axis=axis,*varargs,**kwargs)
    return np.take_along_axis(array,ai,axis=axis)

@implements(np.concatenate)
def concatenate(elems,axis=0):
    for e in elems:
        if is_ad(e): return type(e).concatenate(elems,axis)
    return np.concatenate(elems,axis)

@implements(np.pad)
def pad(array, pad_width, *args,**kwargs):
    # Expand an int, or a single int pair, to one (before,after) pair per axis
    if isinstance(pad_width,numbers.Integral):
        pad_width = (pad_width,)
    if isinstance(pad_width[0],numbers.Integral) and len(pad_width)==1:
        pad_width = ((pad_width[0],pad_width[0]),)
    if len(pad_width)==1:
        pad_width = pad_width*array.ndim
    return array.pad(pad_width,*args,**kwargs)

@implements(np.mean)
def mean(array, *args, **kwargs):
    out = np.sum(array, *args, **kwargs)
    out *= out.size / array.size  # sum -> mean: divide by the number of summed entries
    return out

@implements(np.roll)
def roll(array,shift,axis=None):
    if axis is None:
        if array.ndim>1:
            raise ValueError("Unsupported None axis when ndim>=2")
        axis=0
    elif axis<0:
        axis+=array.ndim

    if not (0<=axis<array.ndim):
        raise ValueError(f"Unsupported axis {axis} with ndim {array.ndim}")

    return array.new(*tuple(np.roll(e,shift,axis) for e in array.as_tuple()))

@implements(np.allclose)
def allclose(a,b,*args,**kwargs): return a.allclose(b,*args,**kwargs)
Raised when attempting to cast between different AD types
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class Taylor1:
    """First order Taylor expansions: each function returns (f(x), f'(x))."""
    @staticmethod
    def pow(x,n): return (x**n,n*x**(n-1))
    @staticmethod
    def log(x): return (np.log(x),1./x)
    @staticmethod
    def exp(x): e=np.exp(x); return (e,e)
    @staticmethod
    def abs(x): return (np.abs(x),np.sign(x))
    @staticmethod
    def sin(x): return (np.sin(x),np.cos(x))
    @staticmethod
    def cos(x): return (np.cos(x),-np.sin(x))
    @staticmethod
    def tan(x): t=np.tan(x); return (t,1.+t**2)
    @staticmethod
    def arcsin(x): return (np.arcsin(x),(1.-x**2)**-0.5)
    @staticmethod
    def arccos(x): return (np.arccos(x),-(1.-x**2)**-0.5)  # parameter was misnamed 'c'
    @staticmethod
    def arctan(x): return (np.arctan(x),1./(1+x**2))
    @staticmethod
    def sinh(x): return (np.sinh(x),np.cosh(x))
    @staticmethod
    def cosh(x): return (np.cosh(x),np.sinh(x))
    @staticmethod
    def tanh(x): t=np.tanh(x); return (t,1.-t**2)
    @staticmethod
    def arcsinh(x): return (np.arcsinh(x),(1.+x**2)**-0.5)
    @staticmethod
    def arccosh(x): return (np.arccosh(x),(x**2-1.)**-0.5)  # parameter was misnamed 'c'
    @staticmethod
    def arctanh(x): return (np.arctanh(x),1./(1-x**2))
class Taylor2:
    """Second order Taylor expansions: each function returns (f(x), f'(x), f''(x))."""
    @staticmethod
    def pow(x,n): return (x**n,n*x**(n-1),(n*(n-1))*x**(n-2))
    @staticmethod
    def log(x): y=1./x; return (np.log(x),y,-y**2)
    @staticmethod
    def exp(x): e=np.exp(x); return (e,e,e)
    @staticmethod
    def abs(x): return (np.abs(x),np.sign(x),np.zeros_like(x))
    @staticmethod
    def sin(x): s=np.sin(x); return (s,np.cos(x),-s)
    @staticmethod
    def cos(x): c=np.cos(x); return (c,-np.sin(x),-c)
    @staticmethod
    def tan(x): t=np.tan(x); u=1.+t**2; return (t,u,2.*u*t)
    @staticmethod
    def arcsin(x): y=1.-x**2; return (np.arcsin(x),y**-0.5,x*y**-1.5)
    @staticmethod
    def arccos(x): y=1.-x**2; return (np.arccos(x),-y**-0.5,-x*y**-1.5)  # parameter was misnamed 'c'
    @staticmethod
    def arctan(x): y=1./(1.+x**2); return (np.arctan(x),y,-2.*x*y**2)
    @staticmethod
    def sinh(x): s=np.sinh(x); return (s,np.cosh(x),s)
    @staticmethod
    def cosh(x): c=np.cosh(x); return (c,np.sinh(x),c)
    @staticmethod
    def tanh(x): t=np.tanh(x); u=1.-t**2; return (t,u,-2.*u*t)
    @staticmethod
    def arcsinh(x): y=1.+x**2; return (np.arcsinh(x),y**-0.5,-x*y**-1.5)
    @staticmethod
    def arccosh(x): y=x**2-1.; return (np.arccosh(x),y**-0.5,-x*y**-1.5)  # parameter was misnamed 'c'
    @staticmethod
    def arctanh(x): y=1./(1-x**2); return (np.arctanh(x),y,2.*x*y**2)
def maximum(a,b):
    """Elementwise maximum of a and b, expressed through np.where (overloaded for AD types in this module)."""
    a_is_larger = a>b
    return np.where(a_is_larger,a,b)
def minimum(a,b):
    """Elementwise minimum of a and b, expressed through np.where (overloaded for AD types in this module)."""
    a_is_smaller = a<b
    return np.where(a_is_smaller,a,b)
class baseAD:
    """Common base class of the AD types.

    Members used here but supplied by subclasses: the ``value`` attribute, and
    ``order``, ``_math_helper``, ``reshape``, ``transpose``, ``concatenate``.
    """

    @property
    def shape(self): return self.value.shape
    @property
    def ndim(self): return self.value.ndim
    @property
    def size(self): return self.value.size
    def flatten(self): return self.reshape( (self.size,) )
    def squeeze(self,axis=None): return self.reshape(self.value.squeeze(axis).shape)
    @property
    def T(self): return self if self.ndim<2 else self.transpose()

    @classmethod
    def stack(cls,elems,axis=0):
        """Stack along a new axis, by expanding each element and concatenating."""
        return cls.concatenate(tuple(np.expand_dims(e,axis=axis) for e in elems),axis)

    @property
    def dtype(self): return self.value.dtype
    def __len__(self): return len(self.value)
    def _ndarray(self): return type(self.value)
    def cupy_based(self): return not isinstance(self.value,np.ndarray)
    def isndarray(self,other): return isinstance(other,self._ndarray()) # same array module
    @classmethod
    def is_ad(cls,other): return isinstance(other,cls)
    @classmethod
    def new(cls,*args,**kwargs):
        return cls(*args,**kwargs)

    @classmethod
    def Taylor(cls):
        """Taylor expansion table matching the AD order (Taylor1 for order 1, else Taylor2)."""
        return Taylor1 if cls.order()==1 else Taylor2

    def sqrt(self): return self**0.5
    def __pow__(self,n): return self._math_helper(self.Taylor().pow(self.value,n))
    def log(self): return self._math_helper(self.Taylor().log(self.value))
    def exp(self): return self._math_helper(self.Taylor().exp(self.value))
    def abs(self): return self._math_helper(self.Taylor().abs(self.value))
    def sin(self): return self._math_helper(self.Taylor().sin(self.value))
    def cos(self): return self._math_helper(self.Taylor().cos(self.value))
    def tan(self): return self._math_helper(self.Taylor().tan(self.value))
    def arcsin(self): return self._math_helper(self.Taylor().arcsin(self.value))
    def arccos(self): return self._math_helper(self.Taylor().arccos(self.value))
    def arctan(self): return self._math_helper(self.Taylor().arctan(self.value))
    def sinh(self): return self._math_helper(self.Taylor().sinh(self.value))
    def cosh(self): return self._math_helper(self.Taylor().cosh(self.value))
    def tanh(self): return self._math_helper(self.Taylor().tanh(self.value))
    def arcsinh(self): return self._math_helper(self.Taylor().arcsinh(self.value))
    def arccosh(self): return self._math_helper(self.Taylor().arccosh(self.value))
    def arctanh(self): return self._math_helper(self.Taylor().arctanh(self.value))

    # See https://docs.scipy.org/doc/numpy/reference/ufuncs.html
    def __array_ufunc__(self,ufunc,method,*inputs,**kwargs):

        # Return an np.ndarray for piecewise constant functions
        # (evaluated on the underlying values, AD information dropped)
        if ufunc in [
            # Comparison functions
            np.greater,np.greater_equal,
            np.less,np.less_equal,
            np.equal,np.not_equal,

            # Math
            np.floor_divide,np.rint,np.sign,np.heaviside,

            # Floating functions
            np.isfinite,np.isinf,np.isnan,np.isnat,
            np.signbit,np.floor,np.ceil,np.trunc
            ]:
            inputs_ = (a.value if self.is_ad(a) else a for a in inputs)
            return self.value.__array_ufunc__(ufunc,method,*inputs_,**kwargs)


        if method=="__call__":

            # Reimplemented
            if ufunc==np.maximum: return maximum(*inputs,**kwargs)
            if ufunc==np.minimum: return minimum(*inputs,**kwargs)

            # Math functions
            if ufunc==np.sqrt: return self.sqrt()
            if ufunc==np.log: return self.log()
            if ufunc==np.exp: return self.exp()
            if ufunc==np.abs: return self.abs()
            if ufunc==np.sin: return self.sin()
            if ufunc==np.cos: return self.cos()
            if ufunc==np.tan: return self.tan()
            if ufunc==np.arcsin: return self.arcsin()
            if ufunc==np.arccos: return self.arccos()
            if ufunc==np.arctan: return self.arctan()
            if ufunc==np.sinh: return self.sinh()
            if ufunc==np.cosh: return self.cosh()
            if ufunc==np.tanh: return self.tanh()
            if ufunc==np.arcsinh: return self.arcsinh()
            if ufunc==np.arccosh: return self.arccosh()
            if ufunc==np.arctanh: return self.arctanh()

            # Operators
            if ufunc==np.add: return add(*inputs,**kwargs)
            if ufunc==np.subtract: return subtract(*inputs,**kwargs)
            if ufunc==np.multiply: return multiply(*inputs,**kwargs)
            if ufunc==np.true_divide: return true_divide(*inputs,**kwargs)


        return NotImplemented

    def __array_function__(self,func,types,args,kwargs):
        return _array_function_overload(self,func,types,args,kwargs)


    # Support for +=, -=, *=, /=, <, <=, >, >=, ==, !=
    def __iadd__(self,other): return add(self,other,out=self)
    def __isub__(self,other): return subtract(self,other,out=self)
    def __imul__(self,other): return multiply(self,other,out=self)
    def __itruediv__(self,other): return true_divide(self,other,out=self)

    def __lt__(self,other): return np.less(self,other)
    def __le__(self,other): return np.less_equal(self,other)
    def __gt__(self,other): return np.greater(self,other)
    def __ge__(self,other): return np.greater_equal(self,other)
    def __eq__(self,other): return np.equal(self,other)
    def __ne__(self,other): return np.not_equal(self,other)

    def argmin(self,*args,**kwargs): return self.value.argmin(*args,**kwargs)
    def argmax(self,*args,**kwargs): return self.value.argmax(*args,**kwargs)

    def min(array,axis=None,keepdims=False,out=None):
        """Minimum along ``axis``, keeping the AD information of the selected entries.

        With ``axis=None`` the array is flattened first (``keepdims`` is then ignored).
        ``out`` is never written to.
        """
        if axis is None: return array.reshape(-1).min(axis=0,out=out)
        if axis<0: axis+=array.ndim  # the shape slicing below needs a non-negative axis
        ai = np.expand_dims(np.argmin(array.value, axis=axis), axis=axis)
        result = np.take_along_axis(array,ai,axis=axis)
        if not keepdims: result = result.reshape(array.shape[:axis]+array.shape[axis+1:])
        return result

    def max(array,axis=None,keepdims=False,out=None):
        """Maximum along ``axis``, keeping the AD information of the selected entries.

        With ``axis=None`` the array is flattened first (``keepdims`` is then ignored).
        ``out`` is never written to.
        """
        if axis is None: return array.reshape(-1).max(axis=0,out=out)
        if axis<0: axis+=array.ndim  # the shape slicing below needs a non-negative axis
        ai = np.expand_dims(np.argmax(array.value, axis=axis), axis=axis)
        result = np.take_along_axis(array,ai,axis=axis)
        if not keepdims: result = result.reshape(array.shape[:axis]+array.shape[axis+1:])
        return result

    def prod(arr,axis=None,dtype=None,out=None,keepdims=False,initial=None):
        """Attempt to reproduce numpy's prod function (rather inefficiently, and presumably only partially).

        The reduced axes are moved to the front and merged into one leading axis,
        which is then reduced by repeated multiplication.
        ``out`` is accepted for signature compatibility but ignored.
        If the reduced axis is empty, ``initial`` is returned as is.
        """

        shape_orig = arr.shape

        if axis is None:
            arr = arr.flatten()
            axis = (0,)
        elif isinstance(axis,numbers.Number):
            axis=(axis,)
        # Normalize negative axes (required by the keepdims membership test below)
        ndim = arr.ndim
        axis = tuple(a+ndim if a<0 else a for a in axis)

        if axis!=(0,):
            d = len(axis)
            arr = np.moveaxis(arr,axis,tuple(range(d)))
            # Merge the d leading (reduced) axes into a single one
            shape1 = (np.prod(arr.shape[:d],dtype=int),)+arr.shape[d:]
            arr = arr.reshape(shape1)

        if len(arr)==0:
            return initial

        if dtype is not None and dtype!=arr.dtype:
            if initial is None:
                initial = dtype(1)
            elif dtype!=initial.dtype:
                initial = initial*dtype(1)

        result = functools.reduce(operator.mul,arr) if initial is None \
            else functools.reduce(operator.mul,arr,initial)

        if keepdims:
            shape_kept = tuple(1 if i in axis else ai for i,ai in enumerate(shape_orig)) \
                if result.size>1 else (1,)*len(shape_orig)
            result = result.reshape(shape_kept)

        return result
def squeeze(self,axis=None):
    """Drop length-one axes, reshaping to the shape that ``self.value.squeeze(axis)`` would have."""
    target_shape = self.value.squeeze(axis).shape
    return self.reshape(target_shape)
def isndarray(self,other):
    """Whether ``other`` is an instance of the array type of ``self.value`` (same array module)."""
    array_type = self._ndarray()
    return isinstance(other,array_type)
def arcsinh(self):
    """Inverse hyperbolic sine, built from the Taylor expansion at ``self.value``."""
    expansion = self.Taylor().arcsinh(self.value)
    return self._math_helper(expansion)
def arccosh(self):
    """Inverse hyperbolic cosine, built from the Taylor expansion at ``self.value``."""
    expansion = self.Taylor().arccosh(self.value)
    return self._math_helper(expansion)
def arctanh(self):
    """Inverse hyperbolic tangent, built from the Taylor expansion at ``self.value``."""
    expansion = self.Taylor().arctanh(self.value)
    return self._math_helper(expansion)
def argmin(self,*args,**kwargs):
    """Index of the minimum, computed on the underlying values only."""
    values = self.value
    return values.argmin(*args,**kwargs)
def argmax(self,*args,**kwargs):
    """Index of the maximum, computed on the underlying values only."""
    values = self.value
    return values.argmax(*args,**kwargs)
def min(array,axis=None,keepdims=False,out=None):
    """Minimum along ``axis``, keeping the AD information of the selected entries.

    With ``axis=None`` the array is flattened first (``keepdims`` is then ignored).
    ``out`` is never written to.
    """
    if axis is None: return array.reshape(-1).min(axis=0,out=out)
    if axis<0: axis+=array.ndim  # the shape slicing below needs a non-negative axis
    ai = np.expand_dims(np.argmin(array.value, axis=axis), axis=axis)
    result = np.take_along_axis(array,ai,axis=axis)
    if not keepdims: result = result.reshape(array.shape[:axis]+array.shape[axis+1:])
    return result
def max(array,axis=None,keepdims=False,out=None):
    """Maximum along ``axis``, keeping the AD information of the selected entries.

    With ``axis=None`` the array is flattened first (``keepdims`` is then ignored).
    ``out`` is never written to.
    """
    if axis is None: return array.reshape(-1).max(axis=0,out=out)
    if axis<0: axis+=array.ndim  # the shape slicing below needs a non-negative axis
    ai = np.expand_dims(np.argmax(array.value, axis=axis), axis=axis)
    result = np.take_along_axis(array,ai,axis=axis)
    if not keepdims: result = result.reshape(array.shape[:axis]+array.shape[axis+1:])
    return result
def prod(arr,axis=None,dtype=None,out=None,keepdims=False,initial=None):
    """Attempt to reproduce numpy's prod function (rather inefficiently, and presumably only partially).

    The reduced axes are moved to the front and merged into one leading axis,
    which is then reduced by repeated multiplication.
    ``out`` is accepted for signature compatibility but ignored.
    If the reduced axis is empty, ``initial`` is returned as is.
    """

    shape_orig = arr.shape

    if axis is None:
        arr = arr.flatten()
        axis = (0,)
    elif isinstance(axis,numbers.Number):
        axis=(axis,)
    # Normalize negative axes (required by the keepdims membership test below)
    ndim = arr.ndim
    axis = tuple(a+ndim if a<0 else a for a in axis)

    if axis!=(0,):
        d = len(axis)
        arr = np.moveaxis(arr,axis,tuple(range(d)))
        # Merge the d leading (reduced) axes into a single one
        shape1 = (np.prod(arr.shape[:d],dtype=int),)+arr.shape[d:]
        arr = arr.reshape(shape1)

    if len(arr)==0:
        return initial

    if dtype is not None and dtype!=arr.dtype:
        if initial is None:
            initial = dtype(1)
        elif dtype!=initial.dtype:
            initial = initial*dtype(1)

    result = functools.reduce(operator.mul,arr) if initial is None \
        else functools.reduce(operator.mul,arr,initial)

    if keepdims:
        shape_kept = tuple(1 if i in axis else ai for i,ai in enumerate(shape_orig)) \
            if result.size>1 else (1,)*len(shape_orig)
        result = result.reshape(shape_kept)

    return result
Attempt to reproduce numpy's prod function (rather inefficiently, and presumably only partially).
def is_ad(data,iterables=tuple()):
    """Whether the object holds AD information, i.e. any element yielded by
    ``functional.rec_iter(data,iterables)`` is a baseAD instance."""
    return any(isinstance(x,baseAD) for x in functional.rec_iter(data,iterables))
Whether the object holds AD information
def isndarray(x):
    """Whether the object is a numpy or cupy ndarray, or an AD type (baseAD instance)."""
    return isinstance(x,(np.ndarray,baseAD,_cp_ndarray))
Whether the object is a numpy or cupy ndarray, or an AD type
def from_cupy(x):
    """Whether the variable is a cupy ndarray, or an AD type based on those
    (a baseAD whose ``cupy_based()`` returns True)."""
    return isinstance(x,_cp_ndarray) or (isinstance(x,baseAD) and x.cupy_based())
Whether the variable is a cupy ndarray, or an AD type based on those
def array(a,copy=True,caster=None):
    """
    Similar to np.array, but does not cast AD subclasses of np.ndarray to the base class.
    Turns a non-empty list or tuple of arrays with the same dimensions into a single
    array, stacked along a new leading axis.
    Turns a scalar into an array scalar.
    Inputs :
     - copy : whether to copy an input which is already an array (numpy, cupy or AD)
     - caster : used to cast a scalar into an array scalar (overrides default array.caster)
    """
    if isinstance(a,(list,tuple)) and len(a)>0:
        return stack([asarray(e,caster=caster) for e in a],axis=0)
    elif isndarray(a): return a.copy() if copy else a
    elif caster is not None: return caster(a)
    else: return array.caster(a)
Similar to np.array, but does not cast AD subclasses of np.ndarray to the base class. Turns a non-empty list or tuple of arrays with the same dimensions into a single array, stacked along a new leading axis. Turns a scalar into an array scalar. Inputs :
- caster : used to cast a scalar into an array scalar (overrides default)
def asarray(a,**kwargs):
    """Like ``array``, but an input that is already an array is returned without copy."""
    return array(a, copy=False, **kwargs)
def ascontiguousarray(a,**kwargs):
    """Contiguous version of ``asarray(a)``.

    numpy and cupy arrays go through their module's ascontiguousarray; AD types
    are rebuilt from their contiguous ``as_tuple()`` fields. Raises ValueError otherwise.
    """
    a = asarray(a,**kwargs)
    if isinstance(a,np.ndarray):
        return np.ascontiguousarray(a)
    if isinstance(a,_cp_ndarray):
        return cp.ascontiguousarray(a)
    if is_ad(a):
        contiguous_fields = tuple(ascontiguousarray(field) for field in a.as_tuple())
        return type(a)(*contiguous_fields)
    raise ValueError(f"ascontiguousarray not applicable to variable of type {type(a)}")
def array_members(data,iterables=(tuple,list,dict)):
    """
    Returns the list of all arrays in given structure, with their access paths.
    Usage : for key,value in ad.Base.array_members(self): print(key,value.nbytes/2**20)
    """
    found = []  # list of ([access paths], array) pairs

    def children(container):
        # (name, value) pairs: mapping items, sequence indices, or object attributes
        if hasattr(container,'items'): return container.items()
        if isinstance(container,(list,tuple)):
            return [(str(i),value) for i,value in enumerate(container)]
        return container.__dict__.items()

    def record(path,arr):
        if not isndarray(arr): return
        name = ".".join(map(str,path))
        for names,known in found:
            # NOTE(review): for numpy arrays, .data is a memoryview and == compares
            # contents, not addresses -- confirm this is the intended dedup criterion.
            if arr.data==known.data:
                names.append(name)
                return
        found.append(([name],arr))

    def visit(prefix,container):
        for name,value in children(container):
            path = prefix+(name,)
            if isinstance(value,iterables): visit(path,value)
            else: record(path,value)

    visit(tuple(),data)
    return found
Returns the list of all arrays in the given structure, with their access paths. Usage : for key,value in ad.Base.array_members(self): print(key,value.nbytes/2**20)
def implements(numpy_function):
    """Register an __array_function__ implementation for AD objects.

    The decorated function is stored in ``numpy_overloads[numpy_function]``
    and returned unchanged.
    """
    def decorator(func):
        numpy_overloads[numpy_function] = func
        return func
    return decorator
Register an __array_function__ implementation for AD objects.
def implements_cupy_alt(numpy_function,exception):
    """Register an alternative to a numpy function only partially supported by cupy.

    The (alternative, exception) pair is stored in ``cupy_alt_overloads``; the
    decorator returns ``functional.func_except_alt(numpy_function,exception,alternative)``.
    """
    def register(alternative):
        cupy_alt_overloads[numpy_function] = (alternative,exception)
        return functional.func_except_alt(numpy_function,exception,alternative)
    return register
Register an alternative to a numpy function only partially supported by cupy
@implements(np.pad)
def pad(array, pad_width, *args,**kwargs):
    """np.pad overload: expands an int or a single int pair to one (before,after) pair per axis,
    then delegates to ``array.pad``."""
    widths = (pad_width,) if isinstance(pad_width,numbers.Integral) else pad_width
    if isinstance(widths[0],numbers.Integral) and len(widths)==1:
        widths = ((widths[0],widths[0]),)
    if len(widths)==1:
        widths = widths*array.ndim
    return array.pad(widths,*args,**kwargs)
@implements(np.roll)
def roll(array,shift,axis=None):
    """np.roll overload for a single axis: rolls each field of ``array.as_tuple()``
    and rebuilds the AD object with ``array.new``."""
    ndim = array.ndim
    if axis is None:
        if ndim>1:
            raise ValueError("Unsupported None axis when ndim>=2")
        axis=0
    elif axis<0:
        axis+=ndim

    if not 0<=axis<ndim:
        raise ValueError(f"Unsupported axis {axis} with ndim {array.ndim}")

    rolled_fields = tuple(np.roll(field,shift,axis) for field in array.as_tuple())
    return array.new(*rolled_fields)