agd.AutomaticDifferentiation.misc

  1# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
  2# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0
  3
  4import numpy as np
  5import numbers
  6from .functional import map_iterables,map_iterables2,pair
  7from .cupy_generic import isndarray,from_cupy,cp
  8from .ad_generic import is_ad,remove_ad
  9from . import ad_generic
 10
 11# ------- Ugly utilities -------
 12def normalize_axis(axis,ndim,allow_tuple=True):
 13	if allow_tuple and isinstance(axis,tuple): 
 14		return tuple(normalize_axis(ax,ndim,False) for ax in axis)
 15	if axis<0: return axis+ndim
 16	return axis
 17
 18def add_ndim(arr,n): return np.reshape(arr,arr.shape+(1,)*n)
 19def _add_dim(a):     return np.expand_dims(a,axis=-1)	
 20def _add_dim2(a):    return _add_dim(_add_dim(a))
 21
 22def _to_tuple(a): return tuple(a) if hasattr(a,"__iter__") else (a,)
 23
 24def key_expand(key,depth=1): 
 25	"""Modifies a key to access an array with more dimensions. Needed if ellipsis is used."""
 26	if isinstance(key,tuple):
 27		if any(a is ... for a in key):
 28			return key + (slice(None),)*depth
 29	return key
 30
 31def _pad_last(a,pad_total): # Always makes a deep copy
 32		return np.pad(a, pad_width=((0,0),)*(a.ndim-1)+((0,pad_total-a.shape[-1]),), mode='constant', constant_values=0)
 33def _add_coef(a,b):
 34	if a.shape[-1]==0: return b
 35	elif b.shape[-1]==0: return a
 36	else: return a+b
 37def _prep_nl(s): return "\n"+s if "\n" in s else s
 38
 39def _concatenate(a,b,shape=None):
 40	if shape is not None:
 41		if a.shape[:-1]!=shape: a = np.broadcast_to(a,shape+a.shape[-1:])
 42		if b.shape[:-1]!=shape: b = np.broadcast_to(b,shape+b.shape[-1:])
 43	return np.concatenate((a,b),axis=-1)
 44
 45def _set_shape_constant(shape=None,constant=None):
 46	if isndarray(shape): shape=tuple(shape)
 47	if constant is None:
 48		if shape is None:
 49			raise ValueError("Error : unspecified shape or constant")
 50		constant = np.full(shape,0.)
 51	else:
 52		if not isndarray(constant):
 53			constant = ad_generic.asarray(constant)
 54		if shape is not None and shape!=constant.shape: 
 55			raise ValueError("Error : incompatible shape and constant")
 56		else:
 57			shape=constant.shape
 58	return shape,constant
 59
 60def _test_or_broadcast_ad(array,shape,broadcast,ad_depth=1):
 61	if broadcast:
 62		if array.shape[:-ad_depth]==shape:
 63			return array
 64		else:
 65			return np.broadcast_to(array,shape+array.shape[-ad_depth:])
 66	else:
 67		assert array.shape[:-ad_depth]==shape
 68		return array
 69
 70
 71
 72# -------- For Dense and Dense2 -----
 73
 74def apply_linear_operator(op,rhs,flatten_ndim=0):
 75	"""
 76	Applies a linear operator to an array with more than two dimensions,
 77	by flattening the last dimensions
 78	"""
 79	assert (rhs.ndim-flatten_ndim) in [1,2]
 80	shape_tail = rhs.shape[1:]
 81	op_input = rhs.reshape((rhs.shape[0],np.prod(shape_tail,dtype=int)))
 82	op_output = op(op_input)
 83	return op_output.reshape((op_output.shape[0],)+shape_tail)
 84
 85
 86# -------- Functional iteration, mainly for Reverse and Reverse2 -------
 87
 88def ready_ad(a):
 89	"""
 90	Readies a variable for adding ad information, if possible.
 91	Returns : readied variable, boolean (wether AD extension is possible)
 92	"""
 93	if is_ad(a):
 94		raise ValueError("Variable a already contains AD information")
 95	elif isinstance(a,numbers.Real) and not isinstance(a,numbers.Integral):
 96		return np.array(a),True
 97	elif isndarray(a) and not issubclass(a.dtype.type,numbers.Integral):
 98		return a,True
 99	else:
