agd.AutomaticDifferentiation.ad_generic
# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

"""
This file implements functions which apply indifferently to several AD types.
"""

import itertools
import numpy as np
import functools

from . import functional
from .Base import is_ad,isndarray,array,asarray,_cp_ndarray


def adtype(data,iterables=tuple()):
    """
    Returns None if no AD variable is found in data, or the AD type if one is found.
    Also checks that all the AD variables found share the same type.

    Arguments:
     - data : the object to inspect.
     - iterables (optional) : container types that are explored recursively
       (for dictionaries, the values are explored).

    Raises:
     - ValueError : if AD variables of several distinct types are found.
    """
    def leaves(x):
        # Yields the non-container values found in x, descending into `iterables`.
        if isinstance(x,iterables):
            for item in (x.values() if isinstance(x,dict) else x):
                yield from leaves(item)
        else:
            yield x

    result = None
    for value in leaves(data):
        # is_ad is applied to the value itself, as in remove_ad and common_cast.
        if not is_ad(value):
            continue
        t = type(value)
        if result is None:
            result = t
        elif result is not t:
            raise ValueError("Error : several distinct AD types")
    return result

def precision(x):
    """
    Decimal precision of the floating point type of x.
    x may be a floating point type, or a value whose array dtype is then used.
    """
    if not isinstance(x,type): x = array(x).dtype.type
    return np.finfo(x).precision

def remove_ad(data,iterables=tuple()):
    """
    Replaces each AD variable in data by its .value field; other values are kept.
    Traversal is delegated to functional.map_iterables (presumably recursing
    into the container types listed in iterables -- confirm).
    """
    def f(a): return a.value if is_ad(a) else a
    return functional.map_iterables(f,data,iterables)

def as_writable(a):
    """
    Returns a writable array containing the same elements as a.
    If the array a, or a field of a for an AD type, is flagged as
    non-writable, then it is copied.
    """
    if isinstance(a,(np.ndarray,_cp_ndarray)):
        return a if a.flags['WRITEABLE'] else a.copy()
    # AD type: rebuild it from writable versions of its fields.
    return a.new(*tuple(as_writable(e) for e in a.as_tuple() ))

def common_cast(*args):
    """
    If any of the arguments is an AD type, casts all other arguments to that type.
    Casts to ndarray if no argument is an AD type.
    Usage : if a and b may or may not be AD arrays,
    a,b = common_cast(a,b); a[0]=b[0]

    Raises:
     - ValueError : if the arguments involve several distinct AD types.
    """
    args = tuple(array(x) for x in args)
    common_type = None
    for x in args:
        if is_ad(x):
            if common_type is None:
                common_type = type(x)
            if not isinstance(x,common_type):
                raise ValueError("Error : several distinct AD types")
    return args if common_type is None else tuple(common_type(x) for x in args)


def min_argmin(array,axis=None):
    """
    Returns the minimum of array along axis, together with its index (np.argmin).
    If axis is None, the array is flattened first.
    """
    if axis is None: return min_argmin(array.reshape(-1),axis=0)
    ai = np.argmin(array,axis=axis)
    return np.squeeze(np.take_along_axis(array,np.expand_dims(ai,
        axis=axis),axis=axis),axis=axis),ai

def max_argmax(array,axis=None):
    """
    Returns the maximum of array along axis, together with its index (np.argmax).
    If axis is None, the array is flattened first.
    """
    if axis is None: return max_argmax(array.reshape(-1),axis=0)
    ai = np.argmax(array,axis=axis)
    return np.squeeze(np.take_along_axis(array,np.expand_dims(ai,
        axis=axis),axis=axis),axis=axis),ai

# ------- Linear operators ------


def apply_linear_mapping(matrix,rhs,niter=1):
    """
    Applies the provided linear operator to a dense AD variable of first or second order.
    A numpy ndarray matrix is applied with np.dot, any other matrix with the * operator.
    """
    def step(x): return np.dot(matrix,x) if isinstance(matrix,np.ndarray) else (matrix*x)
    operator = functional.recurse(step,niter)
    return rhs.apply_linear_operator(operator) if is_ad(rhs) else operator(rhs)

def apply_linear_inverse(solver,matrix,rhs,niter=1):
    """
    Applies the provided linear inverse to a dense AD variable of first or second order.
    Each step evaluates solver(matrix,x).
    """
    def step(x): return solver(matrix,x)
    operator = functional.recurse(step,niter)
    return rhs.apply_linear_operator(operator) if is_ad(rhs) else operator(rhs)

# ------- Shape manipulation -------

def squeeze_shape(shape,axis):
    """
    Removes the singleton dimension at position axis from the tuple shape.
    Returns shape unchanged if axis is None.
    """
    if axis is None:
        return shape
    assert shape[axis]==1
    if axis==-1:
        # shape[axis+1:] would be shape[0:], hence the special case.
        return shape[:-1]
    else:
        return shape[:axis]+shape[(axis+1):]

def expand_shape(shape,axis):
    """
    Inserts a singleton dimension into the tuple shape, so that it lands at
    position axis of the result (negative axis counts from the end).
    Returns shape unchanged if axis is None.
    """
    if axis is None:
        return shape
    if axis==-1:
        return shape+(1,)
    if axis<0:
        axis+=1
    return shape[:axis]+(1,)+shape[axis:]

def _set_shape_free_bound(shape,shape_free,shape_bound):
    """
    Completes shape_free and shape_bound so that shape == shape_free + shape_bound.
    Either may be None, in which case it is deduced from the other;
    if both are None, shape_bound is () and shape_free is shape.
    """
    if shape_free is not None:
        assert shape_free==shape[0:len(shape_free)]
        if shape_bound is None:
            shape_bound=shape[len(shape_free):]
        else:
            assert shape_bound==shape[len(shape_free):]
    if shape_bound is None:
        shape_bound = tuple()
    assert len(shape_bound)==0 or shape_bound==shape[-len(shape_bound):]
    if shape_free is None:
        if len(shape_bound)==0:
            shape_free = shape
        else:
            shape_free = shape[:len(shape)-len(shape_bound)]
    return shape_free,shape_bound

def disassociate(array,shape_free=None,shape_bound=None,
    expand_free_dims=-1,expand_bound_dims=-1):
    """
    Turns an array of shape shape_free + shape_bound
    into an array of shape shape_free whose elements
    are arrays of shape shape_bound.
    Typical usage : recursive automatic differentiation.
    Caveat : by default, singleton dimensions are introduced
    to avoid numpy's "clever" treatment of scalar arrays.

    Arguments:
     - array : the array to be reshaped.
     - (optional) shape_free, shape_bound : outer and inner array shapes.
       One is deduced from the other.
     - (optional) expand_free_dims, expand_bound_dims : position at which a
       singleton dimension is inserted in shape_free (resp. shape_bound),
       or None to insert none. Defaults to -1 (trailing singleton dimension).
    """
    shape_free,shape_bound = _set_shape_free_bound(array.shape,shape_free,shape_bound)
    shape_free = expand_shape(shape_free, expand_free_dims)
    shape_bound = expand_shape(shape_bound,expand_bound_dims)

    # int(...) because np.prod(()) is the float 1.0, which reshape,
    # np.zeros and range reject when shape_free is empty.
    size_free = int(np.prod(shape_free))
    array = array.reshape((size_free,)+shape_bound)
    # Element-wise assignment stores each sub-array as a single object entry.
    result = np.zeros(size_free,object)
    for i in range(size_free): result[i] = array[i]
    return result.reshape(shape_free)

def associate(array,squeeze_free_dims=-1,squeeze_bound_dims=-1):
    """
    Turns an array of shape shape_free, whose elements
    are arrays of shape shape_bound, into an array
    of shape shape_free+shape_bound.
    Inverse operation to disassociate.
    """
    if is_ad(array):
        return array.associate(squeeze_free_dims,squeeze_bound_dims)
    result = np.stack(array.reshape(-1),axis=0)
    shape_free = squeeze_shape(array.shape,squeeze_free_dims)
    shape_bound = squeeze_shape(result.shape[1:],squeeze_bound_dims)
    return result.reshape(shape_free+shape_bound)
def adtype(data,iterables=tuple()):
    """
    Returns None if no AD variable is found in data, or the AD type if one is found.
    Also checks that all the AD variables found share the same type.

    Arguments:
     - data : the object to inspect.
     - iterables (optional) : container types that are explored recursively
       (for dictionaries, the values are explored).

    Raises:
     - ValueError : if AD variables of several distinct types are found.
    """
    def leaves(x):
        # Yields the non-container values found in x, descending into `iterables`.
        if isinstance(x,iterables):
            for item in (x.values() if isinstance(x,dict) else x):
                yield from leaves(item)
        else:
            yield x

    result = None
    for value in leaves(data):
        # is_ad is applied to the value itself, as in remove_ad and common_cast.
        if not is_ad(value):
            continue
        t = type(value)
        if result is None:
            result = t
        elif result is not t:
            raise ValueError("Error : several distinct AD types")
    return result
Returns None if no AD variable is found, or the AD type if one is found. Also checks that the AD types found are consistent.
def precision(x):
    """
    Number of significant decimal digits of a floating point type.

    x is either the floating point type itself, or any value; in the
    latter case the scalar type of array(x).dtype is used.
    """
    float_type = x if isinstance(x, type) else array(x).dtype.type
    return np.finfo(float_type).precision
Precision of the floating point type of x.
def as_writable(a):
    """
    Returns a version of a whose data can be written to.

    A plain array is returned as is when writable, and copied otherwise.
    For an AD type, the same treatment is applied to each of its fields,
    and a new AD object is built from the results.
    """
    if not isinstance(a, (np.ndarray, _cp_ndarray)):
        fields = [as_writable(field) for field in a.as_tuple()]
        return a.new(*fields)
    if a.flags['WRITEABLE']:
        return a
    return a.copy()
Returns a writable array containing the same elements as a. If the array a, or a field of a for an AD type, is flagged as non-writable, then it is copied.
def common_cast(*args):
    """
    Converts all the arguments to a common array type.

    Each argument is first passed through array(). If one of them is an AD
    variable, all the arguments are then cast to the type of the first AD
    variable found; otherwise the converted arrays are returned as they are.
    Usage : if a and b may or may not be AD arrays,
    a,b = common_cast(a,b); a[0]=b[0]

    Raises ValueError if the AD variables are not all instances of that type.
    """
    converted = tuple(array(arg) for arg in args)
    target = None
    for item in filter(is_ad, converted):
        if target is None:
            target = type(item)
        if not isinstance(item, target):
            raise ValueError("Error : several distinct AD types")
    if target is None:
        return converted
    return tuple(target(item) for item in converted)
If any of the arguments is an AD type, casts all other arguments to that type. Casts to ndarray if no argument is an AD type. Usage : if a and b may or may not be AD arrays, a,b = common_cast(a,b); a[0]=b[0]
def apply_linear_mapping(matrix,rhs,niter=1):
    """
    Applies the provided linear operator to a dense AD variable of first or second order.

    A numpy ndarray matrix is applied with np.dot, any other matrix with the
    * operator. The single step is wrapped by functional.recurse(step,niter)
    (presumably applying it niter times -- confirm). AD right-hand sides
    delegate to their apply_linear_operator method.
    """
    def single_step(vec):
        if isinstance(matrix, np.ndarray):
            return np.dot(matrix, vec)
        return matrix * vec

    repeated = functional.recurse(single_step, niter)
    if is_ad(rhs):
        return rhs.apply_linear_operator(repeated)
    return repeated(rhs)
Applies the provided linear operator to a dense AD variable of first or second order.
def apply_linear_inverse(solver,matrix,rhs,niter=1):
    """
    Applies the provided linear inverse to a dense AD variable of first or second order.

    Each step evaluates solver(matrix, x); the step is wrapped by
    functional.recurse(step,niter) (presumably applying it niter times --
    confirm). AD right-hand sides delegate to their apply_linear_operator method.
    """
    repeated = functional.recurse(lambda vec: solver(matrix, vec), niter)
    if is_ad(rhs):
        return rhs.apply_linear_operator(repeated)
    return repeated(rhs)
Applies the provided linear inverse to a dense AD variable of first or second order.
def disassociate(array,shape_free=None,shape_bound=None,
    expand_free_dims=-1,expand_bound_dims=-1):
    """
    Turns an array of shape shape_free + shape_bound
    into an array of shape shape_free whose elements
    are arrays of shape shape_bound.
    Typical usage : recursive automatic differentiation.
    Caveat : by default, singleton dimensions are introduced
    to avoid numpy's "clever" treatment of scalar arrays.

    Arguments:
     - array : the array to be reshaped.
     - (optional) shape_free, shape_bound : outer and inner array shapes.
       One is deduced from the other.
     - (optional) expand_free_dims, expand_bound_dims : position at which a
       singleton dimension is inserted in shape_free (resp. shape_bound),
       or None to insert none. Defaults to -1 (trailing singleton dimension).
    """
    shape_free,shape_bound = _set_shape_free_bound(array.shape,shape_free,shape_bound)
    shape_free = expand_shape(shape_free, expand_free_dims)
    shape_bound = expand_shape(shape_bound,expand_bound_dims)

    # int(...) because np.prod(()) is the float 1.0, which reshape,
    # np.zeros and range reject when shape_free is empty.
    size_free = int(np.prod(shape_free))
    array = array.reshape((size_free,)+shape_bound)
    # Element-wise assignment stores each sub-array as a single object entry.
    result = np.zeros(size_free,object)
    for i in range(size_free): result[i] = array[i]
    return result.reshape(shape_free)
Turns an array of shape shape_free + shape_bound into an array of shape shape_free whose elements are arrays of shape shape_bound. Typical usage : recursive automatic differentiation. Caveat : by default, singleton dimensions are introduced to avoid numpy's "clever" treatment of scalar arrays.
Arguments:
- array : the array to be reshaped
- (optional) shape_free, shape_bound : outer and inner array shapes. One is deduced from the other.
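- (optional) expand_free_dims, expand_bound_dims : position at which a singleton dimension is inserted in shape_free (resp. shape_bound), or None to insert none. Defaults to -1 (a trailing singleton dimension).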
def associate(array,squeeze_free_dims=-1,squeeze_bound_dims=-1):
    """
    Inverse operation to disassociate.

    Takes an array of shape shape_free whose entries are arrays of shape
    shape_bound, and returns a single array of shape shape_free+shape_bound.
    The singleton dimensions at squeeze_free_dims and squeeze_bound_dims are
    removed (none if None). AD inputs delegate to their own associate method.
    """
    if is_ad(array):
        return array.associate(squeeze_free_dims, squeeze_bound_dims)
    stacked = np.stack(array.reshape(-1), axis=0)
    outer = squeeze_shape(array.shape, squeeze_free_dims)
    inner = squeeze_shape(stacked.shape[1:], squeeze_bound_dims)
    return stacked.reshape(outer + inner)
Turns an array of shape shape_free, whose elements are arrays of shape shape_bound, into an array of shape shape_free+shape_bound. Inverse operation to disassociate.