agd.AutomaticDifferentiation.cupy_generic
This file implements functionalities needed to make the agd library generic to cupy/numpy usage. It does not import cupy, unless absolutely required.
# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

"""
This file implements functionalities needed to make the agd library generic to cupy/numpy usage.
It does not import cupy, unless absolutely required.
"""

import itertools
import numpy as np
import sys
import functools
import types
from copy import copy

from . import functional
from .Base import cp,isndarray,from_cupy,is_ad,array,cupy_alt_overloads

def get_array_module(data,iterables=tuple()):
    """
    Returns the cupy module or the numpy module, depending on data.
    - iterables : forwarded to functional.rec_iter, which enumerates the members of data
    numpy is returned when cp is None, or when no enumerated member satisfies from_cupy.
    """
    if cp is None: return np
    return cp if any(from_cupy(x) for x in functional.rec_iter(data,iterables)) else np

def samesize_int_t(float_t):
    """
    Returns an integer type of the same size (32 or 64 bits) as a given float type.
    - float_t : anything accepted by np.dtype (scalar type, dtype instance, type name)
    Raises ValueError if float_t is neither float32 nor float64.
    """
    float_t = np.dtype(float_t).type # Normalize names and dtype instances to the scalar type
    same_size = {np.float32:np.int32, np.float64:np.int64}
    if float_t not in same_size: raise ValueError(
        f"Type {float_t} is not a float type, or has no default matching int type")
    return same_size[float_t]

# ----------- Retrieving data from a cupy array ------------

dtype32to64 = {np.float32:np.float64,np.int32:np.int64,np.uint32:np.uint64,}
dtype64to32 = {np.float64:np.float32,np.int64:np.int32,np.uint64:np.uint32,}

def cupy_get(x,dtype64=False,iterables=tuple()):
    """
    If argument is a cupy ndarray, returns output of 'get' member function,
    which is a numpy ndarray. Likewise for AD types. Returns unchanged argument otherwise.
    - dtype64 : convert 32 bit floats and ints to their 64 bit counterparts
    - iterables : forwarded to functional.map_iterables
    NOTE(review): dtype64 is not read below. Retrieved arrays whose dtype is a key of
    dtype32to64 are always converted. Kept as is, since honoring the default
    dtype64=False would change the behavior seen by existing callers.
    """
    def caster(x):
        if from_cupy(x):
            # AD types are rebuilt from their individually converted components
            if is_ad(x): return type(x)(*(caster(z) for z in x.as_tuple()))
            x = x.get()
            if x.dtype.type in dtype32to64: x=x.astype(dtype32to64[x.dtype.type])
        return x
    return functional.map_iterables(caster,x,iterables)

def cupy_set(x,dtype32=True,iterables=tuple()):
    """
    If argument is a numpy ndarray, converts it to a cupy ndarray. Applies to AD Types.
    Arguments which are not ndarrays, or already come from cupy, are returned unchanged.
    - dtype32 : convert 64 bit floats and ints to their 32 bit counterparts
    - iterables : forwarded to functional.map_iterables
    """
    def caster(x):
        if isndarray(x) and not from_cupy(x):
            # AD types are rebuilt from their individually converted components
            if is_ad(x): return type(x)(*(caster(z) for z in x.as_tuple()))
            dtype = x.dtype.type
            # The narrowing is opt-out : dtype32=False preserves the original dtype
            if dtype32: dtype = dtype64to32.get(dtype,dtype)
            return cp.asarray(x,dtype=dtype)
        return x
    return functional.map_iterables(caster,x,iterables)

@functional.decorator_with_arguments
def cupy_get_args(f,*args,**kwargs):
    """
    Decorator applying cupy_get to all arguments of the given function.
    - *args, **kwargs : passed to cupy_get
    """
    @functools.wraps(f)
    def wrapper(*fargs,**fkwargs):
        fargs = tuple(cupy_get(arg,*args,**kwargs) for arg in fargs)
        fkwargs = {key:cupy_get(value,*args,**kwargs) for key,value in fkwargs.items()}
        return f(*fargs,**fkwargs)
    return wrapper