100		return a,False
101
102# Applying a function
def _apply_output_helper(rev,val,iterables):
	"""
	Adds 'virtual' AD information to an output (with negative indices), 
	in selected places.
	Each leaf of val accepted by ready_ad is replaced with
	rev._identity_rev(constant=leaf), and paired with pair(rev.size_rev,leaf.shape).
	Other leaves are paired with None. The traversal uses map_iterables with split=True.
	"""
	def _virtualize(leaf):
		leaf,can_ad = ready_ad(leaf)
		if not can_ad:
			return leaf,None
		# rev.size_rev is read before calling _identity_rev, which presumably
		# advances it (TODO confirm); keep this ordering.
		location = pair(rev.size_rev,leaf.shape)
		return rev._identity_rev(constant=leaf),location
	return map_iterables(_virtualize,val,iterables,split=True)
116
117
def register(identity,data,iterables):
	"""
	Traverses data (through the containers listed in iterables) and replaces
	each leaf accepted by ready_ad with identity(constant=leaf).
	Other leaves are returned as they are.
	"""
	def _wrap(leaf):
		leaf,extendable = ready_ad(leaf)
		return identity(constant=leaf) if extendable else leaf
	return map_iterables(_wrap,data,iterables)
124
125
def _to_shapes(coef,shapes,iterables):
	"""
	Reshapes a one dimensional array into the given shapes, 
	given as a tuple of pair(start,shape).
	Each (start,shape) leaf of shapes is replaced with the prod(shape) entries
	of coef beginning at index start, reshaped to shape. None leaves stay None.
	"""
	def _extract(location):
		if location is None:
			return None
		start,shape = location
		stop = start+np.prod(shape,dtype=int)
		return coef[start:stop].reshape(shape)
	return map_iterables(_extract,shapes,iterables)
138
def _apply_input_helper(args,kwargs,cls,iterables):
	"""
	Removes the AD information from some function input, and provides the correspondance.
	Every AD leaf of args and kwargs (traversed through iterables) must be an
	instance of cls. It is replaced with remove_ad(leaf), and the pair
	(leaf,stripped) is appended to a list, in traversal order (args, then kwargs).
	Returns : stripped args (tuple), stripped kwargs (dict), list of pairs.
	"""
	correspondence = []
	def _strip(x):
		if not is_ad(x):
			return x
		assert isinstance(x,cls)
		stripped = remove_ad(x)
		correspondence.append((x,stripped))
		return stripped
	new_args = tuple(map_iterables(_strip,arg,iterables) for arg in args)
	new_kwargs = {k:map_iterables(_strip,v,iterables) for k,v in kwargs.items()}
	return new_args,new_kwargs,correspondence
156
157
def sumprod(u,v,iterables,to_first=False):
	"""
	Sums (U*y).sum() over the matching leaves (x,y) of u and v, traversed jointly
	with map_iterables2, where U is x.to_first() if to_first else x.
	Leaves where x is None are skipped. Returns 0. when nothing is accumulated.
	"""
	total = 0.
	def _accumulate(x,y):
		nonlocal total
		if x is None:
			return
		X = x.to_first() if to_first else x
		total = total+(X*y).sum()
	map_iterables2(_accumulate,u,v,iterables)
	return total
167
def reverse_mode(co_output):
	"""
	Identifies the AD mode from co_output. Returns "Forward" if co_output is None.
	Otherwise co_output must be a pair (c,_): the mode is "Reverse2" if c is
	itself a pair, and "Reverse" otherwise.
	"""
	if co_output is None:
		return "Forward"
	assert isinstance(co_output,pair)
	first,_ = co_output
	return "Reverse2" if isinstance(first,pair) else "Reverse"
178
179# ----- Functional -----
180
def recurse(step,niter=1):
	"""Returns the operator rhs -> step(...step(rhs)...), applying step niter times."""
	def operator(rhs):
		result = rhs
		for _ in range(niter):
			result = step(result)
		return result
	return operator
188
189# ------- Common functions -------
190
def as_flat(a):
	"""Returns a.reshape(-1) for arrays, and a one-element array holding a otherwise."""
	if not isndarray(a):
		return ad_generic.array([a])
	return a.reshape(-1)
193
def tocsr(triplets,shape=None):
	"""
	Turns sparse matrix given as triplets into a csr (compressed sparse row) matrix.
	- triplets : (values,(rows,cols)), as accepted by the coo_matrix constructor.
	- shape : optional matrix shape; inferred from the indices when None.
	cupyx.scipy.sparse is used when triplets[0] is a cupy array, scipy.sparse otherwise.
	"""
	if from_cupy(triplets[0]):
		# Import the submodule explicitly, rather than relying on `import cupyx`
		# to expose cupyx.scipy.sparse as an attribute.
		import cupyx.scipy.sparse as spmod
	else:
		import scipy.sparse as spmod
	return spmod.coo_matrix(triplets,shape=shape).tocsr()
