agd.Eikonal.HFM_CUDA.BeckmanOT
import os
import numpy as np
import time
import cupy as cp

from . import cupy_module_helper
from .cupy_module_helper import SetModuleConstant
from ... import FiniteDifferences as fd

"""
This file provides a GPU solver for a very particular instance of the (W1, unbalanced)
optimal transport problem : minimizing the quantity
int |σ| + (λ/2)|div(σ)+ν-μ|^2
where μ and ν are (possibly vector valued) measures, and λ is a relaxation parameter.

More generally, this file is a basic GPU implementation of the Chambolle-Pock primal-dual
optimization algorithm, that can be modified to address other optimization problems.
"""

def solve_ot(λ,ξ,dx,relax_norm_constraint=0.01,
    τ_primal=None,ρ_overrelax=1.8, ϕ0=None, σ0=None,
    atol=0,rtol=1e-6,E_rtol=1e-3,maxiter=5000,
    shape_i = None,stop_period=10, verbosity=2):
    """
    Numerically solves a relaxed Beckmann formulation of optimal transport, namely
    minimizing int |σ| + (λ/2)|div(σ)+ξ|^2, using a primal-dual iteration run on the GPU.

    Input :
    - λ (positive) : the relaxation parameter
    - ξ (array) : the difference of measures ν-μ. The last len(dx) axes are the grid axes,
      the leading axes (if any) are flattened into a single channel axis.
    - dx (array) : the grid scales (one per dimension). Lists, tuples, numpy and cupy
      arrays are accepted.
    - relax_norm_constraint : forwarded to the kernel as its inverse
      (constant 'irelax_norm_constraint').
    - τ_primal : primal time step, defaults to 5/sqrt(gradnorm2) where
      gradnorm2 = 4*sum(dx**-2). The dual time step is set to 1/(gradnorm2*τ_primal).
    - ρ_overrelax : forwarded to the kernel (constant 'rho_overrelax').
    - ϕ0, σ0 : optional initial primal and dual points (zero if not provided).
    - atol, rtol : tolerances forwarded to the kernel
      (presumably used to set the per-block 'stabilized' flags -- see Kernel_BeckmanOT.h).
    - E_rtol : stop when primal-dual < E_rtol*|primal|. Test disabled if E_rtol<=0.
    - maxiter : iteration budget.
    - shape_i : shape of a GPU thread block, defaults to (64,), (8,8), (4,4,4) in
      dimension 1, 2, 3.
    - stop_period : the stopping criteria are evaluated once every stop_period iterations.
    - verbosity : >=1 warns if the budget is exhausted (only when E_rtol>0),
      >=2 also reports the iteration count and elapsed time.

    Output : a dictionary with keys
    - 'ϕ','σ' : the primal and dual points, block-squeezed back to the grid shapes
      (presumably ξ.shape for ϕ, and (len(dx),)+channels+(grid shape minus one) for σ).
    - 'stabilized' : the per-block flags (int8), as left by the last check.
    - 'niter' : number of iterations performed.
    - 'stopping_criterion' : 'stabilized' or 'gap' if the loop stopped on a convergence
      test, 'maxiter' if the iteration budget was exhausted.
    - 'primal_values','dual_values' : energies recorded at each check.
    """

    # traits and sizes
    int_t = np.int32
    float_t = np.float32
    assert np.ndim(dx)==1 # One grid scale per dimension
    vdim = len(dx)
    # Host float array, so that dx**-2 and 1./dx below are well defined for
    # lists, tuples, integer arrays and cupy arrays alike.
    dx = np.asarray(cp.asnumpy(dx),dtype=float)
    ξ = cp.asarray(ξ,dtype=float_t)
    shape_s = ξ.shape[-vdim:] # grid shape of ξ and ϕ
    shape_v = tuple(s-1 for s in shape_s) # grid shape of σ : one point less per axis
    multichannel_shape = ξ.shape[:-vdim]
    nchannels = np.prod(multichannel_shape,dtype=int)
    ξ = ξ.reshape((nchannels,*shape_s)) # All channels in first dimension

    gradnorm2 = 4.*np.sum(dx**-2) # Squared norm of the gradient operator
    if τ_primal is None: τ_primal = 5./np.sqrt(gradnorm2)
    assert τ_primal>0
    if shape_i is None: shape_i = {1:(64,), 2:(8,8), 3:(4,4,4)}[vdim]
    assert len(shape_i)==vdim
    size_i = np.prod(shape_i)

    # Format suitably for the cupy kernel
    ξ = cp.ascontiguousarray(fd.block_expand(cp.asarray(ξ,dtype=float_t),shape_i,
        constant_values=np.nan))
    shape_o = ξ.shape[1:1+vdim]
    size_o = np.prod(shape_o)

    if ϕ0 is None: ϕ = cp.zeros((nchannels,)+shape_o+shape_i,dtype=float_t) # primal point
    else: ϕ = cp.ascontiguousarray(fd.block_expand(cp.asarray(ϕ0,dtype=float_t).reshape((
        nchannels,)+shape_s),shape_i,shape_o))

    if σ0 is None: σ = cp.zeros((nchannels,vdim,*shape_o,*shape_i),dtype=float_t) # dual point
    else: σ = cp.ascontiguousarray(fd.block_expand(cp.asarray(σ0,dtype=float_t).reshape((
        nchannels,vdim)+shape_v),shape_i,shape_o))

    ϕ_ext = cp.zeros((nchannels,)+shape_o+shape_i,dtype=float_t) # extrapolated primal point
    primal_value = cp.zeros(shape_o,dtype=float_t) # primal objective value, by block
    dual_value = cp.zeros(shape_o,dtype=float_t) # dual objective value, by block
    stabilized = cp.zeros(shape_o,dtype=np.int8) # stopping flag, by block

    # --------------- cuda header construction and compilation ----------------
    # Generate the load order for the boundary of shared data
    x_top_e = fd.block_neighbors(shape_i,True)
    x_bot_e = fd.block_neighbors(shape_i,False)

    # Generate the kernels
    traits = {
        'ndim_macro':vdim,
        'Int':int_t,
        'Scalar':float_t,
        'shape_i':shape_i,
        'shape_e':tuple(s+1 for s in shape_i),
        'size_bd_e':len(x_top_e),
        'x_top_e':x_top_e,
        'x_bot_e':1+x_bot_e,
        'nchannels':nchannels,
        }

    cuda_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),"cuda")
    date_modified = cupy_module_helper.getmtime_max(cuda_path)
    source = cupy_module_helper.traits_header(traits,size_of_shape=True,log2_size=True)

    source += [
        '#include "Kernel_BeckmanOT.h"',
        f"// Date cuda code last modified : {date_modified}"]
    # Debugging options that may be appended : "-lineinfo", "--device-debug"
    cuoptions = ("-default-device", f"-I {cuda_path}")

    # Two variants of the same source : with and without the stopping criterion evaluation
    source = "\n".join(source)
    module = cupy_module_helper.GetModule(
        "#define checkstop_macro false\n"+source,cuoptions)
    module_checkstop = cupy_module_helper.GetModule(
        "#define checkstop_macro true\n"+source,cuoptions)
    # -------------- cuda module generated (may be reused...) ----------------

    def setcst(*args):
        """Sets a constant in both compiled modules."""
        cupy_module_helper.SetModuleConstant(module,*args)
        cupy_module_helper.SetModuleConstant(module_checkstop,*args)
    setcst('shape_tot_s',shape_s,int_t)
    setcst('shape_tot_v',shape_v,int_t)
    setcst('shape_o',shape_o,int_t)
    setcst('size_io',size_i*size_o,int_t)
    setcst('tau_primal',τ_primal,float_t)
    setcst('tau_dual',1./(gradnorm2*τ_primal),float_t)
    setcst('idx',1./dx,float_t)
    setcst('lambda',λ,float_t)
    setcst('ilambda',1./λ,float_t)
    setcst('irelax_norm_constraint',1./relax_norm_constraint,float_t)
    setcst('rho_overrelax',ρ_overrelax,float_t)
    setcst('atol',atol,float_t)
    setcst('rtol',rtol,float_t)

    primal_step = module.get_function("primal_step")
    dual_step = module.get_function("dual_step")
    primal_step_checkstop = module_checkstop.get_function("primal_step")
    dual_step_checkstop = module_checkstop.get_function("dual_step")
    primal_values,dual_values = [],[]

    # Main loop
    for arr in (ξ,ϕ,ϕ_ext,σ,primal_value,dual_value,stabilized):
        assert arr.flags['C_CONTIGUOUS'] # Just to be sure (Common source of silent bugs.)
    top = time.time()
    niter = -1 # So that niter+1==0 is reported if the loop body never runs (maxiter<=0)
    exhausted = False # Set if the loop ends without any stopping criterion being met
    for niter in range(maxiter):
        if niter%stop_period!=0:
            primal_step((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ))
            dual_step((size_o,),(size_i,),(σ,ϕ_ext))
        else:
            # Same steps, which in addition fill the objective values and stopping flags
            primal_step_checkstop((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ,
                primal_value,dual_value,stabilized))
            dual_step_checkstop((size_o,),(size_i,),(σ,ϕ_ext,
                primal_value,dual_value,stabilized))
            e_primal = float(primal_value.sum())
            e_dual = -float(dual_value.sum())
            primal_values.append(e_primal)
            dual_values.append(e_dual)
            if E_rtol>0 and e_primal-e_dual<E_rtol*np.abs(e_primal): break
            if np.all(stabilized): break
            stabilized.fill(0) # Reset the flags for the next check
    else:
        exhausted = True
        if E_rtol>0 and verbosity>=1:
            print("Exhausted iteration budget without satisfying convergence criterion")
    if verbosity>=2:
        print(f"GPU primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")

    # Restore the original channel axes and undo the block layout
    ϕ = ϕ.reshape(multichannel_shape+ϕ.shape[1:])
    ϕ = fd.block_squeeze(ϕ,shape_s)
    σ = np.moveaxis(σ,0,1).reshape((vdim,)+multichannel_shape+σ.shape[2:])
    σ = fd.block_squeeze(σ,shape_v)

    # Do not report a convergence criterion if none was met
    if exhausted: stopping_criterion = 'maxiter'
    elif np.all(stabilized): stopping_criterion = 'stabilized'
    else: stopping_criterion = 'gap'

    return {'ϕ':ϕ,'σ':σ,
        'stabilized':stabilized,'niter':niter+1,
        'stopping_criterion':stopping_criterion,
        'primal_values':np.array(primal_values),'dual_values':np.array(dual_values)}
