agd.Plotting

This module gathers a few helper functions for plotting data, that are used throughout the illustrative notebooks.

  1# Copyright 2020 Jean-Marie Mirebeau, University Paris-Sud, CNRS, University Paris-Saclay
  2# Distributed WITHOUT ANY WARRANTY. Licensed under the Apache License, Version 2.0, see http://www.apache.org/licenses/LICENSE-2.0
  3
  4"""
  5This module gathers a few helper functions for plotting data, that are used throughout the 
  6illustrative notebooks.
  7"""
  8
  9from os import path
 10import matplotlib.pyplot as plt
 11from matplotlib import animation
 12import numpy as np
 13from . import LinearParallel as lp
 14import io
 15
 16
 17
 18def SetTitle3D(ax,title):
 19	ax.text2D(0.5,0.95,title,transform=ax.transAxes,horizontalalignment='center')
 20
 21def savefig(fig,fileName,dirName=None,ax=None,**kwargs):
 22	"""Save a figure:
 23	- in a given directory, possibly set in the properties of the function. 
 24	 Silently fails if dirName is None
 25	- with defaulted arguments, possibly set in the properties of the function
 26	"""
 27	# Choose the subplot to be saved 
 28	if ax is not None:
 29		kwargs['bbox_inches'] = ax.get_tightbbox(
 30			fig.canvas.get_renderer()).transformed(fig.dpi_scale_trans.inverted())
 31
 32	# Set arguments to be passed
 33	for key,value in vars(savefig).items():
 34		if key not in kwargs and key!='dirName':
 35			kwargs[key]=value
 36
 37	# Set directory
 38	if dirName is None: 
 39		if savefig.dirName is None: return 
 40		else: dirName=savefig.dirName
 41
 42	# Save figure
 43	if path.isdir(dirName):
 44		fig.savefig(path.join(dirName,fileName),**kwargs) 
 45	else:
 46		print("savefig error: No such directory", dirName)
 47#		raise OSError(2, 'No such directory', dirName)
 48
 49savefig.dirName = None
 50savefig.bbox_inches = 'tight'
 51savefig.pad_inches = 0
 52savefig.dpi = 300
 53
 54def open_local_or_web(func,filepath,local_prefix="../",
 55	# Using data stored on the gh-pages (github pages) branch
 56	web_prefix='https://mirebeau.github.io/AdaptiveGridDiscretizations',
 57	web_suffix=''):
 58	try: return func(local_prefix+filepath)
 59	except FileNotFoundError:
 60		try : return func(filepath)  # Retry without the local prefix...
 61		except FileNotFoundError:
 62			import urllib
 63			return func(urllib.request.urlopen(web_prefix+filepath+web_suffix))
 64
 65
 66def imread(*args,**kwargs):
 67	"""
 68	Reads the image into a numpy array. Tries to find it locally and on the web.
 69	- *args,**args : passed to open_local_or_web
 70	"""
 71	import PIL
 72	return np.array(open_local_or_web(PIL.Image.open,*args,**kwargs))
 73
 74def animation_curve(X,Y,**kwargs):
 75	"""Animates a sequence of curves Y[0],Y[1],... with X as horizontal axis"""
 76	fig, ax = plt.subplots(); plt.close()
 77	ax.set_xlim(( X[0], X[-1]))
 78	ax.set_ylim(( np.min(Y), np.max(Y)))
 79	line, = ax.plot([], [])
 80	def func(i,Y): line.set_data(X,Y[i])
 81	kwargs.setdefault('interval',20)
 82	kwargs.setdefault('repeat',False)
 83	return animation.FuncAnimation(fig,func,fargs=(Y,),frames=len(Y),**kwargs)
 84
 85# ---- Vectors fields, metrics ----
 86
 87def quiver(X,Y,U,V,subsampling=tuple(),**kwargs):
 88	"""
 89	Pyplot quiver with additional arg:
 90	- subsampling (tuple or int). Subsample X,Y,U,V	
 91	"""
 92	if np.ndim(subsampling)==0: subsampling = (subsampling,)*2
 93	where = tuple(slice(None,None,s) for s in subsampling)
 94	def f(Z): return Z.__getitem__(where)
 95	return plt.quiver(f(X),f(Y),f(U),f(V),**kwargs)
 96
 97def Tissot(metric,X,=100,subsampling=5,scale=-1):
 98	"""
 99	Display the collection of unit balls of a two dimensional metric, also known as the 
100	Tissot indicatrix.
101	Inputs : 
102	- metric : the metric to display
103	- X : the geometric domain
104	- nθ : number of angular directions
105	- subsampling (integer or pair of integers): only display a subset of the unit balls
106	- scale : scaling factor for the unit balls (if negative, then relative to auto scale)
107	"""
108	if subsampling is not None:
109		if np.ndim(subsampling)==0: subsampling=[subsampling,subsampling]
110		metric.set_interpolation(X)
111		X = X[:,(subsampling[0]//2)::subsampling[0],(subsampling[1]//2)::subsampling[1]]
112		metric = metric.at(X)
113
114	dx = X[:,1,1]-X[:,0,0]
115	θ = np.linspace(0,2*np.pi,)
116	U = np.array([np.cos(θ),np.sin(θ)]) # unit vectors
117	bd = np.array([u[:,None,None] / metric.norm(u) for u in U.T])
118	
119	if scale<0:
120		default_scale = 0.4*min(dx[0]/np.max(bd[:,0]), dx[1]/np.max(bd[:,1]))
121		scale = np.abs(scale)*default_scale
122	bd = (bd*scale + X).reshape((,2,-1))
123	plt.plot(bd[:,0],bd[:,1],color='red')
124	plt.scatter(*X,s=1,color='black')
125	return scale
126
127
128# -------------- Array to image conversion ----------
129
def imshow_ij(image,**kwargs): 
	"""
	Show an image, using Cartesian array coordinates, 
	as with the option indexing='ij' of np.meshgrid.
	- image : the array to display, whose first two axes are exchanged before display
	- kwargs : passed to plt.imshow, which is called with origin='lower'
	Output : the value returned by plt.imshow.
	"""
	exchanged = np.moveaxis(image,0,1)
	return plt.imshow(exchanged,origin='lower',**kwargs)
134	
def arr2fig(image,xsize=None,**kwargs):
	"""
	Create a figure displaying the given image, 
	and nothing else. Uses Cartesian array coordinates.
	- image : the array to display, see imshow_ij
	- xsize (optional) : the figure width. The height is deduced 
	 from the proportions of the first two axes of image. 
	 Default : min(6.4, 4.8*image.shape[0]/image.shape[1]).
	- kwargs : passed to imshow_ij
	Output : the figure.
	"""
	nx,ny = image.shape[:2]
	if xsize is None: xsize = min(6.4,4.8*nx/ny)
	# Figure with the proportions of the image, without axes nor padding
	fig = plt.figure(figsize=[xsize,xsize*ny/nx])
	plt.axis('off')
	fig.tight_layout(pad=0)
	imshow_ij(image,**kwargs)
	return fig
148
def fig2arr(fig,shape,noalpha=True):
	"""
	Save the figure as an array with the given shape, 
	which must be proportional to its size. Uses Cartesian array coords.

	Approximate inverse of arr2fig.

	Inputs : 
	- fig : the figure, providing get_size_inches and savefig
	- shape : pair of integers, the first two dimensions of the output
	- noalpha : if true, only the first three channels are kept, otherwise the first four
	Output : array of shape (shape[0],shape[1],channels), 
	 with the byte values divided by 255.
	"""
	size = fig.get_size_inches()
	assert np.allclose(shape[0]*size[1],shape[1]*size[0])
		
	# Render the figure as raw bytes, with a resolution chosen so as to 
	# obtain shape[0] pixels along the first dimension of the figure size.
	with io.BytesIO() as io_buf:
		fig.savefig(io_buf, format='raw',dpi=shape[0]/size[0])
		raw = io_buf.getvalue()

	# The shape is passed positionally : the keyword 'newshape' 
	# of np.reshape is deprecated in recent versions of numpy.
	img_arr = np.reshape(np.frombuffer(raw, dtype=np.uint8),
					(shape[1],shape[0], -1))
	# Exchange the first two axes and reverse the second one (Cartesian array coords)
	return np.moveaxis(img_arr,0,1)[:,::-1,:(3 if noalpha else 4)]/255
166
167
168# ---------- Interactive picking of points in image ----------
169
def pick_lines(n=np.inf,broken=True,arrow=False):
	"""
	Interactively pick some (broken) lines.
	- n : maximal number of lines
	- broken : if true, then each line may have arbitrarily many points, 
	 otherwise two points are requested
	- arrow : if true, then an arrow head is drawn at the end of each line
	Output : list of the lines, each given as an array whose columns are the points.
	An empty input terminates the picking, and a single point input is discarded.
	"""
	# Hint, for rounding the output : [np.round(line).astype(int).tolist() for line in lines]
	plt.title("Pick points, type enter once\n to make a line, twice to terminate") 
	plt.show()
	
	npts = -1 if broken else 2 # Number of points requested from plt.ginput
	lines = []
	while len(lines)<n:
		picked = np.array(plt.ginput(npts))
		if len(picked)==0: break # Empty input means end
		if len(picked)==1: continue # Invalid input
		line = picked.T
		lines.append(line)
		plt.plot(*line,color='r')
		if arrow: 
			# Tiny segment, in the direction of the last segment of the line
			tip = line[:,-1]
			direction = 0.001*(tip-line[:,-2])
			plt.arrow(*tip,*direction,color='r',head_width=10)
	
	plt.close()
	return lines
188
def pick_points(n=np.inf):
	"""
	Interactively pick some point coordinates.
	- n : number of points requested from plt.ginput (default : no limit)
	Output : array whose columns are the picked points.
	"""
	if n<np.inf: title = f"Pick {n} point(s)"
	else: title = "Pick points, middle click or type enter to terminate"
	plt.title(title) 
	plt.show()
	npts = -1 if n==np.inf else n
	picked = plt.ginput(npts)
	pts = np.array(picked).T
	plt.close()
	return pts
197
def input_default(prompt='',default=''):
	"""
	Ask the user for an input, and return default if the answer is empty.
	- prompt : the text displayed, completed with a mention of the default value 
	 unless default==''
	- default : value returned in case of an empty answer
	"""
	hint = f" (default={default}) " if default!='' else ''
	answer = input(prompt+hint)
	return answer if answer else default
201
202
203# ----- Convex body 3D display -----
204
def plotly_primal_dual_bodies(V):
	"""	
	Build the plotly objects displaying a three dimensional convex body, 
	given by its vertices, and the associated dual convex body.
	Input : 
	- V : the vertices of the primal body, as an array with three rows (x,y,z) 
	 and one column per vertex.
	Output : (primal_mesh,primal_edges),(dual_mesh,dual_edges), where 
	the meshes are go.Mesh3d objects and the edges are go.Scatter3d objects.
	Raises ValueError if some point of V is not a vertex of the convex hull of V.
	"""
	import scipy.spatial
	import plotly.graph_objects as go

	# Use a convex hull routine to triangulate the primal convex body
	primal_body = scipy.spatial.ConvexHull(V.T) 

	Vsize = V.shape[1]
	# If the sizes match, then all the points of V are vertices of the hull
	if primal_body.vertices.size!=Vsize:  # Then primal_body.vertices == np.arange(Xsize)
		raise ValueError("Non-convex primal body ! See",set(range(Vsize))-set(primal_body.vertices))

	# Use counter-clockwise orientation for all triangles
	# S : vertex indices of each triangle, N : neighbors of each triangle (one column per triangle).
	# These are transposed views, hence the swaps below also modify the arrays of primal_body.
	S = primal_body.simplices.T
	N = primal_body.neighbors.T
	cw = np.sign(lp.det(V[:,S]))==-1 
	S[1,cw],S[2,cw] = S[2,cw],S[1,cw] 
	N[1,cw],N[2,cw] = N[2,cw],N[1,cw] 

	# --- Plotly primal mesh object
	x,y,z = V; i,j,k = S;
	primal_mesh = go.Mesh3d(x=x,y=y,z=z, i=i,j=j,k=k)
	# ---

	# --- Plotly primal edges object
	# Each edge joins a vertex of a triangle to another vertex of the same triangle. 
	# A nan point follows each edge, presumably so that plotly draws separate segments.
	V0=V[:,S]; V1=V[:,np.roll(S,1,axis=0)];
	xe,ye,ze = np.moveaxis([V0,V1,np.full_like(V0,np.nan)],0,1)
	primal_edges = go.Scatter3d(x=xe.T.flat,y=ye.T.flat,z=ze.T.flat,mode='lines')
	# ---
	
	# One point per triangle, obtained by solving a linear system with unit right hand side
	# NOTE(review): presumably the point having unit scalar product with the three vertices 
	# of the triangle - confirm against lp.solve_AV
	Sg = lp.solve_AV(lp.transpose(V[:,S]),np.ones(S.shape)) # The vertices of the dual convex body
	Ng = Sg[:,N] # Gradient of the neighbor cells
	
	# --- Plotly dual edges object
	# Each edge joins the dual vertex of a triangle to the one of a neighbor triangle 
	# (the term 0*Ng only serves to broadcast Sg to the shape of Ng)
	xe,ye,ze = np.moveaxis(np.array([Sg[:,None]+0*Ng,Ng,np.full_like(Ng,np.nan)]),0,1)
	dual_edges = go.Scatter3d(x=xe.T.flat,y=ye.T.flat,z=ze.T.flat,mode='lines')
	# ---
	
	# Choose a reference simplex for each vertex, hence a reference point for each dual facet
	# (for a vertex belonging to several simplices, one of them is retained by the assignment)
	XS = np.full(Vsize,-1,S.dtype)
	XS[S] = np.arange(S.shape[1],dtype=S.dtype)[None]
		
	# --- Plotly dual mesh object, using triangulated faces
	# The triangles join : the reference point of a dual facet (i), 
	# the dual vertex of a neighbor triangle (j), and the dual vertex of the triangle itself (k)
	x,y,z = Sg
	i = XS[S]; j=np.roll(N,-1,axis=0); k=np.tile(np.arange(Sg.shape[1]),(3,1))
	dual_mesh=go.Mesh3d(x=x,y=y,z=z, i=i.flat,j=j.flat,k=k.flat)
	# ---
	
	return (primal_mesh,primal_edges),(dual_mesh,dual_edges)
def SetTitle3D(ax, title):
19def SetTitle3D(ax,title):
20	ax.text2D(0.5,0.95,title,transform=ax.transAxes,horizontalalignment='center')
def savefig(fig, fileName, dirName=None, ax=None, **kwargs):
22def savefig(fig,fileName,dirName=None,ax=None,**kwargs):
23	"""Save a figure:
24	- in a given directory, possibly set in the properties of the function. 
25	 Silently fails if dirName is None
26	- with defaulted arguments, possibly set in the properties of the function
27	"""
28	# Choose the subplot to be saved 
29	if ax is not None:
30		kwargs['bbox_inches'] = ax.get_tightbbox(
31			fig.canvas.get_renderer()).transformed(fig.dpi_scale_trans.inverted())
32
33	# Set arguments to be passed
34	for key,value in vars(savefig).items():
35		if key not in kwargs and key!='dirName':
36			kwargs[key]=value
37
38	# Set directory
39	if dirName is None: 
40		if savefig.dirName is None: return 
41		else: dirName=savefig.dirName
42
43	# Save figure
44	if path.isdir(dirName):
45		fig.savefig(path.join(dirName,fileName),**kwargs) 
46	else:
47		print("savefig error: No such directory", dirName)

Save a figure:

  • in a given directory, possibly set in the properties of the function. Silently fails if dirName is None
  • with defaulted arguments, possibly set in the properties of the function
def open_local_or_web( func, filepath, local_prefix='../', web_prefix='https://mirebeau.github.io/AdaptiveGridDiscretizations', web_suffix=''):
55def open_local_or_web(func,filepath,local_prefix="../",
56	# Using data stored on the gh-pages (github pages) branch
57	web_prefix='https://mirebeau.github.io/AdaptiveGridDiscretizations',
58	web_suffix=''):
59	try: return func(local_prefix+filepath)
60	except FileNotFoundError:
61		try : return func(filepath)  # Retry without the local prefix...
62		except FileNotFoundError:
63			import urllib
64			return func(urllib.request.urlopen(web_prefix+filepath+web_suffix))
def imread(*args, **kwargs):
67def imread(*args,**kwargs):
68	"""
69	Reads the image into a numpy array. Tries to find it locally and on the web.
70	- *args,**args : passed to open_local_or_web
71	"""
72	import PIL
73	return np.array(open_local_or_web(PIL.Image.open,*args,**kwargs))

Reads the image into a numpy array. Tries to find it locally and on the web.

  • *args, **kwargs : passed to open_local_or_web
def animation_curve(X, Y, **kwargs):
75def animation_curve(X,Y,**kwargs):
76	"""Animates a sequence of curves Y[0],Y[1],... with X as horizontal axis"""
77	fig, ax = plt.subplots(); plt.close()
78	ax.set_xlim(( X[0], X[-1]))
79	ax.set_ylim(( np.min(Y), np.max(Y)))
80	line, = ax.plot([], [])
81	def func(i,Y): line.set_data(X,Y[i])
82	kwargs.setdefault('interval',20)
83	kwargs.setdefault('repeat',False)
84	return animation.FuncAnimation(fig,func,fargs=(Y,),frames=len(Y),**kwargs)

Animates a sequence of curves Y[0],Y[1],... with X as horizontal axis

def quiver(X, Y, U, V, subsampling=(), **kwargs):
88def quiver(X,Y,U,V,subsampling=tuple(),**kwargs):
89	"""
90	Pyplot quiver with additional arg:
91	- subsampling (tuple or int). Subsample X,Y,U,V	
92	"""
93	if np.ndim(subsampling)==0: subsampling = (subsampling,)*2
94	where = tuple(slice(None,None,s) for s in subsampling)
95	def f(Z): return Z.__getitem__(where)
96	return plt.quiver(f(X),f(Y),f(U),f(V),**kwargs)

Pyplot quiver with additional arg:

  • subsampling (tuple or int). Subsample X,Y,U,V
def Tissot(metric, X, nθ=100, subsampling=5, scale=-1):
 98def Tissot(metric,X,nθ=100,subsampling=5,scale=-1):
 99	"""
100	Display the collection of unit balls of a two dimensional metric, also known as the 
101	Tissot indicatrix.
102	Inputs : 
103	- metric : the metric to display
104	- X : the geometric domain
105	- nθ : number of angular directions
106	- subsampling (integer or pair of integers): only display a subset of the unit balls
107	- scale : scaling factor for the unit balls (if negative, then relative to auto scale)
108	"""
109	if subsampling is not None:
110		if np.ndim(subsampling)==0: subsampling=[subsampling,subsampling]
111		metric.set_interpolation(X)
112		X = X[:,(subsampling[0]//2)::subsampling[0],(subsampling[1]//2)::subsampling[1]]
113		metric = metric.at(X)
114
115	dx = X[:,1,1]-X[:,0,0]
116	θ = np.linspace(0,2*np.pi,nθ)
117	U = np.array([np.cos(θ),np.sin(θ)]) # unit vectors
118	bd = np.array([u[:,None,None] / metric.norm(u) for u in U.T])
119	
120	if scale<0:
121		default_scale = 0.4*min(dx[0]/np.max(bd[:,0]), dx[1]/np.max(bd[:,1]))
122		scale = np.abs(scale)*default_scale
123	bd = (bd*scale + X).reshape((nθ,2,-1))
124	plt.plot(bd[:,0],bd[:,1],color='red')
125	plt.scatter(*X,s=1,color='black')
126	return scale

Display the collection of unit balls of a two dimensional metric, also known as the Tissot indicatrix. Inputs :

  • metric : the metric to display
  • X : the geometric domain
  • nθ : number of angular directions
  • subsampling (integer or pair of integers): only display a subset of the unit balls
  • scale : scaling factor for the unit balls (if negative, then relative to auto scale)
def imshow_ij(image, **kwargs):
131def imshow_ij(image,**kwargs): 
132	"""Show an image, using Cartesian array coordinates, 
133	as with the option indexing='ij' of np.mesgrid."""
134	return plt.imshow(np.moveaxis(image,0,1),origin='lower',**kwargs)

Show an image, using Cartesian array coordinates, as with the option indexing='ij' of np.meshgrid.

def arr2fig(image, xsize=None, **kwargs):
136def arr2fig(image,xsize=None,**kwargs):
137	"""
138	Create a figure displaying the given image, 
139	and nothing else. Uses Cartesian array coordinates.
140	"""
141	xshape,yshape = image.shape[:2]
142	if xsize is None: xsize = min(6.4,4.8*xshape/yshape)
143	ysize = xsize*yshape/xshape
144	fig = plt.figure(figsize=[xsize,ysize])
145	plt.axis('off')
146	fig.tight_layout(pad=0)
147	imshow_ij(image,**kwargs)
148	return fig

Create a figure displaying the given image, and nothing else. Uses Cartesian array coordinates.

def fig2arr(fig, shape, noalpha=True):
150def fig2arr(fig,shape,noalpha=True):
151	"""
152	Save the figure as an array with the given shape, 
153	which must be proportional to its size. Uses Cartesian array coords.
154
155	Approximate inverse of arr2fig.
156	"""
157	size = fig.get_size_inches()
158	assert np.allclose(shape[0]*size[1],shape[1]*size[0])
159		
160	io_buf = io.BytesIO()
161	fig.savefig(io_buf, format='raw',dpi=shape[0]/size[0])
162	io_buf.seek(0)
163	img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
164					newshape=(shape[1],shape[0], -1))
165	io_buf.close()
166	return np.moveaxis(img_arr,0,1)[:,::-1,:(3 if noalpha else 4)]/255

Save the figure as an array with the given shape, which must be proportional to its size. Uses Cartesian array coords.

Approximate inverse of arr2fig.

def pick_lines(n=inf, broken=True, arrow=False):
171def pick_lines(n=np.inf,broken=True,arrow=False):
172	"""Interactively pick some (broken) lines."""
173	# Set default : [np.round(line).astype(int).tolist() for line in lines]
174	plt.title(f"Pick points, type enter once\n to make a line, twice to terminate") 
175	plt.show()
176	
177	lines = []
178	while len(lines)<n:
179		pts = np.array(plt.ginput(-1 if broken else 2))
180		if len(pts)==0: break # Empty input means end
181		elif len(pts)==1: continue # Invalid input
182		pts = pts.T
183		lines.append(pts)
184		plt.plot(*pts,color='r')
185		if arrow: plt.arrow(*pts[:,-1],*0.001*(pts[:,-1]-pts[:,-2]),color='r',head_width=10)
186	
187	plt.close()
188	return lines

Interactively pick some (broken) lines.

def pick_points(n=inf):
190def pick_points(n=np.inf):
191	"""Interactively pick some point coordinates."""
192	plt.title(f"Pick {n} point(s)" if n<np.inf else 
193		"Pick points, middle click or type enter to terminate") 
194	plt.show()
195	pts = np.array(plt.ginput(-1 if n==np.inf else n)).T
196	plt.close()
197	return pts

Interactively pick some point coordinates.

def input_default(prompt='', default=''):
199def input_default(prompt='',default=''):
200	if default!='': prompt += f" (default={default}) "
201	return input(prompt) or default
def plotly_primal_dual_bodies(V):
206def plotly_primal_dual_bodies(V):
207	"""	
208	Output : facet indices, facet measures. 
209	"""
210	import scipy.spatial
211	import plotly.graph_objects as go
212
213	# Use a convex hull routine to triangulate the primal convex body
214	primal_body = scipy.spatial.ConvexHull(V.T) 
215
216	Vsize = V.shape[1]
217	if primal_body.vertices.size!=Vsize:  # Then primal_body.vertices == np.arange(Xsize)
218		raise ValueError("Non-convex primal body ! See",set(range(Vsize))-set(primal_body.vertices))
219
220	# Use counter-clockwise orientation for all triangles
221	S = primal_body.simplices.T
222	N = primal_body.neighbors.T
223	cw = np.sign(lp.det(V[:,S]))==-1 
224	S[1,cw],S[2,cw] = S[2,cw],S[1,cw] 
225	N[1,cw],N[2,cw] = N[2,cw],N[1,cw] 
226
227	# --- Plotly primal mesh object
228	x,y,z = V; i,j,k = S;
229	primal_mesh = go.Mesh3d(x=x,y=y,z=z, i=i,j=j,k=k)
230	# ---
231
232	# --- Plotly primal edges object
233	V0=V[:,S]; V1=V[:,np.roll(S,1,axis=0)];
234	xe,ye,ze = np.moveaxis([V0,V1,np.full_like(V0,np.nan)],0,1)
235	primal_edges = go.Scatter3d(x=xe.T.flat,y=ye.T.flat,z=ze.T.flat,mode='lines')
236	# ---
237	
238	Sg = lp.solve_AV(lp.transpose(V[:,S]),np.ones(S.shape)) # The vertices of the dual convex body
239	Ng = Sg[:,N] # Gradient of the neighbor cells
240	
241	# --- Plotly dual edges object
242	xe,ye,ze = np.moveaxis(np.array([Sg[:,None]+0*Ng,Ng,np.full_like(Ng,np.nan)]),0,1)
243	dual_edges = go.Scatter3d(x=xe.T.flat,y=ye.T.flat,z=ze.T.flat,mode='lines')
244	# ---
245	
246	# Choose a reference simplex for each vertex, hence a reference point for each dual facet
247	XS = np.full(Vsize,-1,S.dtype)
248	XS[S] = np.arange(S.shape[1],dtype=S.dtype)[None]
249		
250	# --- Plotly dual mesh object, using triangulated faces
251	x,y,z = Sg
252	i = XS[S]; j=np.roll(N,-1,axis=0); k=np.tile(np.arange(Sg.shape[1]),(3,1))
253	dual_mesh=go.Mesh3d(x=x,y=y,z=z, i=i.flat,j=j.flat,k=k.flat)
254	# ---
255	
256	return (primal_mesh,primal_edges),(dual_mesh,dual_edges)

Output : (primal_mesh, primal_edges), (dual_mesh, dual_edges), the plotly mesh and edge objects of the primal convex body and of its dual.