agd.AutomaticDifferentiation.ad_specific

 1# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
 2# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0
 3
 4import itertools
 5import numpy as np
 6from . import ad_generic
 7from . import misc
 8from . import Dense
 9from . import Sparse
10from . import Dense2
11from . import Sparse2
12
def simplify_ad(a,*args,**kwargs):
	"""
	Simplify, where applicable, the sparsity pattern of a sparse AD variable.

	When a is exactly of type Sparse.spAD or Sparse2.spAD2, the call is forwarded to 
	a.simplify_ad(*args,**kwargs), see Sparse.spAD.simplify_ad for detailed help.
	Any other input is returned untouched.
	"""
	sparse_types = (Sparse.spAD,Sparse2.spAD2)
	# Guard clause : nothing to simplify for the other types
	if type(a) not in sparse_types:
		return a
	return a.simplify_ad(*args,**kwargs)
21
def apply(f,*args,**kwargs):
	"""
	Applies the function f to the given arguments, with special treatment when any of 
	the following keywords is provided. These keywords are removed from kwargs, and are 
	not passed to f.
	- envelope : take advantage of the envelope theorem, to differentiate a min or max.
	  The function is called twice, first without AD, then with AD and the oracle parameter.
	  In that case f must return a pair (result,oracle), and so does apply.
	- shape_bound : take advantage of dense-sparse (or dense-dense) AD composition to 
	  differentiate the function efficiently. The function is called with dense AD, and 
	  the dimensions in shape_bound are regarded as a simple scalar.
	  A tuple is expected, other finite iterables of integers (e.g. a list) are converted.
	- reverse_history : use the provided reverse AD trace.

	If none of the arguments is an AD variable, then f(*args,**kwargs) is returned directly.
	The envelope keyword may be combined with shape_bound, whereas reverse_history is 
	only used when neither envelope nor shape_bound is set.

	Raises a TypeError, in the shape_bound case, if some AD argument has a type other 
	than Sparse.spAD, Dense.denseAD, Sparse2.spAD2, Dense2.denseAD2.
	"""
	envelope,shape_bound,reverse_history = (kwargs.pop(s,None) 
		for s in ('envelope','shape_bound','reverse_history'))
	if not any(ad_generic.is_ad(a) for a in itertools.chain(args,kwargs.values())):
		return f(*args,**kwargs)

	if envelope:
		def to_np(a): return a.value if ad_generic.is_ad(a) else a
		# First pass without AD, whose only purpose is to produce the oracle
		_,oracle = f(*[to_np(a) for a in args],
			**{key:to_np(val) for key,val in kwargs.items()})
		# Second pass with AD, the oracle being supplied to f as a keyword argument.
		# NOTE(review): reverse_history is not forwarded to this call - confirm intended.
		result,_ = apply(f,*args,**kwargs,oracle=oracle,envelope=False,shape_bound=shape_bound)
		return result,oracle

	if shape_bound is not None:
		# A tuple is needed for the shape concatenation below (no-op if already a tuple)
		shape_bound = tuple(shape_bound)
		size_factor = np.prod(shape_bound,dtype=int)
		# Tuple containing the original AD vars, reshaped so as to end with shape_bound
		t = tuple(b.reshape((b.size//size_factor,)+shape_bound) 
			for b in itertools.chain(args,kwargs.values()) if ad_generic.is_ad(b))
		lens = tuple(len(b) for b in t)

		# Position, within t, of the next AD argument to be converted. The conversion
		# below visits args and then kwargs, in the same order as used to build t.
		ad_rank = itertools.count()
		def to_dense(b):
			if not ad_generic.is_ad(b): return b
			i = next(ad_rank)
			# Cumulated lengths of the AD arguments located before, and after, this one
			shift = (sum(lens[:i]),sum(lens[(i+1):]))
			if type(b) in (Sparse.spAD,Dense.denseAD): 
				return Dense.identity(constant=b.value,shape_bound=shape_bound,shift=shift)
			if type(b) in (Sparse2.spAD2,Dense2.denseAD2):
				return Dense2.identity(constant=b.value,shape_bound=shape_bound,shift=shift)
			# Previously fell through, and silently handed None to f
			raise TypeError("apply with shape_bound : unsupported AD type {}".format(type(b)))

		args2 = [to_dense(b) for b in args]
		kwargs2 = {key:to_dense(val) for key,val in kwargs.items()}
		result2 = f(*args2,**kwargs2)
		return compose(result2,t,shape_bound=shape_bound)

	if reverse_history:
		return reverse_history.apply(f,*args,**kwargs)
	return f(*args,**kwargs)
64
def compose(a,t,shape_bound):
	"""
	Compose AD types, mostly intended for a dense a and sparse elements of t.
	- a : the outer variable. A tuple is handled elementwise (recursively), and 
	  yields a tuple of results.
	- t : the inner AD variable(s). A non-tuple is wrapped in a one element tuple.
	- shape_bound : only passed along to the recursive calls, otherwise unused here.

	The input a is returned unchanged if it is not exactly of type Dense.denseAD or 
	Dense2.denseAD2, or if t is empty. Otherwise, the composition is delegated to 
	the compose method of the type of t[0].
	"""
	inner = t if isinstance(t,tuple) else (t,)
	if isinstance(a,tuple):
		return tuple(compose(ai,inner,shape_bound) for ai in a)
	is_dense = type(a) in (Dense.denseAD,Dense2.denseAD2)
	if is_dense and len(inner)>0:
		return type(inner[0]).compose(a,inner)
	return a
def simplify_ad(a,*args,**kwargs):
	"""
	Simplify, where applicable, the sparsity pattern of a sparse AD variable.

	When a is exactly of type Sparse.spAD or Sparse2.spAD2, the call is forwarded to 
	a.simplify_ad(*args,**kwargs), see Sparse.spAD.simplify_ad for detailed help.
	Any other input is returned untouched.
	"""
	sparse_types = (Sparse.spAD,Sparse2.spAD2)
	# Guard clause : nothing to simplify for the other types
	if type(a) not in sparse_types:
		return a
	return a.simplify_ad(*args,**kwargs)

Simplifies, if possible, the sparsity pattern of a sparse AD variable. See Sparse.spAD.simplify_ad for detailed help.

def apply(f,*args,**kwargs):
	"""
	Applies the function f to the given arguments, with special treatment when any of 
	the following keywords is provided. These keywords are removed from kwargs, and are 
	not passed to f.
	- envelope : take advantage of the envelope theorem, to differentiate a min or max.
	  The function is called twice, first without AD, then with AD and the oracle parameter.
	  In that case f must return a pair (result,oracle), and so does apply.
	- shape_bound : take advantage of dense-sparse (or dense-dense) AD composition to 
	  differentiate the function efficiently. The function is called with dense AD, and 
	  the dimensions in shape_bound are regarded as a simple scalar.
	  A tuple is expected, other finite iterables of integers (e.g. a list) are converted.
	- reverse_history : use the provided reverse AD trace.

	If none of the arguments is an AD variable, then f(*args,**kwargs) is returned directly.
	The envelope keyword may be combined with shape_bound, whereas reverse_history is 
	only used when neither envelope nor shape_bound is set.

	Raises a TypeError, in the shape_bound case, if some AD argument has a type other 
	than Sparse.spAD, Dense.denseAD, Sparse2.spAD2, Dense2.denseAD2.
	"""
	envelope,shape_bound,reverse_history = (kwargs.pop(s,None) 
		for s in ('envelope','shape_bound','reverse_history'))
	if not any(ad_generic.is_ad(a) for a in itertools.chain(args,kwargs.values())):
		return f(*args,**kwargs)

	if envelope:
		def to_np(a): return a.value if ad_generic.is_ad(a) else a
		# First pass without AD, whose only purpose is to produce the oracle
		_,oracle = f(*[to_np(a) for a in args],
			**{key:to_np(val) for key,val in kwargs.items()})
		# Second pass with AD, the oracle being supplied to f as a keyword argument.
		# NOTE(review): reverse_history is not forwarded to this call - confirm intended.
		result,_ = apply(f,*args,**kwargs,oracle=oracle,envelope=False,shape_bound=shape_bound)
		return result,oracle

	if shape_bound is not None:
		# A tuple is needed for the shape concatenation below (no-op if already a tuple)
		shape_bound = tuple(shape_bound)
		size_factor = np.prod(shape_bound,dtype=int)
		# Tuple containing the original AD vars, reshaped so as to end with shape_bound
		t = tuple(b.reshape((b.size//size_factor,)+shape_bound) 
			for b in itertools.chain(args,kwargs.values()) if ad_generic.is_ad(b))
		lens = tuple(len(b) for b in t)

		# Position, within t, of the next AD argument to be converted. The conversion
		# below visits args and then kwargs, in the same order as used to build t.
		ad_rank = itertools.count()
		def to_dense(b):
			if not ad_generic.is_ad(b): return b
			i = next(ad_rank)
			# Cumulated lengths of the AD arguments located before, and after, this one
			shift = (sum(lens[:i]),sum(lens[(i+1):]))
			if type(b) in (Sparse.spAD,Dense.denseAD): 
				return Dense.identity(constant=b.value,shape_bound=shape_bound,shift=shift)
			if type(b) in (Sparse2.spAD2,Dense2.denseAD2):
				return Dense2.identity(constant=b.value,shape_bound=shape_bound,shift=shift)
			# Previously fell through, and silently handed None to f
			raise TypeError("apply with shape_bound : unsupported AD type {}".format(type(b)))

		args2 = [to_dense(b) for b in args]
		kwargs2 = {key:to_dense(val) for key,val in kwargs.items()}
		result2 = f(*args2,**kwargs2)
		return compose(result2,t,shape_bound=shape_bound)

	if reverse_history:
		return reverse_history.apply(f,*args,**kwargs)
	return f(*args,**kwargs)

Applies the function to the given arguments, with special treatment when any of the following keywords is provided:

  • envelope : take advantage of the envelope theorem, to differentiate a min or max. The function is called twice, first without AD, then with AD and the oracle parameter.
  • shape_bound : take advantage of dense-sparse (or dense-dense) AD composition to differentiate the function efficiently. The function is called with dense AD, and the dimensions in shape_bound are regarded as a simple scalar.
  • reverse_history : use the provided reverse AD trace.
def compose(a,t,shape_bound):
	"""
	Compose AD types, mostly intended for a dense a and sparse elements of t.
	- a : the outer variable. A tuple is handled elementwise (recursively), and 
	  yields a tuple of results.
	- t : the inner AD variable(s). A non-tuple is wrapped in a one element tuple.
	- shape_bound : only passed along to the recursive calls, otherwise unused here.

	The input a is returned unchanged if it is not exactly of type Dense.denseAD or 
	Dense2.denseAD2, or if t is empty. Otherwise, the composition is delegated to 
	the compose method of the type of t[0].
	"""
	inner = t if isinstance(t,tuple) else (t,)
	if isinstance(a,tuple):
		return tuple(compose(ai,inner,shape_bound) for ai in a)
	is_dense = type(a) in (Dense.denseAD,Dense2.denseAD2)
	if is_dense and len(inner)>0:
		return type(inner[0]).compose(a,inner)
	return a

Compose AD types, mostly intended for a dense a and sparse elements of t.