def solve_ot(λ, ξ, dx, relax_norm_constraint=0.01, τ_primal=None, ρ_overrelax=1.8, φ0=None, σ0=None, atol=0, rtol=1e-06, E_rtol=0.001, maxiter=5000, shape_i=None, stop_period=10, verbosity=2):
def solve_ot(λ,ξ,dx,relax_norm_constraint=0.01,
    τ_primal=None,ρ_overrelax=1.8, ϕ0=None, σ0=None,
    atol=0,rtol=1e-6,E_rtol=1e-3,maxiter=5000,
    shape_i = None,stop_period=10, verbosity=2):
    """
    Numerically solves a relaxed Beckmann formulation of optimal transport, namely
    minimizing int |σ| + (λ/2)|div(σ)+ξ|^2, using a primal-dual iteration run on the GPU.

    Input :
    - λ (positive) : the relaxation parameter
    - ξ (array) : the difference of measures ν-μ. The last len(dx) axes are the grid axes,
      the leading axes (if any) are flattened into a single channel axis.
    - dx (array) : the grid scales (one per dimension). Lists, tuples, numpy and cupy
      arrays are accepted.
    - relax_norm_constraint : forwarded to the kernel as its inverse
      (constant 'irelax_norm_constraint').
    - τ_primal : primal time step, defaults to 5/sqrt(gradnorm2) where
      gradnorm2 = 4*sum(dx**-2). The dual time step is set to 1/(gradnorm2*τ_primal).
    - ρ_overrelax : forwarded to the kernel (constant 'rho_overrelax').
    - ϕ0, σ0 : optional initial primal and dual points (zero if not provided).
    - atol, rtol : tolerances forwarded to the kernel
      (presumably used to set the per-block 'stabilized' flags -- see Kernel_BeckmanOT.h).
    - E_rtol : stop when primal-dual < E_rtol*|primal|. Test disabled if E_rtol<=0.
    - maxiter : iteration budget.
    - shape_i : shape of a GPU thread block, defaults to (64,), (8,8), (4,4,4) in
      dimension 1, 2, 3.
    - stop_period : the stopping criteria are evaluated once every stop_period iterations.
    - verbosity : >=1 warns if the budget is exhausted (only when E_rtol>0),
      >=2 also reports the iteration count and elapsed time.

    Output : a dictionary with keys
    - 'ϕ','σ' : the primal and dual points, block-squeezed back to the grid shapes
      (presumably ξ.shape for ϕ, and (len(dx),)+channels+(grid shape minus one) for σ).
    - 'stabilized' : the per-block flags (int8), as left by the last check.
    - 'niter' : number of iterations performed.
    - 'stopping_criterion' : 'stabilized' or 'gap' if the loop stopped on a convergence
      test, 'maxiter' if the iteration budget was exhausted.
    - 'primal_values','dual_values' : energies recorded at each check.
    """

    # traits and sizes
    int_t = np.int32
    float_t = np.float32
    assert np.ndim(dx)==1 # One grid scale per dimension
    vdim = len(dx)
    # Host float array, so that dx**-2 and 1./dx below are well defined for
    # lists, tuples, integer arrays and cupy arrays alike.
    dx = np.asarray(cp.asnumpy(dx),dtype=float)
    ξ = cp.asarray(ξ,dtype=float_t)
    shape_s = ξ.shape[-vdim:] # grid shape of ξ and ϕ
    shape_v = tuple(s-1 for s in shape_s) # grid shape of σ : one point less per axis
    multichannel_shape = ξ.shape[:-vdim]
    nchannels = np.prod(multichannel_shape,dtype=int)
    ξ = ξ.reshape((nchannels,*shape_s)) # All channels in first dimension

    gradnorm2 = 4.*np.sum(dx**-2) # Squared norm of the gradient operator
    if τ_primal is None: τ_primal = 5./np.sqrt(gradnorm2)
    assert τ_primal>0
    if shape_i is None: shape_i = {1:(64,), 2:(8,8), 3:(4,4,4)}[vdim]
    assert len(shape_i)==vdim
    size_i = np.prod(shape_i)

    # Format suitably for the cupy kernel
    ξ = cp.ascontiguousarray(fd.block_expand(cp.asarray(ξ,dtype=float_t),shape_i,
        constant_values=np.nan))
    shape_o = ξ.shape[1:1+vdim]
    size_o = np.prod(shape_o)

    if ϕ0 is None: ϕ = cp.zeros((nchannels,)+shape_o+shape_i,dtype=float_t) # primal point
    else: ϕ = cp.ascontiguousarray(fd.block_expand(cp.asarray(ϕ0,dtype=float_t).reshape((
        nchannels,)+shape_s),shape_i,shape_o))

    if σ0 is None: σ = cp.zeros((nchannels,vdim,*shape_o,*shape_i),dtype=float_t) # dual point
    else: σ = cp.ascontiguousarray(fd.block_expand(cp.asarray(σ0,dtype=float_t).reshape((
        nchannels,vdim)+shape_v),shape_i,shape_o))

    ϕ_ext = cp.zeros((nchannels,)+shape_o+shape_i,dtype=float_t) # extrapolated primal point
    primal_value = cp.zeros(shape_o,dtype=float_t) # primal objective value, by block
    dual_value = cp.zeros(shape_o,dtype=float_t) # dual objective value, by block
    stabilized = cp.zeros(shape_o,dtype=np.int8) # stopping flag, by block

    # --------------- cuda header construction and compilation ----------------
    # Generate the load order for the boundary of shared data
    x_top_e = fd.block_neighbors(shape_i,True)
    x_bot_e = fd.block_neighbors(shape_i,False)

    # Generate the kernels
    traits = {
        'ndim_macro':vdim,
        'Int':int_t,
        'Scalar':float_t,
        'shape_i':shape_i,
        'shape_e':tuple(s+1 for s in shape_i),
        'size_bd_e':len(x_top_e),
        'x_top_e':x_top_e,
        'x_bot_e':1+x_bot_e,
        'nchannels':nchannels,
        }

    cuda_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),"cuda")
    date_modified = cupy_module_helper.getmtime_max(cuda_path)
    source = cupy_module_helper.traits_header(traits,size_of_shape=True,log2_size=True)

    source += [
        '#include "Kernel_BeckmanOT.h"',
        f"// Date cuda code last modified : {date_modified}"]
    # Debugging options that may be appended : "-lineinfo", "--device-debug"
    cuoptions = ("-default-device", f"-I {cuda_path}")

    # Two variants of the same source : with and without the stopping criterion evaluation
    source = "\n".join(source)
    module = cupy_module_helper.GetModule(
        "#define checkstop_macro false\n"+source,cuoptions)
    module_checkstop = cupy_module_helper.GetModule(
        "#define checkstop_macro true\n"+source,cuoptions)
    # -------------- cuda module generated (may be reused...) ----------------

    def setcst(*args):
        """Sets a constant in both compiled modules."""
        cupy_module_helper.SetModuleConstant(module,*args)
        cupy_module_helper.SetModuleConstant(module_checkstop,*args)
    setcst('shape_tot_s',shape_s,int_t)
    setcst('shape_tot_v',shape_v,int_t)
    setcst('shape_o',shape_o,int_t)
    setcst('size_io',size_i*size_o,int_t)
    setcst('tau_primal',τ_primal,float_t)
    setcst('tau_dual',1./(gradnorm2*τ_primal),float_t)
    setcst('idx',1./dx,float_t)
    setcst('lambda',λ,float_t)
    setcst('ilambda',1./λ,float_t)
    setcst('irelax_norm_constraint',1./relax_norm_constraint,float_t)
    setcst('rho_overrelax',ρ_overrelax,float_t)
    setcst('atol',atol,float_t)
    setcst('rtol',rtol,float_t)

    primal_step = module.get_function("primal_step")
    dual_step = module.get_function("dual_step")
    primal_step_checkstop = module_checkstop.get_function("primal_step")
    dual_step_checkstop = module_checkstop.get_function("dual_step")
    primal_values,dual_values = [],[]

    # Main loop
    for arr in (ξ,ϕ,ϕ_ext,σ,primal_value,dual_value,stabilized):
        assert arr.flags['C_CONTIGUOUS'] # Just to be sure (Common source of silent bugs.)
    top = time.time()
    niter = -1 # So that niter+1==0 is reported if the loop body never runs (maxiter<=0)
    exhausted = False # Set if the loop ends without any stopping criterion being met
    for niter in range(maxiter):
        if niter%stop_period!=0:
            primal_step((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ))
            dual_step((size_o,),(size_i,),(σ,ϕ_ext))
        else:
            # Same steps, which in addition fill the objective values and stopping flags
            primal_step_checkstop((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ,
                primal_value,dual_value,stabilized))
            dual_step_checkstop((size_o,),(size_i,),(σ,ϕ_ext,
                primal_value,dual_value,stabilized))
            e_primal = float(primal_value.sum())
            e_dual = -float(dual_value.sum())
            primal_values.append(e_primal)
            dual_values.append(e_dual)
            if E_rtol>0 and e_primal-e_dual<E_rtol*np.abs(e_primal): break
            if np.all(stabilized): break
            stabilized.fill(0) # Reset the flags for the next check
    else:
        exhausted = True
        if E_rtol>0 and verbosity>=1:
            print("Exhausted iteration budget without satisfying convergence criterion")
    if verbosity>=2:
        print(f"GPU primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")

    # Restore the original channel axes and undo the block layout
    ϕ = ϕ.reshape(multichannel_shape+ϕ.shape[1:])
    ϕ = fd.block_squeeze(ϕ,shape_s)
    σ = np.moveaxis(σ,0,1).reshape((vdim,)+multichannel_shape+σ.shape[2:])
    σ = fd.block_squeeze(σ,shape_v)

    # Do not report a convergence criterion if none was met
    if exhausted: stopping_criterion = 'maxiter'
    elif np.all(stabilized): stopping_criterion = 'stabilized'
    else: stopping_criterion = 'gap'

    return {'ϕ':ϕ,'σ':σ,
        'stabilized':stabilized,'niter':niter+1,
        'stopping_criterion':stopping_criterion,
        'primal_values':np.array(primal_values),'dual_values':np.array(dual_values)}
Numerically solves a relaxed Beckmann formulation of optimal transport
- λ (positive) : the relaxation parameter
- ξ (array) : the difference of measures ν-μ
- dx (array) : the grid scales (one per dimension)