agd.AutomaticDifferentiation.misc
# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

import numpy as np
import numbers
from .functional import map_iterables,map_iterables2,pair
from .cupy_generic import isndarray,from_cupy,cp
from .ad_generic import is_ad,remove_ad
from . import ad_generic

# ------- Ugly utilities -------
def normalize_axis(axis, ndim, allow_tuple=True):
    """
    Maps a negative axis index to its non-negative equivalent, for ndim dimensions.
    If allow_tuple holds and axis is a tuple, each entry is normalized in turn.
    No bounds checking is performed.
    """
    if allow_tuple and isinstance(axis, tuple):
        return tuple(normalize_axis(ax, ndim, False) for ax in axis)
    if axis < 0:
        return axis + ndim
    return axis


def add_ndim(arr, n):
    """Reshapes arr so that it gets n additional trailing singleton axes."""
    return np.reshape(arr, arr.shape + (1,) * n)


def _add_dim(a):
    """Appends one trailing singleton axis to a."""
    return np.expand_dims(a, axis=-1)


def _add_dim2(a):
    """Appends two trailing singleton axes to a."""
    return _add_dim(_add_dim(a))


def _to_tuple(a):
    """Returns tuple(a) if a has an __iter__ attribute, and the one element tuple (a,) otherwise."""
    return tuple(a) if hasattr(a, "__iter__") else (a,)


def key_expand(key, depth=1):
    """Modifies a key to access an array with more dimensions. Needed if ellipsis is used."""
    if isinstance(key, tuple):
        if any(a is ... for a in key):
            # The ellipsis would otherwise absorb the additional trailing axes
            return key + (slice(None),) * depth
    return key


def _pad_last(a, pad_total):  # Always makes a deep copy
    """Pads the last axis of a with zeros, at its end, up to the length pad_total."""
    pad_width = ((0, 0),) * (a.ndim - 1) + ((0, pad_total - a.shape[-1]),)
    return np.pad(a, pad_width=pad_width, mode='constant', constant_values=0)


def _add_coef(a, b):
    """Returns a+b, except that an operand whose last axis is empty is ignored (a is tested first)."""
    if a.shape[-1] == 0:
        return b
    elif b.shape[-1] == 0:
        return a
    else:
        return a + b


def _prep_nl(s):
    """Prepends a newline to s if s already contains one, and returns s unchanged otherwise."""
    return "\n" + s if "\n" in s else s


def _concatenate(a, b, shape=None):
    """
    Concatenates a and b along their last axis.
    If shape is given, the leading axes (all but the last) of a and b are first
    broadcasted to shape, when they differ from it.
    """
    if shape is not None:
        if a.shape[:-1] != shape:
            a = np.broadcast_to(a, shape + a.shape[-1:])
        if b.shape[:-1] != shape:
            b = np.broadcast_to(b, shape + b.shape[-1:])
    return np.concatenate((a, b), axis=-1)


def _set_shape_constant(shape=None, constant=None):
    """
    Returns a consistent pair (shape, constant), from possibly partial information.
    - constant omitted : an array of zeros (float) with the given shape is produced.
    - constant given : it is converted with ad_generic.asarray if it is not already
     an array, and its shape is returned.
    Raises ValueError if both are omitted, or if shape differs from constant.shape.
    """
    if isndarray(shape):
        shape = tuple(shape)
    if constant is None:
        if shape is None:
            raise ValueError("Error : unspecified shape or constant")
        constant = np.full(shape, 0.)
    else:
        if not isndarray(constant):
            constant = ad_generic.asarray(constant)
        if shape is not None and shape != constant.shape:
            raise ValueError("Error : incompatible shape and constant")
        else:
            shape = constant.shape
    return shape, constant


