agd.AutomaticDifferentiation.Reverse
# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

import numpy as np
import copy
from . import functional
from . import misc
from . import Sparse

class reverseAD(object):
    """
    A class for reverse first-order automatic differentiation.

    Fields:
    - input_iterables : tuple, subset of {tuple,list,dict,set}.
      Which input structures should be explored when looking for AD information
    - output_iterables : tuple, subset of (tuple,list,dict).
      Which output structures should be explored when looking for AD information
    """

    def __init__(self,operator_data=None,input_iterables=None,output_iterables=None):
        self.operator_data=operator_data
        self.deepcopy_states = False

        self.input_iterables  = (tuple,) if input_iterables  is None else input_iterables
        self.output_iterables = (tuple,) if output_iterables is None else output_iterables
        assert hasattr(self.input_iterables,'__iter__') and hasattr(self.output_iterables,'__iter__')

        self._size_ad  = 0
        self._size_rev = 0
        self._states = []
        self._shapes_ad = tuple()

    @property
    def size_ad(self): return self._size_ad
    @property
    def size_rev(self): return self._size_rev

    # Variable creation
    def register(self,a):
        return misc.register(self.identity,a,self.input_iterables)

    def identity(self,*args,**kwargs):
        """Creates and registers a new AD variable"""
        result = Sparse.identity(*args,**kwargs,shift=self.size_ad)
        self._shapes_ad += (functional.pair(self.size_ad,result.shape),)
        self._size_ad += result.size
        return result

    def _identity_rev(self,*args,**kwargs):
        """Creates and registers an AD variable with negative indices,
        used as placeholders in reverse AD"""
        result = Sparse.identity(*args,**kwargs,shift=self.size_rev)
        self._size_rev += result.size
        result.index = -result.index-1
        return result

    def _index_rev(self,index):
        """Turns the negative placeholder indices into positive ones,
        for sparse matrix creation."""
        index=index.copy()
        pos = index<0
        index[pos] = -index[pos]-1+self.size_ad
        return index

    # Applying a function
    def apply(self,func,*args,**kwargs):
        """
        Applies a function to the given args, saving adequate data
        for reverse AD.
        """
        if self.operator_data == "PassThrough": return func(*args,**kwargs)
        _args,_kwargs,corresp = misc._apply_input_helper(args,kwargs,Sparse.spAD,self.input_iterables)
        if len(corresp)==0: return func(*args,**kwargs)
        _output = func(*_args,**_kwargs)
        output,shapes = misc._apply_output_helper(self,_output,self.output_iterables)
        self._states.append((shapes,func,
            copy.deepcopy(args) if self.deepcopy_states else args,
            copy.deepcopy(kwargs) if self.deepcopy_states else kwargs))
        return output

    def apply_linear_mapping(self,matrix,rhs,niter=1):
        return self.apply(linear_mapping_with_adjoint(matrix,niter=niter),rhs)
    def apply_linear_inverse(self,solver,matrix,rhs,niter=1):
        return self.apply(linear_inverse_with_adjoint(solver,matrix,niter=niter),rhs)
    def simplify(self,rhs):
        return self.apply(identity_with_adjoint,rhs)

    def iterate(self,func,var,*args,**kwargs):
        """
        Input: the function, the variable to be updated, niter, nrec, and optional args.
        Iterates a function, saving adequate data for reverse AD.
        If nrec>0, a recursive strategy is used to limit the amount of data saved.
        """
        niter = kwargs.pop('niter')
        nrec = 0 if niter<=1 else kwargs.pop('nrec',0)
        assert nrec>=0
        if nrec==0:
            for i in range(niter):
                var = self.apply(func,
                    var if self.deepcopy_states else copy.deepcopy(var),
                    *args,**kwargs)
            return var
        else:
            assert False # TODO. See ODE.RecurseRewind for the strategy.
            """
            def recursive_iterate():
                other = reverseAD()
                return other.iterate(func,
            niter_top = int(np.ceil(niter**(1./(1+nrec))))
            for rec_iter in (niter//niter_top,)*niter_top + (niter%niter_top,)

                var = self.apply(recursive_iterate,var,*args,**kwargs,niter=rec_iter,nrec=nrec-1)

            for
            """

    # Adjoint evaluation pass

    def to_inputshapes(self,a):
        return tuple(misc._to_shapes(a,shape,self.input_iterables) for shape in self._shapes_ad)

    def gradient(self,a):
        """Computes the gradient of the scalar spAD variable a"""
        assert isinstance(a,Sparse.spAD) and a.shape==tuple()
        coef = Sparse.spAD(a.value,a.coef,self._index_rev(a.index)).to_dense().coef
        size_total = self.size_ad+self.size_rev
        if coef.size<size_total: coef = misc._pad_last(coef,size_total)
        for outputshapes,func,args,kwargs in reversed(self._states):
            co_output_value = misc._to_shapes(coef[self.size_ad:],outputshapes,self.output_iterables)
            _args,_kwargs,corresp = misc._apply_input_helper(args,kwargs,Sparse.spAD,self.input_iterables)
            co_arg_request = [a for _,a in corresp]
            co_args = func(*_args,**_kwargs,co_output=functional.pair(co_output_value,co_arg_request))
            for a_sparse,a_value2 in corresp:
                found=False
                for a_value,a_adjoint in co_args:
                    if a_value is a_value2:
                        val,(row,col) = a_sparse.triplets()
                        coef_contrib = misc.spapply(
                            (val,(self._index_rev(col),row)),
                            misc.as_flat(a_adjoint))
                        # Possible improvement: shift by np.min(self._index_rev(col)) to avoid adding zeros
                        coef[:coef_contrib.shape[0]] += coef_contrib
                        found=True
                        break
                if not found:
                    raise ValueError(f"ReverseAD error : sensitivity not provided for input value {id(a_sparse)} equal to {a_sparse}")
        return coef[:self.size_ad]

    def output(self,a):
        """Computes the gradient of the output a, multiplied by the co-state,
        for an operator_like reverseAD"""
        assert self.operator_data is not None
        if self.operator_data == "PassThrough":
            return a
        inputs,(co_output_value,_) = self.operator_data
        grad = self.gradient(misc.sumprod(a,co_output_value,self.output_iterables))
        grad = self.to_inputshapes(grad)
        co_arg=[]
        def f(input):
            nonlocal co_arg
            input,to_ad = misc.ready_ad(input)
            if to_ad:
                co_arg.append( (input,grad[len(co_arg)]) )
        misc.map_iterables(f,inputs,self.input_iterables)
        return co_arg
#        return [(x,y) for (x,y) in zip(inputs,self.to_inputshapes(grad))]

# End of class reverseAD

def empty(inputs=None,**kwargs):
    rev = reverseAD(**kwargs)
    return rev if inputs is None else (rev,rev.register(inputs))

# Elementary operators with adjoints

def operator_like(inputs=None,co_output=None,**kwargs):
    """
    Operator_like reverseAD (or reverseAD2, depending on co_output):
    - has a fixed co_output
    """
    mode = misc.reverse_mode(co_output)
    if mode == "Forward":
        return reverseAD(operator_data="PassThrough",**kwargs),inputs
    elif mode == "Reverse":
        rev = reverseAD(operator_data=(inputs,co_output),**kwargs)
        return rev,rev.register(inputs)
    elif mode == "Reverse2":
        from . import Reverse2
        return Reverse2.operator_like(inputs,co_output,**kwargs)

def linear_inverse_with_adjoint(solver,matrix,niter=1):
    from . import apply_linear_inverse
    def operator(x): return apply_linear_inverse(solver,matrix,  x,niter=niter)
    def adjoint(x):  return apply_linear_inverse(solver,matrix.T,x,niter=niter)
    def method(u,co_output=None):
        mode = misc.reverse_mode(co_output)
        if   mode == "Forward":  return operator(u)
        elif mode == "Reverse":  c,_ = co_output; return [(u,adjoint(c))]
        elif mode == "Reverse2": (c1,c2),_ = co_output; return [(u,adjoint(c1),adjoint(c2))]
    return method

def linear_mapping_with_adjoint(matrix,niter=1):
    from . import apply_linear_mapping
    def operator(x): return apply_linear_mapping(matrix,  x,niter=niter)
    def adjoint(x):  return apply_linear_mapping(matrix.T,x,niter=niter)
    def method(u,co_output=None):
        mode = misc.reverse_mode(co_output)
        if   mode == "Forward":  return operator(u)
        elif mode == "Reverse":  c,_ = co_output; return [(u,adjoint(c))]
        elif mode == "Reverse2": (c1,c2),_ = co_output; return [(u,adjoint(c1),adjoint(c2))]
    return method

def identity_with_adjoint(u,co_output=None):
    mode = misc.reverse_mode(co_output)
    if   mode == "Forward":  return u
    elif mode == "Reverse":  c,_ = co_output; return [(u,c)]
    elif mode == "Reverse2": (c1,c2),_ = co_output; return [(u,c1,c2)]
class reverseAD:
A class for reverse first-order automatic differentiation.

Fields:
- input_iterables : tuple, subset of {tuple,list,dict,set}.
  Which input structures should be explored when looking for AD information.
- output_iterables : tuple, subset of (tuple,list,dict).
  Which output structures should be explored when looking for AD information.
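A minimal lifecycle sketch, hedged: it assumes Sparse.identity accepts a constant= keyword and that spAD variables support numpy-style arithmetic, as elsewhere in agd.

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,2.]))  # forward sparse-AD variable, indices 0,1
    y = rev.simplify((x**2).sum())                # records a state; y carries placeholder indices
    g = rev.gradient(y)                           # adjoint pass; expect g == [2.,4.]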
reverseAD(operator_data=None, input_iterables=None, output_iterables=None)
def identity(self, *args, **kwargs):
Creates and registers a new AD variable.
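For instance, successive calls occupy disjoint index ranges, thanks to the shift=self.size_ad argument (the constant= keyword is assumed to be forwarded to Sparse.identity):

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,2.]))  # occupies AD indices 0,1
    y = rev.identity(constant=np.array([3.]))     # shifted by rev.size_ad: occupies index 2
    assert rev.size_ad == 3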
def apply(self, func, *args, **kwargs):
Applies a function to the given args, saving adequate data for reverse AD.
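Functions passed to apply must follow the co_output protocol illustrated by identity_with_adjoint at the bottom of this module: called without co_output they return their value, and called with it they return (input, adjoint) pairs. A hypothetical elementwise square, sketched under that protocol:

    import numpy as np
    from agd.AutomaticDifferentiation import misc, Reverse

    def square_with_adjoint(u, co_output=None):
        mode = misc.reverse_mode(co_output)
        if mode == "Forward":
            return u**2             # forward pass: plain evaluation
        elif mode == "Reverse":
            c,_ = co_output         # co-state, paired with the requested inputs
            return [(u, 2.*u*c)]    # adjoint: J^T c with J = diag(2u)

    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,2.]))
    y = rev.apply(square_with_adjoint, x)  # state saved for the adjoint pass
    g = rev.gradient(y.sum())              # expect g == [2.,4.]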
def iterate(self, func, var, *args, **kwargs):
Iterates a function, saving adequate data for reverse AD. Input: the function, the variable to be updated, niter, nrec, and optional args. If nrec>0, a recursive strategy is used to limit the amount of data saved.
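A sketch with a hypothetical affine step obeying the same co_output protocol as above (the adjoint of u -> 0.5*u + 1 is multiplication by 0.5):

    import numpy as np
    from agd.AutomaticDifferentiation import misc, Reverse

    def step(u, co_output=None):
        mode = misc.reverse_mode(co_output)
        if mode == "Forward": return 0.5*u + 1.
        elif mode == "Reverse":
            c,_ = co_output
            return [(u, 0.5*c)]         # transpose of the linear part

    rev = Reverse.reverseAD()
    u = rev.identity(constant=np.array([0.]))
    u = rev.iterate(step, u, niter=3)   # three recorded applications of step
    g = rev.gradient(u.sum())           # expect g == [0.125] == 0.5**3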
def gradient(self, a):
Computes the gradient of the scalar spAD variable a
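gradient also works when no state was recorded through apply: the scalar's sparse coefficients are simply scattered into a dense vector. For instance (same constant= assumption as above):

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,2.,3.]))
    g = rev.gradient((x**2).sum())  # no replay needed; expect g == [2.,4.,6.]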
def output(self, a):
Computes the gradient of the output a, multiplied by the co-state, for an operator_like reverseAD; see the operator_like sketch below.
def empty(inputs=None, **kwargs):
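A convenience constructor: given inputs (a structure from input_iterables, so by default a tuple), it registers everything in one call. A sketch, assuming misc.register creates one AD variable per array in the structure:

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    rev,(x,y) = Reverse.empty((np.array([1.,2.]), np.array([3.])))
    # comparable to rev = Reverse.reverseAD() followed by two rev.identity calls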
def operator_like(inputs=None, co_output=None, **kwargs):
Operator_like reverseAD (or reverseAD2, depending on co_output): it has a fixed co_output.
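One plausible usage pattern, inferred from how reverseAD.output consumes operator_data; the operator my_square is hypothetical, and it is assumed that misc.register wraps a bare array into a single AD variable. The same function then serves both as an ordinary forward computation and as an adjoint-aware building block for apply:

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    def my_square(u, co_output=None):
        rev,u = Reverse.operator_like(u, co_output)  # pass-through forward, recording in reverse
        return rev.output(u**2)                      # value, or [(input,adjoint)] pairs

    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,2.]))
    y = rev.apply(my_square, x)
    g = rev.gradient(y.sum())                        # expect g == [2.,4.]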
def linear_inverse_with_adjoint(solver, matrix, niter=1):
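For instance, differentiating through a sparse linear solve; that scipy's spsolve(matrix, rhs) matches the solver convention apply_linear_inverse expects is an assumption here:

    import numpy as np
    import scipy.sparse as sp
    from scipy.sparse.linalg import spsolve
    from agd.AutomaticDifferentiation import Reverse

    A = sp.csr_matrix([[2.,0.],[0.,4.]])
    rev = Reverse.reverseAD()
    f = rev.identity(constant=np.array([1.,1.]))
    u = rev.apply_linear_inverse(spsolve, A, f)  # records u with A u = f
    g = rev.gradient(u.sum())                    # adjoint solves with A.T; expect [0.5,0.25]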
def linear_mapping_with_adjoint(matrix, niter=1):
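For a plain matrix-vector product the adjoint is just the transpose; a sketch under the same assumptions as above:

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    A = np.array([[1.,2.],[3.,4.]])
    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,1.]))
    y = rev.apply_linear_mapping(A, x)  # y = A x, recorded with adjoint A.T
    g = rev.gradient(y.sum())           # expect g == A.T @ [1.,1.] == [4.,6.]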
def identity_with_adjoint(u, co_output=None):
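This is the function behind reverseAD.simplify: it records a checkpoint whose adjoint is the identity, replacing a variable that carries many sparse coefficients by fresh placeholder indices. For instance:

    import numpy as np
    from agd.AutomaticDifferentiation import Reverse

    rev = Reverse.reverseAD()
    x = rev.identity(constant=np.array([1.,2.]))
    z = rev.simplify(x[0]*x[1] + x[1]**2)  # z carries a single placeholder coefficient
    g = rev.gradient(z)                    # expect g == [2.,5.]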