agd.ODE.proximal

The proximal operator of a (usually convex) function $f$ is defined as prox_f(x0,τ) := argmin_x |x-x0|^2/2 + τ*f(x) which is equivalent to an implicit time step of size τ for the ODE dx/dt = - grad f(x).

When f is a characteristic function, only taking the values 0 and +infty, prox_f is the projection onto the domain of f, independently of the value of τ.

This file provides implementations of a few proximal operators, and of the ADMM algorithm. A classical reference for proximal operators is : 1.Combettes, P. L. & Pesquet, J.-C. Proximal splitting methods in signal processing. in Fixed-point algorithms for inverse problems in science and engineering 185–212 (Fixed-point algorithms for inverse problems in science and engineering, 2011).

  1# Copyright 2022 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
  2# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0
  3
  4"""
  5The proximal operator of a (usually convex) function $f$ is defined as 
  6prox_f(x0,τ) := argmin_x |x-x0|^2/2 + τ*f(x)
  7which is equivalent to an implicit time step of size τ for the ODE
  8dx/dt = - grad f(x).
  9
 10When f is a characteristic function, only taking the values 0 and +infty, 
 11prox_f is the projection onto the domain of f, independently of the value of τ.
 12
 13This file provides implementations of a few proximal operators, 
 14and of the ADMM algorithm. A classical reference for proximal operators is : 
 151.Combettes, P. L. & Pesquet, J.-C. Proximal splitting methods in signal processing. 
 16in Fixed-point algorithms for inverse problems in science and engineering 185–212 
 17(Fixed-point algorithms for inverse problems in science and engineering, 2011).
 18"""
 19
 20import numpy as np
 21from .. import AutomaticDifferentiation as ad
 22from scipy import fft
 23import time
 24
 25def make_prox_dual(prox):
 26	"""
 27	The proximal operator for f^*(x) = sup_y <x,y> - f(x).
 28	(Moreau formula)
 29	"""
 30	return lambda x,τ: x - τ*prox(x/τ,1/τ)
 31
 32def chambolle_pock(impl_f,impl_gs,τ_f,τ_gs,K,x,y=None,
 33	KT=None,E_rtol=1e-3,maxiter=1000,
 34	ρ_overrelax=None,cvx_f=0.,callback=None,verbosity=2):
 35	"""
 36	The chambolle_pock primal-dual proximal algorithm, for solving : 
 37	inf_x sup_y <K*x,y> + f(x) - g^*(y).
 38	This algorithm is equivalent to the (linearized) ADMM, but provides explicit 
 39	dual points, duality gap, and has a number of useful variants.
 40	Inputs : 
 41	- impl_f : possibilities
 42	  - a tuple (f,f_star,prox_f) implementing the function f, the 
 43	Legendre-Fenchel dual f_star, and the proximal operator prox_f. (f and f_star 
 44	are used to construct the primal and dual energies, which are involved in 
 45	the stopping criterion).
 46	  - an implementation of f, in which case E_rtol must provide E_primal_dual
 47
 48	- impl_gs : similar to impl_f above, but for the function g^*.
 49	- K : possibilities
 50	  - a linear operator, called either as K(x) or K*x. 
 51	  - the string 'Id'
 52
 53	- x : initial guess for the primal point
 54	- y (optional, default : K(x)) : intial guess for the dual point
 55	- KT (optional, default : K.T): transposed linear operator. 
 56
 57	- E_rtol (optional) : possibilities
 58		- (positive float) : algorithm stops when (E_primal-E_dual) < E_rtol *abs(E_primal),
 59		which is checked every 10 iterations.
 60		- a tuple (callable, positive float) : the callable implements E_primal_dual(x,y) 
 61		(returns the pair of primal and dual energies). Same stopping criterion as above.
 62		- (callable) : algorithm stops when E_rtol(niter,x,y) is True
 63
 64	- maxiter (optional) : maximum number of iterations
 65	- ρ_overrelax (optional, use value in [0,2], typically 1.8): over-relaxed variant 
 66	- cvx_f : coercivity constant of f (used in ALG2 variant)
 67	- callback (optional) : each iteration begins with callback(niter,x,y)
 68	"""
 69	if K=='Id': K = KT = lambda x:x 
 70	if KT is None: KT = K.T
 71	if not callable(K): # Handle both call and mutliplication syntax for the linear operator
 72		K_=K; KT_=KT
 73		K  = lambda x : K_*x
 74		KT = lambda y : KT_*y
 75	if y is None: y = K(x)
 76
 77	if isinstance(E_rtol,tuple): # Directly provide the primal and dual energies
 78		E_primal_dual,E_rtol = E_rtol; E_primal,E_dual=None,None
 79		prox_f,prox_gs = impl_f,impl_gs
 80	else: # Construct the primal and dual energies
 81		f,fs,prox_f  = impl_f 
 82		gs,g,prox_gs = impl_gs
 83		def E_primal(x): return f(x)+g(K(x))
 84		def E_dual(y):   return -(fs(-KT(y)) + gs(y))
 85		def E_primal_dual(x,y): return E_primal(x),E_dual(y)
 86	primal_values=[]
 87	dual_values=[]
 88	τs_f=[]
 89	if callback is None: callback=lambda niter,x,y: None
 90	if callable(E_rtol): stopping_criterion=E_rtol
 91	else:
 92		def stopping_criterion(niter,x,y):
 93			if niter%10 != 0: return False # Check every 10 iterations
 94			e_p,e_d = E_primal_dual(x,y)
 95			primal_values.append(e_p); dual_values.append(e_d)
 96			return E_rtol>0 and (e_p-e_d)<E_rtol*abs(e_p)
 97
 98	top = time.time()
 99	θ=1.; x_,xold,yold=None,None,None;