199
def spsolve(triplets,rhs):
	"""
	Solves a sparse linear system where the matrix is given as triplets.
	On the CPU, scipy.sparse.linalg.spsolve is used. For cupy triplets,
	cupyx.scipy.sparse.linalg.lsqr is used; it returns a tuple whose first
	entry is the solution, which is extracted so both paths return an array.
	"""
	if from_cupy(triplets[0]):
		import cupyx.scipy.sparse.linalg
		return cupyx.scipy.sparse.linalg.lsqr(tocsr(triplets),rhs)[0]
	import scipy.sparse.linalg
	return scipy.sparse.linalg.spsolve(tocsr(triplets),rhs)
211
def spapply(triplets,rhs,crop_rhs=False):
	"""
	Applies a sparse matrix, given as triplets, to an rhs.
	- crop_rhs : if True, rhs is truncated along its first axis to length
	  1+(largest column index), matching the matrix width inferred by tocsr.
	  If the matrix has no entries, an empty array of shape (0,)+rhs.shape[1:]
	  is returned.
	"""
	if crop_rhs:
		cols = triplets[1][1]
		if len(cols)==0:
			# Same shape as a zero-row matrix applied to rhs (trailing dims kept).
			return np.zeros_like(rhs,shape=(0,)+rhs.shape[1:])
		size = 1+np.max(cols)
		if rhs.shape[0]>size:
			rhs = rhs[:size]
	return tocsr(triplets)*rhs
def normalize_axis(axis,ndim,allow_tuple=True):
	"""
	Maps a possibly negative axis index to its non-negative equivalent.
	- axis : integer, or (when allow_tuple is True) a tuple of integers,
	  which are then normalized one by one.
	- ndim : number of dimensions of the array the axis refers to.
	Non-negative axes are returned unchanged; no range check is performed.
	"""
	if isinstance(axis,tuple) and allow_tuple:
		return tuple(normalize_axis(a,ndim,allow_tuple=False) for a in axis)
	return axis+ndim if axis<0 else axis
def add_ndim(arr,n):
	"""Appends n trailing singleton dimensions to arr, using np.reshape."""
	new_shape = arr.shape + (1,)*n
	return np.reshape(arr,new_shape)
def key_expand(key,depth=1):
	"""
	Modifies a key to access an array with more dimensions. Needed if ellipsis is used.
	When key is a tuple containing an Ellipsis, depth full slices are appended,
	so that the key still addresses the same leading dimensions once depth
	trailing dimensions are present. Any other key is returned unchanged.
	"""
	# Identity test (not `in`), since key entries may be arrays.
	has_ellipsis = isinstance(key,tuple) and any(k is Ellipsis for k in key)
	if not has_ellipsis:
		return key
	return key + depth*(slice(None),)

Modifies a key to access an array with more dimensions. Needed if ellipsis is used.

def apply_linear_operator(op,rhs,flatten_ndim=0):
	"""
	Applies a linear operator to an array with more than two dimensions,
	by flattening the last dimensions.
	All axes of rhs after the first are merged into a single one, op is applied
	to the resulting 2D array, and the merged axes are restored on the output,
	whose first axis length is the one produced by op.
	"""
	assert (rhs.ndim-flatten_ndim) in [1,2]
	trailing = rhs.shape[1:]
	matrix_shape = (rhs.shape[0],np.prod(trailing,dtype=int))
	out = op(rhs.reshape(matrix_shape))
	out_shape = (out.shape[0],)+trailing
	return out.reshape(out_shape)

Applies a linear operator to an array with possibly more than two dimensions, by flattening all dimensions after the first one into a single axis.

def ready_ad(a):
	"""
	Readies a variable for adding AD information, if possible.
	Returns : readied variable, boolean (whether AD extension is possible).
	- Real, non-integral scalars are converted to 0-d arrays, and can be extended.
	- Arrays can be extended unless their dtype is integral or boolean.
	- Anything else is returned unchanged, and cannot be extended.
	Raises ValueError if a already contains AD information.
	"""
	if is_ad(a):
		raise ValueError("Variable a already contains AD information")
	elif isinstance(a,numbers.Real) and not isinstance(a,numbers.Integral):
		return np.array(a),True
	elif isndarray(a) and not issubclass(a.dtype.type,(numbers.Integral,np.bool_)):
		# numpy does not register np.bool_ as numbers.Integral, yet boolean
		# arrays (e.g. masks) cannot carry derivatives.
		return a,True
	else:
		return a,False

