Source code for nanomesh.mesh_container

from __future__ import annotations

from collections import defaultdict
from enum import Enum
from types import MappingProxyType
from typing import Dict, List

import meshio
import numpy as np

from .mesh import Mesh, PruneZ0Mixin
from .plotting import pointsplot

try:
    # meshio >= 5.3
    from meshio._helpers import extension_to_filetypes
except ImportError:
    # meshio < 5.3
    from meshio._helpers import extension_to_filetype
    extension_to_filetypes = {k: [v] for k, v in extension_to_filetype.items()}


# Cell types addressed by the integer code stored as the second element of
# each ``field_data`` entry: NULL=0, LINE=1, TRIANGLE=2, TETRA=3.
# Built with the functional Enum API; `start=0` makes the first member 0.
_CellType = Enum(
    '_CellType',
    'NULL LINE TRIANGLE TETRA',
    module=__name__,
    start=0,
)


class MeshContainer(meshio.Mesh, PruneZ0Mixin):
    """Low-level container for storing mesh data.

    Can contain multiple cell types sharing a set of points. It can store
    different types of cells and associated data.

    :class:`MeshContainer` is based on :class:`meshio.Mesh`
    (https://github.com/nschloe/meshio).

    Parameters
    ----------
    points : numpy.ndarray
        Array storing the mesh points (e.g. vertices)
    cells : list
        List of cell arrays
    point_data : dict, optional
        Dictionary mapping data names to per-point arrays.
    cell_data : dict, optional
        Dictionary mapping data names to a list of arrays, one array per
        cell block (in the same order as `cells`).
    field_data : dict, optional
        Dictionary mapping field names to ``[number, cell_type_code]``
        pairs, where the code is 1 for line, 2 for triangle and 3 for
        tetra cells.
    point_sets : optional
        Passed through to :class:`meshio.Mesh`.
    cell_sets : optional
        Passed through to :class:`meshio.Mesh`.
    gmsh_periodic : optional
        Passed through to :class:`meshio.Mesh`.
    info : optional
        Passed through to :class:`meshio.Mesh`.
    """

    # Key renames applied to point/cell data when reading gmsh files
    # (`gmsh:` prefix stripped) and when writing them (prefix restored).
    _KEYS_FROM_GMSH = MappingProxyType({
        'gmsh:physical': 'physical',
        'gmsh:geometrical': 'geometrical',
    })
    _KEYS_TO_GMSH = MappingProxyType({
        'physical': 'gmsh:physical',
        'geometrical': 'gmsh:geometrical',
    })

    def __repr__(self):
        """Canonical string representation."""
        lines = super().__repr__().splitlines()
        # Replace meshio's header line with the name of this class.
        lines[0] = f'<{self.__class__.__name__}>'
        return '\n'.join(lines)

    @staticmethod
    def _rename_keys(data: dict, mapping) -> dict:
        """Return a copy of `data` with keys renamed according to `mapping`.

        Keys not present in `mapping` are kept unchanged. If two keys map
        to the same name, the later value wins.
        """
        return {mapping.get(key, key): value for key, value in data.items()}

    def _group_field_data(self, number_first: bool):
        """Group :attr:`MeshContainer.field_data` by lower-case cell type.

        Parameters
        ----------
        number_first : bool
            If True, build ``{cell_type: {number: field}}``, otherwise
            ``{cell_type: {field: number}}``.

        Returns
        -------
        MappingProxyType
            Read-only nested mapping.
        """
        grouped = defaultdict(dict)
        for field, (number, dimension) in self.field_data.items():
            dim_name = _CellType(dimension).name.lower()
            if number_first:
                grouped[dim_name][number] = field
            else:
                grouped[dim_name][field] = number
        return MappingProxyType(
            {k: MappingProxyType(v)
             for k, v in grouped.items()})

    @property
    def number_to_field(self):
        """Mapping from numbers to fields, proxy to :attr:`MeshContainer.field_data`.

        Read-only ``{cell_type: {number: field}}`` mapping, keyed by the
        lower-case cell type name (e.g. ``'triangle'``).
        """
        return self._group_field_data(number_first=True)

    @property
    def field_to_number(self):
        """Mapping from fields to numbers, proxy to :attr:`MeshContainer.field_data`.

        Read-only ``{cell_type: {field: number}}`` mapping, keyed by the
        lower-case cell type name (e.g. ``'triangle'``).
        """
        return self._group_field_data(number_first=False)

    def set_field_data(self, cell_type: str, field_data: Dict[int, str]):
        """Update the values in :attr:`MeshContainer.field_data`.

        Existing entries for `cell_type` are merged with `field_data`
        (entries in `field_data` take precedence) and written back.

        Parameters
        ----------
        cell_type : str
            Cell type to update the values for.
        field_data : dict
            Dictionary with key-to-number mapping, i.e.
            `field_data={0: 'green', 1: 'blue', 2: 'red'}`
            maps `0` to `green`, etc.

        Raises
        ------
        KeyError
            If `cell_type` is not a known cell type name.
        """
        target = _CellType[cell_type.upper()]

        try:
            merged = dict(self.number_to_field[cell_type])
        except KeyError:
            merged = {}
        merged.update(field_data)

        # Keep the entries of all other cell types; entries for `target`
        # are rebuilt from `merged` so stale names do not linger.
        new_field_data = {}
        for field, entry in self.field_data.items():
            _, dimension = entry
            if _CellType(dimension) != target:
                new_field_data[field] = entry

        for number, field in merged.items():
            new_field_data[field] = [number, target.value]

        self.field_data: Dict[str, List[int]] = new_field_data

    @property
    def cell_types(self):
        """Return cell types in order."""
        return tuple(cell.type for cell in self.cells)

    def set_cell_data(self, cell_type: str, key: str, value: np.ndarray):
        """Set the cell data to the given value.

        Updates :attr:`MeshContainer.cell_data`. If `key` does not exist
        yet, the data of all other cell blocks is initialized to zeros.

        Parameters
        ----------
        cell_type : str
            Cell type, must be in :attr:`MeshContainer.cell_types`
        key : str
            The key of the value in :attr:`MeshContainer.cell_data`
        value : numpy.ndarray
            Array of values to set

        Raises
        ------
        ValueError
            If `cell_type` is not in :attr:`MeshContainer.cell_types`, or
            if `value` does not have one entry per cell of that block.
        """
        index = self.cell_types.index(cell_type)

        n_cells = len(self.cells[index].data)
        if len(value) != n_cells:
            raise ValueError(
                f'Expected {n_cells} values for cell type {cell_type!r}, '
                f'got {len(value)}.')

        try:
            self.cell_data[key][index] = value
        except KeyError:
            # New key: one array per cell block. Missing blocks are set to
            # zeros sized to *their own* number of cells.
            self.cell_data[key] = [
                value if i == index else np.zeros(len(block.data), dtype=int)
                for i, block in enumerate(self.cells)
            ]

    def get_default_type(self) -> str:
        """Try to return highest dimension type.

        Default to first type :attr:`MeshContainer.cells_dict`.

        Returns
        -------
        cell_type : str

        Raises
        ------
        IndexError
            If the container holds no cells.
        """
        for type_ in ('tetra', 'triangle', 'line'):
            if type_ in self.cells_dict:
                return type_
        return list(self.cells_dict.keys())[0]

    def get(self, cell_type: str = None):
        """Extract mesh with points/cells of `cell_type`.

        Parameters
        ----------
        cell_type : str, optional
            Element type, such as line, triangle, tetra, etc.
            Defaults to :meth:`MeshContainer.get_default_type`.

        Returns
        -------
        Mesh
            Mesh of the given type

        Raises
        ------
        KeyError
            If there are no cells of `cell_type`.
        """
        if not cell_type:
            cell_type = self.get_default_type()

        try:
            cells = self.cells_dict[cell_type]
        except KeyError as e:
            msg = (f'No such cell type: {cell_type!r}. '
                   f'Must be one of {tuple(self.cells_dict.keys())!r}')
            raise KeyError(msg) from e

        cell_data = self.get_all_cell_data(cell_type)
        fields = self.field_to_number.get(cell_type, None)

        return Mesh(cells=cells, points=self.points, fields=fields, **cell_data)

    def get_all_cell_data(self, cell_type: str = None) -> dict:
        """Get all cell data for given `cell_type`.

        Parameters
        ----------
        cell_type : str, optional
            Element type, such as line, triangle, tetra, etc.
            Defaults to :meth:`MeshContainer.get_default_type`.

        Returns
        -------
        data_dict : dict
            Dictionary with cell data. Any ``:`` in a key is replaced by
            ``-`` (e.g. ``gmsh:physical`` becomes ``gmsh-physical``).
        """
        if not cell_type:
            cell_type = self.get_default_type()

        return {
            key.replace(':', '-'): self.get_cell_data(key, cell_type)
            for key in self.cell_data
        }

    def plot(self, cell_type: str = None, **kwargs):
        """Plot data.

        If `cell_type` is not given and the container holds exactly line
        and triangle cells, both are drawn together; otherwise the mesh of
        `cell_type` (or the default type) is plotted.

        Parameters
        ----------
        cell_type : str, optional
            Cell type to plot.
        **kwargs
            These parameters are passed to plotting method.
        """
        if (not cell_type) and (set(self.cell_types) == {'line', 'triangle'}):
            from .plotting import linetrianglemeshplot
            return linetrianglemeshplot(self, **kwargs)

        mesh = self.get(cell_type)
        return mesh.plot(**kwargs)

    def plot_mpl(self, cell_type: str = None, **kwargs):
        """Plot data using :mod:`matplotlib`.

        Parameters
        ----------
        cell_type : str, optional
            Cell type to plot.
        **kwargs
            These parameters are passed to plotting method.
        """
        mesh = self.get(cell_type)
        return mesh.plot_mpl(**kwargs)

    def plot_itk(self, cell_type: str = None, **kwargs):
        """Plot data using :mod:`itk`.

        Parameters
        ----------
        cell_type : str, optional
            Cell type to plot.
        **kwargs
            These parameters are passed to plotting method.
        """
        mesh = self.get(cell_type)
        return mesh.plot_itk(**kwargs)

    def plot_pyvista(self, cell_type: str = None, **kwargs):
        """Plot data using :mod:`pyvista`.

        Parameters
        ----------
        cell_type : str, optional
            Cell type to plot.
        **kwargs
            These parameters are passed to plotting method.
        """
        mesh = self.get(cell_type)
        return mesh.plot_pyvista(**kwargs)

    def plot_points(self, **kwargs):
        """Plot points data using :mod:`matplotlib`.

        Parameters
        ----------
        **kwargs
            These parameters are passed to the plotting method.
        """
        return pointsplot(self, **kwargs)

    @classmethod
    def from_mesh(cls, mesh: Mesh):
        """Convert from :class:`nanomesh.mesh.Mesh` to :class:`MeshContainer`.

        Parameters
        ----------
        mesh : Mesh
            Input mesh, must be a subclass of :class:`nanomesh.mesh.Mesh`.

        Returns
        -------
        MeshContainer
        """
        meshio_mesh = mesh.to_meshio()
        return cls(points=meshio_mesh.points,
                   cells=meshio_mesh.cells,
                   cell_data=meshio_mesh.cell_data)

    @classmethod
    def from_triangle_dict(cls, triangle_dict: dict):
        """Return instance of :class:`MeshContainer` from triangle results dict.

        Triangles are always included. Line cells (``edges``) are only
        included when both ``triangle_attributes`` and ``edge_markers`` are
        present, so that the ``physical`` cell data has one entry per cell
        block.

        Parameters
        ----------
        triangle_dict : dict
            Triangle triangulate output dictionary.

        Returns
        -------
        mesh : MeshContainer
        """
        points = triangle_dict['vertices']
        cells = {'triangle': triangle_dict['triangles']}

        cell_data = {}
        if 'triangle_attributes' in triangle_dict:
            # Order must match order of `cells`
            physical = [triangle_dict['triangle_attributes'].squeeze()]
            if 'edges' in triangle_dict and 'edge_markers' in triangle_dict:
                cells['line'] = triangle_dict['edges']
                physical.append(triangle_dict['edge_markers'].squeeze())
            cell_data['physical'] = physical

        point_data = {}
        if 'vertex_markers' in triangle_dict:
            point_data['physical'] = triangle_dict['vertex_markers'].squeeze()

        mesh = cls(
            points=points,
            cells=cells,
            cell_data=cell_data,
            point_data=point_data,
        )
        mesh.triangle_dict = triangle_dict

        return mesh

    @classmethod
    def read(cls, *args, **kwargs) -> MeshContainer:
        """Wrapper for :func:`meshio.read`.

        For gmsh:
        - remaps `gmsh:physical` -> `physical`
        - remaps `gmsh:geometrical` -> `geometrical`

        The z-coordinate is pruned afterwards via ``prune_z_0``.

        Parameters
        ----------
        *args
            These parameters passed to reader
        **kwargs
            These parameters are passed to the reader

        Returns
        -------
        MeshContainer
        """
        from meshio import read
        mesh = read(*args, **kwargs)

        ret = cls(
            mesh.points,
            mesh.cells,
            point_data=cls._rename_keys(mesh.point_data, cls._KEYS_FROM_GMSH),
            cell_data=cls._rename_keys(mesh.cell_data, cls._KEYS_FROM_GMSH),
            field_data=mesh.field_data,
            point_sets=mesh.point_sets,
            cell_sets=mesh.cell_sets,
            gmsh_periodic=mesh.gmsh_periodic,
            info=mesh.info,
        )

        ret.prune_z_0()

        return ret

    def write(self, filename, file_format: str = None, **kwargs):
        """Thin wrapper of `meshio.write` to avoid altering class.

        For gmsh:
        - remaps `physical` -> `gmsh:physical`
        - remaps `geometrical` -> `gmsh:geometrical`

        Parameters
        ----------
        filename : str
            File to write to.
        file_format : str, optional
            Specify file format. By default, this is guessed from the
            extension.
        **kwargs
            These parameters are passed to :func:`meshio.write`.

        Raises
        ------
        IOError
            If `file_format` is not given and cannot be determined from
            the file extension.
        """
        from pathlib import Path

        from meshio import write

        if file_format is None:
            suffix = Path(filename).suffix
            try:
                file_types = extension_to_filetypes[suffix]
            except KeyError as err:
                raise IOError('Unknown extension, specify file format.') from err
            if not file_types:
                raise IOError(
                    f'No file format registered for extension {suffix!r}, '
                    f'specify file format ({file_types}).')
            file_format = file_types[0]

        if file_format.startswith('gmsh'):
            cell_data = self._rename_keys(self.cell_data, self._KEYS_TO_GMSH)
            point_data = self._rename_keys(self.point_data,
                                           self._KEYS_TO_GMSH)
        else:
            cell_data = self.cell_data
            point_data = self.point_data

        out_mesh = meshio.Mesh(
            self.points,
            self.cells,
            point_data=point_data,
            cell_data=cell_data,
            field_data=self.field_data,
            point_sets=self.point_sets,
            cell_sets=self.cell_sets,
            gmsh_periodic=self.gmsh_periodic,
            info=self.info,
        )

        write(filename, mesh=out_mesh, file_format=file_format, **kwargs)