from __future__ import annotations
import re
from collections import abc
from itertools import cycle
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Optional, Sequence, Tuple)
import matplotlib.pyplot as plt
import numpy as np
from .._doc import copy_func, doc
if TYPE_CHECKING:
from nanomesh import LineMesh, MeshContainer, TriangleMesh
def _get_color_cycle(colors) -> cycle:
"""Get default matplotlib color cycle.
Returns
-------
itertools.cycle
Cycles through color strings in hex format (#XXXXXX).
"""
if not colors:
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
return cycle(colors)
def _get_point(mesh, label: int, method: str = 'mean') -> Tuple[float, float]:
"""Pick middle point from mesh matching default key.
Parameters
----------
mesh : Mesh
Input mesh
label : int
Input label
Returns
-------
Tuple[float, float]
(x, y) point
"""
idx = (mesh.cell_data[mesh.default_key] == label)
cells = mesh.cells[idx]
points = mesh.points[np.unique(cells)]
if method == 'middle':
return points[len(points) // 2] # take middle point as anchor
else:
return points.mean(axis=0)
def _annotate(ax: plt.Axes,
name: str,
xy: Tuple[float, float],
flip_xy: bool = True):
"""Annotate point on axis.
Parameters
----------
ax : matplotlib.axes.Axes
Matplotlib axes
name : str
Annotation text
xy : Tuple[float, float]
The point (x, y) to annotate
flip_xy : bool, optional
If true, (x,y) -> (y,x)
Returns
-------
matplotlib.text.Annotation
"""
if flip_xy:
xy = xy[::-1]
ax.annotate(name,
xy,
textcoords='offset pixels',
xytext=(4, 4),
color='red',
va='bottom')
def _deduplicate_labels(
handles_labels: Tuple[List[Any],
List[str]]) -> Tuple[List[Any], List[str]]:
"""Deduplicate legend handles and labels.
Parameters
----------
handles_labels : Tuple[List[Any], List[str]]
Legend handles and labels.
Returns
-------
(new_handles, new_labels) : Tuple[List[Any], List[str]]
Deduplicated legend handles and labels
"""
new_handles = []
new_labels = []
for handle, label in zip(*handles_labels):
if label not in new_labels:
new_handles.append(handle)
new_labels.append(label)
return (new_handles, new_labels)
def _legend_with_triplot_fix(ax: plt.Axes, **kwargs):
    """Create a legend on `ax` in which every label appears only once.

    The handles and labels currently registered on the axes are passed
    through :func:`_deduplicate_labels` before the legend is built.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes to apply legend to.
    **kwargs
        These parameters are passed to :func:`matplotlib.pyplot.legend`.

    Returns
    -------
    matplotlib.legend.Legend
    """
    handles, labels = _deduplicate_labels(ax.get_legend_handles_labels())
    return ax.legend(handles, labels, **kwargs)
def _legend_with_field_names_only(ax: plt.Axes,
triplot_fix: bool = False,
**kwargs):
"""Add legend with named fields only.
Parameters
----------
ax : matplotlib.axes.Axes
Matplotlib axes to apply legend to.
**kwargs
These parameters are passed to :func:`matplotlib.pyplot.legend`.
Returns
-------
matplotlib.legend.Legend
"""
handles_labels = ax.get_legend_handles_labels()
new_handles = []
new_labels = []
for handle, label in zip(*handles_labels):
try:
float(label)
except ValueError:
new_handles.append(handle)
new_labels.append(label)
new_handles_labels = (new_handles, new_labels)
if triplot_fix:
new_handles_labels = _deduplicate_labels(new_handles_labels)
return ax.legend(*new_handles_labels, **kwargs)
def _legend(ax: plt.Axes, title: str, triplot_fix: bool = False):
"""Wrapper around ax.legend with dispatch for fix.
Parameters
----------
ax : matplotlib.axes.Axes
Matplotlib axes object
title : str
Legend title
triplot_fix : bool, optional
If true, apply fix for triplot.
"""
if triplot_fix:
return _legend_with_triplot_fix(ax, title=title)
else:
return ax.legend(title=title)
@doc(cell_type='triangle',
     cell_dim=3,
     description='Shallow interface to :func:`matplotlib.pyplot.triplot`')
def triplot(ax: plt.Axes, **kwargs):
    """Plot collection of {cell_type}s.

    {description}.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    x : (n, 1) numpy.ndarray
        x-coordinates of points.
    y : (n, 1) numpy.ndarray
        y-coordinates of points.
    cells : (m, {cell_dim}) numpy.ndarray
        Integer array describing the {cell_type}s.
    mask : (m, 1) numpy.ndarray, optional
        Mask for {cell_type}s.
    label : str, optional
        Label for legend.
    **kwargs
        Extra keywords arguments passed to :func:`matplotlib.pyplot.plot`

    Returns
    -------
    list of :class:`matplotlib.lines.Line2D`
        A list of lines representing the {cell_type}s and nodes.
    """
    # NOTE: this docstring is a template; the braces are placeholders filled
    # in by the `doc` decorator, so they must be kept intact.
    # `x`, `y` and `cells` are required keywords (KeyError when missing).
    x_coords = kwargs.pop('x')
    y_coords = kwargs.pop('y')
    cells = kwargs.pop('cells')

    # matplotlib calls the connectivity array `triangles`.
    kwargs['triangles'] = cells
    return ax.triplot(x_coords, y_coords, **kwargs)
@doc(triplot,
     cell_type='line segment',
     cell_dim=2,
     description='API mimics :func:`triplot`')
def lineplot(ax: plt.Axes,
             *,
             x: np.ndarray,
             y: np.ndarray,
             cells: np.ndarray,
             mask: Optional[np.ndarray] = None,
             label: Optional[str] = None,
             **kwargs):
    # No docstring on purpose: the documentation is generated by the `doc`
    # decorator from the template on `triplot` (presumably any docstring
    # placed here would be merged into it -- verify against `.._doc`).
    #
    # All segments are drawn with a single `ax.plot` call: the two end
    # points of each segment are followed by a NaN, which makes matplotlib
    # break the line between consecutive segments.
    kwargs.setdefault('marker', '.')
    kwargs.setdefault('markersize', 1)

    if mask is not None:
        # Flatten instead of squeeze: squeezing the mask of a single-cell
        # mesh yields a 0-d boolean, and indexing with a 0-d boolean adds
        # an axis to `cells`, which made the NaN insertion below fail.
        hidden = np.asarray(mask, dtype=bool).ravel()
        cells = cells[~hidden]

    # NaN cannot be stored in integer arrays, so promote those to float.
    if np.issubdtype(x.dtype, np.integer):
        x = x.astype(float)
    if np.issubdtype(y.dtype, np.integer):
        y = y.astype(float)

    # (m, 2) -> (m, 3): start, end, NaN separator.
    lines_x = np.insert(x[cells], 2, np.nan, axis=1)
    lines_y = np.insert(y[cells], 2, np.nan, axis=1)

    return ax.plot(lines_x.ravel(), lines_y.ravel(), label=label, **kwargs)
def _select_labels(mesh, selection):
    """Translate a show/hide selection into a set of label numbers.

    Returns None when nothing was selected (None, False, empty string or
    empty iterable), so that the caller can fall back to its default.
    """
    if selection is None or selection is False:
        return None

    # Checked before truthiness so that label 0 is a valid selection.
    if isinstance(selection, (int, np.integer)):
        return {selection}

    if isinstance(selection, str):
        if not selection:
            return None
        # A string is a regex matched against the start of each field name.
        return {
            mesh.fields[field]
            for field in mesh.fields if re.match(selection, field)
        }

    if isinstance(selection, abc.Iterable):
        # Entries may be field names (mapped to numbers) or label numbers.
        labels = {mesh.fields.get(item, item) for item in selection}
        return labels or None

    raise TypeError('Labels must be given as an int, a str or an iterable, '
                    f'got {type(selection).__name__}')


def _floating_label_anchor(mesh, cell_data, label, method):
    """Return the point at which the floating text for `label` is placed."""
    cells = mesh.cells[cell_data == label]
    points = mesh.points[np.unique(cells)]
    if method == 'middle':
        return points[len(points) // 2]  # take middle point as anchor
    return points.mean(axis=0)


def meshplot(mesh: LineMesh | TriangleMesh,
             ax: Optional[plt.Axes] = None,
             key: Optional[str] = None,
             legend: str = 'fields',
             show_labels: Optional[Iterable | str | int] = None,
             hide_labels: Optional[Iterable | str | int] = None,
             show_region_markers: bool = True,
             colors: Optional[Sequence[str]] = None,
             color_map: Optional[Dict[str | int, str]] = None,
             flip_xy: bool = True,
             **kwargs) -> plt.Axes:
    """Plot a :class:`nanomesh.TriangleMesh` or :class:`nanomesh.LineMesh`
    using :mod:`matplotlib`.

    Every distinct value of the selected cell data is drawn with its own
    color and legend entry.

    Parameters
    ----------
    mesh : LineMesh | TriangleMesh
        Mesh to plot. Its ``cell_type`` (``'line'`` or ``'triangle'``)
        selects :func:`lineplot` or :func:`triplot`.
    ax : matplotlib.axes.Axes, optional
        Axes to use for plotting. A new figure is created if omitted.
    key : str, optional
        Label of cell data item to plot, defaults to
        :attr:`nanomesh.LineMesh.default_key` or
        :attr:`nanomesh.TriangleMesh.default_key`.
    legend : str
        Style for the legend.

        - off : No legend
        - all : Create legend with all labels
        - fields : Create legend with field names only
        - floating : Add floating labels to plot
    show_labels : Iterable | str | int
        List of labels or field names of cell data to show. A single label to
        show can also be specified directly by its name or number. A string
        is interpreted as a regular expression matched against the start of
        the field names.
    hide_labels : Iterable | str | int
        List of labels or field names of cell data to hide. A single label to
        hide can also be specified directly by its name or number. A string
        is interpreted as a regular expression matched against the start of
        the field names.
    show_region_markers : bool, default True
        If True, show region markers on the plot
    colors : Sequence[str]
        List of colors to cycle through
    color_map : dict
        Mapping of labels or field names to colors. The field name is
        looked up first, then the label number.
    flip_xy : bool, optional
        Flip x/y coordinates. This is sometimes necessary to combine the
        plot with other plots.
    **kwargs
        These parameters are passed to :func:`lineplot` or :func:`triplot`.

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    KeyError
        If ``mesh.cell_type`` is neither ``'line'`` nor ``'triangle'``.
    TypeError
        If `show_labels` or `hide_labels` is not an int, str or iterable.
    """
    dispatch: Dict[str, Callable] = {
        'line': lineplot,
        'triangle': triplot,
    }
    plotting_func = dispatch[mesh.cell_type]

    key = key if key else mesh.default_key

    # https://github.com/python/mypy/issues/9430
    cell_data = mesh.cell_data.get(key, mesh.zero_labels)  # type: ignore
    unique_vals = np.unique(cell_data)

    # `is None` semantics (inside the helper) make label 0 selectable; a
    # plain truthiness test silently ignored `show_labels=0`.
    labels_to_show = _select_labels(mesh, show_labels)
    if labels_to_show is None:
        labels_to_show = {*unique_vals.astype(int)}

    labels_to_hide = _select_labels(mesh, hide_labels)
    if labels_to_hide:
        labels_to_show -= labels_to_hide

    if ax is None:
        _, ax = plt.subplots()

    color_cycle = _get_color_cycle(colors)
    color_map = color_map if color_map else {}

    vert_x, vert_y = mesh.points.T
    if flip_xy:
        vert_x, vert_y = vert_y, vert_x

    for cell_data_val in unique_vals:
        if cell_data_val not in labels_to_show:
            continue

        name = mesh.number_to_field.get(cell_data_val, cell_data_val)

        # The cycle advances for every plotted label, also when the color
        # comes from `color_map`; this keeps the colors of the remaining
        # labels independent of which entries are mapped explicitly.
        cycled_color = next(color_cycle)
        color = color_map.get(name,
                              color_map.get(cell_data_val, cycled_color))

        plotting_func(
            ax=ax,
            x=vert_x,
            y=vert_y,
            cells=mesh.cells,
            mask=cell_data != cell_data_val,
            label=name,
            color=color,
            **kwargs,
        )

        if legend == 'floating':
            method = 'middle' if mesh.cell_type == 'triangle' else 'mean'
            # Anchor on the cell data that is actually being plotted, not
            # on the data stored under the mesh's default key.
            xy = _floating_label_anchor(mesh, cell_data, cell_data_val,
                                        method)
            _annotate(ax, name, xy, flip_xy=flip_xy)

    if show_region_markers and mesh.region_markers:
        mark_x, mark_y = np.array([m.point for m in mesh.region_markers]).T
        # Markers follow the same orientation as the mesh itself.
        if flip_xy:
            mark_x, mark_y = mark_y, mark_x

        ax.scatter(mark_x,
                   mark_y,
                   marker='*',
                   color='red',
                   label='Region markers')

        for marker in mesh.region_markers:
            label = marker.name if marker.name else marker.label
            point = marker.point[::-1] if flip_xy else marker.point
            ax.annotate(label,
                        point,
                        textcoords='offset pixels',
                        xytext=(4, -4),
                        color='red',
                        va='top')

    ax.set_title(f'{mesh.cell_type} mesh')
    ax.axis('equal')

    # prevent double entries in legend for triangles
    triplot_fix = (mesh.cell_type == 'triangle')

    if legend == 'all':
        _legend(ax=ax, title=key, triplot_fix=triplot_fix)
    elif legend == 'fields':
        _legend_with_field_names_only(ax=ax,
                                      title=key,
                                      triplot_fix=triplot_fix)

    # force consistent orientation: y axis always points down
    y0, y1 = ax.get_ylim()
    if y0 < y1:
        ax.set_ylim(y1, y0)

    return ax
# Cell-type specific aliases of `meshplot`. Each alias is its own object
# created by `copy_func` (presumably an independent copy of the function --
# verify against `.._doc`), so it can carry its own short docstring without
# touching the docstring of `meshplot`.
linemeshplot = copy_func(meshplot)
linemeshplot.__doc__ = """Alias for :func:`meshplot`."""

trianglemeshplot = copy_func(meshplot)
trianglemeshplot.__doc__ = """Alias for :func:`meshplot`."""
def linetrianglemeshplot(mesh: MeshContainer,
                         **kwargs) -> Tuple[plt.Axes, plt.Axes]:
    """Plot line/triangle mesh together.

    A new figure with two side-by-side axes is created: the line mesh is
    drawn on the left axes, the triangle mesh on the right axes.

    Parameters
    ----------
    mesh : MeshContainer
        Input mesh containing line and triangle cells.
    **kwargs
        Extra keyword arguments passed to

        - :func:`linemeshplot`
        - :func:`trianglemeshplot`

    Returns
    -------
    Tuple(matplotlib.axes.Axes, matplotlib.axes.Axes)
        Tuple of matplotlib axes
    """
    _, (ax1, ax2) = plt.subplots(ncols=2)

    # The same keyword arguments are forwarded to both plots.
    linemeshplot(mesh.get('line'), ax=ax1, **kwargs)
    trianglemeshplot(mesh.get('triangle'), ax=ax2, **kwargs)

    return ax1, ax2
def pointsplot(mesh: MeshContainer,
               key: Optional[str] = None,
               ax: Optional[plt.Axes] = None,
               *,
               flip_xy: bool = True,
               **kwargs) -> plt.Axes:
    """Plot mesh points.

    Parameters
    ----------
    mesh : MeshContainer
        Input mesh
    key : str, optional
        Key of the point data to use for coloring
    ax : matplotlib.axes.Axes, optional
        Axes to use for plotting. A new figure is created if omitted.
    flip_xy : bool, optional
        Flip x/y coordinates, i.e. draw the second coordinate of each point
        along the horizontal axis. Defaults to True, which matches the
        default orientation of :func:`meshplot`.
    **kwargs :
        These parameters are passed to :func:`matplotlib.pyplot.scatter`

    Returns
    -------
    matplotlib.axes.Axes
    """
    colors = mesh.point_data[key] if key else None

    if ax is None:
        _, ax = plt.subplots()

    # Assumes 2-column points; the transpose unpacks one array per axis.
    x, y = mesh.points.T
    if flip_xy:
        x, y = y, x

    ax.scatter(x, y, c=colors, **kwargs)

    ax.set_title('Points plot')
    ax.axis('equal')

    return ax