def _test_or_broadcast_ad(array, shape, broadcast, ad_depth=1):
    """
    Ensures that the leading axes of array (all but the last ad_depth) match shape.
    If broadcast holds, the array is broadcasted when needed;
    otherwise the match is only asserted.
    """
    if broadcast:
        if array.shape[:-ad_depth] == shape:
            return array
        else:
            return np.broadcast_to(array, shape + array.shape[-ad_depth:])
    else:
        assert array.shape[:-ad_depth] == shape
        return array



# -------- For Dense and Dense2 -----

def apply_linear_operator(op, rhs, flatten_ndim=0):
    """
    Applies a linear operator to an array with more than two dimensions,
    by flattening the last dimensions.
    Note : flatten_ndim is only used in the dimension check; all the axes of rhs
    after the first one are flattened before calling op, and restored afterwards.
    """
    assert (rhs.ndim - flatten_ndim) in [1, 2]
    shape_tail = rhs.shape[1:]
    op_input = rhs.reshape((rhs.shape[0], np.prod(shape_tail, dtype=int)))
    op_output = op(op_input)
    return op_output.reshape((op_output.shape[0],) + shape_tail)


# -------- Functional iteration, mainly for Reverse and Reverse2 -------

def ready_ad(a):
    """
    Readies a variable for adding AD information, if possible.
    Returns : readied variable, boolean (whether AD extension is possible)
    Raises ValueError if a already contains AD information.
    """
    if is_ad(a):
        raise ValueError("Variable a already contains AD information")
    elif isinstance(a, numbers.Real) and not isinstance(a, numbers.Integral):
        # Non-integral real scalars are wrapped in a zero dimensional array
        return np.array(a), True
    elif isndarray(a) and not issubclass(a.dtype.type, numbers.Integral):
        return a, True
    else:
        return a, False


# Applying a function
def _apply_output_helper(rev, val, iterables):
    """
    Adds 'virtual' AD information to an output (with negative indices),
    in selected places.
    Each element which is ready for AD is replaced with rev._identity_rev(constant=...),
    and paired with pair(rev.size_rev, shape); other elements are paired with None.
    """
    def f(a):
        a, to_ad = ready_ad(a)
        if to_ad:
            # rev.size_rev must be read before calling rev._identity_rev
            # (presumably the latter updates it - TODO confirm)
            shape = pair(rev.size_rev, a.shape)
            return rev._identity_rev(constant=a), shape
        else:
            return a, None
    return map_iterables(f, val, iterables, split=True)


def register(identity, data, iterables):
    """
    Applies identity(constant=...) to each element of data which is ready for AD,
    see ready_ad, and leaves the other elements untouched.
    """
    def reg(a):
        a, to_ad = ready_ad(a)
        if to_ad:
            return identity(constant=a)
        else:
            return a
    return map_iterables(reg, data, iterables)


def _to_shapes(coef, shapes, iterables):
    """
    Reshapes a one dimensional array into the given shapes,
    given as a tuple of pair(start,shape). None entries are mapped to None.
    """
    def f(s):
        if s is None:
            return None
        else:
            start, shape = s
            return coef[start: start + np.prod(shape, dtype=int)].reshape(shape)
    return map_iterables(f, shapes, iterables)


def _apply_input_helper(args, kwargs, cls, iterables):
    """
    Removes the AD information from some function input, and provides the correspondence.
    Returns : the stripped args (tuple), the stripped kwargs (dict),
    and the list of pairs (AD variable, stripped value) in the order encountered.
    """
    corresp = []

    def _make_arg(a):
        if is_ad(a):
            assert isinstance(a, cls)
            a_value = remove_ad(a)
            corresp.append((a, a_value))
            return a_value
        else:
            return a
    _args = tuple(map_iterables(_make_arg, val, iterables) for val in args)
    _kwargs = {key: map_iterables(_make_arg, val, iterables) for key, val in kwargs.items()}
    return _args, _kwargs, corresp