100	for niter in range(maxiter):
101		callback(niter,x,y)
102		if stopping_criterion(niter,x,y): break
103		xold,yold = x,y
104		prox_f_arg = x-τ_f*KT(y)
105# 		If one introduces a smooth term, then a different stopping criterion is needed.
106#-dfsmooth (optionnal) : gradient of an additional smooth term fsmooth(x) in the objective
107#		if dfsmooth is not None: prox_f_arg -= τ_f*dfmooth(x) 
108		x = prox_f(prox_f_arg,τ_f)
109
110		x_ = 2*x - xold if cvx_f==0 else x+θ*(x-xold)
111		y = prox_gs(y+τ_gs*K(x_),τ_gs)
112		if ρ_overrelax is not None: 
113			x = (1-ρ_overrelax)*xold+ρ_overrelax*x
114			y = (1-ρ_overrelax)*yold+ρ_overrelax*y
115		if cvx_f>0:
116			τs_f.append(τ_f)
117			θ = 1/np.sqrt(1+cvx_f*τ_f)
118			τ_f  *= θ
119			τ_gs /= θ
120	else:
121		if E_rtol>0 and verbosity>=1: 
122			print("Warning : duality gap not reduced to target within iteration budget")
123	if verbosity>=2: 
124		print(f"Primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")
125
126	primal_values = np.array(primal_values); dual_values = np.array(dual_values)
127	return {'x':x,'y':y,'niter':niter+1,
128	'primal_values':primal_values,'dual_values':dual_values,
129#	'rgap':2*(primal_values-dual_values)/(np.abs(primal_values)+np.abs(dual_values)),
130	'ops':{'K':K,'KT':KT,'E_primal':E_primal,'E_dual':E_dual,'E_primal_dual':E_primal_dual},
131	'tmp':{'x_':x_,'xold':xold,'yold':yold,'θ':θ,'τ_f':τ_f,'τ_gs':τ_gs,'τs_f':τs_f}} 
132	
133
134# ------------------------- Helpers for proximal operators --------------------------
135
def make_prox_multivar(prox):
	"""
	Wraps proximal operator(s) so as to act on a collection x=(x1,...,xn) of variables.
	- if prox is a tuple (prox1,...,proxn) : 
	  Prox((x1,...,xn),τ) := (prox1(x1,τ),...,proxn(xn,τ))
	- otherwise : Prox((x1,...,xn),τ) := prox(x1,...,xn,τ)
	In both cases, the result is cast to np.array(dtype=object)
	"""
	if not isinstance(prox,tuple): 
		# A single operator, receiving all the variables as separate arguments
		def prox_joint(x,τ):
			return np.array(prox(*x,τ), dtype=object)
		return prox_joint

	# One operator per variable, all applied with the same time step
	def prox_separate(x,τ):
		results = tuple(prox_i(x_i,τ) for (prox_i,x_i) in zip(prox,x))
		return np.array(results, dtype=object)
	return prox_separate
