diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py
new file mode 100644
index 00000000..14226948
--- /dev/null
+++ b/map2loop/contact_extractor.py
@@ -0,0 +1,175 @@
+"""Utility class for extracting geological contacts."""
+
+from __future__ import annotations
+
+import itertools
+from typing import List, Optional
+
+import geopandas
+import pandas
+import shapely
+
+from .logging import getLogger
+
+logger = getLogger(__name__)
+
+
+class ContactExtractor:
+    """Extract contacts between geological units and classify them.
+
+    Args:
+        geology_data: Geology polygons; the ``UNITNAME``, ``INTRUSIVE`` and
+            ``SILL`` columns are read. May be left unset at construction, but
+            is required by :meth:`extract_all_contacts`.
+        fault_data: Optional fault geometries. A buffer around them is cut
+            out of the geology before contacts are computed.
+    """
+
+    def __init__(
+        self,
+        geology_data: Optional[geopandas.GeoDataFrame] = None,
+        fault_data: Optional[geopandas.GeoDataFrame] = None,
+    ) -> None:
+        self.geology_data = geology_data
+        self.fault_data = fault_data
+        # Results cached by the extract_* methods.
+        self.contacts = None
+        self.basal_contacts = None
+
+    def extract_all_contacts(self) -> geopandas.GeoDataFrame:
+        """Extract all contacts between units in ``geology_data``.
+
+        Returns:
+            One row per contact piece with ``UNITNAME_1``, ``UNITNAME_2``,
+            ``geometry`` and ``length`` columns. Also stored in
+            ``self.contacts``.
+
+        Raises:
+            ValueError: If no geology data has been provided.
+        """
+        if self.geology_data is None:
+            raise ValueError("Geology data must be provided before extracting contacts")
+
+        logger.info("Extracting contacts")
+        geology = self.geology_data.copy()
+        geology = geology.dissolve(by="UNITNAME", as_index=False)
+
+        # Remove intrusions
+        geology = geology[~geology["INTRUSIVE"]]
+        geology = geology[~geology["SILL"]]
+
+        # Remove faults from contact geometry
+        if self.fault_data is not None:
+            faults = self.fault_data.copy()
+            faults["geometry"] = faults.buffer(50)
+            geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False)
+
+        units = geology["UNITNAME"].unique()
+        column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"]
+        by_unit = {unit: geology[geology["UNITNAME"] == unit] for unit in units}
+
+        # Collect the per-pair results and concatenate once at the end.
+        pieces = []
+        for unit1, unit2 in itertools.combinations(units, 2):
+            join = geopandas.overlay(
+                by_unit[unit1],
+                by_unit[unit2],
+                keep_geom_type=False,
+            )[column_names]
+            join["geometry"] = join.buffer(1)
+            buffered = by_unit[unit2][["geometry"]].copy()
+            buffered["geometry"] = buffered.boundary
+            end = geopandas.overlay(buffered, join, keep_geom_type=False)
+            if len(end):
+                pieces.append(end)
+
+        if pieces:
+            contacts = pandas.concat(pieces, ignore_index=True)
+        else:
+            contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None)
+
+        contacts["length"] = [row.length for row in contacts["geometry"]]
+        self.contacts = contacts
+
+        return contacts
+
+    def extract_basal_contacts(
+        self,
+        stratigraphic_column: List[str],
+        contacts: Optional[geopandas.GeoDataFrame] = None,
+    ) -> geopandas.GeoDataFrame:
+        """Classify contacts as basal or abnormal using ``stratigraphic_column``.
+
+        Args:
+            stratigraphic_column: Unit names in stratigraphic order.
+            contacts: Contacts to classify. Defaults to ``self.contacts``, the
+                result of the last :meth:`extract_all_contacts` call.
+
+        Returns:
+            ``ID`` (lower column index of the two units), ``basal_unit`` (unit
+            name at that index), ``type`` (``"BASAL"`` when the units are at
+            most one position apart, otherwise ``"ABNORMAL"``) and
+            ``geometry``. Also stored, abnormal rows included, in
+            ``self.basal_contacts``.
+
+        Raises:
+            ValueError: If no contacts are available, or a contact refers to a
+                unit that is missing from ``stratigraphic_column``.
+        """
+        logger.info("Extracting basal contacts")
+
+        if contacts is None:
+            contacts = self.contacts
+        if contacts is None:
+            raise ValueError("Contacts must be extracted before extracting basal contacts")
+
+        units = list(stratigraphic_column)
+        basal_contacts = contacts.copy()
+
+        # First position of each unit, matching ``list.index`` for duplicates.
+        unit_index = {}
+        for position, unit in enumerate(units):
+            unit_index.setdefault(unit, position)
+
+        # Both sides of every contact must be in the column, otherwise the
+        # positions below cannot be looked up.
+        contact_units = pandas.concat(
+            [basal_contacts["UNITNAME_1"], basal_contacts["UNITNAME_2"]]
+        ).unique()
+        missing_units = [str(unit) for unit in contact_units if unit not in unit_index]
+        if missing_units:
+            message = (
+                "There are units in the Geology dataset, but not in the stratigraphic column: "
+                + ", ".join(missing_units)
+                + ". Please readjust the stratigraphic column if this is a user defined column."
+            )
+            logger.error(message)
+            raise ValueError(message)
+
+        index_1 = basal_contacts["UNITNAME_1"].map(unit_index).astype(int)
+        index_2 = basal_contacts["UNITNAME_2"].map(unit_index).astype(int)
+
+        # apply minimum lithological id between the two units
+        basal_contacts["ID"] = index_1.where(index_1 <= index_2, index_2)
+
+        # match the name of the unit with the minimum id
+        basal_contacts["basal_unit"] = [units[position] for position in basal_contacts["ID"]]
+
+        # how many units apart are the two units?
+        stratigraphic_distance = (index_1 - index_2).abs()
+
+        # more than 1 unit apart means unit(s) are missing in between: abnormal contact
+        basal_contacts["type"] = [
+            "ABNORMAL" if distance > 1 else "BASAL" for distance in stratigraphic_distance
+        ]
+
+        basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]].copy()
+
+        # snap and merge multi-lines that touch each other; needed for the
+        # reconstruction based on featureId
+        basal_contacts["geometry"] = [
+            shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"]
+        ]
+
+        self.basal_contacts = basal_contacts
+        return basal_contacts
diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py
index 4137af27..1e20232c
--- a/map2loop/mapdata.py
+++ b/map2loop/mapdata.py
@@ -3,7 +3,13 @@
 from .config import Config
 from .aus_state_urls import AustraliaStateUrls
 from .utils import generate_random_hex_colors, calculate_minimum_fault_length