def sumprod(u, v, iterables, to_first=False):
    """
    Accumulates (u*v).sum() over the matching elements of the structures u and v,
    skipping those where the element of u is None.
    If to_first holds, then u.to_first() is used in place of u.
    """
    acc = 0.

    def f(u, v):
        nonlocal acc
        if u is not None:
            U = u.to_first() if to_first else u
            acc = acc + (U * v).sum()
    map_iterables2(f, u, v, iterables)
    return acc


def reverse_mode(co_output):
    """
    Returns the name of the differentiation mode encoded by co_output :
    "Forward" if None, "Reverse2" if the first entry of the pair is itself a pair,
    "Reverse" otherwise.
    """
    if co_output is None:
        return "Forward"
    else:
        assert isinstance(co_output, pair)
        c, _ = co_output
        if isinstance(c, pair):
            return "Reverse2"
        else:
            return "Reverse"

# ----- Functional -----

def recurse(step, niter=1):
    """Returns the operator which applies the function step, niter times in a row."""
    def operator(rhs):
        for _ in range(niter):
            rhs = step(rhs)
        return rhs
    return operator

# ------- Common functions -------

def as_flat(a):
    """Flattens a if it is an array, and otherwise returns ad_generic.array([a])."""
    return a.reshape(-1) if isndarray(a) else ad_generic.array([a])


def tocsr(triplets, shape=None):
    """Turns sparse matrix given as triplets into a csr (compressed sparse row) matrix"""
    # The sparse submodule is imported explicitly on both branches, since importing
    # a package does not guarantee that its submodules are loaded.
    if from_cupy(triplets[0]):
        import cupyx.scipy.sparse as spmod
    else:
        import scipy.sparse as spmod
    return spmod.coo_matrix(triplets, shape=shape).tocsr()


def spsolve(triplets, rhs):
    """
    Solves a sparse linear system where the matrix is given as triplets.
    NOTE(review) : the GPU branch uses lsqr, whose return value may be structured
    differently from the one of scipy's spsolve - confirm against callers.
    """
    if from_cupy(triplets[0]):
        import cupyx.scipy.sparse.linalg
        solver = cupyx.scipy.sparse.linalg.lsqr  # Only available solver
    else:
        import scipy.sparse.linalg
        solver = scipy.sparse.linalg.spsolve
    return solver(tocsr(triplets), rhs)


def spapply(triplets, rhs, crop_rhs=False):
    """
    Applies a sparse matrix, given as triplets, to an rhs.
    If crop_rhs holds, then rhs is truncated along its first axis to one plus the
    largest column index, and an empty array is returned if there are no entries.
    """
    if crop_rhs:
        cols = triplets[1][1]
        if len(cols) == 0:
            return np.zeros_like(rhs, shape=(0,))
        size = 1 + np.max(cols)
        if rhs.shape[0] > size:
            rhs = rhs[:size]
    return tocsr(triplets) * rhs
def
normalize_axis(axis, ndim, allow_tuple=True):
def
add_ndim(arr, n):
def add_ndim(arr, n):
    """Reshapes arr so that it gets n additional trailing singleton axes."""
    trailing_ones = (1,) * n
    return np.reshape(arr, arr.shape + trailing_ones)
def
key_expand(key, depth=1):
def key_expand(key, depth=1):
    """
    Adapts an indexing key to an array with depth additional trailing axes.
    Only tuple keys containing an ellipsis are modified : full slices are appended,
    so that the ellipsis does not absorb the additional axes.
    """
    if not isinstance(key, tuple):
        return key
    has_ellipsis = any(item is Ellipsis for item in key)
    if has_ellipsis:
        return key + depth * (slice(None),)
    return key
Modifies a key to access an array with more dimensions. Needed if ellipsis is used.
def
apply_linear_operator(op, rhs, flatten_ndim=0):
def apply_linear_operator(op, rhs, flatten_ndim=0):
    """
    Applies the linear operator op to rhs, which may have more than two dimensions.
    All the axes of rhs after the first one are flattened into a single axis before
    calling op, and restored on the output.
    Note : flatten_ndim is only involved in the dimension check.
    """
    assert (rhs.ndim - flatten_ndim) in [1, 2]
    n_rows, trailing = rhs.shape[0], rhs.shape[1:]
    n_cols = np.prod(trailing, dtype=int)
    result = op(rhs.reshape((n_rows, n_cols)))
    return result.reshape((result.shape[0],) + trailing)