# ----- Casting data to appropriate floating point and integer types ------

def has_dtype(arg,dtype="dtype",iterables=(tuple,)):
    """
    Whether one member of arg is an ndarray with the provided dtype.
    - dtype : anything accepted by np.dtype. The default value is a placeholder
     which np.dtype does not understand, hence dtype must be provided in practice.
    - iterables : forwarded to functional.rec_iter, which enumerates the members of arg
    """
    dtype = np.dtype(dtype)
    return any(isndarray(x) and x.dtype==dtype
        for x in functional.rec_iter(arg,iterables=iterables))

def get_float_t(arg,**kwargs):
    """
    Returns float32 if found in any argument, else float64.
    - kwargs : passed to has_dtype
    """
    return np.float32 if has_dtype(arg,dtype=np.float32,**kwargs) else np.float64

def array_float_caster(arg,**kwargs):
    """
    Returns lambda arr : xp.asarray(arr,dtype=float_t)
    where xp and float_t are consistent with the arguments.
    - kwargs : passed to both get_array_module and get_float_t
    """
    xp = get_array_module(arg,**kwargs)
    float_t = get_float_t(arg,**kwargs)
    return lambda arr:xp.asarray(arr,dtype=float_t)

@functional.decorator_with_arguments
def set_output_dtype32(f,silent=True,iterables=(tuple,)):
    """
    If the output of the given function contains ndarrays with 64bit dtype,
    int or float, they are converted to 32 bit dtype.
    - silent : if False, a message is printed for each converted array
    - iterables : forwarded to functional.map_iterables, applied to the output of f
    """
    def caster(a):
        # Only float64 and int64 are handled, other dtypes are left untouched
        if isndarray(a) and a.dtype in (np.float64,np.int64):
            xp = get_array_module(a)
            dtype = np.float32 if a.dtype==np.float64 else np.int32
            if not silent: print(
                f"Casting output of function {f.__name__} "
                f"from {a.dtype} to {np.dtype(dtype)}")
            return xp.asarray(a,dtype=dtype)
        return a

    @functools.wraps(f)
    def wrapper(*args,**kwargs):
        output = f(*args,**kwargs)
        return functional.map_iterables(caster,output,iterables=iterables)

    return wrapper

# ------------ A helper function for cupy/numpy notebooks -------------