-from .data_checks import check_geology_fields_validity, check_structure_fields_validity, check_fault_fields_validity, check_fold_fields_validity
+from .data_checks import (
+    check_geology_fields_validity,
+    check_structure_fields_validity,
+    check_fault_fields_validity,
+    check_fold_fields_validity,
+)
+from .contact_extractor import ContactExtractor
 
 # external imports
 import geopandas
@@ -91,6 +97,7 @@ def __init__(self, verbose_level: VerboseLevel = VerboseLevel.ALL):
         self.colour_filename = None
         self.verbose_level = verbose_level
         self.config = Config()
+        self.contact_extractor = ContactExtractor()
 
     @property
     @beartype.beartype
@@ -1508,102 +1515,44 @@ def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame):
 
     @beartype.beartype
     def extract_all_contacts(self, save_contacts=True):
-        """
-        Extract the contacts between units in the geology GeoDataFrame
-        """
-        logger.info("Extracting contacts")
-        geology = self.get_map_data(Datatype.GEOLOGY).copy()
-        geology = geology.dissolve(by="UNITNAME", as_index=False)
-        # Remove intrusions
-        geology = geology[~geology["INTRUSIVE"]]
-        geology = geology[~geology["SILL"]]
-        # Remove faults from contact geomety
-        if self.get_map_data(Datatype.FAULT) is not None:
-            faults = self.get_map_data(Datatype.FAULT).copy()
-            faults["geometry"] = faults.buffer(50)
-            geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False)
-        units = geology["UNITNAME"].unique()
-        column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"]
-        contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None)
-        while len(units) > 1:
-            unit1 = units[0]
-            units = units[1:]
-            for unit2 in units:
-                if unit1 != unit2:
-                    # print(f'contact: {unit1} and {unit2}')
-                    join = geopandas.overlay(
-                        geology[geology["UNITNAME"] == unit1],
-                        geology[geology["UNITNAME"] == unit2],
-                        keep_geom_type=False,
-                    )[column_names]
-                    join["geometry"] = join.buffer(1)
-                    buffered = geology[geology["UNITNAME"] == unit2][["geometry"]].copy()
-                    buffered["geometry"] = buffered.boundary
-                    end = geopandas.overlay(buffered, join, keep_geom_type=False)
-                    if len(end):
-                        contacts = pandas.concat([contacts, end], ignore_index=True)
-        # contacts["TYPE"] = "UNKNOWN"
-        contacts["length"] = [row.length for row in contacts["geometry"]]
-        # print('finished extracting contacts')
+        """Extract contacts from the loaded geology and fault data.
+
+        Args:
+            save_contacts (bool): When true, store the result in ``self.contacts``.
+
+        Returns:
+            The contacts returned by ``ContactExtractor.extract_all_contacts``.
+        """
+        geology = self.get_map_data(Datatype.GEOLOGY)
+        faults = self.get_map_data(Datatype.FAULT)
+
+        # ContactExtractor takes its input data at construction time.
+        self.contact_extractor = ContactExtractor(geology, faults)
+        contacts = self.contact_extractor.extract_all_contacts()
+
         if save_contacts:
             self.contacts = contacts
         return contacts
 
     @beartype.beartype
     def extract_basal_contacts(self, stratigraphic_column: list, save_contacts=True):