Applies a linear operator to an array with more than two dimensions, by flattening the last dimensions
def
ready_ad(a):
def ready_ad(a):
    """
    Prepares a variable for receiving AD information, when that is possible.
    Returns : the prepared variable, and a boolean (whether AD extension is possible).
    Raises ValueError if a already contains AD information.
    """
    if is_ad(a):
        raise ValueError("Variable a already contains AD information")

    # Non-integral real scalars are wrapped in a zero dimensional array
    if isinstance(a, numbers.Real) and not isinstance(a, numbers.Integral):
        return np.array(a), True

    # Arrays are accepted as they are, unless their scalar type is integral
    if isndarray(a) and not issubclass(a.dtype.type, numbers.Integral):
        return a, True

    return a, False
Readies a variable for adding AD information, if possible. Returns: the readied variable, and a boolean (whether AD extension is possible)
def
register(identity, data, iterables):
def
sumprod(u, v, iterables, to_first=False):
def
reverse_mode(co_output):
def
recurse(step, niter=1):
def
as_flat(a):
def
tocsr(triplets, shape=None):
def tocsr(triplets, shape=None):
    """
    Turns sparse matrix given as triplets into a csr (compressed sparse row) matrix.
    The triplets are forwarded as they are to coo_matrix, together with shape.
    The sparse module is taken from cupyx if from_cupy(triplets[0]) holds,
    and from scipy otherwise.
    """
    # The submodule is imported explicitly on both branches, since importing
    # a package does not guarantee that its submodules are loaded.
    if from_cupy(triplets[0]):
        import cupyx.scipy.sparse as spmod
    else:
        import scipy.sparse as spmod
    return spmod.coo_matrix(triplets, shape=shape).tocsr()
Turns sparse matrix given as triplets into a csr (compressed sparse row) matrix
def
spsolve(triplets, rhs):
def spsolve(triplets, rhs):
    """
    Solves a sparse linear system where the matrix is given as triplets.
    The solver is cupyx's lsqr if from_cupy(triplets[0]) holds, and scipy's spsolve
    otherwise; it is called on the csr matrix built from the triplets, and on rhs.
    NOTE(review) : the return value of lsqr may be structured differently from
    the one of scipy's spsolve - confirm against callers.
    """
    if from_cupy(triplets[0]):
        # Explicit submodule import, as on the scipy branch : a bare 'import cupyx'
        # does not guarantee that cupyx.scipy.sparse.linalg is loaded.
        import cupyx.scipy.sparse.linalg
        solver = cupyx.scipy.sparse.linalg.lsqr  # Only available solver
    else:
        import scipy.sparse.linalg
        solver = scipy.sparse.linalg.spsolve
    return solver(tocsr(triplets), rhs)
Solves a sparse linear system where the matrix is given as triplets.
def
spapply(triplets, rhs, crop_rhs=False):
def spapply(triplets, rhs, crop_rhs=False):
    """
    Multiplies the sparse matrix, given as triplets, with rhs.
    If crop_rhs holds, then rhs is first truncated along its first axis to one plus
    the largest column index; if there are no column indices at all, an empty
    array is returned without building the matrix.
    """
    if crop_rhs:
        col_indices = triplets[1][1]
        if len(col_indices) == 0:
            # No matrix entries : nothing to apply
            return np.zeros_like(rhs, shape=(0,))
        n_used = 1 + np.max(col_indices)
        if rhs.shape[0] > n_used:
            rhs = rhs[:n_used]
    matrix = tocsr(triplets)
    return matrix * rhs
Applies a sparse matrix, given as triplets, to an rhs.