def cupy_friendly(arg):
    """
    Returns a "cupy-friendly" copy of the input module, function, or object,
    following arbitrary and ad-hoc rules. A message describing the applied rule is printed.
    - module numpy : returns cp passed through functional.decorate_module_functions
     with set_output_dtype32, and sets array.caster to produce float32 cupy arrays.
     Raises ValueError if cp is None.
    - module scipy.ndimage : returns cupyx.scipy.ndimage.
    - module agd.Eikonal : sets default_mode = 'gpu' in place, returns the same module.
    - other modules : returns functional.decorate_module_functions(arg,cupy_get_args).
    - np.allclose : returns a wrapper with defaults atol=rtol=1e-5.
    - python functions : returns the alternative found in cupy_alt_overloads, if any,
     and otherwise cupy_get_args(arg).
    - ndarrays : returns cupy_set(arg).
    - anything else : returns cupy_set(arg,iterables=(type(arg),)).
    """

    if isinstance(arg,types.ModuleType):
        # Special cases
        if arg is np:
            # Fail before announcing a replacement which cannot take place
            if cp is None:
                raise ValueError("cupy module not found.\n"
                    "If your are using Google Colab, go to modify->notebook parameters and activate GPU acceleration.")
            print("Replacing numpy with cupy, set to output 32bit ints and floats by default.")
            cp32 = functional.decorate_module_functions(cp,set_output_dtype32)
            print("Using cp.asarray(*,dtype=np.float32) as the default caster in ad.array.")
            array.caster = lambda x: cp.asarray(x,dtype=np.float32)
            return cp32
        if arg.__name__ == 'scipy.ndimage':
            print("Replacing module scipy.ndimage with cupyx.scipy.ndimage .")
            from cupyx.scipy import ndimage
            return ndimage
        if arg.__name__ == 'agd.Eikonal':
            print("Setting dictIn.default_mode = 'gpu' in module agd.Eikonal .")
            arg.dictIn.default_mode = 'gpu'
            arg.VoronoiDecomposition.default_mode = 'gpu'
            return arg

        # Default behavior
        print(f"Returning a copy of module {arg.__name__} whose functions accept cupy arrays as input.")
        return functional.decorate_module_functions(arg,cupy_get_args)


    if arg is np.allclose:
        print("Setting float32 compatible default values atol=rtol=1e-5 in np.allclose")
        def allclose(*args,**kwargs):
            kwargs.setdefault('atol',1e-5)
            kwargs.setdefault('rtol',1e-5)
            return np.allclose(*args,**kwargs)
        return allclose

    if isinstance(arg,types.FunctionType):

        if arg in cupy_alt_overloads:
            # Entries are pairs, only the first component is used here
            alt,_ = cupy_alt_overloads[arg]
            print("Adding (partial) support for (old versions of) cupy"
                f" versions to function {arg.__name__}")
            return alt

        # Default behavior
        print(f"Returning a copy of function {arg.__name__} which accepts cupy arrays as input.")
        return cupy_get_args(arg)

    if isndarray(arg):
        print(f"Replacing ndarray object with cupy variant, for object of type {type(arg)}")
        return cupy_set(arg)
    else:
        print("Replacing ndarray members with their cupy variants, "
            f"for object of type {type(arg)}")
        return cupy_set(arg,iterables=(type(arg),))
def get_array_module(data,iterables=tuple()):
    """
    Selects the array library matching data : cupy as soon as one enumerated
    member of data satisfies from_cupy, numpy otherwise.
    - iterables : forwarded to functional.rec_iter, which enumerates the members of data
    numpy is always returned when cp is None.
    """
    if cp is None:
        return np
    for member in functional.rec_iter(data,iterables):
        if from_cupy(member):
            return cp
    return np
Returns the cupy module or the numpy module, depending on data
def samesize_int_t(float_t):
    """
    Returns an integer type of the same size (32 or 64 bits) as a given float type.
    - float_t : anything accepted by np.dtype (scalar type, dtype instance, type name)
    Raises ValueError if float_t is neither float32 nor float64.
    """
    float_t = np.dtype(float_t).type # Normalize names and dtype instances to the scalar type
    same_size = {np.float32:np.int32, np.float64:np.int64}
    if float_t not in same_size: raise ValueError(
        f"Type {float_t} is not a float type, or has no default matching int type")
    return same_size[float_t]
Returns an integer type of the same size (32 or 64 bits) as a given float type
def cupy_get(x,dtype64=False,iterables=tuple()):
    """
    Retrieves cupy data as numpy data : a cupy ndarray is replaced with the output
    of its 'get' member function, and an AD type coming from cupy is rebuilt from
    its converted components. Any other argument is returned unchanged.
    - dtype64 : convert 32 bit floats and ints to their 64 bit counterparts
    - iterables : forwarded to functional.map_iterables
    NOTE(review): dtype64 is not read below. Retrieved arrays whose dtype is a key of
    dtype32to64 are always converted, whatever the value of this flag.
    """
    def to_numpy(value):
        if not from_cupy(value):
            return value
        if is_ad(value):
            return type(value)(*map(to_numpy,value.as_tuple()))
        retrieved = value.get()
        wide_t = dtype32to64.get(retrieved.dtype.type)
        if wide_t is not None:
            retrieved = retrieved.astype(wide_t)
        return retrieved
    return functional.map_iterables(to_numpy,x,iterables)