Readies a variable for adding AD information, if possible. Returns: the readied variable, and a boolean (whether AD extension is possible).

def register(identity,data,iterables):
	"""
	Traverses data (through the containers listed in iterables) and replaces
	each leaf accepted by ready_ad with identity(constant=leaf).
	Other leaves are returned as they are.
	"""
	def _wrap(leaf):
		leaf,extendable = ready_ad(leaf)
		return identity(constant=leaf) if extendable else leaf
	return map_iterables(_wrap,data,iterables)
def sumprod(u,v,iterables,to_first=False):
	"""
	Sums (U*y).sum() over the matching leaves (x,y) of u and v, traversed jointly
	with map_iterables2, where U is x.to_first() if to_first else x.
	Leaves where x is None are skipped. Returns 0. when nothing is accumulated.
	"""
	total = 0.
	def _accumulate(x,y):
		nonlocal total
		if x is None:
			return
		X = x.to_first() if to_first else x
		total = total+(X*y).sum()
	map_iterables2(_accumulate,u,v,iterables)
	return total
def reverse_mode(co_output):
	"""
	Identifies the AD mode from co_output. Returns "Forward" if co_output is None.
	Otherwise co_output must be a pair (c,_): the mode is "Reverse2" if c is
	itself a pair, and "Reverse" otherwise.
	"""
	if co_output is None:
		return "Forward"
	assert isinstance(co_output,pair)
	first,_ = co_output
	return "Reverse2" if isinstance(first,pair) else "Reverse"
def recurse(step,niter=1):
	"""Returns the operator rhs -> step(...step(rhs)...), applying step niter times."""
	def operator(rhs):
		result = rhs
		for _ in range(niter):
			result = step(result)
		return result
	return operator
def as_flat(a):
	"""Returns a.reshape(-1) for arrays, and a one-element array holding a otherwise."""
	if not isndarray(a):
		return ad_generic.array([a])
	return a.reshape(-1)
def tocsr(triplets,shape=None):
	"""
	Turns sparse matrix given as triplets into a csr (compressed sparse row) matrix.
	- triplets : (values,(rows,cols)), as accepted by the coo_matrix constructor.
	- shape : optional matrix shape; inferred from the indices when None.
	cupyx.scipy.sparse is used when triplets[0] is a cupy array, scipy.sparse otherwise.
	"""
	if from_cupy(triplets[0]):
		# Import the submodule explicitly, rather than relying on `import cupyx`
		# to expose cupyx.scipy.sparse as an attribute.
		import cupyx.scipy.sparse as spmod
	else:
		import scipy.sparse as spmod
	return spmod.coo_matrix(triplets,shape=shape).tocsr()

Turns sparse matrix given as triplets into a csr (compressed sparse row) matrix

def spsolve(triplets,rhs):
	"""
	Solves a sparse linear system where the matrix is given as triplets.
	On the CPU, scipy.sparse.linalg.spsolve is used. For cupy triplets,
	cupyx.scipy.sparse.linalg.lsqr is used; it returns a tuple whose first
	entry is the solution, which is extracted so both paths return an array.
	"""
	if from_cupy(triplets[0]):
		import cupyx.scipy.sparse.linalg
		return cupyx.scipy.sparse.linalg.lsqr(tocsr(triplets),rhs)[0]
	import scipy.sparse.linalg
	return scipy.sparse.linalg.spsolve(tocsr(triplets),rhs)

Solves a sparse linear system where the matrix is given as triplets.

def spapply(triplets,rhs,crop_rhs=False):
	"""
	Applies a sparse matrix, given as triplets, to an rhs.
	- crop_rhs : if True, rhs is truncated along its first axis to length
	  1+(largest column index), matching the matrix width inferred by tocsr.
	  If the matrix has no entries, an empty array of shape (0,)+rhs.shape[1:]
	  is returned.
	"""
	if crop_rhs:
		cols = triplets[1][1]
		if len(cols)==0:
			# Same shape as a zero-row matrix applied to rhs (trailing dims kept).
			return np.zeros_like(rhs,shape=(0,)+rhs.shape[1:])
		size = 1+np.max(cols)
		if rhs.shape[0]>size:
			rhs = rhs[:size]
	return tocsr(triplets)*rhs

Applies a sparse matrix, given as triplets, to an rhs.