145
146
def impl_inmult(impl,λ):
	"""
	Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(λ*x)
	Input : 
	- impl : f,fs,prox_f
	- λ : a non-zero scalar
	Output : 
	- F, Fs, prox_F where F(x) = f(λ x), hence
	  Fs(y) = fs(y/λ) and prox_F(x,τ) = prox_f(λ*x,λ^2*τ)/λ
	"""
	primal0,dual0,prox0 = impl
	iλ = 1/λ; λ2 = λ**2
	def primal(x): return primal0(λ*x)
	def dual(x):   return dual0(iλ*x)
	# Change of variables u=λ*x in the minimization defining the proximal operator
	def prox(x,τ): return iλ*prox0(λ*x,λ2*τ)
	return primal,dual,prox
162
def impl_exmult(impl,λ):
	"""
	Implements the Legendre-Fenchel dual and proximal operator of F(x) := λ*f(x)
	Input : 
	- impl : f,fs,prox_f
	- λ : a scalar (the formula used for the dual assumes λ>0)
	Output : 
	- F, Fs, prox_F where F(x) = λ f(x), hence
	  Fs(y) = λ*fs(y/λ) and prox_F(x,τ) = prox_f(x,λ*τ)
	"""
	primal0,dual0,prox0 = impl
	iλ = 1/λ
	def primal(x): return λ*primal0(x)
	def dual(x):   return λ*dual0(iλ*x)
	# Scaling the function amounts to scaling the time step
	def prox(x,τ): return prox0(x,λ*τ)
	return primal,dual,prox
178
179
def impl_sub(impl,x0):
	"""
	Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(x-x0)
	Input : 
	- impl : f,fs,prox_f
	- x0 : the shift parameter
	Output : 
	- F, Fs, prox_F where F(x) = f(x-x0), hence
	  Fs(y) = <y,x0> + fs(y) and prox_F(x,τ) = x0 + prox_f(x-x0,τ)
	"""
	primal0,dual0,prox0 = impl
	def primal(x): return primal0(x-x0)
	def dual(x):   return np.sum(x*x0)+dual0(x)
	# Translate to the origin, apply the original operator, translate back
	def prox(x,τ): return x0+prox0(x-x0,τ)
	return primal,dual,prox
194
195
196# def make_prox_addlin(prox,w):
197# 	"""The proximal operator for F(x) := f(x) + w.x"""
198# 	return lambda x,τ: prox(x-τ*w,τ)
199
200# def make_prox_addquad(prox,a):
201# 	"""The proximal operator for F(x) := f(x)+a x^2, where a>=0"""
202# 	b = 1./(τ*a+1)
203# 	return lambda x,τ: prox(b*x,b*τ)
204
205# def norm_multivar(arrs): 
206# 	"""Euclidean norm of an np.array(dtype=object)"""
207# 	return np.sqrt(sum(np.sum(arr**2) for arr in arrs))
def make_prox_dual(prox):
def make_prox_dual(prox):
	"""
	Build the proximal operator of the Legendre-Fenchel dual 
	f^*(x) = sup_y <x,y> - f(y), from the proximal operator of f.
	Relies on the Moreau formula : prox_{f^*}(x,τ) = x - τ*prox_f(x/τ,1/τ).
	Input : 
	- prox : callable prox(x,τ), the proximal operator of f
	Output : 
	- a callable with the same (x,τ) calling convention, the proximal operator of f^*
	"""
	def prox_dual(x,τ):
		primal_part = prox(x/τ,1/τ)
		return x - τ*primal_part
	return prox_dual