If argument is a cupy ndarray, returns output of 'get' member function, which is a numpy ndarray. Likewise for AD types. Returns unchanged argument otherwise.
- dtype64 : convert 32 bit floats and ints to their 64 bit counterparts
def cupy_set(x,dtype32=True,iterables=tuple()):
    """
    If argument is a numpy ndarray, converts it to a cupy ndarray. Applies to AD Types.
    Arguments which are not ndarrays, or already come from cupy, are returned unchanged.
    - dtype32 : convert 64 bit floats and ints to their 32 bit counterparts
    - iterables : forwarded to functional.map_iterables
    """
    def caster(x):
        if isndarray(x) and not from_cupy(x):
            # AD types are rebuilt from their individually converted components
            if is_ad(x): return type(x)(*(caster(z) for z in x.as_tuple()))
            dtype = x.dtype.type
            # The narrowing is opt-out : dtype32=False preserves the original dtype
            if dtype32: dtype = dtype64to32.get(dtype,dtype)
            return cp.asarray(x,dtype=dtype)
        return x
    return functional.map_iterables(caster,x,iterables)
If argument is a numpy ndarray, converts it to a cupy ndarray. Applies to AD Types.
- dtype32 : convert 64 bit floats and ints to their 32 bit counterparts
@functional.decorator_with_arguments
def cupy_get_args(f,*args,**kwargs):
    """
    Decorator which passes every positional and keyword argument of the given
    function through cupy_get, before calling it.
    - *args, **kwargs : passed to cupy_get
    """
    def retrieve(value):
        return cupy_get(value,*args,**kwargs)

    @functools.wraps(f)
    def wrapper(*fargs,**fkwargs):
        retrieved_args = tuple(map(retrieve,fargs))
        retrieved_kwargs = {name:retrieve(value) for name,value in fkwargs.items()}
        return f(*retrieved_args,**retrieved_kwargs)
    return wrapper
Decorator applying cupy_get to all arguments of the given function.
- *args, **kwargs : passed to cupy_get
def has_dtype(arg,dtype="dtype",iterables=(tuple,)):
    """
    Whether one member of arg is an ndarray with the provided dtype.
    - dtype : anything accepted by np.dtype. The default value is a placeholder
     which np.dtype does not understand, hence dtype must be provided in practice.
    - iterables : forwarded to functional.rec_iter, which enumerates the members of arg.
     The default is a one element tuple of types, as in set_output_dtype32.
    """
    dtype = np.dtype(dtype)
    return any(isndarray(x) and x.dtype==dtype
        for x in functional.rec_iter(arg,iterables=iterables))
Whether one member of arg is an ndarray with the provided dtype.
def get_float_t(arg,**kwargs):
    """
    Floating point type matching arg : float32 when has_dtype finds
    that dtype in arg, and float64 in all other cases.
    - kwargs : passed to has_dtype
    """
    if has_dtype(arg,dtype=np.float32,**kwargs):
        return np.float32
    return np.float64
Returns float32 if found in any argument, else float64.
- kwargs : passed to has_dtype
def array_float_caster(arg,**kwargs):
    """
    Builds a function arr -> xp.asarray(arr,dtype=float_t), where the module xp
    is given by get_array_module and the type float_t by get_float_t.
    - kwargs : passed to both get_array_module and get_float_t
    """
    xp = get_array_module(arg,**kwargs)
    float_t = get_float_t(arg,**kwargs)
    def caster(arr):
        return xp.asarray(arr,dtype=float_t)
    return caster
Returns lambda arr : xp.asarray(arr,dtype=float_t), where xp and float_t are consistent with the arguments.
@functional.decorator_with_arguments
def set_output_dtype32(f,silent=True,iterables=(tuple,)):
    """
    Decorator narrowing the output of the given function : ndarrays with dtype
    float64 are cast to float32, and those with dtype int64 are cast to int32.
    Anything else is returned as is.
    - silent : if False, a message is printed for each array which is cast
    - iterables : forwarded to functional.map_iterables, applied to the output of f
    """
    def narrow(a):
        if not isndarray(a):
            return a
        if a.dtype==np.float64:
            target_t = np.float32
        elif a.dtype==np.int64:
            target_t = np.int32
        else:
            return a
        xp = get_array_module(a)
        if not silent:
            print(
                f"Casting output of function {f.__name__} "
                f"from {a.dtype} to {np.dtype(target_t)}")
        return xp.asarray(a,dtype=target_t)

    @functools.wraps(f)
    def wrapper(*args,**kwargs):
        return functional.map_iterables(narrow,f(*args,**kwargs),iterables=iterables)

    return wrapper
