Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 111 additions & 55 deletions src/adios4dolfinx/backends/adios2/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def get_default_backend_args(arguments: dict[str, Any] | None) -> dict[str, Any]
args = arguments or {}
if "engine" not in args.keys():
args["engine"] = "BP4"
if "legacy" not in args.keys():
args["legacy"] = False # Only used for legacy HDF5 meshtags
return args


Expand Down Expand Up @@ -71,7 +73,7 @@ def write_attributes(
filename=filename,
mode=adios2.Mode.Append,
io_name="AttributeWriter",
**backend_args,
engine=backend_args["engine"],
) as adios_file:
adios_file.file.BeginStep()

Expand Down Expand Up @@ -105,7 +107,7 @@ def read_attributes(
adios=adios,
filename=filename,
mode=adios2.Mode.Read,
**backend_args,
engine=backend_args["engine"],
io_name="AttributesReader",
) as adios_file:
adios_file.file.BeginStep()
Expand Down Expand Up @@ -143,7 +145,7 @@ def read_timestamps(
adios=adios,
filename=filename,
mode=adios2.Mode.Read,
**backend_args,
engine=backend_args["engine"],
io_name="TimestepReader",
) as adios_file:
time_name = f"{function_name}_time"
Expand Down Expand Up @@ -195,7 +197,8 @@ def write_mesh(
filename=filename,
mode=mode,
comm=comm,
**backend_args,
engine=backend_args["engine"],
io_name=backend_args["io_name"],
) as adios_file:
adios_file.file.BeginStep()
# Write geometry
Expand Down Expand Up @@ -473,73 +476,126 @@ def read_meshtags_data(
"""

adios = adios2.ADIOS(comm)
backend_args = {} if backend_args is None else backend_args
backend_args = get_default_backend_args(backend_args)
io_name = backend_args.get("io_name", "MeshTagsReader")
engine = backend_args.get("engine", "BP4")
engine = backend_args["engine"]
legacy = backend_args["legacy"]
with ADIOSFile(
adios=adios,
filename=filename,
mode=adios2.Mode.Read,
engine=engine,
io_name=io_name,
) as adios_file:
# Get mesh cell type
dim_attr_name = f"{name}_dim"
step = 0
for i in range(adios_file.file.Steps()):
adios_file.file.BeginStep()
if dim_attr_name in adios_file.io.AvailableAttributes().keys():
step = i
break
adios_file.file.EndStep()
if dim_attr_name not in adios_file.io.AvailableAttributes().keys():
raise KeyError(f"{dim_attr_name} not found in {filename}")
if not legacy:
# Get mesh cell type
dim_attr_name = f"{name}_dim"
step = 0
for i in range(adios_file.file.Steps()):
adios_file.file.BeginStep()
if dim_attr_name in adios_file.io.AvailableAttributes().keys():
step = i
break
adios_file.file.EndStep()
if dim_attr_name not in adios_file.io.AvailableAttributes().keys():
raise KeyError(f"{dim_attr_name} not found in {filename}")

m_dim = adios_file.io.InquireAttribute(dim_attr_name)
dim = int(m_dim.Data()[0])
m_dim = adios_file.io.InquireAttribute(dim_attr_name)
dim = int(m_dim.Data()[0])

# Get mesh tags entites
topology_name = f"{name}_topology"
for i in range(step, adios_file.file.Steps()):
if i > step:
adios_file.file.BeginStep()
if topology_name in adios_file.io.AvailableVariables().keys():
break
adios_file.file.EndStep()
if topology_name not in adios_file.io.AvailableVariables().keys():
raise KeyError(f"{topology_name} not found in {filename}")
# Get mesh tags entites
topology_name = f"{name}_topology"
for i in range(step, adios_file.file.Steps()):
if i > step:
adios_file.file.BeginStep()
if topology_name in adios_file.io.AvailableVariables().keys():
break
adios_file.file.EndStep()
if topology_name not in adios_file.io.AvailableVariables().keys():
raise KeyError(f"{topology_name} not found in {filename}")

topology = adios_file.io.InquireVariable(topology_name)
top_shape = topology.Shape()
topology_range = compute_local_range(comm, top_shape[0])

topology.SetSelection(
[
[topology_range[0], 0],
[topology_range[1] - topology_range[0], top_shape[1]],
]
)
mesh_entities = np.empty(
(topology_range[1] - topology_range[0], top_shape[1]), dtype=np.int64
)
adios_file.file.Get(topology, mesh_entities, adios2.Mode.Deferred)

# Get mesh tags values
values_name = f"{name}_values"
if values_name not in adios_file.io.AvailableVariables().keys():
raise KeyError(f"{values_name} not found")

values = adios_file.io.InquireVariable(values_name)
val_shape = values.Shape()
assert val_shape[0] == top_shape[0]
values.SetSelection([[topology_range[0]], [topology_range[1] - topology_range[0]]])
tag_values = np.empty(
(topology_range[1] - topology_range[0]), dtype=values.Type().strip("_t")
)
adios_file.file.Get(values, tag_values, adios2.Mode.Deferred)

topology = adios_file.io.InquireVariable(topology_name)
top_shape = topology.Shape()
topology_range = compute_local_range(comm, top_shape[0])
adios_file.file.PerformGets()
adios_file.file.EndStep()
else:
# Get mesh cell type
dim_attr_name = f"{name}_dim"
assert adios_file.file.Steps() == 0
if (ct_key := f"/{name}/topology/celltype") in adios_file.io.AvailableAttributes():
cell_type = adios_file.io.InquireAttribute(ct_key)
else:
raise ValueError(f"Celltype not found for meshtags {name} in {filename}.")
dim = dolfinx.mesh.cell_dim(dolfinx.mesh.to_type(cell_type.DataString()[0]))

# Get mesh tags entites
if (top_key := f"/{name}/topology") in adios_file.io.AvailableVariables():
topology = adios_file.io.InquireVariable(top_key)
else:
raise ValueError(f"Topology not found for meshtags {name} in {filename}.")

top_shape = topology.Shape()
topology_range = compute_local_range(comm, top_shape[0])

topology.SetSelection(
[
[topology_range[0], 0],
[topology_range[1] - topology_range[0], top_shape[1]],
]
)
mesh_entities = np.empty(
(topology_range[1] - topology_range[0], top_shape[1]), dtype=np.int64
)
adios_file.file.Get(topology, mesh_entities, adios2.Mode.Deferred)

topology.SetSelection(
[
[topology_range[0], 0],
[topology_range[1] - topology_range[0], top_shape[1]],
]
)
mesh_entities = np.empty(
(topology_range[1] - topology_range[0], top_shape[1]), dtype=np.int64
)
adios_file.file.Get(topology, mesh_entities, adios2.Mode.Deferred)
# Get mesh tags values
if (val_key := f"/{name}/values") in adios_file.io.AvailableVariables():
values = adios_file.io.InquireVariable(val_key)
else:
raise ValueError(f"Values not found for meshtags {name} in {filename}.")

# Get mesh tags values
values_name = f"{name}_values"
if values_name not in adios_file.io.AvailableVariables().keys():
raise KeyError(f"{values_name} not found")
val_shape = values.Shape()
assert val_shape[0] == top_shape[0]

values = adios_file.io.InquireVariable(values_name)
val_shape = values.Shape()
assert val_shape[0] == top_shape[0]
values.SetSelection([[topology_range[0]], [topology_range[1] - topology_range[0]]])
tag_values = np.empty((topology_range[1] - topology_range[0]), dtype=np.int32)
adios_file.file.Get(values, tag_values, adios2.Mode.Deferred)
values.SetSelection([[topology_range[0]], [topology_range[1] - topology_range[0]]])
tag_values = np.empty(
(topology_range[1] - topology_range[0]), dtype=values.Type().strip("_t")
)
adios_file.file.Get(values, tag_values, adios2.Mode.Deferred)

adios_file.file.PerformGets()
adios_file.file.EndStep()
adios_file.file.PerformGets()
adios_file.file.EndStep()

return MeshTagsData(name=name, values=tag_values, indices=mesh_entities, dim=dim)
return MeshTagsData(
name=name, values=tag_values.astype(np.int32), indices=mesh_entities, dim=dim
)


def read_dofmap(
Expand Down
53 changes: 31 additions & 22 deletions src/adios4dolfinx/backends/h5py/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mpi4py import MPI

import dolfinx
import h5py
import numpy as np
import numpy.typing as npt
from dolfinx.graph import adjacencylist
Expand All @@ -23,6 +24,10 @@

read_mode = ReadMode.parallel

# try:
# except ModuleNotFoundError:
# raise ModuleNotFoundError("This backend requires h5py to be installed.")


@contextlib.contextmanager
def h5pyfile(h5name, filemode="r", force_serial: bool = False, comm=None):
Expand All @@ -35,8 +40,6 @@ def h5pyfile(h5name, filemode="r", force_serial: bool = False, comm=None):
comm: The MPI communicator

"""
import h5py

if comm is None:
comm = MPI.COMM_WORLD

Expand All @@ -58,13 +61,7 @@ def h5pyfile(h5name, filemode="r", force_serial: bool = False, comm=None):


def get_default_backend_args(arguments: dict[str, Any] | None) -> dict[str, Any]:
args = arguments or {}

if arguments:
# Currently no default arguments for h5py backend
# TODO: Pehaps we would like to make this into a warning instead?
raise RuntimeError("Unexpected backend arguments to h5py backend")

args = arguments or {"legacy": False} # If meshtags is read from legacy
return args


Expand Down Expand Up @@ -384,20 +381,32 @@ def read_meshtags_data(
Internal data structure for the mesh tags read from file
"""
backend_args = get_default_backend_args(backend_args)
legacy = backend_args["legacy"]
with h5pyfile(filename, filemode="r", comm=comm, force_serial=False) as h5file:
if "mesh" not in h5file.keys():
raise KeyError("No mesh found")
mesh = h5file["mesh"]
if "tags" not in mesh.keys():
raise KeyError("Could not find 'tags' in file, are you sure this is a checkpoint?")
tags = mesh["tags"]
if name not in tags.keys():
raise KeyError(f"Could not find {name} in '/mesh/tags/' in {filename}")
tag = tags[name]

dim = tag.attrs["dim"]
topology = tag["Topology"]
values = tag["Values"]
if legacy:
if name not in h5file.keys():
raise RuntimeError(f"MeshTag {name} not found in {filename}.")
mesh = h5file[name]
topology = mesh["topology"]
cell_type = topology.attrs["celltype"]
if isinstance(cell_type, np.bytes_):
cell_type = cell_type.decode("utf-8")
dim = dolfinx.mesh.cell_dim(dolfinx.mesh.to_type(cell_type))
values = mesh["values"]
else:
if "mesh" not in h5file.keys():
raise KeyError("No mesh found")
mesh = h5file["mesh"]
if "tags" not in mesh.keys():
raise KeyError("Could not find 'tags' in file, are you sure this is a checkpoint?")
tags = mesh["tags"]
if name not in tags.keys():
raise KeyError(f"Could not find {name} in '/mesh/tags/' in {filename}")
tag = tags[name]

dim = tag.attrs["dim"]
topology = tag["Topology"]
values = tag["Values"]
num_entities_global = topology.shape[0]
topology_range = compute_local_range(comm, num_entities_global)
indices = topology[slice(*topology_range), :].astype(np.int64)
Expand Down
30 changes: 23 additions & 7 deletions src/adios4dolfinx/backends/pyvista/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
try:
import pyvista
except ImportError:
raise ModuleNotFoundError("This module requires pyvista")
raise ModuleNotFoundError("The PyVista-backend requires pyvista")
from pathlib import Path

from mpi4py import MPI
Expand Down Expand Up @@ -138,7 +138,11 @@ def read_mesh_data(
grid = in_data
elif isinstance(in_data, pyvista.core.composite.MultiBlock):
# To handle multiblock like pvd
pyvista._VTK_SNAKE_CASE_STATE = "allow"
if hasattr(pyvista, "_VTK_SNAKE_CASE_STATE"):
pyvista._VTK_SNAKE_CASE_STATE = "allow"
else:
# Compatibility with 0.47
pyvista.core.vtk_snake_case._state = "allow"
number_of_blocks = in_data.number_of_blocks
assert number_of_blocks == 1
b0 = in_data.get_block(0)
Expand Down Expand Up @@ -203,7 +207,11 @@ def read_point_data(
grid = in_data
elif isinstance(in_data, pyvista.core.composite.MultiBlock):
# To handle multiblock like pvd
pyvista._VTK_SNAKE_CASE_STATE = "allow"
if hasattr(pyvista, "_VTK_SNAKE_CASE_STATE"):
pyvista._VTK_SNAKE_CASE_STATE = "allow"
else:
# Compatibility with 0.47
pyvista.core.vtk_snake_case._state = "allow"
number_of_blocks = in_data.number_of_blocks
assert number_of_blocks == 1
b0 = in_data.get_block(0)
Expand All @@ -217,10 +225,10 @@ def read_point_data(
else:
num_components = dataset.shape[1]
if np.issubdtype(dataset.dtype, np.integer):
gtype = in_data.points.dtype
gtype = grid.points.dtype
dataset = dataset.astype(gtype)
else:
gtype = in_data.dtype
gtype = dataset.dtype
num_components, gtype = comm.bcast((num_components, gtype), root=0)
local_range_start = 0
else:
Expand All @@ -246,7 +254,11 @@ def read_cell_data(
grid = in_data
elif isinstance(in_data, pyvista.core.composite.MultiBlock):
# To handle multiblock like pvd
pyvista._VTK_SNAKE_CASE_STATE = "allow"
if hasattr(pyvista, "_VTK_SNAKE_CASE_STATE"):
pyvista._VTK_SNAKE_CASE_STATE = "allow"
else:
# Compatibility with 0.47
pyvista.core.vtk_snake_case._state = "allow"
number_of_blocks = in_data.number_of_blocks
assert number_of_blocks == 1
b0 = in_data.get_block(0)
Expand Down Expand Up @@ -351,7 +363,11 @@ def read_function_names(
grid = in_data
elif isinstance(in_data, pyvista.core.composite.MultiBlock):
# To handle multiblock like pvd
pyvista._VTK_SNAKE_CASE_STATE = "allow"
if hasattr(pyvista, "_VTK_SNAKE_CASE_STATE"):
pyvista._VTK_SNAKE_CASE_STATE = "allow"
else:
# Compatibility with 0.47
pyvista.core.vtk_snake_case._state = "allow"
number_of_blocks = in_data.number_of_blocks
assert number_of_blocks == 1
b0 = in_data.get_block(0)
Expand Down
Loading