agd.Eikonal.HFM_CUDA.BeckmanOT

  1import os
  2import numpy as np
  3import time
  4import cupy as cp
  5
  6from . import cupy_module_helper
  7from .cupy_module_helper import SetModuleConstant
  8from ... import FiniteDifferences as fd
  9
 10"""
 11This file provides a GPU solver for a very particular instance of the (W1, unbalanced) 
 12optimal transport problem : minimizing the quantity
 13int |σ| + (λ/2)|div(σ)+ν-μ|^2 
 14where μ and ν are (possibly vector valued) measures, and λ is a relaxation parameter.
 15
 16More generally, this file is a basic GPU implementation of the Chambolle-Pock primal-dual
 17optimization algorithm, that can be modified to address other optimization problems.
 18"""
 19
 20def solve_ot(λ,ξ,dx,relax_norm_constraint=0.01,
 21	τ_primal=None,ρ_overrelax=1.8, ϕ0=None, σ0=None,
 22	atol=0,rtol=1e-6,E_rtol=1e-3,maxiter=5000,
 23	shape_i = None,stop_period=10, verbosity=2):
 24	"""
 25	Numerically solves a relaxed Bellman formulation of optimal transport
 26	- λ (positive) : the relaxation parameter
 27	- ξ (array) : the difference of measures ν-μ
 28	- dx (array) : the grid scales (one per dimension) 
 29	"""
 30
 31	# traits and sizes
 32	int_t = np.int32
 33	float_t = np.float32
 34	assert np.ndim(dx)==1 # One grid scale per dimension
 35	vdim = len(dx)
 36	ξ = cp.asarray(ξ,dtype=float_t)
 37	shape_s = ξ.shape[-vdim:] # shape for vector fields
 38	shape_v = tuple(s-1 for s in shape_s)
 39	multichannel_shape = ξ.shape[:-vdim]
 40	nchannels = np.prod(multichannel_shape,dtype=int)
 41	ξ = ξ.reshape((nchannels,*shape_s)) # All channels in first dimension
 42
 43	gradnorm2 = 4.*np.sum(dx**-2) # Squared norm of the gradient operator
 44	if τ_primal is None: τ_primal = 5./np.sqrt(gradnorm2)
 45	assert τ_primal>0
 46	if shape_i is None: shape_i = {1:(64,), 2:(8,8), 3:(4,4,4)}[vdim]
 47	assert len(shape_i)==vdim
 48	size_i = np.prod(shape_i)
 49
 50	# Format suitably for the cupy kernel
 51	ξ = cp.ascontiguousarray(fd.block_expand(cp.asarray(ξ,dtype=float_t),shape_i,
 52		constant_values=np.nan))
 53	shape_o = ξ.shape[1:1+vdim]
 54	size_o = np.prod(shape_o)
 55
 56	if ϕ0 is None: ϕ = cp.zeros((nchannels,)+shape_o+shape_i,dtype=float_t) # primal point
 57	else: ϕ = cp.ascontiguousarray(fd.block_expand(cp.asarray(ϕ0,dtype=float_t).reshape((
 58		nchannels,)+shape_s),shape_i,shape_o))
 59
 60	if σ0 is None:σ=cp.zeros((nchannels,vdim,*shape_o,*shape_i),dtype=float_t)#dual point
 61	else: σ = cp.ascontiguousarray(fd.block_expand(cp.asarray(σ0,dtype=float_t).reshape((
 62		nchannels,vdim)+shape_v),shape_i,shape_o))
 63
 64	ϕ_ext = cp.zeros((nchannels,)+shape_o+shape_i,dtype=float_t) # extrapolated primal point
 65	primal_value = cp.zeros(shape_o,dtype=float_t) # primal objective value, by block
 66	dual_value = cp.zeros(shape_o,dtype=float_t) # dual objective value, by block
 67	stabilized = cp.zeros(shape_o,dtype=np.int8)
 68
 69	# --------------- cuda header construction and compilation ----------------
 70	# Generate the load order for the boundary of shared data
 71	x_top_e = fd.block_neighbors(shape_i,True)
 72	x_bot_e = fd.block_neighbors(shape_i,False)
 73
 74	# Generate the kernels
 75	traits = {
 76		'ndim_macro':vdim,
 77		'Int':int_t,
 78		'Scalar':float_t,
 79		'shape_i':shape_i,
 80		'shape_e':tuple(s+1 for s in shape_i),
 81		'size_bd_e':len(x_top_e),
 82		'x_top_e':x_top_e,
 83		'x_bot_e':1+x_bot_e,
 84		'nchannels':nchannels,
 85	}
 86
 87	cuda_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),"cuda")
 88	date_modified = cupy_module_helper.getmtime_max(cuda_path)
 89	source = cupy_module_helper.traits_header(traits,size_of_shape=True,log2_size=True)
 90
 91	source += [
 92	'#include "Kernel_BeckmanOT.h"',
 93	f"// Date cuda code last modified : {date_modified}"]
 94	cuoptions = ("-default-device", f"-I {cuda_path}") #,"-lineinfo") ,"--device-debug"
 95
 96	source = "\n".join(source)
 97	module = cupy_module_helper.GetModule(
 98		"#define checkstop_macro false\n"+source,cuoptions)
 99	module_checkstop = cupy_module_helper.GetModule(
100		"#define checkstop_macro true\n"+source,cuoptions)
101	# -------------- cuda module generated (may be reused...) ----------------
102
103	def setcst(*args): 
104		cupy_module_helper.SetModuleConstant(module,*args)
105		cupy_module_helper.SetModuleConstant(module_checkstop,*args)
106	setcst('shape_tot_s',shape_s,int_t)
107	setcst('shape_tot_v',shape_v,int_t)
108	setcst('shape_o',shape_o,int_t)
109	setcst('size_io',size_i*size_o,int_t)
110	setcst('tau_primal',τ_primal,float_t)
111	setcst('tau_dual',1./(gradnorm2*τ_primal),float_t)
112	setcst('idx',1./dx,float_t)
113	setcst('lambda',λ,float_t)
114	setcst('ilambda',1./λ,float_t)
115	setcst('irelax_norm_constraint',1./relax_norm_constraint,float_t)
116	setcst('rho_overrelax',ρ_overrelax,float_t)
117	setcst('atol',atol,float_t)
118	setcst('rtol',rtol,float_t)
119
120
121	primal_step = module.get_function("primal_step")
122	dual_step = module.get_function("dual_step")
123	primal_step_checkstop = module_checkstop.get_function("primal_step")
124	dual_step_checkstop = module_checkstop.get_function("dual_step")
125	primal_values,dual_values = [],[]
126
127	# Main loop
128	for arr in (ξ,ϕ,ϕ_ext,σ,primal_value,dual_value,stabilized): 
129		assert arr.flags['C_CONTIGUOUS'] # Just to be sure (Common source of silent bugs.)
130	top = time.time()
131	for niter in range(maxiter):
132		if niter%stop_period!=0:
133			primal_step((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ))
134			dual_step((size_o,),(size_i,),(σ,ϕ_ext))
135		else: 
136			primal_step_checkstop((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ,
137				primal_value,dual_value,stabilized))
138#			print(f"{niter=},primal={primal_value.sum()},dual={dual_value.sum()}")
139			dual_step_checkstop((size_o,),(size_i,),(σ,ϕ_ext,
140				primal_value,dual_value,stabilized))
141#			print(f"primal={primal_value.sum()},dual={dual_value.sum()}")		
142			e_primal =  float(primal_value.sum())
143			e_dual   = -float(dual_value.sum())
144			primal_values.append(e_primal)
145			dual_values.append(e_dual)
146			if E_rtol>0 and e_primal-e_dual<E_rtol*np.abs(e_primal): break
147			if np.all(stabilized): break
148			stabilized.fill(0)
149#		if niter==10: print(f"First {niter} iterations took {time.time()-top} seconds")
150	else: 
151		if E_rtol>0 and verbosity>=1: 
152			print("Exhausted iteration budget without satisfying convergence criterion") 
153	if verbosity>=2:
154		print(f"GPU primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")
155
156	ϕ = ϕ.reshape(multichannel_shape+ϕ.shape[1:])
157	ϕ = fd.block_squeeze(ϕ,shape_s)
158	σ = np.moveaxis(σ,0,1).reshape((vdim,)+multichannel_shape+σ.shape[2:])
159	σ = fd.block_squeeze(σ,shape_v)
160
161	return {'ϕ':ϕ,'σ':σ, #'ϕ_ext':fd.block_squeeze(ϕ_ext,shape_s),
162	'stabilized':stabilized,'niter':niter+1,
163	'stopping_criterion':'stabilized' if np.all(stabilized) else 'gap',
164	'primal_values':np.array(primal_values),'dual_values':np.array(dual_values)}
def solve_ot(λ,ξ,dx,relax_norm_constraint=0.01,
	τ_primal=None,ρ_overrelax=1.8, ϕ0=None, σ0=None,
	atol=0,rtol=1e-6,E_rtol=1e-3,maxiter=5000,
	shape_i = None,stop_period=10, verbosity=2):
	"""
	Numerically solves a relaxed Beckmann formulation of optimal transport, namely 
	minimizes int |σ| + (λ/2)|div(σ)+ξ|^2, using the Chambolle-Pock primal-dual 
	algorithm on the GPU.

	Inputs :
	- λ (positive) : the relaxation parameter
	- ξ (array) : the difference of measures ν-μ. The last len(dx) axes are the 
	  spatial axes; any leading axes are independent channels, processed together.
	- dx (array-like) : the grid scales (one per dimension). Lists and integer 
	  arrays are accepted.
	- relax_norm_constraint (optional) : forwarded to the CUDA kernel, as its inverse.
	- τ_primal (optional, positive) : primal step size. Defaults to 5/sqrt(gradnorm2),
	  where gradnorm2 = 4*sum(dx**-2). The dual step is 1/(gradnorm2*τ_primal).
	- ρ_overrelax (optional) : over-relaxation parameter, forwarded to the CUDA kernel.
	- ϕ0, σ0 (optional) : initial primal and dual points. Default to zero.
	- atol, rtol (optional) : tolerances forwarded to the CUDA kernel (presumably
	  used for the blockwise 'stabilized' test - see Kernel_BeckmanOT.h).
	- E_rtol (optional) : stop when primal_energy-dual_energy < E_rtol*|primal_energy|.
	  This criterion is disabled if E_rtol<=0.
	- maxiter (optional) : maximum number of iterations.
	- shape_i (optional) : shape of the CUDA thread blocks. Defaults are only 
	  provided in dimension 1, 2 and 3.
	- stop_period (optional) : the stopping criteria are evaluated, and the energies
	  recorded, every stop_period iterations (starting with the first one).
	- verbosity (optional) : >=1 warns when the iteration budget is exhausted (if 
	  E_rtol>0), >=2 also reports the computation time.

	Output : a dictionary with the following keys
	- 'ϕ' : the primal variable, with the same shape as ξ.
	- 'σ' : the dual variable, with shape (vdim,)+channels+(s-1 for s in spatial shape).
	- 'stabilized' : blockwise flags, as left by the last stopping criterion check.
	- 'niter' : the number of iterations performed.
	- 'stopping_criterion' : 'stabilized' or 'gap' if a stopping criterion was met,
	  'maxiter' if the iteration budget was exhausted.
	- 'primal_values', 'dual_values' : the primal and dual energies, at each check.
	"""

	# traits and sizes
	int_t = np.int32
	float_t = np.float32
	# Host side float64 array, so that dx**-2 and 1./dx below also work when dx is
	# a list or an integer array (numpy refuses negative powers of integers).
	# cp.asnumpy also accepts cupy arrays, transferring them to the host.
	dx = np.asarray(cp.asnumpy(dx),dtype=float)
	assert dx.ndim==1 # One grid scale per dimension
	vdim = len(dx)
	ξ = cp.asarray(ξ,dtype=float_t)
	shape_s = ξ.shape[-vdim:] # shape for scalar fields (ξ, ϕ)
	shape_v = tuple(s-1 for s in shape_s) # shape for vector fields (σ)
	multichannel_shape = ξ.shape[:-vdim]
	nchannels = np.prod(multichannel_shape,dtype=int)
	ξ = ξ.reshape((nchannels,*shape_s)) # All channels in first dimension

	gradnorm2 = 4.*np.sum(dx**-2) # Squared norm of the gradient operator
	if τ_primal is None: τ_primal = 5./np.sqrt(gradnorm2)
	assert τ_primal>0
	if shape_i is None: shape_i = {1:(64,), 2:(8,8), 3:(4,4,4)}[vdim]
	assert len(shape_i)==vdim
	size_i = np.prod(shape_i)

	# Format suitably for the cupy kernel : layout (nchannels,)+shape_o+shape_i,
	# with out-of-domain entries presumably filled with NaN.
	ξ = cp.ascontiguousarray(fd.block_expand(ξ,shape_i,constant_values=np.nan))
	shape_o = ξ.shape[1:1+vdim]
	size_o = np.prod(shape_o)

	if ϕ0 is None: ϕ = cp.zeros((nchannels,*shape_o,*shape_i),dtype=float_t) # primal point
	else: 
		ϕ0 = cp.asarray(ϕ0,dtype=float_t).reshape((nchannels,*shape_s))
		ϕ = cp.ascontiguousarray(fd.block_expand(ϕ0,shape_i,shape_o))

	if σ0 is None: σ = cp.zeros((nchannels,vdim,*shape_o,*shape_i),dtype=float_t) # dual point
	else: 
		σ0 = cp.asarray(σ0,dtype=float_t).reshape((nchannels,vdim,*shape_v))
		σ = cp.ascontiguousarray(fd.block_expand(σ0,shape_i,shape_o))

	ϕ_ext = cp.zeros((nchannels,*shape_o,*shape_i),dtype=float_t) # extrapolated primal point
	primal_value = cp.zeros(shape_o,dtype=float_t) # primal objective value, by block
	dual_value = cp.zeros(shape_o,dtype=float_t) # dual objective value, by block
	stabilized = cp.zeros(shape_o,dtype=np.int8)

	# --------------- cuda header construction and compilation ----------------
	# Generate the load order for the boundary of shared data
	x_top_e = fd.block_neighbors(shape_i,True)
	x_bot_e = fd.block_neighbors(shape_i,False)

	# Generate the kernels
	traits = {
		'ndim_macro':vdim,
		'Int':int_t,
		'Scalar':float_t,
		'shape_i':shape_i,
		'shape_e':tuple(s+1 for s in shape_i),
		'size_bd_e':len(x_top_e),
		'x_top_e':x_top_e,
		'x_bot_e':1+x_bot_e,
		'nchannels':nchannels,
	}

	cuda_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),"cuda")
	date_modified = cupy_module_helper.getmtime_max(cuda_path)
	source = cupy_module_helper.traits_header(traits,size_of_shape=True,log2_size=True)

	source += [
	'#include "Kernel_BeckmanOT.h"',
	f"// Date cuda code last modified : {date_modified}"]
	cuoptions = ("-default-device", f"-I {cuda_path}")

	source = "\n".join(source)
	# Two variants of the same kernels : with and without the (costlier) 
	# evaluation of the energies and stopping criteria.
	module = cupy_module_helper.GetModule(
		"#define checkstop_macro false\n"+source,cuoptions)
	module_checkstop = cupy_module_helper.GetModule(
		"#define checkstop_macro true\n"+source,cuoptions)
	# -------------- cuda module generated (may be reused...) ----------------

	def setcst(*args): 
		"""Sets a constant in both compiled modules."""
		cupy_module_helper.SetModuleConstant(module,*args)
		cupy_module_helper.SetModuleConstant(module_checkstop,*args)
	setcst('shape_tot_s',shape_s,int_t)
	setcst('shape_tot_v',shape_v,int_t)
	setcst('shape_o',shape_o,int_t)
	setcst('size_io',size_i*size_o,int_t)
	setcst('tau_primal',τ_primal,float_t)
	setcst('tau_dual',1./(gradnorm2*τ_primal),float_t)
	setcst('idx',1./dx,float_t)
	setcst('lambda',λ,float_t)
	setcst('ilambda',1./λ,float_t)
	setcst('irelax_norm_constraint',1./relax_norm_constraint,float_t)
	setcst('rho_overrelax',ρ_overrelax,float_t)
	setcst('atol',atol,float_t)
	setcst('rtol',rtol,float_t)

	primal_step = module.get_function("primal_step")
	dual_step = module.get_function("dual_step")
	primal_step_checkstop = module_checkstop.get_function("primal_step")
	dual_step_checkstop = module_checkstop.get_function("dual_step")
	primal_values,dual_values = [],[]

	# Main loop
	for arr in (ξ,ϕ,ϕ_ext,σ,primal_value,dual_value,stabilized): 
		assert arr.flags['C_CONTIGUOUS'] # Just to be sure (Common source of silent bugs.)
	niter = -1 # Keeps niter+1 (number of steps taken) defined even if maxiter<=0
	stopping_criterion = None
	top = time.time()
	for niter in range(maxiter):
		if niter%stop_period!=0:
			primal_step((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ))
			dual_step((size_o,),(size_i,),(σ,ϕ_ext))
		else: 
			primal_step_checkstop((size_o,),(size_i,),(ϕ,ϕ_ext,σ,ξ,
				primal_value,dual_value,stabilized))
			dual_step_checkstop((size_o,),(size_i,),(σ,ϕ_ext,
				primal_value,dual_value,stabilized))
			e_primal =  float(primal_value.sum())
			e_dual   = -float(dual_value.sum())
			primal_values.append(e_primal)
			dual_values.append(e_dual)
			if E_rtol>0 and e_primal-e_dual<E_rtol*np.abs(e_primal): break
			if np.all(stabilized): break
			stabilized.fill(0)
	else: 
		# No stopping criterion was met within the iteration budget
		stopping_criterion = 'maxiter'
		if E_rtol>0 and verbosity>=1: 
			print("Exhausted iteration budget without satisfying convergence criterion") 
	if stopping_criterion is None: # The loop was exited through a break
		stopping_criterion = 'stabilized' if np.all(stabilized) else 'gap'
	if verbosity>=2:
		# Kernel launches are asynchronous : wait for completion before timing
		cp.cuda.get_current_stream().synchronize()
		print(f"GPU primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")

	# Back to the user layout : channels first, then spatial axes
	ϕ = ϕ.reshape(multichannel_shape+ϕ.shape[1:])
	ϕ = fd.block_squeeze(ϕ,shape_s)
	σ = np.moveaxis(σ,0,1).reshape((vdim,)+multichannel_shape+σ.shape[2:])
	σ = fd.block_squeeze(σ,shape_v)

	return {'ϕ':ϕ,'σ':σ,
	'stabilized':stabilized,'niter':niter+1,
	'stopping_criterion':stopping_criterion,
	'primal_values':np.array(primal_values),'dual_values':np.array(dual_values)}

Numerically solves a relaxed Beckmann formulation of optimal transport

  • λ (positive) : the relaxation parameter
  • ξ (array) : the difference of measures ν-μ
  • dx (array) : the grid scales (one per dimension)