agd.AutomaticDifferentiation.ad_generic

  1# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
  2# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0
  3
  4import itertools
  5import numpy as np
  6import functools
  7
  8from . import functional
  9from .Base import is_ad,isndarray,array,asarray,_cp_ndarray
 10
 11"""
  12This file implements functions that apply generically to several AD types.
 13"""
 14
 15def adtype(data,iterables=tuple()):
 16	"""
 17	Returns None if no ad variable found, or the adtype if one is found.
 18	Also checks consistency of the ad types.
 19	"""
 20	result = None
 21	for value in rec_iter(data,iterables):
 22		t=type(x)
 23		if is_ad(t): 
 24			if result is None: result=t
 25			else: assert result==t
 26	return result
 27
 28def precision(x):
 29	"""
 30	Precision of the floating point type of x.
 31	"""
 32	if not isinstance(x,type): x = array(x).dtype.type
 33	return np.finfo(x).precision	
 34
 35def remove_ad(data,iterables=tuple()):
 36	def f(a): return a.value if is_ad(a) else a
 37	return functional.map_iterables(f,data,iterables)
 38
 39def as_writable(a):
 40	"""
 41	Returns a writable array containing the same elements as a.
 42	If the array a, or a field of a for an AD type, is flagged as 
 43	non-writable, then it is copied.
 44	"""
 45	if isinstance(a,(np.ndarray,_cp_ndarray)):
 46		return a if a.flags['WRITEABLE'] else a.copy()
 47	return a.new(*tuple(as_writable(e) for e in a.as_tuple() ))
 48
 49def common_cast(*args):
 50	"""
 51	If any of the arguments is an AD type, casts all other arguments to that type.
 52	Casts to ndarray if no argument is an AD type. 
 53	Usage : if a and b may or may not b AD arrays, 
 54	a,b = common_cast(a,b); a[0]=b[0]
 55	"""
 56	args = tuple(array(x) for x in args)
 57	common_type = None
 58	for x in args: 
 59		if is_ad(x):
 60			if common_type is None:
 61				common_type = type(x)
 62			if not isinstance(x,common_type):
 63				raise ValueError("Error : several distinct AD types")
 64	return args if common_type is None else tuple(common_type(x) for x in args)
 65
 66
 67def min_argmin(array,axis=None):
 68	if axis is None: return min_argmin(array.reshape(-1),axis=0)
 69	ai = np.argmin(array,axis=axis)
 70	return np.squeeze(np.take_along_axis(array,np.expand_dims(ai,
 71		axis=axis),axis=axis),axis=axis),ai
 72
 73def max_argmax(array,axis=None):
 74	if axis is None: return max_argmax(array.reshape(-1),axis=0)
 75	ai = np.argmax(array,axis=axis)
 76	return np.squeeze(np.take_along_axis(array,np.expand_dims(ai,
 77		axis=axis),axis=axis),axis=axis),ai
 78
 79# ------- Linear operators ------
 80
 81
 82def apply_linear_mapping(matrix,rhs,niter=1):
 83	"""
 84	Applies the provided linear operator, to a dense AD variable of first or second order.
 85	"""
 86	def step(x): return np.dot(matrix,x) if isinstance(matrix,np.ndarray) else (matrix*x)
 87	operator = functional.recurse(step,niter)
 88	return rhs.apply_linear_operator(operator) if is_ad(rhs) else operator(rhs)
 89
 90def apply_linear_inverse(solver,matrix,rhs,niter=1):
 91	"""
 92	Applies the provided linear inverse to a dense AD variable of first or second order.
 93	"""
 94	def step(x): return solver(matrix,x)
 95	operator = functional.recurse(step,niter)
 96	return rhs.apply_linear_operator(operator) if is_ad(rhs) else operator(rhs)
 97
 98# ------- Shape manipulation -------
 99
def squeeze_shape(shape,axis):
	"""
	Removes the singleton dimension at position axis from shape.
	If axis is None, shape is returned unchanged. Negative axes count from the end.

	Raises ValueError if shape[axis] is not 1, and IndexError if axis is out of range.
	"""
	if axis is None:
		return shape
	# Explicit check rather than assert, which is stripped under python -O
	# and would then silently drop a non-singleton dimension.
	if shape[axis]!=1:
		raise ValueError("Cannot squeeze axis {} of shape {}: dimension is not 1".format(axis,shape))
	if axis==-1:
		# shape[(axis+1):] would be shape[0:], i.e. the whole shape
		return shape[:-1]
	return shape[:axis]+shape[(axis+1):]
108
def expand_shape(shape,axis):
	"""
	Inserts a singleton dimension into shape, so that it sits at position axis
	of the result (matching np.expand_dims for valid axes).
	Negative axes count from the end of the result. If axis is None, shape is unchanged.
	"""
	if axis is None: return shape
	if axis==-1: return shape+(1,)
	# A negative axis refers to the expanded shape, which has one more entry.
	pos = axis+1 if axis<0 else axis
	return shape[:pos]+(1,)+shape[pos:]
117
def _set_shape_free_bound(shape,shape_free,shape_bound):
	"""
	Splits shape as shape_free + shape_bound, returned as a pair of tuples.
	Each part may be given (as a tuple or list) or deduced from the other;
	if both are omitted, shape_bound=() and shape_free=shape.
	Raises ValueError if a given part is inconsistent with shape.
	"""
	shape = tuple(shape)
	if shape_free is not None:
		shape_free = tuple(shape_free)
		if shape_free!=shape[:len(shape_free)]:
			raise ValueError("shape_free {} is not a prefix of shape {}".format(shape_free,shape))
		rest = shape[len(shape_free):]
		if shape_bound is None:
			shape_bound = rest
		elif tuple(shape_bound)!=rest:
			raise ValueError("shape_free {} + shape_bound {} does not match shape {}".format(
				shape_free,tuple(shape_bound),shape))
	shape_bound = tuple() if shape_bound is None else tuple(shape_bound)
	nbound = len(shape_bound)
	# Guard nbound>0 : shape[-0:] would be the whole shape.
	if nbound>0 and shape_bound!=shape[-nbound:]:
		raise ValueError("shape_bound {} is not a suffix of shape {}".format(shape_bound,shape))
	if shape_free is None:
		shape_free = shape[:len(shape)-nbound]
	return shape_free,shape_bound
134
def disassociate(array,shape_free=None,shape_bound=None,
	expand_free_dims=-1,expand_bound_dims=-1):
	"""
	Turns an array of shape shape_free + shape_bound 
	into an array of shape shape_free whose elements 
	are arrays of shape shape_bound.
	Typical usage : recursive automatic differentiation.
	Caveat : by default, singleton dimensions are introduced 
	to avoid numpy's "clever" treatment of scalar arrays.

	Arguments: 
	- array : the array to be split.
	- (optional) shape_free, shape_bound : outer and inner array shapes. 
	  One is deduced from the other; if both are omitted, shape_bound=().
	- (optional) expand_free_dims, expand_bound_dims : position at which a singleton 
	  dimension is inserted into shape_free, resp. shape_bound (see expand_shape). 
	  None inserts no singleton dimension.

	Returns an object array of shape shape_free (after expansion).
	"""
	shape_free,shape_bound = _set_shape_free_bound(array.shape,shape_free,shape_bound)
	shape_free  = expand_shape(shape_free, expand_free_dims)
	shape_bound = expand_shape(shape_bound,expand_bound_dims)
	
	# int(...) : np.prod(()) is the float 1.0, which reshape, np.zeros and range reject.
	size_free = int(np.prod(shape_free))
	array = array.reshape((size_free,)+shape_bound)
	result = np.zeros(size_free,object)
	# Fill element by element so that each entry stores the sub-array array[i];
	# a single slice assignment would be treated by numpy as a broadcast.
	for i in range(size_free): result[i] = array[i]
	return result.reshape(shape_free)
159
def associate(array,squeeze_free_dims=-1,squeeze_bound_dims=-1):
	"""
	Turns an array of shape shape_free, whose elements are arrays of shape
	shape_bound, into an array of shape shape_free+shape_bound.
	Inverse operation to disassociate : the singleton dimensions at positions
	squeeze_free_dims and squeeze_bound_dims are removed (None keeps them).
	AD variables are delegated to their own associate method.
	"""
	if is_ad(array):
		return array.associate(squeeze_free_dims,squeeze_bound_dims)
	stacked = np.stack(array.reshape(-1),axis=0)
	outer = squeeze_shape(array.shape,squeeze_free_dims)
	inner = squeeze_shape(stacked.shape[1:],squeeze_bound_dims)
	return stacked.reshape(outer+inner)
def adtype(data,iterables=tuple()):
	"""
	Returns None if no AD variable is found in data, or the AD type if one is found.
	Also checks that all AD variables found share the same type.

	Arguments:
	- data : an object, possibly nested inside containers.
	- iterables (optional) : tuple of container types whose elements are visited
	  recursively (for a dict, its values are visited).

	Raises ValueError if AD variables of several distinct types are found.
	"""
	def leaves(x):
		# Recursively yield the non-container leaves of x.
		if isinstance(x,iterables):
			if isinstance(x,dict): x = x.values()
			for y in x: yield from leaves(y)
		else: yield x

	result = None
	for value in leaves(data):
		t = type(value)
		if is_ad(t):
			if result is None: result = t
			elif result is not t:
				# Explicit error rather than assert (stripped under python -O),
				# consistent with common_cast.
				raise ValueError("Error : several distinct AD types")
	return result

Returns None if no AD variable is found, or the AD type if one is found. Also checks that all AD variables found have the same type.

def precision(x):
	"""
	Approximate number of decimal digits of the floating point type of x,
	as reported by np.finfo. x is either a scalar type (e.g. np.float32),
	or an object converted with this module's `array`, whose dtype is then used.
	"""
	float_type = x if isinstance(x,type) else array(x).dtype.type
	return np.finfo(float_type).precision

Precision of the floating point type of x.

def remove_ad(data,iterables=tuple()):
	"""
	Replaces each AD variable in data by its value field, leaving other objects
	untouched. Containers whose types are listed in iterables are mapped over
	with functional.map_iterables.
	"""
	strip = lambda obj: obj.value if is_ad(obj) else obj
	return functional.map_iterables(strip,data,iterables)
def as_writable(a):
	"""
	Returns a writable array holding the same elements as a.
	An ndarray (or _cp_ndarray) is returned as is if writable, and copied otherwise.
	Any other object (an AD variable) is rebuilt with a.new from its fields
	a.as_tuple(), each field being made writable recursively.
	"""
	if not isinstance(a,(np.ndarray,_cp_ndarray)):
		fields = [as_writable(field) for field in a.as_tuple()]
		return a.new(*fields)
	if a.flags['WRITEABLE']: return a
	return a.copy()

Returns a writable array containing the same elements as a. If the array a, or a field of a for an AD type, is flagged as non-writable, then it is copied.

def common_cast(*args):
	"""
	Converts all arguments with this module's `array`, then casts them to a common type.
	If some argument is an AD variable, all arguments are cast to the type of
	the first AD argument; otherwise the converted arguments are returned as is.
	Usage : if a and b may or may not be AD arrays,
	a,b = common_cast(a,b); a[0]=b[0]
	Raises ValueError if an AD argument is not an instance of that first AD type.
	"""
	arrays = tuple(map(array,args))
	ad_type = None
	for arr in arrays:
		if not is_ad(arr): continue
		if ad_type is None: ad_type = type(arr)
		if not isinstance(arr,ad_type):
			raise ValueError("Error : several distinct AD types")
	if ad_type is None: return arrays
	return tuple(ad_type(arr) for arr in arrays)

If any of the arguments is an AD type, casts all other arguments to that type. Casts to ndarray if no argument is an AD type. Usage : if a and b may or may not be AD arrays, a,b = common_cast(a,b); a[0]=b[0]

def min_argmin(array,axis=None):
	"""
	Returns the pair (minimum of array along axis, index of that minimum).
	If axis is None, array is flattened first and the index refers to the flattened array.
	The value is gathered at the argmin indices with np.take_along_axis.
	"""
	if axis is None:
		array,axis = array.reshape(-1),0
	indices = np.argmin(array,axis=axis)
	picked = np.take_along_axis(array,np.expand_dims(indices,axis=axis),axis=axis)
	return np.squeeze(picked,axis=axis),indices
def max_argmax(array,axis=None):
	"""
	Returns the pair (maximum of array along axis, index of that maximum).
	If axis is None, array is flattened first and the index refers to the flattened array.
	The value is gathered at the argmax indices with np.take_along_axis.
	"""
	if axis is None:
		array,axis = array.reshape(-1),0
	indices = np.argmax(array,axis=axis)
	picked = np.take_along_axis(array,np.expand_dims(indices,axis=axis),axis=axis)
	return np.squeeze(picked,axis=axis),indices
def apply_linear_mapping(matrix,rhs,niter=1):
	"""
	Applies the provided linear operator to rhs, e.g. a dense AD variable of first or second order.
	A numpy matrix acts through np.dot, any other matrix type through matrix*x.
	The step is wrapped by functional.recurse(step,niter) (presumably niter successive
	applications). AD variables go through rhs.apply_linear_operator.
	"""
	dense = isinstance(matrix,np.ndarray)
	def step(x):
		if dense: return np.dot(matrix,x)
		return matrix*x
	operator = functional.recurse(step,niter)
	if is_ad(rhs): return rhs.apply_linear_operator(operator)
	return operator(rhs)

Applies the provided linear operator to a dense AD variable of first or second order.

def apply_linear_inverse(solver,matrix,rhs,niter=1):
	"""
	Applies the provided linear inverse, solver(matrix,x), to rhs,
	e.g. a dense AD variable of first or second order.
	The step is wrapped by functional.recurse(step,niter) (presumably niter successive
	applications). AD variables go through rhs.apply_linear_operator.
	"""
	operator = functional.recurse(lambda x: solver(matrix,x),niter)
	if not is_ad(rhs): return operator(rhs)
	return rhs.apply_linear_operator(operator)

Applies the provided linear inverse to a dense AD variable of first or second order.

def squeeze_shape(shape,axis):
	"""
	Removes the singleton dimension at position axis from shape.
	If axis is None, shape is returned unchanged. Negative axes count from the end.

	Raises ValueError if shape[axis] is not 1, and IndexError if axis is out of range.
	"""
	if axis is None:
		return shape
	# Explicit check rather than assert, which is stripped under python -O
	# and would then silently drop a non-singleton dimension.
	if shape[axis]!=1:
		raise ValueError("Cannot squeeze axis {} of shape {}: dimension is not 1".format(axis,shape))
	if axis==-1:
		# shape[(axis+1):] would be shape[0:], i.e. the whole shape
		return shape[:-1]
	return shape[:axis]+shape[(axis+1):]
def expand_shape(shape,axis):
	"""
	Inserts a singleton dimension into shape, so that it sits at position axis
	of the result (matching np.expand_dims for valid axes).
	Negative axes count from the end of the result. If axis is None, shape is unchanged.
	"""
	if axis is None: return shape
	if axis==-1: return shape+(1,)
	# A negative axis refers to the expanded shape, which has one more entry.
	pos = axis+1 if axis<0 else axis
	return shape[:pos]+(1,)+shape[pos:]
def disassociate(array,shape_free=None,shape_bound=None,
	expand_free_dims=-1,expand_bound_dims=-1):
	"""
	Turns an array of shape shape_free + shape_bound 
	into an array of shape shape_free whose elements 
	are arrays of shape shape_bound.
	Typical usage : recursive automatic differentiation.
	Caveat : by default, singleton dimensions are introduced 
	to avoid numpy's "clever" treatment of scalar arrays.

	Arguments: 
	- array : the array to be split.
	- (optional) shape_free, shape_bound : outer and inner array shapes. 
	  One is deduced from the other; if both are omitted, shape_bound=().
	- (optional) expand_free_dims, expand_bound_dims : position at which a singleton 
	  dimension is inserted into shape_free, resp. shape_bound (see expand_shape). 
	  None inserts no singleton dimension.

	Returns an object array of shape shape_free (after expansion).
	"""
	shape_free,shape_bound = _set_shape_free_bound(array.shape,shape_free,shape_bound)
	shape_free  = expand_shape(shape_free, expand_free_dims)
	shape_bound = expand_shape(shape_bound,expand_bound_dims)
	
	# int(...) : np.prod(()) is the float 1.0, which reshape, np.zeros and range reject.
	size_free = int(np.prod(shape_free))
	array = array.reshape((size_free,)+shape_bound)
	result = np.zeros(size_free,object)
	# Fill element by element so that each entry stores the sub-array array[i];
	# a single slice assignment would be treated by numpy as a broadcast.
	for i in range(size_free): result[i] = array[i]
	return result.reshape(shape_free)

Turns an array of shape shape_free + shape_bound into an array of shape shape_free whose elements are arrays of shape shape_bound. Typical usage : recursive automatic differentiation. Caveat : by default, singleton dimensions are introduced to avoid numpy's "clever" treatment of scalar arrays.

Arguments:

  • array : the array to be split (it is reshaped internally)
  • (optional) shape_free, shape_bound : outer and inner array shapes. One is deduced from the other.
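  • (optional) expand_free_dims, expand_bound_dims : position at which a singleton dimension is inserted into shape_free, resp. shape_bound. None inserts no singleton dimension.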
def associate(array,squeeze_free_dims=-1,squeeze_bound_dims=-1):
	"""
	Turns an array of shape shape_free, whose elements are arrays of shape
	shape_bound, into an array of shape shape_free+shape_bound.
	Inverse operation to disassociate : the singleton dimensions at positions
	squeeze_free_dims and squeeze_bound_dims are removed (None keeps them).
	AD variables are delegated to their own associate method.
	"""
	if is_ad(array):
		return array.associate(squeeze_free_dims,squeeze_bound_dims)
	stacked = np.stack(array.reshape(-1),axis=0)
	outer = squeeze_shape(array.shape,squeeze_free_dims)
	inner = squeeze_shape(stacked.shape[1:],squeeze_bound_dims)
	return stacked.reshape(outer+inner)

Turns an array of shape shape_free, whose elements are arrays of shape shape_bound, into an array of shape shape_free+shape_bound. Inverse operation to disassociate.