The proximal operator for f^*(x) = sup_y <x,y> - f(y). (Moreau formula)

def chambolle_pock( impl_f, impl_gs, τ_f, τ_gs, K, x, y=None, KT=None, E_rtol=0.001, maxiter=1000, ρ_overrelax=None, cvx_f=0.0, callback=None, verbosity=2):
 33def chambolle_pock(impl_f,impl_gs,τ_f,τ_gs,K,x,y=None,
 34	KT=None,E_rtol=1e-3,maxiter=1000,
 35	ρ_overrelax=None,cvx_f=0.,callback=None,verbosity=2):
 36	"""
 37	The chambolle_pock primal-dual proximal algorithm, for solving : 
 38	inf_x sup_y <K*x,y> + f(x) - g^*(y).
 39	This algorithm is equivalent to the (linearized) ADMM, but provides explicit 
 40	dual points, duality gap, and has a number of useful variants.
 41	Inputs : 
 42	- impl_f : possibilities
 43	  - a tuple (f,f_star,prox_f) implementing the function f, the 
 44	Legendre-Fenchel dual f_star, and the proximal operator prox_f. (f and f_star 
 45	are used to construct the primal and dual energies, which are involved in 
 46	the stopping criterion).
 47	  - an implementation of f, in which case E_rtol must provide E_primal_dual
 48
 49	- impl_gs : similar to impl_f above, but for the function g^*.
 50	- K : possibilities
 51	  - a linear operator, called either as K(x) or K*x. 
 52	  - the string 'Id'
 53
 54	- x : initial guess for the primal point
 55	- y (optional, default : K(x)) : intial guess for the dual point
 56	- KT (optional, default : K.T): transposed linear operator. 
 57
 58	- E_rtol (optional) : possibilities
 59		- (positive float) : algorithm stops when (E_primal-E_dual) < E_rtol *abs(E_primal),
 60		which is checked every 10 iterations.
 61		- a tuple (callable, positive float) : the callable implements E_primal_dual(x,y) 
 62		(returns the pair of primal and dual energies). Same stopping criterion as above.
 63		- (callable) : algorithm stops when E_rtol(niter,x,y) is True
 64
 65	- maxiter (optional) : maximum number of iterations
 66	- ρ_overrelax (optional, use value in [0,2], typically 1.8): over-relaxed variant 
 67	- cvx_f : coercivity constant of f (used in ALG2 variant)
 68	- callback (optional) : each iteration begins with callback(niter,x,y)
 69	"""
 70	if K=='Id': K = KT = lambda x:x 
 71	if KT is None: KT = K.T
 72	if not callable(K): # Handle both call and mutliplication syntax for the linear operator
 73		K_=K; KT_=KT
 74		K  = lambda x : K_*x
 75		KT = lambda y : KT_*y
 76	if y is None: y = K(x)
 77
 78	if isinstance(E_rtol,tuple): # Directly provide the primal and dual energies
 79		E_primal_dual,E_rtol = E_rtol; E_primal,E_dual=None,None
 80		prox_f,prox_gs = impl_f,impl_gs
 81	else: # Construct the primal and dual energies
 82		f,fs,prox_f  = impl_f 
 83		gs,g,prox_gs = impl_gs
 84		def E_primal(x): return f(x)+g(K(x))
 85		def E_dual(y):   return -(fs(-KT(y)) + gs(y))
 86		def E_primal_dual(x,y): return E_primal(x),E_dual(y)
 87	primal_values=[]
 88	dual_values=[]
 89	τs_f=[]
 90	if callback is None: callback=lambda niter,x,y: None
 91	if callable(E_rtol): stopping_criterion=E_rtol
 92	else:
 93		def stopping_criterion(niter,x,y):
 94			if niter%10 != 0: return False # Check every 10 iterations
 95			e_p,e_d = E_primal_dual(x,y)
 96			primal_values.append(e_p); dual_values.append(e_d)
 97			return E_rtol>0 and (e_p-e_d)<E_rtol*abs(e_p)
 98
 99	top = time.time()
