agd.AutomaticDifferentiation.ad_specific
# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

import itertools
import numpy as np
from . import ad_generic
from . import misc
from . import Dense
from . import Sparse
from . import Dense2
from . import Sparse2


def simplify_ad(a, *args, **kwargs):
    """
    Simplifies, if possible, the sparsity pattern of a sparse AD variable.
    See Sparse.spAD.simplify_ad for detailed help.
    Inputs that are not (exactly) Sparse.spAD or Sparse2.spAD2 are returned unchanged.
    """
    if type(a) in (Sparse.spAD, Sparse2.spAD2):
        return a.simplify_ad(*args, **kwargs)
    return a


def apply(f, *args, **kwargs):
    """
    Applies the function to the given arguments, with special treatment of the
    following keywords (they are removed from kwargs before f is called):
    - envelope : take advantage of the envelope theorem, to differentiate a min or max.
      The function is called twice, first without AD, then with AD and the oracle
      parameter. In this mode f must return a pair (result, oracle), and apply
      returns (result, oracle) as well.
    - shape_bound : take advantage of dense-sparse (or dense-dense) AD composition to
      differentiate the function efficiently. The function is called with dense AD, and
      the dimensions in shape_bound are regarded as a simple scalar. Any sequence of
      ints is accepted (it is converted to a tuple).
    - reverse_history : use the provided reverse AD trace (also honored together
      with envelope).

    If no argument is an AD variable, f is simply called on the arguments.

    Raises:
    - TypeError : in shape_bound mode, if an AD argument is not one of
      Sparse.spAD, Dense.denseAD, Sparse2.spAD2, Dense2.denseAD2.
    """
    envelope, shape_bound, reverse_history = (
        kwargs.pop(s, None) for s in ('envelope', 'shape_bound', 'reverse_history'))

    # Nothing to differentiate: plain call.
    if not any(ad_generic.is_ad(a) for a in itertools.chain(args, kwargs.values())):
        return f(*args, **kwargs)

    if envelope:
        # First pass on plain values, to obtain the oracle returned as f's second output.
        def to_np(a): return a.value if ad_generic.is_ad(a) else a
        _, oracle = f(*[to_np(a) for a in args],
                      **{key: to_np(val) for key, val in kwargs.items()})
        # Second pass with AD at the fixed oracle. reverse_history is forwarded so that
        # the reverse trace is not silently dropped in envelope mode.
        result, _ = apply(f, *args, **kwargs, oracle=oracle, envelope=False,
                          shape_bound=shape_bound, reverse_history=reverse_history)
        return result, oracle

    if shape_bound is not None:
        # Normalize, since shape_bound is concatenated with tuples below.
        shape_bound = tuple(shape_bound)
        size_factor = np.prod(shape_bound, dtype=int)
        # Tuple containing the original AD vars, reshaped to (n_i,)+shape_bound.
        t = tuple(b.reshape((b.size // size_factor,) + shape_bound)
                  for b in itertools.chain(args, kwargs.values()) if ad_generic.is_ad(b))
        lens = tuple(len(b) for b in t)

        # The i-th AD variable gets shift (sum(lens[:i]), sum(lens[i+1:])).
        total = sum(lens)
        shifts = []
        before = 0
        for n in lens:
            shifts.append((before, total - before - n))
            before += n
        shift_iter = iter(shifts)

        def to_dense(b):
            """Replaces an AD variable with a dense identity AD variable, shifted
            according to its position among all AD arguments."""
            if not ad_generic.is_ad(b): return b
            shift = next(shift_iter)
            if type(b) in (Sparse.spAD, Dense.denseAD):
                return Dense.identity(constant=b.value, shape_bound=shape_bound, shift=shift)
            if type(b) in (Sparse2.spAD2, Dense2.denseAD2):
                return Dense2.identity(constant=b.value, shape_bound=shape_bound, shift=shift)
            raise TypeError(f"apply with shape_bound does not support AD type {type(b)}")

        # Arguments are converted in the same order as t was built (args, then kwargs).
        args2 = [to_dense(b) for b in args]
        kwargs2 = {key: to_dense(val) for key, val in kwargs.items()}
        result2 = f(*args2, **kwargs2)
        return compose(result2, t, shape_bound=shape_bound)

    if reverse_history:
        return reverse_history.apply(f, *args, **kwargs)
    return f(*args, **kwargs)


def compose(a, t, shape_bound):
    """
    Compose AD types; mostly intended for a dense `a` and sparse AD variables in `t`.

    t may be a single AD variable or a tuple of them. If a is a tuple, each entry is
    composed separately. Values that are not Dense.denseAD / Dense2.denseAD2, or an
    empty t, are returned unchanged. shape_bound is not used here.
    """
    if not isinstance(t, tuple): t = (t,)
    if isinstance(a, tuple):
        return tuple(compose(ai, t, shape_bound) for ai in a)
    if not (type(a) in (Dense.denseAD, Dense2.denseAD2)) or len(t) == 0:
        return a
    return type(t[0]).compose(a, t)
def
simplify_ad(a, *args, **kwargs):
def simplify_ad(a, *args, **kwargs):
    """
    Simplify the sparsity pattern of `a` when it is a sparse AD variable.

    Only exact instances of Sparse.spAD or Sparse2.spAD2 are simplified; every other
    input is handed back as is. Extra arguments are forwarded to the method,
    see Sparse.spAD.simplify_ad for their meaning.
    """
    sparse_types = (Sparse.spAD, Sparse2.spAD2)
    if type(a) not in sparse_types:
        return a
    return a.simplify_ad(*args, **kwargs)
Simplifies, if possible, the sparsity pattern of a sparse AD variable. See Sparse.spAD.simplify_ad for detailed help.
def
apply(f, *args, **kwargs):
def apply(f, *args, **kwargs):
    """
    Applies the function to the given arguments, with special treatment of the
    following keywords (they are removed from kwargs before f is called):
    - envelope : take advantage of the envelope theorem, to differentiate a min or max.
      The function is called twice, first without AD, then with AD and the oracle
      parameter. In this mode f must return a pair (result, oracle), and apply
      returns (result, oracle) as well.
    - shape_bound : take advantage of dense-sparse (or dense-dense) AD composition to
      differentiate the function efficiently. The function is called with dense AD, and
      the dimensions in shape_bound are regarded as a simple scalar. Any sequence of
      ints is accepted (it is converted to a tuple).
    - reverse_history : use the provided reverse AD trace (also honored together
      with envelope).

    If no argument is an AD variable, f is simply called on the arguments.

    Raises:
    - TypeError : in shape_bound mode, if an AD argument is not one of
      Sparse.spAD, Dense.denseAD, Sparse2.spAD2, Dense2.denseAD2.
    """
    envelope, shape_bound, reverse_history = (
        kwargs.pop(s, None) for s in ('envelope', 'shape_bound', 'reverse_history'))

    # Nothing to differentiate: plain call.
    if not any(ad_generic.is_ad(a) for a in itertools.chain(args, kwargs.values())):
        return f(*args, **kwargs)

    if envelope:
        # First pass on plain values, to obtain the oracle returned as f's second output.
        def to_np(a): return a.value if ad_generic.is_ad(a) else a
        _, oracle = f(*[to_np(a) for a in args],
                      **{key: to_np(val) for key, val in kwargs.items()})
        # Second pass with AD at the fixed oracle. reverse_history is forwarded so that
        # the reverse trace is not silently dropped in envelope mode.
        result, _ = apply(f, *args, **kwargs, oracle=oracle, envelope=False,
                          shape_bound=shape_bound, reverse_history=reverse_history)
        return result, oracle

    if shape_bound is not None:
        # Normalize, since shape_bound is concatenated with tuples below.
        shape_bound = tuple(shape_bound)
        size_factor = np.prod(shape_bound, dtype=int)
        # Tuple containing the original AD vars, reshaped to (n_i,)+shape_bound.
        t = tuple(b.reshape((b.size // size_factor,) + shape_bound)
                  for b in itertools.chain(args, kwargs.values()) if ad_generic.is_ad(b))
        lens = tuple(len(b) for b in t)

        # The i-th AD variable gets shift (sum(lens[:i]), sum(lens[i+1:])).
        total = sum(lens)
        shifts = []
        before = 0
        for n in lens:
            shifts.append((before, total - before - n))
            before += n
        shift_iter = iter(shifts)

        def to_dense(b):
            """Replaces an AD variable with a dense identity AD variable, shifted
            according to its position among all AD arguments."""
            if not ad_generic.is_ad(b): return b
            shift = next(shift_iter)
            if type(b) in (Sparse.spAD, Dense.denseAD):
                return Dense.identity(constant=b.value, shape_bound=shape_bound, shift=shift)
            if type(b) in (Sparse2.spAD2, Dense2.denseAD2):
                return Dense2.identity(constant=b.value, shape_bound=shape_bound, shift=shift)
            raise TypeError(f"apply with shape_bound does not support AD type {type(b)}")

        # Arguments are converted in the same order as t was built (args, then kwargs).
        args2 = [to_dense(b) for b in args]
        kwargs2 = {key: to_dense(val) for key, val in kwargs.items()}
        result2 = f(*args2, **kwargs2)
        return compose(result2, t, shape_bound=shape_bound)

    if reverse_history:
        return reverse_history.apply(f, *args, **kwargs)
    return f(*args, **kwargs)
Applies the function to the given arguments, with special treatment of the following keywords:
- envelope : take advantage of the envelope theorem, to differentiate a min or max. The function is called twice, first without AD, then with AD and the oracle parameter.
- shape_bound : take advantage of dense-sparse (or dense-dense) AD composition to differentiate the function efficiently. The function is called with dense AD, and the dimensions in shape_bound are regarded as a simple scalar.
- reverse_history : use the provided reverse AD trace.
def
compose(a, t, shape_bound):
def compose(a, t, shape_bound):
    """
    Compose AD types; mostly intended for a dense `a` and sparse AD variables in `t`.

    t may be a single AD variable or a tuple of them. When a is a tuple, each entry
    is composed on its own. Anything that is not exactly a Dense.denseAD or
    Dense2.denseAD2, or an empty t, is returned untouched. shape_bound is only
    passed along in recursive calls.
    """
    inner = t if isinstance(t, tuple) else (t,)
    if isinstance(a, tuple):
        return tuple(compose(item, inner, shape_bound) for item in a)
    dense_types = (Dense.denseAD, Dense2.denseAD2)
    if type(a) in dense_types and inner:
        return type(inner[0]).compose(a, inner)
    return a
Compose AD types; mostly intended for a dense `a` and sparse AD variables in `t`.