Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions map2loop/contact_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Utility class for extracting geological contacts."""

from __future__ import annotations

from typing import List, Optional

import geopandas
import pandas
import shapely

from .logging import getLogger

logger = getLogger(__name__)


class ContactExtractor:
    """Extract contacts between geological units from a polygon map.

    The extractor can be used in two ways:

    * stateful - pass the data to the constructor and call the methods
      without arguments; results are cached on ``contacts`` and
      ``basal_contacts``;
    * stateless - build it without data and pass the data to each call.

    Attributes:
        geology_data: Unit polygons. The extraction reads the ``UNITNAME``,
            ``INTRUSIVE`` and ``SILL`` columns.
        fault_data: Optional fault traces cut out of the geology before
            contacts are searched.
        contacts: Result of the last :meth:`extract_all_contacts` call.
        basal_contacts: Result of the last :meth:`extract_basal_contacts` call.
    """

    # Buffer distances and snapping tolerance, expressed in the units of the
    # data's coordinate reference system.
    FAULT_BUFFER = 50
    CONTACT_BUFFER = 1
    SNAP_TOLERANCE = 1

    def __init__(
        self,
        geology_data: Optional[geopandas.GeoDataFrame] = None,
        fault_data: Optional[geopandas.GeoDataFrame] = None,
    ) -> None:
        """Store the input layers; both may also be supplied per call."""
        self.geology_data = geology_data
        self.fault_data = fault_data
        self.contacts = None
        self.basal_contacts = None

    def extract_all_contacts(
        self,
        geology_data: Optional[geopandas.GeoDataFrame] = None,
        fault_data: Optional[geopandas.GeoDataFrame] = None,
    ) -> geopandas.GeoDataFrame:
        """Extract all contacts between the units of the geology layer.

        Args:
            geology_data: Geology polygons to use for this call. Defaults to
                the layer given to the constructor.
            fault_data: Fault traces to use for this call. Defaults to the
                layer given to the constructor.

        Returns:
            A GeoDataFrame with the columns ``UNITNAME_1``, ``UNITNAME_2``,
            ``geometry`` and ``length``. The same object is stored on
            ``self.contacts``.

        Raises:
            ValueError: If no geology data is available.
        """
        if geology_data is None:
            geology_data = self.geology_data
        if fault_data is None:
            fault_data = self.fault_data
        if geology_data is None:
            raise ValueError("Geology data must be provided before extracting contacts")

        logger.info("Extracting contacts")

        # ``dissolve`` returns a new frame, so the caller's data is untouched.
        geology = geology_data.dissolve(by="UNITNAME", as_index=False)

        # Remove intrusions
        geology = geology[~geology["INTRUSIVE"]]
        geology = geology[~geology["SILL"]]

        # Remove faults from contact geometry
        if fault_data is not None:
            fault_zones = fault_data.copy()
            fault_zones["geometry"] = fault_zones.buffer(self.FAULT_BUFFER)
            geology = geopandas.overlay(
                geology, fault_zones, how="difference", keep_geom_type=False
            )

        column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"]
        units = geology["UNITNAME"].unique()

        # Visit every unordered pair of units once; the pieces are gathered
        # and concatenated a single time after the loop.
        pieces = []
        for position, unit1 in enumerate(units[:-1]):
            first = geology[geology["UNITNAME"] == unit1]
            for unit2 in units[position + 1:]:
                second = geology[geology["UNITNAME"] == unit2]
                join = geopandas.overlay(first, second, keep_geom_type=False)[column_names]
                join["geometry"] = join.buffer(self.CONTACT_BUFFER)
                # Keep the part of unit2's outline lying inside the buffered
                # intersection of the two units.
                outline = second[["geometry"]].copy()
                outline["geometry"] = outline.boundary
                end = geopandas.overlay(outline, join, keep_geom_type=False)
                if len(end):
                    pieces.append(end)

        if pieces:
            contacts = pandas.concat(pieces, ignore_index=True)
        else:
            contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None)

        contacts["length"] = [geometry.length for geometry in contacts["geometry"]]
        self.contacts = contacts

        return contacts

    # ------------------------------------------------------------------
    def extract_basal_contacts(
        self,
        stratigraphic_column: List[str],
        contacts: Optional[geopandas.GeoDataFrame] = None,
    ) -> geopandas.GeoDataFrame:
        """Identify the basal unit of each contact from ``stratigraphic_column``.

        For every contact the unit with the lower index in
        ``stratigraphic_column`` is reported as the basal unit. Contacts whose
        two units are more than one position apart in the column are flagged
        ``"ABNORMAL"``; all others are ``"BASAL"``.

        Args:
            stratigraphic_column: Unit names in stratigraphic order.
            contacts: Contacts to classify. Defaults to ``self.contacts``.

        Returns:
            A GeoDataFrame with the columns ``ID``, ``basal_unit``, ``type``
            and ``geometry``. It is also stored on ``self.basal_contacts``.

        Raises:
            ValueError: If no contacts are available, or if a contact refers
                to a unit that is absent from ``stratigraphic_column``.
        """
        if contacts is None:
            contacts = self.contacts
        if contacts is None:
            raise ValueError("Contacts must be extracted before extracting basal contacts")

        logger.info("Extracting basal contacts")

        units = list(stratigraphic_column)
        # Position of each unit in the column; ``setdefault`` keeps the first
        # occurrence, matching ``list.index``.
        rank = {}
        for position, unit in enumerate(units):
            rank.setdefault(unit, position)

        # Verify that the units on BOTH sides of every contact are in the column.
        contact_units = pandas.concat(
            [contacts["UNITNAME_1"], contacts["UNITNAME_2"]], ignore_index=True
        )
        missing_units = contact_units[~contact_units.isin(units)].unique().tolist()
        if missing_units:
            message = (
                "There are units in the Geology dataset, but not in the stratigraphic column: "
                + ", ".join(str(unit) for unit in missing_units)
                + ". Please readjust the stratigraphic column if this is a user defined column."
            )
            logger.error(message)
            raise ValueError(message)

        first_rank = contacts["UNITNAME_1"].map(rank)
        second_rank = contacts["UNITNAME_2"].map(rank)

        basal_contacts = contacts.copy()

        # apply minimum lithological id between the two units
        basal_contacts["ID"] = first_rank.where(first_rank <= second_rank, second_rank)

        # match the name of the unit with the minimum id
        basal_contacts["basal_unit"] = basal_contacts["ID"].map(
            lambda position: units[position]
        )

        # if the units are more than 1 unit apart, the contact is abnormal
        stratigraphic_distance = (first_rank - second_rank).abs()
        basal_contacts["type"] = (stratigraphic_distance > 1).map(
            {True: "ABNORMAL", False: "BASAL"}
        )

        # Copy so the geometry assignment below acts on an independent frame.
        basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]].copy()

        # Snap each geometry to itself and merge touching line parts.
        basal_contacts["geometry"] = [
            shapely.line_merge(shapely.snap(geo, geo, self.SNAP_TOLERANCE))
            for geo in basal_contacts["geometry"]
        ]

        self.basal_contacts = basal_contacts

        return basal_contacts

110 changes: 20 additions & 90 deletions map2loop/mapdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from .config import Config
from .aus_state_urls import AustraliaStateUrls
from .utils import generate_random_hex_colors, calculate_minimum_fault_length
from .data_checks import check_geology_fields_validity, check_structure_fields_validity, check_fault_fields_validity, check_fold_fields_validity
from .data_checks import (
check_geology_fields_validity,
check_structure_fields_validity,
check_fault_fields_validity,
check_fold_fields_validity,
)
from .contact_extractor import ContactExtractor

# external imports
import geopandas
Expand Down Expand Up @@ -91,6 +97,7 @@ def __init__(self, verbose_level: VerboseLevel = VerboseLevel.ALL):
self.colour_filename = None
self.verbose_level = verbose_level
self.config = Config()
self.contact_extractor = ContactExtractor()

@property
@beartype.beartype
Expand Down Expand Up @@ -1508,107 +1515,30 @@ def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame):

@beartype.beartype
def extract_all_contacts(self, save_contacts=True):
"""
Extract the contacts between units in the geology GeoDataFrame
"""
logger.info("Extracting contacts")
geology = self.get_map_data(Datatype.GEOLOGY).copy()
geology = geology.dissolve(by="UNITNAME", as_index=False)
# Remove intrusions
geology = geology[~geology["INTRUSIVE"]]
geology = geology[~geology["SILL"]]
# Remove faults from contact geomety
if self.get_map_data(Datatype.FAULT) is not None:
faults = self.get_map_data(Datatype.FAULT).copy()
faults["geometry"] = faults.buffer(50)
geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False)
units = geology["UNITNAME"].unique()
column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"]
contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None)
while len(units) > 1:
unit1 = units[0]
units = units[1:]
for unit2 in units:
if unit1 != unit2:
# print(f'contact: {unit1} and {unit2}')
join = geopandas.overlay(
geology[geology["UNITNAME"] == unit1],
geology[geology["UNITNAME"] == unit2],
keep_geom_type=False,
)[column_names]
join["geometry"] = join.buffer(1)
buffered = geology[geology["UNITNAME"] == unit2][["geometry"]].copy()
buffered["geometry"] = buffered.boundary
end = geopandas.overlay(buffered, join, keep_geom_type=False)
if len(end):
contacts = pandas.concat([contacts, end], ignore_index=True)
# contacts["TYPE"] = "UNKNOWN"
contacts["length"] = [row.length for row in contacts["geometry"]]
# print('finished extracting contacts')
"""Extract contacts from the loaded geology and fault data."""

geology = self.get_map_data(Datatype.GEOLOGY)
faults = self.get_map_data(Datatype.FAULT)

contacts = self.contact_extractor.extract_all_contacts(geology, faults)

if save_contacts:
self.contacts = contacts
return contacts

@beartype.beartype
def extract_basal_contacts(self, stratigraphic_column: list, save_contacts=True):
"""
Identify the basal unit of the contacts based on the stratigraphic column
"""Identify basal contacts using the loaded contacts."""