-        """
-        Identify the basal unit of the contacts based on the stratigraphic column
-
-        Args:
-            stratigraphic_column (list):
-                The stratigraphic column to use
-        """
-        logger.info("Extracting basal contacts")
-
-        units = stratigraphic_column
-        basal_contacts = self.contacts.copy()
-
-        # check if the units in the strati colum are in the geology dataset, so that basal contacts can be built
-        # if not, stop the project
-        if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()):
-            missing_units = (
-                basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"]
-                .unique()
-                .tolist()
-            )
-            logger.error(
-                "There are units in the Geology dataset, but not in the stratigraphic column: "
-                + ", ".join(missing_units)
-                + ". Please readjust the stratigraphic column if this is a user defined column."
-            )
-            raise ValueError(
-                "There are units in stratigraphic column, but not in the Geology dataset: "
-                + ", ".join(missing_units)
-                + ". Please readjust the stratigraphic column if this is a user defined column."
-            )
+        """Identify basal contacts using the loaded contacts.
+
+        Args:
+            stratigraphic_column (list): The stratigraphic column to use.
+            save_contacts (bool): When true, store every classified contact in
+                ``self.all_basal_contacts`` and only the ``BASAL`` ones in
+                ``self.basal_contacts``.
+
+        Raises:
+            ValueError: If contacts have not been extracted yet.
+        """
+        if self.contacts is None:
+            raise ValueError("Contacts must be extracted before extracting basal contacts")
 
-        # apply minimum lithological id between the two units
-        basal_contacts["ID"] = basal_contacts.apply(
-            lambda row: min(units.index(row["UNITNAME_1"]), units.index(row["UNITNAME_2"])), axis=1
+        basal_contacts = self.contact_extractor.extract_basal_contacts(
+            stratigraphic_column, contacts=self.contacts
         )
-        # match the name of the unit with the minimum id
-        basal_contacts["basal_unit"] = basal_contacts.apply(lambda row: units[row["ID"]], axis=1)
-        # how many units apart are the two units?
-        basal_contacts["stratigraphic_distance"] = basal_contacts.apply(
-            lambda row: abs(units.index(row["UNITNAME_1"]) - units.index(row["UNITNAME_2"])), axis=1
-        )
-        # if the units are more than 1 unit apart, the contact is abnormal (meaning that there is one (or more) unit(s) missing in between the two)
-        basal_contacts["type"] = basal_contacts.apply(
-            lambda row: "ABNORMAL" if abs(row["stratigraphic_distance"]) > 1 else "BASAL", axis=1
-        )
-
-        basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]]
-
-        # added code to make sure that multi-line that touch each other are snapped and merged.
-        # necessary for the reconstruction based on featureId
-        basal_contacts["geometry"] = [
-            shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"]
-        ]
 
         if save_contacts:
             # keep abnormal contacts as all_basal_contacts
diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py
new file mode 100644
index 00000000..eb7d511c
--- /dev/null
+++ b/tests/contact_extractor/test_contact_extractor.py
@@ -0,0 +1,92 @@
+import importlib.util
+import logging
+import pathlib
+import sys
+import types
+
+import geopandas as gpd
+import pytest
+import shapely.geometry
+
+ROOT = pathlib.Path(__file__).resolve().parents[2]
+PACKAGE_NAME = "map2loop"
+
+# If the real package is not already imported, register a bare stand-in so
+# ``contact_extractor`` can be loaded by path on its own. ``loggers`` and
+# ``ch`` are presumably what ``map2loop.logging`` reads - TODO confirm.
+if PACKAGE_NAME not in sys.modules:
+    pkg = types.ModuleType(PACKAGE_NAME)
+    pkg.__path__ = [str(ROOT / PACKAGE_NAME)]
+    pkg.loggers = {}
+    pkg.ch = logging.StreamHandler()
+    pkg.ch.setLevel(logging.WARNING)
+    sys.modules[PACKAGE_NAME] = pkg
+
+spec = importlib.util.spec_from_file_location(
+    f"{PACKAGE_NAME}.contact_extractor",
+    ROOT / PACKAGE_NAME / "contact_extractor.py",
+)
+module = importlib.util.module_from_spec(spec)
+sys.modules[f"{PACKAGE_NAME}.contact_extractor"] = module
+spec.loader.exec_module(module)
+ContactExtractor = module.ContactExtractor
+
+
+@pytest.fixture
+def simple_geology():
+    """Two adjacent 2x2 squares, one per unit, sharing the edge x == 2."""
+    poly1 = shapely.geometry.Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
+    poly2 = shapely.geometry.Polygon([(2, 0), (4, 0), (4, 2), (2, 2)])
+
+    return gpd.GeoDataFrame(
+        {
+            "UNITNAME": ["unit1", "unit2"],
+            "INTRUSIVE": [False, False],
+            "SILL": [False, False],
+            "geometry": [poly1, poly2],
+        },
+        # Projected CRS: the extractor buffers by fixed distances, which
+        # geopandas warns about on a geographic CRS such as EPSG:4326.
+        crs="EPSG:28350",
+    )
+
+
+def test_extract_all_contacts(simple_geology):
+    extractor = ContactExtractor(simple_geology)
+    result = extractor.extract_all_contacts()
+    assert len(result) == 1
+
+
+def test_extract_all_contacts_requires_geology():
+    with pytest.raises(ValueError):
+        ContactExtractor().extract_all_contacts()
+
+
+def test_extract_basal_contacts(simple_geology):
+    extractor = ContactExtractor(simple_geology)
+    extractor.extract_all_contacts()
+    basal = extractor.extract_basal_contacts(["unit1", "unit2"])
+    assert list(basal["basal_unit"]) == ["unit1"]
+
+
+def test_extract_basal_contacts_requires_contacts(simple_geology):
+    extractor = ContactExtractor(simple_geology)
+    with pytest.raises(ValueError):
+        extractor.extract_basal_contacts(["unit1", "unit2"])
+
+
+def test_extract_basal_contacts_rejects_unknown_units(simple_geology):
+    extractor = ContactExtractor(simple_geology)
+    extractor.extract_all_contacts()
+    with pytest.raises(ValueError, match="unit2"):
+        extractor.extract_basal_contacts(["unit1"])
+
+
+def test_extract_basal_contacts_accepts_explicit_contacts(simple_geology):
+    contacts = ContactExtractor(simple_geology).extract_all_contacts()
+    # A fresh extractor has no cached contacts, so the keyword is the only source.
+    basal = ContactExtractor().extract_basal_contacts(
+        ["unit1", "unit2"], contacts=contacts
+    )
+    assert list(basal["type"]) == ["BASAL"]