
This module gathers a few plotting helper functions that are used throughout the illustrative notebooks.

  1# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
  2# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see
  5This module gathers a few helper functions for plotting data, that are used throughout the 
  6illustrative notebooks.
  9from os import path
 10import matplotlib.pyplot as plt
 11from matplotlib import animation
 12import numpy as np
 13from . import LinearParallel as lp
 14import io
 18def SetTitle3D(ax,title):
 19	ax.text2D(0.5,0.95,title,transform=ax.transAxes,horizontalalignment='center')
 21def savefig(fig,fileName,dirName=None,ax=None,**kwargs):
 22	"""Save a figure:
 23	- in a given directory, possibly set in the properties of the function. 
 24	 Silently fails if dirName is None
 25	- with defaulted arguments, possibly set in the properties of the function
 26	"""
 27	# Choose the subplot to be saved 
 28	if ax is not None:
 29		kwargs['bbox_inches'] = ax.get_tightbbox(
 30			fig.canvas.get_renderer()).transformed(fig.dpi_scale_trans.inverted())
 32	# Set arguments to be passed
 33	for key,value in vars(savefig).items():
 34		if key not in kwargs and key!='dirName':
 35			kwargs[key]=value
 37	# Set directory
 38	if dirName is None: 
 39		if savefig.dirName is None: return 
 40		else: dirName=savefig.dirName
 42	# Save figure
 43	if path.isdir(dirName):
 44		fig.savefig(path.join(dirName,fileName),**kwargs) 
 45	else:
 46		print("savefig error: No such directory", dirName)
 47#		raise OSError(2, 'No such directory', dirName)
 49savefig.dirName = None
 50savefig.bbox_inches = 'tight'
 51savefig.pad_inches = 0
 52savefig.dpi = 300
 54def open_local_or_web(func,filepath,local_prefix="../",
 55	# Using data stored on the gh-pages (github pages) branch
 56	web_prefix='',
 57	web_suffix=''):
 58	try: return func(local_prefix+filepath)
 59	except FileNotFoundError:
 60		try : return func(filepath)  # Retry without the local prefix...
 61		except FileNotFoundError:
 62			import urllib
 63			return func(urllib.request.urlopen(web_prefix+filepath+web_suffix))
 66def imread(*args,**kwargs):
 67	"""
 68	Reads the image into a numpy array. Tries to find it locally and on the web.
 69	- *args,**args : passed to open_local_or_web
 70	"""
 71	import PIL
 72	return np.array(open_local_or_web(,*args,**kwargs))
 74def animation_curve(X,Y,**kwargs):
 75	"""Animates a sequence of curves Y[0],Y[1],... with X as horizontal axis"""
 76	fig, ax = plt.subplots(); plt.close()
 77	ax.set_xlim(( X[0], X[-1]))
 78	ax.set_ylim(( np.min(Y), np.max(Y)))
 79	line, = ax.plot([], [])
 80	def func(i,Y): line.set_data(X,Y[i])
 81	kwargs.setdefault('interval',20)
 82	kwargs.setdefault('repeat',False)
 83	return animation.FuncAnimation(fig,func,fargs=(Y,),frames=len(Y),**kwargs)
 85# ---- Vectors fields, metrics ----
 87def quiver(X,Y,U,V,subsampling=tuple(),**kwargs):
 88	"""
 89	Pyplot quiver with additional arg:
 90	- subsampling (tuple or int). Subsample X,Y,U,V	
 91	"""
 92	if np.ndim(subsampling)==0: subsampling = (subsampling,)*2
 93	where = tuple(slice(None,None,s) for s in subsampling)
 94	def f(Z): return Z.__getitem__(where)
 95	return plt.quiver(f(X),f(Y),f(U),f(V),**kwargs)
 97def Tissot(metric,X,=100,subsampling=5,scale=-1):
 98	"""
 99	Display the collection of unit balls of a two dimensional metric, also known as the 
100	Tissot indicatrix.
101	Inputs : 
102	- metric : the metric to display
103	- X : the geometric domain
104	- nθ : number of angular directions
105	- subsampling (integer or pair of integers): only display a subset of the unit balls
106	- scale : scaling factor for the unit balls (if negative, then relative to auto scale)
107	"""
108	if subsampling is not None:
109		if np.ndim(subsampling)==0: subsampling=[subsampling,subsampling]
110		metric.set_interpolation(X)
111		X = X[:,(subsampling[0]//2)::subsampling[0],(subsampling[1]//2)::subsampling[1]]
112		metric =
114	dx = X[:,1,1]-X[:,0,0]
115	θ = np.linspace(0,2*np.pi,)
116	U = np.array([np.cos(θ),np.sin(θ)]) # unit vectors
117	bd = np.array([u[:,None,None] / metric.norm(u) for u in U.T])
119	if scale<0:
120		default_scale = 0.4*min(dx[0]/np.max(bd[:,0]), dx[1]/np.max(bd[:,1]))
121		scale = np.abs(scale)*default_scale
122	bd = (bd*scale + X).reshape((,2,-1))
123	plt.plot(bd[:,0],bd[:,1],color='red')
124	plt.scatter(*X,s=1,color='black')
125	return scale
128# -------------- Array to image conversion ----------
def imshow_ij(image,**kwargs):
	"""
	Show an image stored in Cartesian array coordinates, i.e. image[i,j] is
	the value at horizontal index i and vertical index j, as with the option
	indexing='ij' of np.meshgrid. Other keyword arguments go to plt.imshow.
	"""
	# plt.imshow expects (row,column) order, so exchange the first two axes,
	# and place the origin at the bottom left.
	transposed = np.swapaxes(image,0,1)
	return plt.imshow(transposed,origin='lower',**kwargs)
def arr2fig(image,xsize=None,**kwargs):
	"""
	Create a new figure displaying the given image and nothing else (axes are
	hidden). Uses Cartesian array coordinates, as imshow_ij.
	- xsize (optional) : figure width in inches, by default min(6.4, 4.8*nx/ny)
	  where (nx,ny) are the first two dimensions of image.
	- **kwargs : forwarded to imshow_ij.
	The figure height follows from the image aspect ratio. Returns the figure.
	"""
	nx,ny = image.shape[:2]
	width = min(6.4,4.8*nx/ny) if xsize is None else xsize
	height = width*ny/nx
	fig = plt.figure(figsize=[width,height])
	plt.axis('off')
	fig.tight_layout(pad=0)
	imshow_ij(image,**kwargs)
	return fig
def fig2arr(fig,shape,noalpha=True):
	"""
	Render the figure as an array with the given shape, which must be 
	proportional to the figure size. Uses Cartesian array coordinates.
	Approximate inverse of arr2fig.
	- fig : the figure to render.
	- shape : (nx,ny), number of pixels along the horizontal and vertical axes.
	- noalpha (optional) : if True, drop the alpha channel.
	Returns : float array of shape (nx,ny,3), or (nx,ny,4) if noalpha is False,
	with values in [0,1].
	"""
	size = fig.get_size_inches()
	assert np.allclose(shape[0]*size[1],shape[1]*size[0]), \
		"fig2arr error: shape must be proportional to the figure size"

	# Raw RGBA bytes, row by row starting from the top of the image.
	# The context manager releases the buffer even if savefig raises.
	with io.BytesIO() as io_buf:
		fig.savefig(io_buf, format='raw',dpi=shape[0]/size[0])
		raw = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
	# Shape passed positionally : the `newshape` keyword is deprecated since NumPy 2.1
	img_arr = np.reshape(raw, (shape[1],shape[0], -1))
	# (row from top, column) -> Cartesian (x, y from bottom) coordinates
	return np.moveaxis(img_arr,0,1)[:,::-1,:(3 if noalpha else 4)]/255
168# ---------- Interactive picking of points in image ----------
def pick_lines(n=np.inf,broken=True,arrow=False):
	"""
	Interactively pick up to n (broken) lines by clicking on the current figure.
	- n : maximum number of lines (unbounded by default).
	- broken : if True, each line is a polyline with any number of points,
	  ended by typing enter; otherwise each line is a two-point segment.
	- arrow : if True, draw a small arrow head at the end of each line.
	Returns : list of arrays of shape (2,k), one per line. Closes the figure.
	"""
	# To hard-code the picked lines (e.g. as defaults), one may use :
	# [np.round(line).astype(int).tolist() for line in lines]
	plt.title("Pick points, type enter once\n to make a line, twice to terminate")
	points_per_line = -1 if broken else 2
	lines = []
	while len(lines)<n:
		clicked = np.array(plt.ginput(points_per_line))
		if len(clicked)==0: break     # Empty input : the user is done
		if len(clicked)==1: continue  # A single point does not make a line
		line = clicked.T
		lines.append(line)
		plt.plot(*line,color='r')
		if arrow:
			tip = line[:,-1]
			direction = 0.001*(tip-line[:,-2])
			plt.arrow(*tip,*direction,color='r',head_width=10)
	plt.close()
	return lines
def pick_points(n=np.inf):
	"""
	Interactively pick n points by clicking on the current figure
	(any number of points when n is infinite).
	Returns : the transposed array of clicked coordinates, of shape (2,k) when
	at least one point is picked. Closes the figure.
	"""
	if n<np.inf:
		plt.title(f"Pick {n} point(s)")
	else:
		plt.title("Pick points, middle click or type enter to terminate")
	npoints = -1 if n==np.inf else n
	pts = np.array(plt.ginput(npoints)).T
	plt.close()
	return pts
def input_default(prompt='',default=''):
	"""
	Ask the user for a string, returning `default` if the answer is empty.
	A non-empty default is mentioned at the end of the prompt.
	"""
	if default!='':
		prompt = prompt + f" (default={default}) "
	answer = input(prompt)
	return answer if answer else default
203# ----- Convex body 3D display -----
def plotly_primal_dual_bodies(V):
	"""
	Build plotly objects displaying a three dimensional convex body and its dual.
	- V : array of shape (3,Vsize), the points defining the primal convex body.
	  Each of them must be a vertex of their convex hull.
	Returns : (primal_mesh,primal_edges),(dual_mesh,dual_edges), where the meshes
	are plotly `Mesh3d` objects (triangulated surfaces), and the edges are
	`Scatter3d` objects (line segments separated by nan values).
	Raises : ValueError if some point of V is not a vertex of the convex hull.

	The dual body has one vertex per triangular facet of the primal hull, computed
	via lp.solve_AV from the facet vertices (presumably the vector g such that
	<g,v>=1 for the three vertices v of the facet, i.e. the polar body, which
	assumes the origin is interior to the primal body; TODO confirm).
	"""
	import scipy.spatial
	import plotly.graph_objects as go

	# Use a convex hull routine to triangulate the primal convex body
	primal_body = scipy.spatial.ConvexHull(V.T) 

	Vsize = V.shape[1]
	if primal_body.vertices.size!=Vsize:  # Otherwise, primal_body.vertices is a permutation of np.arange(Vsize)
		missing = set(range(Vsize))-set(primal_body.vertices)
		raise ValueError(f"Non-convex primal body ! Points not on the convex hull : {sorted(missing)}")

	# Use counter-clockwise orientation for all triangles
	S = primal_body.simplices.T # Vertex indices of each facet, shape (3,nfacets)
	N = primal_body.neighbors.T # N[m,s] : facet adjacent to s, opposite to vertex S[m,s]
	cw = np.sign(lp.det(V[:,S]))==-1 
	# Swap two vertices of each clockwise facet, together with their opposite neighbors
	S[1,cw],S[2,cw] = S[2,cw],S[1,cw] 
	N[1,cw],N[2,cw] = N[2,cw],N[1,cw] 

	# --- Plotly primal mesh object
	x,y,z = V
	i,j,k = S
	primal_mesh = go.Mesh3d(x=x,y=y,z=z, i=i,j=j,k=k)

	# --- Plotly primal edges object : the three edges of each facet,
	# as segments separated by nan values (which break the polyline)
	V0 = V[:,S]
	V1 = V[:,np.roll(S,1,axis=0)]
	xe,ye,ze = np.moveaxis([V0,V1,np.full_like(V0,np.nan)],0,1)
	primal_edges = go.Scatter3d(x=xe.T.flat,y=ye.T.flat,z=ze.T.flat,mode='lines')

	Sg = lp.solve_AV(lp.transpose(V[:,S]),np.ones(S.shape)) # The vertices of the dual convex body
	Ng = Sg[:,N] # Gradient of the neighbor cells

	# --- Plotly dual edges object : each dual vertex joined to its three neighbors
	xe,ye,ze = np.moveaxis(np.array([Sg[:,None]+0*Ng,Ng,np.full_like(Ng,np.nan)]),0,1)
	dual_edges = go.Scatter3d(x=xe.T.flat,y=ye.T.flat,z=ze.T.flat,mode='lines')

	# Choose a reference simplex for each vertex, hence a reference point for each dual facet
	XS = np.full(Vsize,-1,S.dtype)
	XS[S] = np.arange(S.shape[1],dtype=S.dtype)[None]

	# --- Plotly dual mesh object, using triangulated faces : each dual facet
	# (one per primal vertex) is split into triangles sharing its reference point
	# (some of these triangles are degenerate)
	x,y,z = Sg
	i = XS[S]
	j = np.roll(N,-1,axis=0)
	k = np.tile(np.arange(Sg.shape[1]),(3,1))
	dual_mesh = go.Mesh3d(x=x,y=y,z=z, i=i.flat,j=j.flat,k=k.flat)

	return (primal_mesh,primal_edges),(dual_mesh,dual_edges)
