agd.ODE.proximal
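The proximal operator of a (usually convex) function $f$ is defined as
prox_f(x0,τ) := argmin_x |x-x0|^2/2 + τ*f(x),
which is equivalent to an implicit (backward Euler) time step of size τ for the ODE dx/dt = - grad f(x). In other words, the minimizer x1 satisfies x1 = x0 - τ*grad f(x1), where grad f(x1) is replaced by some subgradient of f at x1 when f is not differentiable.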
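The following minimal sketch checks the implicit-step interpretation numerically. It uses the quadratic f(x) = a*x^2/2, chosen here as an assumption because its prox has the closed form x0/(1+a*τ). The script computes the argmin by brute force on a fine grid, then verifies that the result satisfies the backward Euler relation x1 = x0 - τ*f'(x1). The helpers prox_quadratic and prox_bruteforce are hypothetical illustrations and are not part of the module.

```python
import numpy as np

def prox_quadratic(x0, tau, a):
    """Closed-form prox of f(x) = a*x**2/2, i.e. argmin_x |x-x0|^2/2 + tau*f(x)."""
    return x0 / (1.0 + a * tau)

def prox_bruteforce(f, x0, tau, grid):
    """Approximate argmin over a 1D grid of |x-x0|^2/2 + tau*f(x)."""
    objective = 0.5 * (grid - x0) ** 2 + tau * f(grid)
    return grid[np.argmin(objective)]

a, x0, tau = 3.0, 2.0, 0.5
grid = np.linspace(-5.0, 5.0, 1_000_001)  # grid spacing 1e-5
x1 = prox_quadratic(x0, tau, a)
x1_grid = prox_bruteforce(lambda x: a * x**2 / 2, x0, tau, grid)

# Backward Euler step for dx/dt = -f'(x) = -a*x reads x1 = x0 - tau*a*x1
residual = x1 - (x0 - tau * a * x1)
print(x1, x1_grid, residual)
```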
When f is a characteristic function, taking only the values 0 and +infty, prox_f is the projection onto the domain of f (the set where f is finite), independently of the value of τ.
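As an illustration, take the characteristic function of a box, which is an assumed example. Minimizing |x-x0|^2/2 + τ*f(x) then reduces to minimizing |x-x0|^2 over the box, because τ only multiplies the values 0 and +infty. The prox is therefore the Euclidean projection, here np.clip, whatever τ > 0. The names indicator_box and prox_box are hypothetical.

```python
import numpy as np

def indicator_box(x, lo, hi):
    """Characteristic function of the box [lo, hi]^n: 0 inside, +inf outside."""
    return 0.0 if np.all((lo <= x) & (x <= hi)) else np.inf

def prox_box(x0, tau, lo=-1.0, hi=1.0):
    """Prox of the box indicator: the projection onto the box (tau is irrelevant)."""
    return np.clip(x0, lo, hi)

x0 = np.array([-3.0, 0.25, 2.0])
p_small = prox_box(x0, tau=1e-3)
p_large = prox_box(x0, tau=1e3)
print(p_small, p_large, indicator_box(p_small, -1.0, 1.0))
```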
This file provides implementations of a few proximal operators, and of the ADMM algorithm. A classical reference for proximal operators is:
1. Combettes, P. L. & Pesquet, J.-C. Proximal splitting methods in signal processing. In Fixed-Point Algorithms for Inverse Problems in Science and Engineering, 185–212 (Springer, 2011).
```python
# Copyright 2022 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0

"""
The proximal operator of a (usually convex) function $f$ is defined as
prox_f(x0,τ) := argmin_x |x-x0|^2/2 + τ*f(x)
which is equivalent to an implicit time step of size τ for the ODE
dx/dt = - grad f(x).

When f is a characteristic function, only taking the values 0 and +infty,
prox_f is the projection onto the domain of f, independently of the value of τ.

This file provides implementations of a few proximal operators,
and of the ADMM algorithm. A classical reference for proximal operators is :
1. Combettes, P. L. & Pesquet, J.-C. Proximal splitting methods in signal processing.
In Fixed-Point Algorithms for Inverse Problems in Science and Engineering, 185–212
(Springer, 2011).
"""

import numpy as np
from .. import AutomaticDifferentiation as ad
from scipy import fft
import time

def make_prox_dual(prox):
    """
    The proximal operator for f^*(x) = sup_y <x,y> - f(y).
    (Moreau formula)
    """
    return lambda x,τ: x - τ*prox(x/τ,1/τ)

def chambolle_pock(impl_f,impl_gs,τ_f,τ_gs,K,x,y=None,
    KT=None,E_rtol=1e-3,maxiter=1000,
    ρ_overrelax=None,cvx_f=0.,callback=None,verbosity=2):
    """
    The chambolle_pock primal-dual proximal algorithm, for solving :
    inf_x sup_y <K*x,y> + f(x) - g^*(y).
    This algorithm is equivalent to the (linearized) ADMM, but provides explicit
    dual points, duality gap, and has a number of useful variants.
    Inputs :
    - impl_f : possibilities
        - a tuple (f,f_star,prox_f) implementing the function f, the
        Legendre-Fenchel dual f_star, and the proximal operator prox_f. (f and f_star
        are used to construct the primal and dual energies, which are involved in
        the stopping criterion).
        - the proximal operator prox_f alone, in which case E_rtol must provide E_primal_dual

    - impl_gs : similar to impl_f above, but for the function g^*.
    - K : possibilities
        - a linear operator, called either as K(x) or K*x.
        - the string 'Id'

    - x : initial guess for the primal point
    - y (optional, default : K(x)) : initial guess for the dual point
    - KT (optional, default : K.T): transposed linear operator.

    - E_rtol (optional) : possibilities
        - (positive float) : algorithm stops when (E_primal-E_dual) < E_rtol *abs(E_primal),
        which is checked every 10 iterations.
        - a tuple (callable, positive float) : the callable implements E_primal_dual(x,y)
        (returns the pair of primal and dual energies). Same stopping criterion as above.
        - (callable) : algorithm stops when E_rtol(niter,x,y) is True

    - maxiter (optional) : maximum number of iterations
    - ρ_overrelax (optional, use value in [0,2], typically 1.8): over-relaxed variant
    - cvx_f : coercivity constant of f (used in ALG2 variant)
    - callback (optional) : each iteration begins with callback(niter,x,y)
    """
    if isinstance(K,str) and K=='Id': K = KT = lambda x:x
    if KT is None: KT = K.T
    if not callable(K): # Handle both call and multiplication syntax for the linear operator
        K_=K; KT_=KT
        K = lambda x : K_*x
        KT = lambda y : KT_*y
    if y is None: y = K(x)

    if isinstance(E_rtol,tuple): # Directly provide the primal and dual energies
        E_primal_dual,E_rtol = E_rtol; E_primal,E_dual=None,None
        prox_f,prox_gs = impl_f,impl_gs
    else: # Construct the primal and dual energies
        f,fs,prox_f = impl_f
        gs,g,prox_gs = impl_gs
        def E_primal(x): return f(x)+g(K(x))
        def E_dual(y): return -(fs(-KT(y)) + gs(y))
        def E_primal_dual(x,y): return E_primal(x),E_dual(y)
    primal_values=[]
    dual_values=[]
    τs_f=[]
    if callback is None: callback=lambda niter,x,y: None
    if callable(E_rtol): stopping_criterion=E_rtol
    else:
        def stopping_criterion(niter,x,y):
            if niter%10 != 0: return False # Check every 10 iterations
            e_p,e_d = E_primal_dual(x,y)
            primal_values.append(e_p); dual_values.append(e_d)
            return E_rtol>0 and (e_p-e_d)<E_rtol*abs(e_p)

    top = time.time()
    θ=1.; x_,xold,yold=None,None,None
    for niter in range(maxiter):
        callback(niter,x,y)
        if stopping_criterion(niter,x,y): break
        xold,yold = x,y
        prox_f_arg = x-τ_f*KT(y)
        # If one introduces a smooth term, then a different stopping criterion is needed.
        #-dfsmooth (optional) : gradient of an additional smooth term fsmooth(x) in the objective
        # if dfsmooth is not None: prox_f_arg -= τ_f*dfsmooth(x)
        x = prox_f(prox_f_arg,τ_f)

        x_ = 2*x - xold if cvx_f==0 else x+θ*(x-xold)
        y = prox_gs(y+τ_gs*K(x_),τ_gs)
        if ρ_overrelax is not None:
            x = (1-ρ_overrelax)*xold+ρ_overrelax*x
            y = (1-ρ_overrelax)*yold+ρ_overrelax*y
        if cvx_f>0:
            τs_f.append(τ_f)
            θ = 1/np.sqrt(1+cvx_f*τ_f)
            τ_f *= θ
            τ_gs /= θ
    else: # Iteration budget exhausted (a callable E_rtol has no numeric target)
        if not callable(E_rtol) and E_rtol>0 and verbosity>=1:
            print("Warning : duality gap not reduced to target within iteration budget")
    if verbosity>=2:
        print(f"Primal-dual solver completed {niter+1} steps in {time.time()-top} seconds")

    primal_values = np.array(primal_values); dual_values = np.array(dual_values)
    return {'x':x,'y':y,'niter':niter+1,
        'primal_values':primal_values,'dual_values':dual_values,
#       'rgap':2*(primal_values-dual_values)/(np.abs(primal_values)+np.abs(dual_values)),
        'ops':{'K':K,'KT':KT,'E_primal':E_primal,'E_dual':E_dual,'E_primal_dual':E_primal_dual},
        'tmp':{'x_':x_,'xold':xold,'yold':yold,'θ':θ,'τ_f':τ_f,'τ_gs':τ_gs,'τs_f':τs_f}}


# ------------------------- Helpers for proximal operators --------------------------

def make_prox_multivar(prox):
    """
    Defines Prox((x1,...,xn),τ) := prox(x1,...,xn,τ),
    or Prox((x1,...,xn),τ) := (prox1(x1,τ),...,proxn(xn,τ)).
    Result is cast to np.array(dtype=object)
    """
    if isinstance(prox,tuple):
        return lambda x,τ: np.array( tuple(proxi(xi,τ) for (proxi,xi) in zip(prox,x)), dtype=object)
    return lambda x,τ: np.array( prox(*x,τ), dtype=object)


def impl_inmult(impl,λ):
    """
    Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(λ*x)
    Input :
     - impl : f,fs,prox_f
     - λ : a scalar
    Output :
     - F, Fs, prox_F where F(x) = f(λ x)
    """
    primal0,dual0,prox0 = impl
    iλ = 1/λ; λ2 = λ**2
    def primal(x): return primal0(λ*x)
    def dual(x): return dual0(iλ*x)
    def prox(x,τ): return iλ*prox0(λ*x,λ2*τ)
    return primal,dual,prox

def impl_exmult(impl,λ):
    """
    Implements the Legendre-Fenchel dual and proximal operator of F(x) := λ*f(x)
    Input :
     - impl : f,fs,prox_f
     - λ : a scalar
    Output :
     - F, Fs, prox_F where F(x) = λ f(x)
    """
    primal0,dual0,prox0 = impl
    iλ = 1/λ
    def primal(x): return λ*primal0(x)
    def dual(x): return λ*dual0(iλ*x)
    def prox(x,τ): return prox0(x,λ*τ)
    return primal,dual,prox


def impl_sub(impl,x0):
    """
    Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(x-x0)
    Input :
     - impl : f,fs,prox_f
     - x0 : the shift parameter
    Output :
     - F, Fs, prox_F where F(x) = f(x-x0)
    """
    primal0,dual0,prox0 = impl
    def primal(x): return primal0(x-x0)
    def dual(x): return np.sum(x*x0)+dual0(x)
    def prox(x,τ): return x0+prox0(x-x0,τ) # change of variables u = x0 + v
    return primal,dual,prox


# def make_prox_addlin(prox,w):
#     """The proximal operator for F(x) := f(x) + w.x"""
#     return lambda x,τ: prox(x-τ*w,τ)

# def make_prox_addquad(prox,a):
#     """The proximal operator for F(x) := f(x)+a x^2, where a>=0"""
#     return lambda x,τ: prox(x/(1+2*τ*a),τ/(1+2*τ*a))

# def norm_multivar(arrs):
#     """Euclidean norm of an np.array(dtype=object)"""
#     return np.sqrt(sum(np.sum(arr**2) for arr in arrs))
```
def make_prox_dual(prox)
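The proximal operator for the Legendre-Fenchel dual f^*(x) = sup_y <x,y> - f(y), obtained from prox = prox_f via the Moreau formula prox_{f^*}(x,τ) = x - τ*prox_f(x/τ,1/τ).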
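The following standalone sketch checks the Moreau formula on an assumed example. f(x) = |x|_1 has soft thresholding as its prox, and its dual f^* is the characteristic function of the unit sup-norm ball, whose prox is the projection np.clip(x,-1,1). The helper make_prox_dual is copied from the listing so that the snippet runs without the package. prox_l1 is a hypothetical name.

```python
import numpy as np

def make_prox_dual(prox):
    """Copied from the listing: prox of f^* via the Moreau formula."""
    return lambda x, τ: x - τ * prox(x / τ, 1 / τ)

def prox_l1(x, τ):
    """Prox of f(x) = |x|_1, i.e. soft thresholding at level τ."""
    return np.sign(x) * np.maximum(np.abs(x) - τ, 0.0)

# f^* is the characteristic function of {|y|_inf <= 1}: its prox is a projection
prox_l1_dual = make_prox_dual(prox_l1)
x = np.array([-2.5, -0.3, 0.0, 0.7, 4.0])
for τ in (0.1, 1.0, 10.0):
    assert np.allclose(prox_l1_dual(x, τ), np.clip(x, -1.0, 1.0))
print(prox_l1_dual(x, 1.0))
```

The projection result does not depend on τ, as expected for a characteristic function.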
def chambolle_pock(impl_f,impl_gs,τ_f,τ_gs,K,x,y=None,KT=None,E_rtol=1e-3,maxiter=1000,ρ_overrelax=None,cvx_f=0.,callback=None,verbosity=2)
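The chambolle_pock primal-dual proximal algorithm, for solving the saddle-point problem
inf_x sup_y <K*x,y> + f(x) - g^*(y).
This algorithm is equivalent to the (linearized) ADMM, but provides explicit dual points and a duality gap, and it has a number of useful variants.
Inputs :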
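- impl_f : possibilities
   - a tuple (f,f_star,prox_f) implementing the function f, its Legendre-Fenchel dual f_star, and its proximal operator prox_f. (f and f_star are used to construct the primal and dual energies, which are involved in the stopping criterion.)
   - the proximal operator prox_f alone, in which case E_rtol must be a tuple providing E_primal_dual.
- impl_gs : similar to impl_f above, but for the function g^*. The tuple form is (g_star, g, prox_gs).
- τ_f, τ_gs : step sizes of the primal and dual proximal steps. (The basic scheme converges when τ_f*τ_gs*|K|^2 < 1, where |K| is the operator norm of K.)
- K : possibilities
   - a linear operator, called either as K(x) or K*x.
   - the string 'Id'
- x : initial guess for the primal point
- y (optional, default : K(x)) : initial guess for the dual point
- KT (optional, default : K.T) : transposed linear operator.
- E_rtol (optional) : possibilities
   - (positive float) : the algorithm stops when (E_primal-E_dual) < E_rtol*abs(E_primal). This is checked every 10 iterations.
   - a tuple (callable, positive float) : the callable implements E_primal_dual(x,y), which returns the pair of primal and dual energies. Same stopping criterion as above.
   - (callable) : the algorithm stops when E_rtol(niter,x,y) is True.
- maxiter (optional) : maximum number of iterations
- ρ_overrelax (optional, use value in [0,2], typically 1.8) : over-relaxed variant
- cvx_f (optional) : coercivity constant of f. When positive, τ_f and τ_gs are updated at each iteration (ALG2 variant).
- callback (optional) : each iteration begins with callback(niter,x,y)
- verbosity (optional, default 2) : at level >= 1, a warning is printed if the duality gap target is not reached within maxiter iterations. At level >= 2, the number of steps and the elapsed time are also printed.
Output : a dict with keys 'x', 'y', 'niter', 'primal_values' and 'dual_values' (energies recorded at each check of the stopping criterion), 'ops' and 'tmp'.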
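The loop below is a minimal standalone sketch of the basic iteration performed by chambolle_pock. It uses θ = 1, with no over-relaxation and no stopping criterion. The problem is an assumed toy example whose solution is known in closed form: min_x |x-b|^2/2 + lam*|x|_1, with K = Id, f(x) = |x-b|^2/2 and g = lam*|.|_1. The dual g^* is the characteristic function of {|y|_inf <= lam}, so prox_gs is a clip. E_dual follows the listing's convention E_dual(y) = -(f^*(-KT(y)) + g^*(y)). The names b, lam and the helper functions are hypothetical.

```python
import numpy as np

b = np.array([3.0, -0.2, 0.5, -2.0])
lam = 1.0

def prox_f(z, τ):       # argmin_x |x-z|^2/2 + τ*|x-b|^2/2
    return (z + τ * b) / (1 + τ)

def prox_gs(z, σ):      # projection onto [-lam, lam]^n, independent of σ
    return np.clip(z, -lam, lam)

def E_primal(x):
    return 0.5 * np.sum((x - b) ** 2) + lam * np.sum(np.abs(x))

def E_dual(y):          # -(f^*(-y) + g^*(y)), with f^*(p) = <p,b> + |p|^2/2 and y feasible
    return np.dot(y, b) - 0.5 * np.sum(y ** 2)

τ_f = τ_gs = 0.9        # τ_f*τ_gs*|K|^2 = 0.81 < 1 since K = Id
x, y = np.zeros_like(b), np.zeros_like(b)
for _ in range(500):
    xold = x
    x = prox_f(x - τ_f * y, τ_f)        # primal step (KT = Id)
    x_ = 2 * x - xold                   # extrapolation with θ = 1
    y = prox_gs(y + τ_gs * x_, τ_gs)    # dual step (K = Id)

gap = E_primal(x) - E_dual(y)                               # duality gap
x_exact = np.sign(b) * np.maximum(np.abs(b) - lam, 0.0)     # soft thresholding
print(x, gap)
```

Monitoring the duality gap in this way is the idea behind the default stopping criterion, which compares E_primal - E_dual to E_rtol*abs(E_primal).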
def make_prox_multivar(prox)
Defines Prox((x1,...,xn),τ) := prox(x1,...,xn,τ) when prox is a single callable,
or Prox((x1,...,xn),τ) := (prox1(x1,τ),...,proxn(xn,τ)) when prox is a tuple (prox1,...,proxn).
The result is cast to np.array(dtype=object).
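The following usage sketch shows the tuple form. It combines an assumed projection onto [-1,1] for the first block with soft thresholding for the second block. The helper is copied from the listing so that the snippet runs on its own. prox_box and prox_l1 are hypothetical names.

```python
import numpy as np

def make_prox_multivar(prox):
    """Copied from the listing above."""
    if isinstance(prox, tuple):
        return lambda x, τ: np.array(tuple(proxi(xi, τ) for (proxi, xi) in zip(prox, x)), dtype=object)
    return lambda x, τ: np.array(prox(*x, τ), dtype=object)

def prox_box(x, τ):     # projection onto [-1, 1], used for the first block
    return np.clip(x, -1.0, 1.0)

def prox_l1(x, τ):      # soft thresholding, used for the second block
    return np.sign(x) * np.maximum(np.abs(x) - τ, 0.0)

prox_pair = make_prox_multivar((prox_box, prox_l1))
x1, x2 = np.array([-3.0, 0.5]), np.array([2.0, -0.1, 0.4])  # blocks of different lengths
p = prox_pair((x1, x2), 0.3)
print(p.dtype, p.shape, p[0], p[1])
```

The blocks have different lengths so that p is a length-2 object array of arrays. If all blocks share the same shape, NumPy instead stacks them into a higher-dimensional object array.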
def impl_inmult(impl,λ)
Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(λ*x)
Input :
- impl : f,fs,prox_f
- λ : a nonzero scalar
Output :
- F, Fs, prox_F where F(x) = f(λ x)
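The rule used here for the proximal operator, prox_F(x,τ) = (1/λ)*prox_f(λ*x, λ^2*τ), follows from the change of variables z = λ*u. The following sketch checks it by brute force for the assumed example f(z) = |z|, whose prox is soft thresholding. prox_abs is a hypothetical helper.

```python
import numpy as np

def prox_abs(z, σ):
    """Prox of f(z) = |z| (soft thresholding at level σ)."""
    return np.sign(z) * np.maximum(np.abs(z) - σ, 0.0)

λ, τ, x = 2.0, 0.3, 1.0
# Rule used by impl_inmult for F(x) = f(λ*x)
p_formula = (1 / λ) * prox_abs(λ * x, λ**2 * τ)

# Brute-force argmin_u |u-x|^2/2 + τ*F(u) on a fine grid (spacing 1e-5)
grid = np.linspace(-3.0, 3.0, 600_001)
p_grid = grid[np.argmin(0.5 * (grid - x) ** 2 + τ * np.abs(λ * grid))]
print(p_formula, p_grid)
```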
def impl_exmult(impl,λ)
Implements the Legendre-Fenchel dual and proximal operator of F(x) := λ*f(x)
Input :
- impl : f,fs,prox_f
- λ : a positive scalar
Output :
- F, Fs, prox_F where F(x) = λ f(x)
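The following sketch checks the two rules used here, F^*(y) = λ*f^*(y/λ) and prox_F(x,τ) = prox_f(x, λ*τ), for the assumed example f(z) = z^2/2. That function is its own Legendre-Fenchel dual and has prox z/(1+σ). The sup and the argmin are approximated on a grid. All helper names are hypothetical.

```python
import numpy as np

λ, τ = 3.0, 0.5

def f(z):        # assumed example function
    return 0.5 * z**2

def fs(y):       # its Legendre-Fenchel dual
    return 0.5 * y**2

def prox_f(z, σ):
    return z / (1 + σ)

def dual_F(y):   # rule used by impl_exmult for F = λ*f
    return λ * fs(y / λ)

def prox_F(x, t):
    return prox_f(x, λ * t)

grid = np.linspace(-10.0, 10.0, 2_000_001)   # spacing 1e-5
y, x = 1.2, 2.0
dual_grid = np.max(y * grid - λ * f(grid))   # sup_u <u,y> - F(u)
prox_grid = grid[np.argmin(0.5 * (grid - x) ** 2 + τ * λ * f(grid))]
print(dual_F(y), dual_grid, prox_F(x, τ), prox_grid)
```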
def impl_sub(impl,x0)
Implements the Legendre-Fenchel dual and proximal operator of F(x) := f(x-x0)
Input :
- impl : f,fs,prox_f
- x0 : the shift parameter
Output :
- F, Fs, prox_F where F(x) = f(x-x0)
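The change of variables v = u - x0 shows that the proximal operator of F(x) = f(x-x0) is prox_F(x,τ) = x0 + prox_f(x-x0,τ). The following sketch checks this formula by brute force for the assumed example f(z) = |z|. prox_abs is a hypothetical helper.

```python
import numpy as np

def prox_abs(z, σ):
    """Prox of f(z) = |z| (soft thresholding at level σ)."""
    return np.sign(z) * np.maximum(np.abs(z) - σ, 0.0)

x0, τ, x = 1.5, 0.4, 3.0
p_formula = x0 + prox_abs(x - x0, τ)   # prox of F(u) = |u - x0|

# Brute-force argmin_u |u-x|^2/2 + τ*|u-x0| on a fine grid (spacing 1e-5)
grid = np.linspace(-5.0, 5.0, 1_000_001)
p_grid = grid[np.argmin(0.5 * (grid - x) ** 2 + τ * np.abs(grid - x0))]
print(p_formula, p_grid)
```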