If the output of the given function contains ndarrays with 64bit dtype, int or float, they are converted to 32 bit dtype.
def cupy_friendly(arg):
    """
    Returns a "cupy-friendly" copy of the input module, function, or object,
    following arbitrary and ad-hoc rules. A message describing the applied rule is printed.
    - module numpy : returns cp passed through functional.decorate_module_functions
     with set_output_dtype32, and sets array.caster to produce float32 cupy arrays.
     Raises ValueError if cp is None.
    - module scipy.ndimage : returns cupyx.scipy.ndimage.
    - module agd.Eikonal : sets default_mode = 'gpu' in place, returns the same module.
    - other modules : returns functional.decorate_module_functions(arg,cupy_get_args).
    - np.allclose : returns a wrapper with defaults atol=rtol=1e-5.
    - python functions : returns the alternative found in cupy_alt_overloads, if any,
     and otherwise cupy_get_args(arg).
    - ndarrays : returns cupy_set(arg).
    - anything else : returns cupy_set(arg,iterables=(type(arg),)).
    """

    if isinstance(arg,types.ModuleType):
        # Special cases
        if arg is np:
            # Fail before announcing a replacement which cannot take place
            if cp is None:
                raise ValueError("cupy module not found.\n"
                    "If your are using Google Colab, go to modify->notebook parameters and activate GPU acceleration.")
            print("Replacing numpy with cupy, set to output 32bit ints and floats by default.")
            cp32 = functional.decorate_module_functions(cp,set_output_dtype32)
            print("Using cp.asarray(*,dtype=np.float32) as the default caster in ad.array.")
            array.caster = lambda x: cp.asarray(x,dtype=np.float32)
            return cp32
        if arg.__name__ == 'scipy.ndimage':
            print("Replacing module scipy.ndimage with cupyx.scipy.ndimage .")
            from cupyx.scipy import ndimage
            return ndimage
        if arg.__name__ == 'agd.Eikonal':
            print("Setting dictIn.default_mode = 'gpu' in module agd.Eikonal .")
            arg.dictIn.default_mode = 'gpu'
            arg.VoronoiDecomposition.default_mode = 'gpu'
            return arg

        # Default behavior
        print(f"Returning a copy of module {arg.__name__} whose functions accept cupy arrays as input.")
        return functional.decorate_module_functions(arg,cupy_get_args)


    if arg is np.allclose:
        print("Setting float32 compatible default values atol=rtol=1e-5 in np.allclose")
        def allclose(*args,**kwargs):
            # Explicitly provided tolerances take precedence over the relaxed defaults
            kwargs.setdefault('atol',1e-5)
            kwargs.setdefault('rtol',1e-5)
            return np.allclose(*args,**kwargs)
        return allclose

    if isinstance(arg,types.FunctionType):

        if arg in cupy_alt_overloads:
            # Entries are pairs, only the first component is used here
            alt,_ = cupy_alt_overloads[arg]
            print("Adding (partial) support for (old versions of) cupy"
                f" versions to function {arg.__name__}")
            return alt

        # Default behavior
        print(f"Returning a copy of function {arg.__name__} which accepts cupy arrays as input.")
        return cupy_get_args(arg)

    if isndarray(arg):
        print(f"Replacing ndarray object with cupy variant, for object of type {type(arg)}")
        return cupy_set(arg)
    else:
        print("Replacing ndarray members with their cupy variants, "
            f"for object of type {type(arg)}")
        return cupy_set(arg,iterables=(type(arg),))
Returns a "cupy-friendly" copy of the input module, function, or object, following arbitrary and ad-hoc rules.