agd.AutomaticDifferentiation.Reverse2
Module source:

# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

import numpy as np
import copy
from . import functional
from . import misc
from . import Dense
from . import Sparse
from . import Reverse
from . import Sparse2
from .cupy_generic import isndarray

class reverseAD2(object):
    """
    A class for reverse second order automatic differentiation
    """

    def __init__(self,operator_data=None,input_iterables=None,output_iterables=None):
        self.operator_data=operator_data
        self.deepcopy_states = False

        self.input_iterables  = (tuple,) if input_iterables  is None else input_iterables
        self.output_iterables = (tuple,) if output_iterables is None else output_iterables
        assert hasattr(self.input_iterables,'__iter__') and hasattr(self.output_iterables,'__iter__')

        self._size_ad  = 0
        self._size_rev = 0
        self._states = []
        self._shapes_ad = tuple()

    @property
    def size_ad(self): return self._size_ad
    @property
    def size_rev(self): return self._size_rev

    # Variable creation
    def register(self,a):
        return misc.register(self.identity,a,self.input_iterables)

    def identity(self,*args,**kwargs):
        """Creates and registers a new AD variable"""
        assert (self.operator_data is None) or kwargs.pop("operator_initialization",False)
        result = Sparse2.identity(*args,**kwargs,shift=self.size_ad)
        self._shapes_ad += (functional.pair(self.size_ad,result.shape),)
        self._size_ad += result.size
        return result

    def _identity_rev(self,*args,**kwargs):
        """Creates and registers an AD variable with negative indices,
        used as placeholders in reverse AD"""
        result = Sparse2.identity(*args,**kwargs,shift=self.size_rev)
        self._size_rev += result.size
        result.index = -result.index-1
        return result

    def _index_rev(self,index):
        """Turns the negative placeholder indices into positive ones,
        for sparse matrix creation."""
        index=index.copy()
        pos = index<0
        index[pos] = -index[pos]-1+self.size_ad
        return index

    def apply(self,func,*args,**kwargs):
        """
        Applies a function on the given args, saving adequate data
        for reverse AD.
        """
        if self.operator_data == "PassThrough": return func(*args,**kwargs)
        _args,_kwargs,corresp = misc._apply_input_helper(args,kwargs,Sparse2.spAD2,self.input_iterables)
        if len(corresp)==0: return func(*args,**kwargs)
        _output = func(*_args,**_kwargs)
        output,shapes = misc._apply_output_helper(self,_output,self.output_iterables)
        # Record the state needed to replay this call during the reverse sweep
        self._states.append((shapes,func,
            copy.deepcopy(args) if self.deepcopy_states else args,
            copy.deepcopy(kwargs) if self.deepcopy_states else kwargs))
        return output

    def apply_linear_mapping(self,matrix,rhs,niter=1):
        return self.apply(Reverse.linear_mapping_with_adjoint(matrix,niter=niter),rhs)
    def apply_linear_inverse(self,matrix,solver,rhs,niter=1):
        return self.apply(Reverse.linear_inverse_with_adjoint(matrix,solver,niter=niter),rhs)
    def simplify(self,rhs):
        return self.apply(Reverse.identity_with_adjoint,rhs)

    # Adjoint evaluation pass
    def gradient(self,a):
        """Computes the gradient of the scalar spAD2 variable a"""
        assert isinstance(a,Sparse2.spAD2) and a.shape==tuple()
        coef = Sparse.spAD(a.value,a.coef1,self._index_rev(a.index)).to_dense().coef
        for outputshapes,func,args,kwargs in reversed(self._states):
            co_output_value = misc._to_shapes(coef[self.size_ad:],outputshapes,self.output_iterables)
            _args,_kwargs,corresp = misc._apply_input_helper(args,kwargs,Sparse2.spAD2,self.input_iterables)
            co_arg_request = [a for _,a in corresp]
            co_args = func(*_args,**_kwargs,co_output=functional.pair(co_output_value,co_arg_request))
            for a_sparse,a_value2 in corresp:
                found = False
                for a_value,a_adjoint in co_args:
                    if a_value is a_value2:
                        val,(row,col) = a_sparse.to_first().triplets()
                        coef_contrib = misc.spapply(
                            (val,(self._index_rev(col),row)),
                            misc.as_flat(a_adjoint))
                        # Possible improvement : shift by np.min(self._index_rev(col)) to avoid adding zeros
                        coef[:coef_contrib.shape[0]] += coef_contrib
                        found=True
                        break
                if not found:
                    raise ValueError(f"ReverseAD error : sensitivity not provided for input value {id(a_sparse)} equal to {a_sparse}")

        return coef[:self.size_ad]

    def _hessian_forward_input_helper(self,args,kwargs,dir):
        """Replaces Sparse AD information with dense one, based on dir_hessian."""
        from . import is_ad
        corresp = []
        def _make_arg(a):
            nonlocal dir,corresp
            if is_ad(a):
                assert isinstance(a,Sparse2.spAD2)
                a1 = Sparse.spAD(a.value,a.coef1,self._index_rev(a.index))
                coef = misc.spapply(a1.triplets(),dir[:a1.bound_ad()])
                a_value = Dense.denseAD(a1.value, coef.reshape(a.shape+(1,)))
                corresp.append((a,a_value))
                return a_value
            else:
                return a
        def make_arg(a):
            return misc.map_iterables(_make_arg,a,self.input_iterables)
        _args = tuple(make_arg(a) for a in args)
        _kwargs = {key:make_arg(val) for key,val in kwargs.items()}
        return _args,_kwargs,corresp

    def _hessian_forward_make_dir(self,values,shapes,dir):
        def f(val,s):
            nonlocal self,dir
            if s is not None:
                start,shape = s
                assert isinstance(val,Dense.denseAD) and val.size_ad==1
                assert val.shape==shape
                sstart = self.size_ad+start
                dir[sstart:(sstart+val.size)] = val.coef.flatten()
        misc.map_iterables2(f,values,shapes,self.output_iterables)

    def hessian(self,a):
        """Returns the hessian operator associated with the scalar spAD2 variable a"""
        assert isinstance(a,Sparse2.spAD2) and a.shape==tuple()
        def hess_operator(dir_hessian,coef2_init=None,with_grad=False):
            nonlocal self,a
            # Forward pass : propagate the hessian direction
            size_total = self.size_ad+self.size_rev
            dir_hessian_forwarded = np.zeros(size_total)
            dir_hessian_forwarded[:self.size_ad] = dir_hessian
            denseArgs = []
            for outputshapes,func,args,kwargs in self._states:
                # Produce denseAD arguments containing the hessian direction
                _args,_kwargs,corresp = self._hessian_forward_input_helper(args,kwargs,dir_hessian_forwarded)
                denseArgs.append((_args,_kwargs,corresp))
                # Evaluate the function
                output = func(*_args,**_kwargs)
                # Collect the forwarded hessian direction
                self._hessian_forward_make_dir(output,outputshapes,dir_hessian_forwarded)

            # Reverse pass : evaluate the hessian operator
            # TODO : avoid the recomputation of the gradient
            coef1 = Sparse.spAD(a.value,a.coef1,self._index_rev(a.index)).to_dense().coef
            coef2 = misc.spapply((a.coef2,(self._index_rev(a.index_row),self._index_rev(a.index_col))),dir_hessian_forwarded, crop_rhs=True)
            if coef1.size<size_total: coef1 = misc._pad_last(coef1,size_total)
            if coef2.size<size_total: coef2 = misc._pad_last(coef2,size_total)
            if coef2_init is not None: coef2 += misc._pad_last(coef2_init,size_total)
            for (outputshapes,func,_,_),(_args,_kwargs,corresp) in zip(reversed(self._states),reversed(denseArgs)):
                co_output_value1 = misc._to_shapes(coef1[self.size_ad:],outputshapes,self.output_iterables)
                co_output_value2 = misc._to_shapes(coef2[self.size_ad:],outputshapes,self.output_iterables)
                co_arg_request = [a for _,a in corresp]
                co_args = func(*_args,**_kwargs,co_output=functional.pair(functional.pair(co_output_value1,co_output_value2),co_arg_request))
                for a_value,a_adjoint1,a_adjoint2 in co_args:
                    for a_sparse,a_value2 in corresp:
                        if a_value is a_value2:
                            # Linear contribution to the gradient
                            val,(row,col) = a_sparse.to_first().triplets()
                            triplets = (val,(self._index_rev(col),row))
                            coef1_contrib = misc.spapply(triplets,misc.as_flat(a_adjoint1))
                            coef1[:coef1_contrib.shape[0]] += coef1_contrib

                            # Linear contribution to the hessian
                            linear_contrib = misc.spapply(triplets,misc.as_flat(a_adjoint2))
                            coef2[:linear_contrib.shape[0]] += linear_contrib

                            # Quadratic contribution to the hessian
                            obj = (a_adjoint1*a_sparse).sum()
                            quadratic_contrib = misc.spapply((obj.coef2,(self._index_rev(obj.index_row),self._index_rev(obj.index_col))),
                                dir_hessian_forwarded, crop_rhs=True)
                            coef2[:quadratic_contrib.shape[0]] += quadratic_contrib

                            break
            return (coef1[:self.size_ad],coef2[:self.size_ad]) if with_grad else coef2[:self.size_ad]
        return hess_operator

    def to_inputshapes(self,a):
        return tuple(misc._to_shapes(a,shape,self.input_iterables) for shape in self._shapes_ad)

    def output(self,a):
        assert self.operator_data is not None
        if self.operator_data == "PassThrough":
            return a
        inputs,((co_output_value1,co_output_value2),_),dir_hessian = self.operator_data
        _a  = misc.sumprod(a,co_output_value1,self.output_iterables)
        _a2 = misc.sumprod(a,co_output_value2,self.output_iterables,to_first=True)
        coef2_init = Sparse.spAD(_a2.value,_a2.coef,self._index_rev(_a2.index)).to_dense().coef

        hess = self.hessian(_a)
        coef1,coef2 = hess(dir_hessian,coef2_init=coef2_init,with_grad=True)

        coef1 = self.to_inputshapes(coef1)
        coef2 = self.to_inputshapes(coef2)
        co_arg = []
        def f(input):
            nonlocal co_arg
            if isndarray(input):
                assert isinstance(input,Dense.denseAD) and input.size_ad==1
                l = len(co_arg)
                co_arg.append( (input,coef1[l],coef2[l]) )
        misc.map_iterables(f,inputs,self.input_iterables)
        return co_arg

# End of class reverseAD2

def empty(inputs=None,**kwargs):
    rev = reverseAD2(**kwargs)
    return rev if inputs is None else (rev,rev.register(inputs))

def operator_like(inputs=None,co_output=None,**kwargs):
    """
    Operator_like reverseAD2 (or Reverse depending on reverse mode):
    - should not register new inputs (conflicts with the way dir_hessian is provided)
    - fixed co_output
    - gets dir_hessian from inputs
    """
    mode = misc.reverse_mode(co_output)
    if mode == "Forward":
        return reverseAD2(operator_data="PassThrough",**kwargs),inputs
    elif mode == "Reverse":
        from . import Reverse
        return Reverse.operator_like(inputs,co_output,**kwargs)
    elif mode=="Reverse2":
        dir_hessian = tuple()
        def reg_coef(a):
            nonlocal dir_hessian
            if isndarray(a):
                assert isinstance(a,Dense.denseAD) and a.size_ad==1
                dir_hessian += (a.coef.flatten(),)
        input_iterables = kwargs.get('input_iterables',(tuple,))
        misc.map_iterables(reg_coef,inputs,input_iterables)
        dir_hessian = np.concatenate(dir_hessian)
        rev = reverseAD2(operator_data=(inputs,co_output,dir_hessian),**kwargs)
        def reg_value(a):
            nonlocal rev
            if isinstance(a,Dense.denseAD):
                return rev.identity(constant=a.value,operator_initialization=True)
            else: return a
        return rev,misc.map_iterables(reg_value,inputs,rev.input_iterables)
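For orientation, here is a minimal usage sketch, not part of the module source; it assumes the agd package is importable under the path above and that Sparse2.spAD2 variables support elementwise arithmetic and sum(), as they do in the rest of the library.

import numpy as np
from agd.AutomaticDifferentiation import Reverse2

# Create a tape and register an input array; x is an spAD2 copy of x0.
x0 = np.array([1., 2., 3.])
rev, (x,) = Reverse2.empty(inputs=(x0,))

# Build a scalar objective from the registered variable.
y = (x**2).sum()

# First-order reverse sweep: gradient w.r.t. the registered coefficients (here 2*x0).
grad = rev.gradient(y)

# Second-order: hessian(y) returns an operator computing Hessian-vector products.
hess = rev.hessian(y)
Hd = hess(np.array([1., 0., 0.]))   # the Hessian of y is 2*Id, so Hd should be close to [2., 0., 0.]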
class reverseAD2:
A class for reverse second order automatic differentiation
reverseAD2(operator_data=None, input_iterables=None, output_iterables=None)
def identity(self, *args, **kwargs):
Creates and registers a new AD variable
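A short sketch of direct use; the constant keyword is forwarded to Sparse2.identity, exactly as in the operator_like code above.

import numpy as np
from agd.AutomaticDifferentiation import Reverse2

rev = Reverse2.empty()
x = rev.identity(constant=np.zeros(4))   # spAD2 variable tracking 4 new first-order indices
y = rev.identity(constant=np.ones(2))    # its indices are shifted past those of x
# rev.size_ad is now 6: the tape has allocated 4 + 2 independent coefficients.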
def apply(self, func, *args, **kwargs):
Applies a function on the given args, saving adequate data for reverse AD.
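The function passed to apply must accept a co_output keyword and return the requested sensitivities during the reverse pass; the wrappers apply_linear_mapping, apply_linear_inverse and simplify (defined in the source above) package this protocol for common cases. A hedged sketch, assuming apply_linear_mapping accepts a scipy sparse matrix:

import numpy as np
import scipy.sparse as sp
from agd.AutomaticDifferentiation import Reverse2

rev, (x,) = Reverse2.empty(inputs=(np.array([1., 1., 1.]),))
A = sp.diags([1., 2., 3.]).tocsr()

Ax = rev.apply_linear_mapping(A, x)   # recorded on the tape; Ax carries placeholder AD indices
y = (Ax**2).sum()                     # scalar objective built from the mapped values
grad = rev.gradient(y)                # reverse sweep propagates sensitivities through A.T
# Mathematically, grad should match 2*A.T@A@x0 = [2., 8., 18.].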
def gradient(self, a):
Computes the gradient of the scalar spAD2 variable a
def hessian(self, a):
Returns the hessian operator associated with the scalar spAD2 variable a
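Since the returned operator only computes Hessian-vector products, a full dense Hessian can be assembled one basis vector at a time; a sketch, under the same assumptions as the earlier examples:

import numpy as np
from agd.AutomaticDifferentiation import Reverse2

rev, (x,) = Reverse2.empty(inputs=(np.array([1., 2., 3.]),))
y = (x**3).sum()                          # here the exact Hessian is diag(6*x)

hess = rev.hessian(y)
n = rev.size_ad
H = np.stack([hess(e) for e in np.eye(n)], axis=1)   # one Hessian-vector product per basis vector

grad, Hd = hess(np.ones(n), with_grad=True)          # with_grad=True also returns the gradient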
def output(self, a):
def empty(inputs=None, **kwargs):
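Both call forms, as a brief sketch:

import numpy as np
from agd.AutomaticDifferentiation import Reverse2

rev = Reverse2.empty()                                      # bare tape; register inputs later with rev.register
rev2, (x,) = Reverse2.empty(inputs=(np.array([0., 1.]),))   # tape plus registered AD copies of the inputs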
def operator_like(inputs=None, co_output=None, **kwargs):
Operator_like reverseAD2 (or Reverse depending on reverse mode):
- should not register new inputs (conflicts with the way dir_hessian is provided)
- fixed co_output
- gets dir_hessian from inputs