From e4e74f485fbfea9b49d2582153e183334fb53e8b Mon Sep 17 00:00:00 2001 From: Rabii Chaarani <50892556+rabii-chaarani@users.noreply.github.com> Date: Mon, 30 Jun 2025 15:26:24 +0930 Subject: [PATCH 1/6] Refactor contact extraction into new class --- map2loop/contact_extractor.py | 129 ++++++++++++++++++ map2loop/mapdata.py | 114 ++-------------- map2loop/project.py | 6 +- .../test_contact_extractor.py | 44 ++++++ 4 files changed, 188 insertions(+), 105 deletions(-) create mode 100644 map2loop/contact_extractor.py create mode 100644 tests/contact_extractor/test_contact_extractor.py diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py new file mode 100644 index 00000000..e3575be2 --- /dev/null +++ b/map2loop/contact_extractor.py @@ -0,0 +1,129 @@ +"""Utility class for extracting geological contacts.""" + +from __future__ import annotations + +from typing import List + +import geopandas +import pandas +import shapely + +from .m2l_enums import Datatype +from .logging import getLogger + +logger = getLogger(__name__) + + +class ContactExtractor: + """Encapsulates contact extraction logic used by :class:`MapData`.""" + + def __init__(self, map_data: "MapData") -> None: + self.map_data = map_data + + # ------------------------------------------------------------------ + def extract_all_contacts(self, save_contacts: bool = True) -> geopandas.GeoDataFrame: + """Extract all contacts between units in the geology GeoDataFrame.""" + + logger.info("Extracting contacts") + geology = self.map_data.get_map_data(Datatype.GEOLOGY).copy() + geology = geology.dissolve(by="UNITNAME", as_index=False) + + # Remove intrusions + geology = geology[~geology["INTRUSIVE"]] + geology = geology[~geology["SILL"]] + + # Remove faults from contact geometry + if self.map_data.get_map_data(Datatype.FAULT) is not None: + faults = self.map_data.get_map_data(Datatype.FAULT).copy() + faults["geometry"] = faults.buffer(50) + geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False) + + units = geology["UNITNAME"].unique() + column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"] + contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None) + + while len(units) > 1: + unit1 = units[0] + units = units[1:] + for unit2 in units: + if unit1 != unit2: + join = geopandas.overlay( + geology[geology["UNITNAME"] == unit1], + geology[geology["UNITNAME"] == unit2], + keep_geom_type=False, + )[column_names] + join["geometry"] = join.buffer(1) + buffered = geology[geology["UNITNAME"] == unit2][["geometry"]].copy() + buffered["geometry"] = buffered.boundary + end = geopandas.overlay(buffered, join, keep_geom_type=False) + if len(end): + contacts = pandas.concat([contacts, end], ignore_index=True) + + contacts["length"] = [row.length for row in contacts["geometry"]] + + if save_contacts: + self.map_data.contacts = contacts + + return contacts + + # ------------------------------------------------------------------ + def extract_basal_contacts( + self, stratigraphic_column: List[str], save_contacts: bool = True + ) -> geopandas.GeoDataFrame: + """Identify the basal unit of the contacts based on the stratigraphic column.""" + + logger.info("Extracting basal contacts") + + units = stratigraphic_column + basal_contacts = self.map_data.contacts.copy() + + # verify units exist in the geology dataset + if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): + missing_units = ( + basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"] + .unique() + .tolist() + ) + logger.error( + "There are units in the Geology dataset, but not in the stratigraphic column: " + + ", ".join(missing_units) + + ". Please readjust the stratigraphic column if this is a user defined column." + ) + raise ValueError( + "There are units in stratigraphic column, but not in the Geology dataset: " + + ", ".join(missing_units) + + ". Please readjust the stratigraphic column if this is a user defined column." + ) + + # apply minimum lithological id between the two units + basal_contacts["ID"] = basal_contacts.apply( + lambda row: min(units.index(row["UNITNAME_1"]), units.index(row["UNITNAME_2"])), axis=1 + ) + + # match the name of the unit with the minimum id + basal_contacts["basal_unit"] = basal_contacts.apply(lambda row: units[row["ID"]], axis=1) + + # how many units apart are the two units? + basal_contacts["stratigraphic_distance"] = basal_contacts.apply( + lambda row: abs(units.index(row["UNITNAME_1"]) - units.index(row["UNITNAME_2"])), + axis=1, + ) + + # if the units are more than 1 unit apart, the contact is abnormal + basal_contacts["type"] = basal_contacts.apply( + lambda row: "ABNORMAL" if abs(row["stratigraphic_distance"]) > 1 else "BASAL", + axis=1, + ) + + basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]] + + basal_contacts["geometry"] = [ + shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"] + ] + + if save_contacts: + self.map_data.all_basal_contacts = basal_contacts + self.map_data.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"] + + return basal_contacts + diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index 4137af27..342281b4 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -3,7 +3,13 @@ from .config import Config from .aus_state_urls import AustraliaStateUrls from .utils import generate_random_hex_colors, calculate_minimum_fault_length -from .data_checks import check_geology_fields_validity, check_structure_fields_validity, check_fault_fields_validity, check_fold_fields_validity +from .data_checks import ( + check_geology_fields_validity, + check_structure_fields_validity, + check_fault_fields_validity, + check_fold_fields_validity, +) +from .contact_extractor import ContactExtractor # external imports import geopandas @@ -91,6 +97,7 @@ def __init__(self, verbose_level: VerboseLevel = VerboseLevel.ALL): self.colour_filename = None self.verbose_level = verbose_level self.config = Config() + self.contact_extractor = ContactExtractor(self) @property @beartype.beartype @@ -1508,110 +1515,13 @@ def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame): @beartype.beartype def extract_all_contacts(self, save_contacts=True): - """ - Extract the contacts between units in the geology GeoDataFrame - """ - logger.info("Extracting contacts") - geology = self.get_map_data(Datatype.GEOLOGY).copy() - geology = geology.dissolve(by="UNITNAME", as_index=False) - # Remove intrusions - geology = geology[~geology["INTRUSIVE"]] - geology = geology[~geology["SILL"]] - # Remove faults from contact geomety - if self.get_map_data(Datatype.FAULT) is not None: - faults = self.get_map_data(Datatype.FAULT).copy() - faults["geometry"] = faults.buffer(50) - geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False) - units = geology["UNITNAME"].unique() - column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"] - contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None) - while len(units) > 1: - unit1 = units[0] - units = units[1:] - for unit2 in units: - if unit1 != unit2: - # print(f'contact: {unit1} and {unit2}') - join = geopandas.overlay( - geology[geology["UNITNAME"] == unit1], - geology[geology["UNITNAME"] == unit2], - keep_geom_type=False, - )[column_names] - join["geometry"] = join.buffer(1) - buffered = geology[geology["UNITNAME"] == unit2][["geometry"]].copy() - buffered["geometry"] = buffered.boundary - end = geopandas.overlay(buffered, join, keep_geom_type=False) - if len(end): - contacts = pandas.concat([contacts, end], ignore_index=True) - # contacts["TYPE"] = "UNKNOWN" - contacts["length"] = [row.length for row in contacts["geometry"]] - # print('finished extracting contacts') - if save_contacts: - self.contacts = contacts - return contacts + """Delegate contact extraction to the ContactExtractor.""" + return self.contact_extractor.extract_all_contacts(save_contacts) @beartype.beartype def extract_basal_contacts(self, stratigraphic_column: list, save_contacts=True): - """ - Identify the basal unit of the contacts based on the stratigraphic column - - Args: - stratigraphic_column (list): - The stratigraphic column to use - """ - logger.info("Extracting basal contacts") - - units = stratigraphic_column - basal_contacts = self.contacts.copy() - - # check if the units in the strati colum are in the geology dataset, so that basal contacts can be built - # if not, stop the project - if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): - missing_units = ( - basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"] - .unique() - .tolist() - ) - logger.error( - "There are units in the Geology dataset, but not in the stratigraphic column: " - + ", ".join(missing_units) - + ". Please readjust the stratigraphic column if this is a user defined column." - ) - raise ValueError( - "There are units in stratigraphic column, but not in the Geology dataset: " - + ", ".join(missing_units) - + ". Please readjust the stratigraphic column if this is a user defined column." - ) - - # apply minimum lithological id between the two units - basal_contacts["ID"] = basal_contacts.apply( - lambda row: min(units.index(row["UNITNAME_1"]), units.index(row["UNITNAME_2"])), axis=1 - ) - # match the name of the unit with the minimum id - basal_contacts["basal_unit"] = basal_contacts.apply(lambda row: units[row["ID"]], axis=1) - # how many units apart are the two units? - basal_contacts["stratigraphic_distance"] = basal_contacts.apply( - lambda row: abs(units.index(row["UNITNAME_1"]) - units.index(row["UNITNAME_2"])), axis=1 - ) - # if the units are more than 1 unit apart, the contact is abnormal (meaning that there is one (or more) unit(s) missing in between the two) - basal_contacts["type"] = basal_contacts.apply( - lambda row: "ABNORMAL" if abs(row["stratigraphic_distance"]) > 1 else "BASAL", axis=1 - ) - - basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]] - - # added code to make sure that multi-line that touch each other are snapped and merged. - # necessary for the reconstruction based on featureId - basal_contacts["geometry"] = [ - shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"] - ] - - if save_contacts: - # keep abnormal contacts as all_basal_contacts - self.all_basal_contacts = basal_contacts - # remove the abnormal contacts from basal contacts - self.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"] - - return basal_contacts + """Delegate basal contact extraction to the ContactExtractor.""" + return self.contact_extractor.extract_basal_contacts(stratigraphic_column, save_contacts) @beartype.beartype def colour_units( diff --git a/map2loop/project.py b/map2loop/project.py index d9cfbb83..0560a764 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -529,7 +529,7 @@ def extract_geology_contacts(self): Use the stratigraphic column, and fault and geology data to extract points along contacts """ # Use stratigraphic column to determine basal contacts - self.map_data.extract_basal_contacts(self.stratigraphic_column.column) + self.map_data.contact_extractor.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample( @@ -558,7 +558,7 @@ def calculate_stratigraphic_order(self, take_best=False): for sorter in sorters ] basal_contacts = [ - self.map_data.extract_basal_contacts(column, save_contacts=False) + self.map_data.contact_extractor.extract_basal_contacts(column, save_contacts=False) for column in columns ] basal_lengths = [ @@ -763,7 +763,7 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): logger.info(f'User defined stratigraphic column: {user_defined_stratigraphic_column}') # Calculate contacts before stratigraphic column - self.map_data.extract_all_contacts() + self.map_data.contact_extractor.extract_all_contacts() # Calculate the stratigraphic column if issubclass(type(user_defined_stratigraphic_column), list): diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py new file mode 100644 index 00000000..51a8e8f7 --- /dev/null +++ b/tests/contact_extractor/test_contact_extractor.py @@ -0,0 +1,44 @@ +import geopandas as gpd +import shapely.geometry +import pytest + +from map2loop.mapdata import MapData + + +@pytest.fixture +def simple_mapdata(): + # Create two adjacent square polygons representing two units + poly1 = shapely.geometry.Polygon([(0, 0), (2, 0), (2, 2), (0, 2)]) + poly2 = shapely.geometry.Polygon([(2, 0), (4, 0), (4, 2), (2, 2)]) + + data = gpd.GeoDataFrame( + { + "UNITNAME": ["unit1", "unit2"], + "INTRUSIVE": [False, False], + "SILL": [False, False], + "geometry": [poly1, poly2], + }, + crs="EPSG:4326", + ) + + md = MapData() + md.data[0] = data # Datatype.GEOLOGY == 0 + md.data_states[0] = 5 # Datastate.COMPLETE + md.dirtyflags[0] = False + return md + + +def test_extract_all_contacts(simple_mapdata): + result = simple_mapdata.contact_extractor.extract_all_contacts() + assert len(result) == 1 + assert simple_mapdata.contacts is not None + + +def test_extract_basal_contacts(simple_mapdata): + simple_mapdata.contact_extractor.extract_all_contacts() + contacts = simple_mapdata.contact_extractor.extract_basal_contacts([ + "unit1", + "unit2", + ]) + assert list(contacts["basal_unit"]) == ["unit1"] + assert simple_mapdata.basal_contacts is not None From 14851a2c59cb282013884d863f66fa7fa8e36db0 Mon Sep 17 00:00:00 2001 From: Rabii Chaarani <50892556+rabii-chaarani@users.noreply.github.com> Date: Mon, 30 Jun 2025 15:37:50 +0930 Subject: [PATCH 2/6] refactor: decouple ContactExtractor from MapData --- map2loop/contact_extractor.py | 36 ++++++----- map2loop/mapdata.py | 30 +++++++-- map2loop/project.py | 6 +- .../test_contact_extractor.py | 61 ++++++++++++------- 4 files changed, 84 insertions(+), 49 deletions(-) diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py index e3575be2..9c6cd4d6 100644 --- a/map2loop/contact_extractor.py +++ b/map2loop/contact_extractor.py @@ -8,24 +8,27 @@ import pandas import shapely -from .m2l_enums import Datatype from .logging import getLogger logger = getLogger(__name__) class ContactExtractor: - """Encapsulates contact extraction logic used by :class:`MapData`.""" + """Encapsulates contact extraction logic.""" - def __init__(self, map_data: "MapData") -> None: - self.map_data = map_data + def __init__(self) -> None: + pass # ------------------------------------------------------------------ - def extract_all_contacts(self, save_contacts: bool = True) -> geopandas.GeoDataFrame: - """Extract all contacts between units in the geology GeoDataFrame.""" + def extract_all_contacts( + self, + geology: geopandas.GeoDataFrame, + faults: geopandas.GeoDataFrame | None = None, + ) -> geopandas.GeoDataFrame: + """Extract all contacts between units in ``geology``.""" logger.info("Extracting contacts") - geology = self.map_data.get_map_data(Datatype.GEOLOGY).copy() + geology = geology.copy() geology = geology.dissolve(by="UNITNAME", as_index=False) # Remove intrusions @@ -33,8 +36,8 @@ def extract_all_contacts(self, save_contacts: bool = True) -> geopandas.GeoDataF geology = geology[~geology["SILL"]] # Remove faults from contact geometry - if self.map_data.get_map_data(Datatype.FAULT) is not None: - faults = self.map_data.get_map_data(Datatype.FAULT).copy() + if faults is not None: + faults = faults.copy() faults["geometry"] = faults.buffer(50) geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False) @@ -61,21 +64,20 @@ def extract_all_contacts(self, save_contacts: bool = True) -> geopandas.GeoDataF contacts["length"] = [row.length for row in contacts["geometry"]] - if save_contacts: - self.map_data.contacts = contacts - return contacts # ------------------------------------------------------------------ def extract_basal_contacts( - self, stratigraphic_column: List[str], save_contacts: bool = True + self, + contacts: geopandas.GeoDataFrame, + stratigraphic_column: List[str], ) -> geopandas.GeoDataFrame: - """Identify the basal unit of the contacts based on the stratigraphic column.""" + """Identify the basal unit of ``contacts`` based on ``stratigraphic_column``.""" logger.info("Extracting basal contacts") units = stratigraphic_column - basal_contacts = self.map_data.contacts.copy() + basal_contacts = contacts.copy() # verify units exist in the geology dataset if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): @@ -121,9 +123,5 @@ def extract_basal_contacts( shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"] ] - if save_contacts: - self.map_data.all_basal_contacts = basal_contacts - self.map_data.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"] - return basal_contacts diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index 342281b4..1e20232c 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -97,7 +97,7 @@ def __init__(self, verbose_level: VerboseLevel = VerboseLevel.ALL): self.colour_filename = None self.verbose_level = verbose_level self.config = Config() - self.contact_extractor = ContactExtractor(self) + self.contact_extractor = ContactExtractor() @property @beartype.beartype @@ -1515,13 +1515,33 @@ def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame): @beartype.beartype def extract_all_contacts(self, save_contacts=True): - """Delegate contact extraction to the ContactExtractor.""" - return self.contact_extractor.extract_all_contacts(save_contacts) + """Extract contacts from the loaded geology and fault data.""" + + geology = self.get_map_data(Datatype.GEOLOGY) + faults = self.get_map_data(Datatype.FAULT) + + contacts = self.contact_extractor.extract_all_contacts(geology, faults) + + if save_contacts: + self.contacts = contacts + return contacts @beartype.beartype def extract_basal_contacts(self, stratigraphic_column: list, save_contacts=True): - """Delegate basal contact extraction to the ContactExtractor.""" - return self.contact_extractor.extract_basal_contacts(stratigraphic_column, save_contacts) + """Identify basal contacts using the loaded contacts.""" + + if self.contacts is None: + raise ValueError("Contacts must be extracted before extracting basal contacts") + + basal_contacts = self.contact_extractor.extract_basal_contacts( + self.contacts, stratigraphic_column + ) + + if save_contacts: + self.all_basal_contacts = basal_contacts + self.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"] + + return basal_contacts @beartype.beartype def colour_units( diff --git a/map2loop/project.py b/map2loop/project.py index 0560a764..d9cfbb83 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -529,7 +529,7 @@ def extract_geology_contacts(self): Use the stratigraphic column, and fault and geology data to extract points along contacts """ # Use stratigraphic column to determine basal contacts - self.map_data.contact_extractor.extract_basal_contacts(self.stratigraphic_column.column) + self.map_data.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample( @@ -558,7 +558,7 @@ def calculate_stratigraphic_order(self, take_best=False): for sorter in sorters ] basal_contacts = [ - self.map_data.contact_extractor.extract_basal_contacts(column, save_contacts=False) + self.map_data.extract_basal_contacts(column, save_contacts=False) for column in columns ] basal_lengths = [ @@ -763,7 +763,7 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): logger.info(f'User defined stratigraphic column: {user_defined_stratigraphic_column}') # Calculate contacts before stratigraphic column - self.map_data.contact_extractor.extract_all_contacts() + self.map_data.extract_all_contacts() # Calculate the stratigraphic column if issubclass(type(user_defined_stratigraphic_column), list): diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py index 51a8e8f7..40ba7f33 100644 --- a/tests/contact_extractor/test_contact_extractor.py +++ b/tests/contact_extractor/test_contact_extractor.py @@ -1,17 +1,41 @@ import geopandas as gpd import shapely.geometry import pytest - -from map2loop.mapdata import MapData +import importlib.util +import pathlib +import types +import sys + +ROOT = pathlib.Path(__file__).resolve().parents[2] +PACKAGE_NAME = "map2loop" + +if PACKAGE_NAME not in sys.modules: + pkg = types.ModuleType(PACKAGE_NAME) + pkg.__path__ = [str(ROOT / PACKAGE_NAME)] + import logging + pkg.loggers = {} + pkg.ch = logging.StreamHandler() + pkg.ch.setLevel(logging.WARNING) + sys.modules[PACKAGE_NAME] = pkg + +spec = importlib.util.spec_from_file_location( + f"{PACKAGE_NAME}.contact_extractor", + ROOT / PACKAGE_NAME / "contact_extractor.py", +) +module = importlib.util.module_from_spec(spec) +sys.modules[f"{PACKAGE_NAME}.contact_extractor"] = module +spec.loader.exec_module(module) +ContactExtractor = module.ContactExtractor @pytest.fixture -def simple_mapdata(): - # Create two adjacent square polygons representing two units +def simple_geology(): + """Create a minimal geology dataset for testing.""" + poly1 = shapely.geometry.Polygon([(0, 0), (2, 0), (2, 2), (0, 2)]) poly2 = shapely.geometry.Polygon([(2, 0), (4, 0), (4, 2), (2, 2)]) - data = gpd.GeoDataFrame( + return gpd.GeoDataFrame( { "UNITNAME": ["unit1", "unit2"], "INTRUSIVE": [False, False], @@ -21,24 +45,17 @@ def simple_mapdata(): crs="EPSG:4326", ) - md = MapData() - md.data[0] = data # Datatype.GEOLOGY == 0 - md.data_states[0] = 5 # Datastate.COMPLETE - md.dirtyflags[0] = False - return md - -def test_extract_all_contacts(simple_mapdata): - result = simple_mapdata.contact_extractor.extract_all_contacts() +def test_extract_all_contacts(simple_geology): + extractor = ContactExtractor() + result = extractor.extract_all_contacts(simple_geology) assert len(result) == 1 - assert simple_mapdata.contacts is not None -def test_extract_basal_contacts(simple_mapdata): - simple_mapdata.contact_extractor.extract_all_contacts() - contacts = simple_mapdata.contact_extractor.extract_basal_contacts([ - "unit1", - "unit2", - ]) - assert list(contacts["basal_unit"]) == ["unit1"] - assert simple_mapdata.basal_contacts is not None +def test_extract_basal_contacts(simple_geology): + extractor = ContactExtractor() + contacts = extractor.extract_all_contacts(simple_geology) + basal = extractor.extract_basal_contacts(contacts, ["unit1", "unit2"]) + assert list(basal["basal_unit"]) == ["unit1"] + + From 44c3b28d6f5f1a7adafa1d821724537ccaa567b1 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Mon, 30 Jun 2025 15:50:14 +0930 Subject: [PATCH 3/6] refactor: remove dependency to MapData and add tests --- map2loop/contact_extractor.py | 26 +++++++----- tests/contacts/test_contact_extractor.py | 51 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 10 deletions(-) create mode 100644 tests/contacts/test_contact_extractor.py diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py index 9c6cd4d6..a0d36633 100644 --- a/map2loop/contact_extractor.py +++ b/map2loop/contact_extractor.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List +from typing import List, Optional import geopandas import pandas @@ -16,19 +16,24 @@ class ContactExtractor: """Encapsulates contact extraction logic.""" - def __init__(self) -> None: - pass + def __init__(self, + geology_data: geopandas.GeoDataFrame, + fault_data: Optional[geopandas.GeoDataFrame] = None) -> None: + + self.geology_data = geology_data + self.fault_data = fault_data + self.contacts = None # ------------------------------------------------------------------ - def extract_all_contacts( - self, - geology: geopandas.GeoDataFrame, - faults: geopandas.GeoDataFrame | None = None, - ) -> geopandas.GeoDataFrame: - """Extract all contacts between units in ``geology``.""" + def extract_all_contacts(self) -> geopandas.GeoDataFrame: + """Extract all contacts between units in ``geology_data``.""" logger.info("Extracting contacts") - geology = geology.copy() + if self.fault_data is not None: + faults = self.fault_data.copy() + else: + faults = None + geology = self.geology_data.copy() geology = geology.dissolve(by="UNITNAME", as_index=False) # Remove intrusions @@ -63,6 +68,7 @@ def extract_all_contacts( contacts = pandas.concat([contacts, end], ignore_index=True) contacts["length"] = [row.length for row in contacts["geometry"]] + self.contacts = contacts return contacts diff --git a/tests/contacts/test_contact_extractor.py b/tests/contacts/test_contact_extractor.py new file mode 100644 index 00000000..7fc2e6a0 --- /dev/null +++ b/tests/contacts/test_contact_extractor.py @@ -0,0 +1,51 @@ +import types +import importlib.util +import sys +import logging +import pathlib +import geopandas as gpd +from shapely.geometry import Polygon + +def load_contact_extractor(): + base = pathlib.Path(__file__).resolve().parents[2] / "map2loop" + pkg = types.ModuleType("map2loop") + pkg.loggers = {} + pkg.ch = logging.StreamHandler() + sys.modules["map2loop"] = pkg + for name in ["logging", "m2l_enums", "contacts"]: + spec = importlib.util.spec_from_file_location(f"map2loop.{name}", base / f"{name}.py") + mod = importlib.util.module_from_spec(spec) + mod.__package__ = "map2loop" + sys.modules[f"map2loop.{name}"] = mod + spec.loader.exec_module(mod) + return sys.modules["map2loop.contacts"].ContactExtractor + +ContactExtractor = load_contact_extractor() + +def simple_geology(): + poly1 = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)]) + poly2 = Polygon([(2, 0), (4, 0), (4, 2), (2, 2)]) + return gpd.GeoDataFrame( + { + "UNITNAME": ["A", "B"], + "INTRUSIVE": [False, False], + "SILL": [False, False], + "geometry": [poly1, poly2], + }, + crs="EPSG:3857", + ) + +def test_extract_all_contacts_basic(): + ce = ContactExtractor() + gdf = simple_geology() + contacts = ce.extract_all_contacts(gdf) + assert {"UNITNAME_1", "UNITNAME_2", "geometry", "length"} <= set(contacts.columns) + assert len(contacts) > 0 + +def test_extract_basal_contacts_basic(): + ce = ContactExtractor() + gdf = simple_geology() + contacts = ce.extract_all_contacts(gdf) + allc, basal = ce.extract_basal_contacts(contacts, ["A", "B"]) + assert len(allc) >= len(basal) + assert "basal_unit" in allc.columns \ No newline at end of file From 5024e59ad1048d8b37eb0601278b5a24afb5816b Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Mon, 30 Jun 2025 15:57:19 +0930 Subject: [PATCH 4/6] refactor: update ContactExtractor instantiation in basal contacts test --- tests/contacts/test_contact_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contacts/test_contact_extractor.py b/tests/contacts/test_contact_extractor.py index 7fc2e6a0..54c2e572 100644 --- a/tests/contacts/test_contact_extractor.py +++ b/tests/contacts/test_contact_extractor.py @@ -43,7 +43,7 @@ def test_extract_all_contacts_basic(): assert len(contacts) > 0 def test_extract_basal_contacts_basic(): - ce = ContactExtractor() + ce = ContactExtractor(simple_geology()) gdf = simple_geology() contacts = ce.extract_all_contacts(gdf) allc, basal = ce.extract_basal_contacts(contacts, ["A", "B"]) From 3832d86416f8bc047b3b6070e7f64c5e556862e5 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Mon, 30 Jun 2025 16:01:00 +0930 Subject: [PATCH 5/6] refactor: update tests --- map2loop/contact_extractor.py | 8 +-- .../test_contact_extractor.py | 10 ++-- tests/contacts/test_contact_extractor.py | 51 ------------------- 3 files changed, 9 insertions(+), 60 deletions(-) delete mode 100644 tests/contacts/test_contact_extractor.py diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py index a0d36633..14226948 100644 --- a/map2loop/contact_extractor.py +++ b/map2loop/contact_extractor.py @@ -23,8 +23,9 @@ def __init__(self, self.geology_data = geology_data self.fault_data = fault_data self.contacts = None - - # ------------------------------------------------------------------ + self.basal_contacts = None + + def extract_all_contacts(self) -> geopandas.GeoDataFrame: """Extract all contacts between units in ``geology_data``.""" @@ -75,7 +76,6 @@ def extract_all_contacts(self) -> geopandas.GeoDataFrame: # ------------------------------------------------------------------ def extract_basal_contacts( self, - contacts: geopandas.GeoDataFrame, stratigraphic_column: List[str], ) -> geopandas.GeoDataFrame: """Identify the basal unit of ``contacts`` based on ``stratigraphic_column``.""" @@ -83,7 +83,7 @@ def extract_basal_contacts( logger.info("Extracting basal contacts") units = stratigraphic_column - basal_contacts = contacts.copy() + basal_contacts = self.contacts.copy() # verify units exist in the geology dataset if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py index 40ba7f33..67c75d7a 100644 --- a/tests/contact_extractor/test_contact_extractor.py +++ b/tests/contact_extractor/test_contact_extractor.py @@ -47,15 +47,15 @@ def simple_geology(): def test_extract_all_contacts(simple_geology): - extractor = ContactExtractor() - result = extractor.extract_all_contacts(simple_geology) + extractor = ContactExtractor(simple_geology) + result = extractor.extract_all_contacts() assert len(result) == 1 def test_extract_basal_contacts(simple_geology): - extractor = ContactExtractor() - contacts = extractor.extract_all_contacts(simple_geology) - basal = extractor.extract_basal_contacts(contacts, ["unit1", "unit2"]) + extractor = ContactExtractor(simple_geology) + contacts = extractor.extract_all_contacts() + basal = extractor.extract_basal_contacts(["unit1", "unit2"]) assert list(basal["basal_unit"]) == ["unit1"] diff --git a/tests/contacts/test_contact_extractor.py b/tests/contacts/test_contact_extractor.py deleted file mode 100644 index 54c2e572..00000000 --- a/tests/contacts/test_contact_extractor.py +++ /dev/null @@ -1,51 +0,0 @@ -import types -import importlib.util -import sys -import logging -import pathlib -import geopandas as gpd -from shapely.geometry import Polygon - -def load_contact_extractor(): - base = pathlib.Path(__file__).resolve().parents[2] / "map2loop" - pkg = types.ModuleType("map2loop") - pkg.loggers = {} - pkg.ch = logging.StreamHandler() - sys.modules["map2loop"] = pkg - for name in ["logging", "m2l_enums", "contacts"]: - spec = importlib.util.spec_from_file_location(f"map2loop.{name}", base / f"{name}.py") - mod = importlib.util.module_from_spec(spec) - mod.__package__ = "map2loop" - sys.modules[f"map2loop.{name}"] = mod - spec.loader.exec_module(mod) - return sys.modules["map2loop.contacts"].ContactExtractor - -ContactExtractor = load_contact_extractor() - -def simple_geology(): - poly1 = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)]) - poly2 = Polygon([(2, 0), (4, 0), (4, 2), (2, 2)]) - return gpd.GeoDataFrame( - { - "UNITNAME": ["A", "B"], - "INTRUSIVE": [False, False], - "SILL": [False, False], - "geometry": [poly1, poly2], - }, - crs="EPSG:3857", - ) - -def test_extract_all_contacts_basic(): - ce = ContactExtractor() - gdf = simple_geology() - contacts = ce.extract_all_contacts(gdf) - assert {"UNITNAME_1", "UNITNAME_2", "geometry", "length"} <= set(contacts.columns) - assert len(contacts) > 0 - -def test_extract_basal_contacts_basic(): - ce = ContactExtractor(simple_geology()) - gdf = simple_geology() - contacts = ce.extract_all_contacts(gdf) - allc, basal = ce.extract_basal_contacts(contacts, ["A", "B"]) - assert len(allc) >= len(basal) - assert "basal_unit" in allc.columns \ No newline at end of file From 4455520985341cd4b4172e85e69348cceab2dbff Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Mon, 30 Jun 2025 16:02:19 +0930 Subject: [PATCH 6/6] refactor: remove unnecessary variable assignment in basal contacts test --- tests/contact_extractor/test_contact_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py index 67c75d7a..eb7d511c 100644 --- a/tests/contact_extractor/test_contact_extractor.py +++ b/tests/contact_extractor/test_contact_extractor.py @@ -54,7 +54,7 @@ def test_extract_all_contacts(simple_geology): def test_extract_basal_contacts(simple_geology): extractor = ContactExtractor(simple_geology) - contacts = extractor.extract_all_contacts() + extractor.extract_all_contacts() basal = extractor.extract_basal_contacts(["unit1", "unit2"]) assert list(basal["basal_unit"]) == ["unit1"]