Args:
stratigraphic_column (list):
The stratigraphic column to use
"""
logger.info("Extracting basal contacts")

units = stratigraphic_column
basal_contacts = self.contacts.copy()

# check if the units in the strati colum are in the geology dataset, so that basal contacts can be built
# if not, stop the project
if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()):
missing_units = (
basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"]
.unique()
.tolist()
)
logger.error(
"There are units in the Geology dataset, but not in the stratigraphic column: "
+ ", ".join(missing_units)
+ ". Please readjust the stratigraphic column if this is a user defined column."
)
raise ValueError(
"There are units in stratigraphic column, but not in the Geology dataset: "
+ ", ".join(missing_units)
+ ". Please readjust the stratigraphic column if this is a user defined column."
)
if self.contacts is None:
raise ValueError("Contacts must be extracted before extracting basal contacts")

# apply minimum lithological id between the two units
basal_contacts["ID"] = basal_contacts.apply(
lambda row: min(units.index(row["UNITNAME_1"]), units.index(row["UNITNAME_2"])), axis=1
basal_contacts = self.contact_extractor.extract_basal_contacts(
self.contacts, stratigraphic_column
)
# match the name of the unit with the minimum id
basal_contacts["basal_unit"] = basal_contacts.apply(lambda row: units[row["ID"]], axis=1)
# how many units apart are the two units?
basal_contacts["stratigraphic_distance"] = basal_contacts.apply(
lambda row: abs(units.index(row["UNITNAME_1"]) - units.index(row["UNITNAME_2"])), axis=1
)
# if the units are more than 1 unit apart, the contact is abnormal (meaning that there is one (or more) unit(s) missing in between the two)
basal_contacts["type"] = basal_contacts.apply(
lambda row: "ABNORMAL" if abs(row["stratigraphic_distance"]) > 1 else "BASAL", axis=1
)

basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]]

# added code to make sure that multi-line that touch each other are snapped and merged.
# necessary for the reconstruction based on featureId
basal_contacts["geometry"] = [
shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"]
]

if save_contacts:
# keep abnormal contacts as all_basal_contacts
self.all_basal_contacts = basal_contacts
# remove the abnormal contacts from basal contacts
self.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"]

return basal_contacts
Expand Down
61 changes: 61 additions & 0 deletions tests/contact_extractor/test_contact_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import geopandas as gpd
import shapely.geometry
import pytest
import importlib.util
import pathlib
import types
import sys

# Load ``contact_extractor.py`` straight from its file instead of through a
# regular ``import map2loop``, so the real package ``__init__`` is not executed.
ROOT = pathlib.Path(__file__).resolve().parents[2]
PACKAGE_NAME = "map2loop"

if PACKAGE_NAME not in sys.modules:
    # Register a stub package. Its ``__path__`` lets the relative
    # ``from .logging import getLogger`` inside contact_extractor resolve.
    pkg = types.ModuleType(PACKAGE_NAME)
    pkg.__path__ = [str(ROOT / PACKAGE_NAME)]
    import logging
    # NOTE(review): ``loggers`` and ``ch`` are presumably the package-level
    # attributes that map2loop's logging module expects - confirm.
    pkg.loggers = {}
    pkg.ch = logging.StreamHandler()
    pkg.ch.setLevel(logging.WARNING)
    sys.modules[PACKAGE_NAME] = pkg

spec = importlib.util.spec_from_file_location(
    f"{PACKAGE_NAME}.contact_extractor",
    ROOT / PACKAGE_NAME / "contact_extractor.py",
)
module = importlib.util.module_from_spec(spec)
# The module is registered in ``sys.modules`` before it is executed.
sys.modules[f"{PACKAGE_NAME}.contact_extractor"] = module
spec.loader.exec_module(module)
ContactExtractor = module.ContactExtractor


@pytest.fixture
def simple_geology():
    """Create a minimal geology dataset for testing.

    Two 2 x 2 squares that share the edge ``x == 2``; neither unit is an
    intrusion or a sill, so both take part in contact extraction.
    """

    poly1 = shapely.geometry.Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
    poly2 = shapely.geometry.Polygon([(2, 0), (4, 0), (4, 2), (2, 2)])

    return gpd.GeoDataFrame(
        {
            "UNITNAME": ["unit1", "unit2"],
            "INTRUSIVE": [False, False],
            "SILL": [False, False],
            "geometry": [poly1, poly2],
        },
        # Projected CRS: the extractor buffers geometries by linear distances,
        # and geopandas warns when buffering data in a geographic CRS.
        crs="EPSG:28350",
    )


def test_extract_all_contacts(simple_geology):
    """Two units sharing one edge produce exactly one contact row."""
    extractor = ContactExtractor(simple_geology)
    result = extractor.extract_all_contacts()

    assert len(result) == 1

    contact = result.iloc[0]
    # The single contact must link the two input units and have real extent.
    assert {contact["UNITNAME_1"], contact["UNITNAME_2"]} == {"unit1", "unit2"}
    assert contact["length"] > 0

    # The returned frame is the one cached on the extractor.
    assert extractor.contacts is result


def test_extract_basal_contacts(simple_geology):
    """Adjacent units in the column give one BASAL contact for the first unit."""
    extractor = ContactExtractor(simple_geology)
    extractor.extract_all_contacts()
    basal = extractor.extract_basal_contacts(["unit1", "unit2"])

    assert list(basal["basal_unit"]) == ["unit1"]
    # "unit1" sits at index 0 of the column, one position from "unit2".
    assert list(basal["ID"]) == [0]
    assert list(basal["type"]) == ["BASAL"]


Loading