100	θ=1.; x_,xold,yold=None,None,None;
101	for niter in range(maxiter):
102		callback(niter,x,y)
103		if stopping_criterion(niter,x,y): break
104		xold,yold = x,y
105		prox_f_arg = x-τ_f*KT(y)
106# 		If one introduces a smooth term, then a different stopping criterion is needed.
107#-dfsmooth (optionnal) : gradient of an additional smooth term fsmooth(x) in the objective
108#		if dfsmooth is not None: prox_f_arg -= τ_f*dfmooth(x) 
109		x = prox_f(prox_f_arg,τ_f)
110
111		x_ = 2*x - xold if cvx_f==0 else x+θ*(x-xold)
112		y = prox_gs(y+τ_gs*K(x_),τ_gs)
113		if ρ_overrelax is not None: 
114			x = (1-ρ_overrelax)*xold+ρ_overrelax*x
115			y = (1-ρ_overrelax)*yold+ρ_overrelax*y
116		if cvx_f>0:
117			τs_f.append(τ_f)
118			θ = 1/np.sqrt(1+cvx_f*τ_f)
119			τ_f  *= θ
120			τ_gs /= θ
121	else:
122		if E_rtol>0 and verbosity>=1: 
123			print("Warning : duality gap not reduced to target within iteration budget")
124	if verbosity>=2: 
125		print(f"Primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")
126
127	primal_values = np.array(primal_values); dual_values = np.array(dual_values)
128	return {'x':x,'y':y,'niter':niter+1,
129	'primal_values':primal_values,'dual_values':dual_values,
130#	'rgap':2*(primal_values-dual_values)/(np.abs(primal_values)+np.abs(dual_values)),
131	'ops':{'K':K,'KT':KT,'E_primal':E_primal,'E_dual':E_dual,'E_primal_dual':E_primal_dual},
132	'tmp':{'x_':x_,'xold':xold,'yold':yold,'θ':θ,'τ_f':τ_f,'τ_gs':τ_gs,'τs_f':τs_f}} 

The chambolle_pock primal-dual proximal algorithm, for solving : inf_x sup_y <K*x,y> + f(x) - g^*(y). This algorithm is equivalent to the (linearized) ADMM, but provides explicit dual points, duality gap, and has a number of useful variants. Inputs :

  • impl_f : possibilities
    • a tuple (f,f_star,prox_f) implementing the function f, the Legendre-Fenchel dual f_star, and the proximal operator prox_f. (f and f_star are used to construct the primal and dual energies, which are involved in the stopping criterion).
    • an implementation of f, in which case E_rtol must provide E_primal_dual
  • impl_gs : similar to impl_f above, but for the function g^*.
  • K : possibilities

    • a linear operator, called either as K(x) or K*x.
    • the string 'Id'
  • x : initial guess for the primal point

  • y (optional, default : K(x)) : intial guess for the dual point
  • KT (optional, default : K.T): transposed linear operator.

  • E_rtol (optional) : possibilities - (positive float) : algorithm stops when (E_primal-E_dual) < E_rtol *abs(E_primal), which is checked every 10 iterations. - a tuple (callable, positive float) : the callable implements E_primal_dual(x,y) (returns the pair of primal and dual energies). Same stopping criterion as above. - (callable) : algorithm stops when E_rtol(niter,x,y) is True

  • maxiter (optional) : maximum number of iterations

  • ρ_overrelax (optional, use value in [0,2], typically 1.8): over-relaxed variant
  • cvx_f : coercivity constant of f (used in ALG2 variant)
  • callback (optional) : each iteration begins with callback(niter,x,y)
def make_prox_multivar(prox):
def make_prox_multivar(prox):
	"""
	Wraps proximal operator(s) so as to act on a collection x=(x1,...,xn) of variables.
	- if prox is a tuple (prox1,...,proxn) : 
	  Prox((x1,...,xn),τ) := (prox1(x1,τ),...,proxn(xn,τ))
	- otherwise : Prox((x1,...,xn),τ) := prox(x1,...,xn,τ)
	In both cases, the result is cast to np.array(dtype=object)
	"""
	if not isinstance(prox,tuple): 
		# A single operator, receiving all the variables as separate arguments
		def prox_joint(x,τ):
			return np.array(prox(*x,τ), dtype=object)
		return prox_joint

	# One operator per variable, all applied with the same time step
	def prox_separate(x,τ):
		results = tuple(prox_i(x_i,τ) for (prox_i,x_i) in zip(prox,x))
		return np.array(results, dtype=object)
	return prox_separate

Defines Prox((x1,...,xn),τ) := prox(x1,...,xn,τ), or Prox((x1,...,xn),τ) := (prox1(x1,τ),...,proxn(xn,τ)). Result is cast to np.array(dtype=object)

def impl_inmult(impl, λ):
def impl_inmult(impl,λ):
	"""
	Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(λ*x)
	Input : 
	- impl : f,fs,prox_f
	- λ : a non-zero scalar
	Output : 
	- F, Fs, prox_F where F(x) = f(λ x), hence
	  Fs(y) = fs(y/λ) and prox_F(x,τ) = prox_f(λ*x,λ^2*τ)/λ
	"""
	primal0,dual0,prox0 = impl
	iλ = 1/λ; λ2 = λ**2
	def primal(x): return primal0(λ*x)
	def dual(x):   return dual0(iλ*x)
	# Change of variables u=λ*x in the minimization defining the proximal operator
	def prox(x,τ): return iλ*prox0(λ*x,λ2*τ)
	return primal,dual,prox

Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(λ*x). Input :

  • impl : f,fs,prox_f
  • λ : a scalar Output :
  • F, Fs, prox_F where F(x) = f(λ x)
def impl_exmult(impl, λ):
def impl_exmult(impl,λ):
	"""
	Implements the Legendre-Fenchel dual and proximal operator of F(x) := λ*f(x)
	Input : 
	- impl : f,fs,prox_f
	- λ : a scalar (the formula used for the dual assumes λ>0)
	Output : 
	- F, Fs, prox_F where F(x) = λ f(x), hence
	  Fs(y) = λ*fs(y/λ) and prox_F(x,τ) = prox_f(x,λ*τ)
	"""
	primal0,dual0,prox0 = impl
	iλ = 1/λ
	def primal(x): return λ*primal0(x)
	def dual(x):   return λ*dual0(iλ*x)
	# Scaling the function amounts to scaling the time step
	def prox(x,τ): return prox0(x,λ*τ)
	return primal,dual,prox

Implements the Legendre-Fenchel dual and proximal operator of F(x) := λ*f(x). Input :

  • impl : f,fs,prox_f
  • λ : a scalar Output :
  • F, Fs, prox_F where F(x) = λ f(x)
def impl_sub(impl, x0):
def impl_sub(impl,x0):
	"""
	Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(x-x0)
	Input : 
	- impl : f,fs,prox_f
	- x0 : the shift parameter
	Output : 
	- F, Fs, prox_F where F(x) = f(x-x0), hence
	  Fs(y) = <y,x0> + fs(y) and prox_F(x,τ) = x0 + prox_f(x-x0,τ)
	"""
	primal0,dual0,prox0 = impl
	def primal(x): return primal0(x-x0)
	def dual(x):   return np.sum(x*x0)+dual0(x)
	# Translate to the origin, apply the original operator, translate back
	def prox(x,τ): return x0+prox0(x-x0,τ)
	return primal,dual,prox

Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(x-x0). Input :

  • impl : f,fs,prox_f
  • x0 : the shift parameter Output :
  • F, Fs, prox_F where F(x) = f(x-x0)