From b5052c05cd01483a1294f60ecd8c4680d34508bb Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 14 Jan 2026 16:08:36 +0100 Subject: [PATCH 01/17] make base_file ABC --- src/osekit/core_api/audio_file.py | 8 +++----- src/osekit/core_api/base_file.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/osekit/core_api/audio_file.py b/src/osekit/core_api/audio_file.py index 77508b11..5781f356 100644 --- a/src/osekit/core_api/audio_file.py +++ b/src/osekit/core_api/audio_file.py @@ -2,6 +2,7 @@ from __future__ import annotations +import typing from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -21,6 +22,8 @@ class AudioFile(BaseFile): """Audio file associated with timestamps.""" + supported_extensions: typing.ClassVar = [".wav", ".flac", ".mp3"] + def __init__( self, path: PathLike | str, @@ -116,11 +119,6 @@ def frames_indexes(self, start: Timestamp, stop: Timestamp) -> tuple[int, int]: stop_sample = round(((stop - self.begin) * self.sample_rate).total_seconds()) return start_sample, stop_sample - @classmethod - def from_base_file(cls, file: BaseFile) -> AudioFile: - """Return an AudioFile object from a BaseFile object.""" - return cls(path=file.path, begin=file.begin) - def move(self, folder: Path) -> None: """Move the file to the target folder. diff --git a/src/osekit/core_api/base_file.py b/src/osekit/core_api/base_file.py index 370124b3..1566a4fc 100644 --- a/src/osekit/core_api/base_file.py +++ b/src/osekit/core_api/base_file.py @@ -5,7 +5,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import typing +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Self from osekit.config import ( TIMESTAMP_FORMAT_EXPORTED_FILES_LOCALIZED, @@ -28,12 +30,14 @@ from osekit.utils.timestamp_utils import strptime_from_text -class BaseFile(Event): +class BaseFile(Event, ABC): """Base class for the File objects. A File object associates file-written data to timestamps. """ + supported_extensions: typing.ClassVar = [] + def __init__( self, path: PathLike | str, @@ -89,6 +93,7 @@ def __init__( self.end = end if end is not None else (self.begin + Timedelta(seconds=1)) + @abstractmethod def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the data that is between start and stop from the file. @@ -106,6 +111,7 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: The data between start and stop. """ + ... def to_dict(self) -> dict: """Serialize a BaseFile to a dictionary. @@ -123,7 +129,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, serialized: dict) -> BaseFile: + def from_dict(cls: type[Self], serialized: dict) -> type[Self]: """Return a BaseFile object from a dictionary. Parameters @@ -151,7 +157,7 @@ def __str__(self) -> str: """Overwrite __str__.""" return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES_LOCALIZED) - def __eq__(self, other: BaseFile): + def __eq__(self, other: BaseFile) -> bool: """Override __eq__.""" if not isinstance(other, BaseFile): return False From bf9833600f40da0c4874fceac6c517648e6a7746 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 14 Jan 2026 17:21:35 +0100 Subject: [PATCH 02/17] make BaseData and BaseItem ABC --- src/osekit/core_api/audio_data.py | 126 ++++++++++-------------------- src/osekit/core_api/audio_item.py | 15 ---- src/osekit/core_api/base_data.py | 63 ++++++++++----- src/osekit/core_api/base_item.py | 3 +- 4 files changed, 88 insertions(+), 119 deletions(-) diff --git a/src/osekit/core_api/audio_data.py b/src/osekit/core_api/audio_data.py index 1dfa098e..ab2f8f48 100644 --- a/src/osekit/core_api/audio_data.py +++ b/src/osekit/core_api/audio_data.py @@ -30,11 +30,14 @@ class AudioData(BaseData[AudioItem, AudioFile]): The data is accessed via an AudioItem object per AudioFile. """ + item_cls = AudioItem + def __init__( self, items: list[AudioItem] | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, sample_rate: int | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, @@ -54,6 +57,8 @@ def __init__( end: Timestamp | None Only effective if items is None. Set the end of the empty data. + name: str | None + Name of the exported files. instrument: Instrument | None Instrument that might be used to obtain acoustic pressure from the wav audio data. @@ -61,7 +66,7 @@ def __init__( The type of normalization to apply to the audio data. """ - super().__init__(items=items, begin=begin, end=end) + super().__init__(items=items, begin=begin, end=end, name=name) self._set_sample_rate(sample_rate=sample_rate) self.instrument = instrument self.normalization = normalization @@ -115,6 +120,19 @@ def normalization_values(self, value: dict | None) -> None: } ) + @classmethod + def make_item( + cls, + file: AudioFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> AudioItem: + return AudioItem(file=file, begin=begin, end=end) + + @classmethod + def make_file(cls, path: Path, begin: Timestamp) -> AudioFile: + return AudioFile(path=path, begin=begin) + def get_normalization_values(self) -> dict: values = np.array(self.get_raw_value()) return { @@ -390,7 +408,13 @@ def to_dict(self) -> dict: ) @classmethod - def from_dict(cls, dictionary: dict) -> AudioData: + def from_base_dict( + cls, + dictionary: dict, + files: list[AudioFile], + begin: Timestamp, + end: Timestamp, + ) -> AudioData: """Deserialize an AudioData from a dictionary. Parameters @@ -404,102 +428,34 @@ def from_dict(cls, dictionary: dict) -> AudioData: The deserialized AudioData. """ - base_data = BaseData.from_dict(dictionary) instrument = ( None if dictionary["instrument"] is None else Instrument.from_dict(dictionary["instrument"]) ) - return cls.from_base_data( - data=base_data, + return cls.from_files( + files=files, + begin=begin, + end=end, + instrument=instrument, sample_rate=dictionary["sample_rate"], normalization=Normalization(dictionary["normalization"]), normalization_values=dictionary["normalization_values"], - instrument=instrument, ) @classmethod - def from_files( + def from_files( # noqa: D102 cls, - files: list[AudioFile], + files: list[AudioFile], # The method is redefined just to specify the type begin: Timestamp | None = None, end: Timestamp | None = None, - sample_rate: float | None = None, - instrument: Instrument | None = None, - normalization: Normalization = Normalization.RAW, - normalization_values: dict | None = None, + name: str | None = None, + **kwargs, ) -> AudioData: - """Return an AudioData object from a list of AudioFiles. - - Parameters - ---------- - files: list[AudioFile] - List of AudioFiles containing the data. - begin: Timestamp | None - Begin of the data object. - Defaulted to the begin of the first file. - end: Timestamp | None - End of the data object. - Defaulted to the end of the last file. - sample_rate: float | None - Sample rate of the AudioData. - instrument: Instrument | None - Instrument that might be used to obtain acoustic pressure from - the wav audio data. - normalization: Normalization - The type of normalization to apply to the audio data. - normalization_values: dict|None - Mean, peak and std values with which to normalize the data. - - Returns - ------- - AudioData: - The AudioData object. - - """ - return cls.from_base_data( - data=BaseData.from_files(files, begin, end), - sample_rate=sample_rate, - instrument=instrument, - normalization=normalization, - normalization_values=normalization_values, - ) - - @classmethod - def from_base_data( - cls, - data: BaseData, - sample_rate: float | None = None, - instrument: Instrument | None = None, - normalization: Normalization = Normalization.RAW, - normalization_values: dict | None = None, - ) -> AudioData: - """Return an AudioData object from a BaseData object. - - Parameters - ---------- - data: BaseData - BaseData object to convert to AudioData. - sample_rate: float | None - Sample rate of the AudioData. - instrument: Instrument | None - Instrument that might be used to obtain acoustic pressure from - the wav audio data. - normalization: Normalization - The type of normalization to apply to the audio data. - normalization_values: dict|None - Mean, peak and std values with which to normalize the data. - - Returns - ------- - AudioData: - The AudioData object. - - """ - return cls( - items=[AudioItem.from_base_item(item) for item in data.items], - sample_rate=sample_rate, - instrument=instrument, - normalization=normalization, - normalization_values=normalization_values, + return super().from_files( + files=files, # This way, this static error doesn't appear to the user + begin=begin, + end=end, + name=name, + **kwargs, ) diff --git a/src/osekit/core_api/audio_item.py b/src/osekit/core_api/audio_item.py index 170c28fd..c32442ce 100644 --- a/src/osekit/core_api/audio_item.py +++ b/src/osekit/core_api/audio_item.py @@ -7,7 +7,6 @@ import numpy as np from osekit.core_api.audio_file import AudioFile -from osekit.core_api.base_file import BaseFile from osekit.core_api.base_item import BaseItem if TYPE_CHECKING: @@ -65,17 +64,3 @@ def get_value(self) -> np.ndarray: if self.is_empty: return np.zeros((1, self.nb_channels)) return super().get_value() - - @classmethod - def from_base_item(cls, item: BaseItem) -> AudioItem: - """Return an AudioItem object from a BaseItem object.""" - file = item.file - if not file or isinstance(file, AudioFile): - return cls(file=file, begin=item.begin, end=item.end) - if isinstance(file, BaseFile): - return cls( - file=AudioFile.from_base_file(file), - begin=item.begin, - end=item.end, - ) - raise TypeError diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index c1067c9b..6ca68e57 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -6,8 +6,9 @@ from __future__ import annotations +from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, TypeVar +from typing import Any, Generic, Self, TypeVar import numpy as np from pandas import Timestamp, date_range @@ -27,13 +28,16 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseData(Generic[TItem, TFile], Event): +class BaseData(Generic[TItem, TFile], Event, ABC): """Base class for the Data objects. Data corresponds to data scattered through different Files. The data is accessed via an Item object per File. """ + item_cls: type[TItem] + file_cls: type[TFile] + def __init__( self, items: list[TItem] | None = None, @@ -59,7 +63,7 @@ def __init__( """ if not items: - items = [BaseItem(begin=begin, end=end)] + items = [self.make_empty_item(begin=begin, end=end)] self.items = items self._begin = min(item.begin for item in self.items) self._end = max(item.end for item in self.items) @@ -135,9 +139,11 @@ def create_directories(path: Path) -> None: """ path.mkdir(parents=True, exist_ok=True, mode=DPDEFAULT) + @abstractmethod def write(self, folder: Path, link: bool = False) -> None: """Abstract method for writing data to file.""" + @abstractmethod def link(self, folder: Path) -> None: """Abstract method for linking data to a file in a given folder. @@ -184,22 +190,40 @@ def from_dict(cls, dictionary: dict) -> BaseData: """ files = [ - BaseFile( - Path(file["path"]), + cls.make_file( + path=Path(file["path"]), begin=strptime_from_text( file["begin"], datetime_template=TIMESTAMP_FORMATS_EXPORTED_FILES, ), - end=strptime_from_text( - file["end"], - datetime_template=TIMESTAMP_FORMATS_EXPORTED_FILES, - ), ) for file in dictionary["files"].values() ] begin = Timestamp(dictionary["begin"]) end = Timestamp(dictionary["end"]) - return cls.from_files(files, begin, end) + return cls.from_base_dict(dictionary, files, begin, end) + + @classmethod + def make_file(cls, path: Path, begin: Timestamp) -> type[TFile]: ... + + @classmethod + @abstractmethod + def make_item( + cls, + file: TFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> TItem: ... + + @classmethod + @abstractmethod + def from_base_dict( + cls, + dictionary: dict, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + ) -> Self: ... @property def files(self) -> set[TFile]: @@ -234,7 +258,8 @@ def from_files( begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - ) -> BaseData[TItem, TFile]: + **kwargs: Any, + ) -> Self: """Return a base DataBase object from a list of Files. Parameters @@ -257,7 +282,7 @@ def from_files( """ items = cls.items_from_files(files=files, begin=begin, end=end) - return cls(items=items, name=name) + return cls(items=items, name=name, **kwargs) @classmethod def items_from_files( @@ -265,7 +290,7 @@ def items_from_files( files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, - ) -> list[BaseItem]: + ) -> list[TItem]: """Return a list of Items from a list of Files and timestamps. The Items range from begin to end. @@ -295,12 +320,14 @@ def items_from_files( file for file in files if file.overlaps(Event(begin=begin, end=end)) ] - items = [BaseItem(file, begin, end) for file in included_files] + items = [ + cls.make_item(file=file, begin=begin, end=end) for file in included_files + ] if not items: - items.append(BaseItem(begin=begin, end=end)) + items.append(cls.make_item(begin=begin, end=end)) if (first_item := sorted(items, key=lambda item: item.begin)[0]).begin > begin: - items.append(BaseItem(begin=begin, end=first_item.begin)) + items.append(cls.make_item(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: - items.append(BaseItem(begin=last_item.end, end=end)) + items.append(cls.make_item(begin=last_item.end, end=end)) items = Event.remove_overlaps(items) - return Event.fill_gaps(items, BaseItem) + return Event.fill_gaps(items, cls.item_cls) diff --git a/src/osekit/core_api/base_item.py b/src/osekit/core_api/base_item.py index 1c5e7408..7d8c89a4 100644 --- a/src/osekit/core_api/base_item.py +++ b/src/osekit/core_api/base_item.py @@ -5,6 +5,7 @@ from __future__ import annotations +from abc import ABC from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np @@ -18,7 +19,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseItem(Generic[TFile], Event): +class BaseItem(Generic[TFile], Event, ABC): """Base class for the Item objects. An Item correspond to a portion of a File object. From bbd9843c14632f3429642d51a0519ff981b6bdd4 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 14 Jan 2026 18:06:24 +0100 Subject: [PATCH 03/17] make BaseDataset ABC --- src/osekit/core_api/audio_dataset.py | 68 +++++------- src/osekit/core_api/base_dataset.py | 150 ++++++++++++++++----------- 2 files changed, 116 insertions(+), 102 deletions(-) diff --git a/src/osekit/core_api/audio_dataset.py b/src/osekit/core_api/audio_dataset.py index 398f38ed..1eb523ea 100644 --- a/src/osekit/core_api/audio_dataset.py +++ b/src/osekit/core_api/audio_dataset.py @@ -32,6 +32,8 @@ class AudioDataset(BaseDataset[AudioData, AudioFile]): """ + file_cls = AudioFile + def __init__( self, data: list[AudioData], @@ -145,26 +147,8 @@ def write( ) @classmethod - def from_dict(cls, dictionary: dict) -> AudioDataset: - """Deserialize an AudioDataset from a dictionary. - - Parameters - ---------- - dictionary: dict - The serialized dictionary representing the AudioDataset. - - Returns - ------- - AudioDataset - The deserialized AudioDataset. - - """ - return cls( - [AudioData.from_dict(d) for d in dictionary["data"].values()], - name=dictionary["name"], - suffix=dictionary["suffix"], - folder=Path(dictionary["folder"]), - ) + def data_from_dict(cls, dictionary: dict) -> list[AudioData]: + return [AudioData.from_dict(data) for data in dictionary.values()] @classmethod def from_folder( # noqa: PLR0913 @@ -181,7 +165,7 @@ def from_folder( # noqa: PLR0913 name: str | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, - **kwargs: any, + **kwargs, ) -> AudioDataset: """Return an AudioDataset from a folder containing the audio files. @@ -241,13 +225,7 @@ def from_folder( # noqa: PLR0913 The audio dataset. """ - kwargs.update( - { - "file_class": AudioFile, - "supported_file_extensions": [".wav", ".flac", ".mp3"], - }, - ) - base_dataset = BaseDataset.from_folder( + return super().from_folder( folder=folder, strptime_format=strptime_format, begin=begin, @@ -256,13 +234,9 @@ def from_folder( # noqa: PLR0913 mode=mode, overlap=overlap, data_duration=data_duration, - **kwargs, - ) - return cls.from_base_dataset( - base_dataset=base_dataset, + sample_rate=sample_rate, name=name, instrument=instrument, - sample_rate=sample_rate, normalization=normalization, ) @@ -272,11 +246,11 @@ def from_files( # noqa: PLR0913 files: list[AudioFile], begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total", overlap: float = 0.0, data_duration: Timedelta | None = None, sample_rate: float | None = None, - name: str | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, ) -> AudioDataset: @@ -324,20 +298,34 @@ def from_files( # noqa: PLR0913 The DataBase object. """ - base = BaseDataset.from_files( + return super().from_files( files=files, begin=begin, end=end, + name=name, + instrument=instrument, + normalization=normalization, + sample_rate=sample_rate, mode=mode, overlap=overlap, data_duration=data_duration, ) - return cls.from_base_dataset( - base, + + @classmethod + def data_from_files( + cls, + files: list[AudioFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, + ) -> AudioData: + return AudioData.from_files( + files=files, + begin=begin, + end=end, name=name, - sample_rate=sample_rate, - instrument=instrument, - normalization=normalization, + **kwargs, ) @classmethod diff --git a/src/osekit/core_api/base_dataset.py b/src/osekit/core_api/base_dataset.py index 222f6e62..2d679df5 100644 --- a/src/osekit/core_api/base_dataset.py +++ b/src/osekit/core_api/base_dataset.py @@ -7,9 +7,10 @@ from __future__ import annotations import os +from abc import ABC, abstractmethod from bisect import bisect from pathlib import Path -from typing import TYPE_CHECKING, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar from pandas import Timedelta, Timestamp, date_range from soundfile import LibsndfileError @@ -30,13 +31,15 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseDataset(Generic[TData, TFile], Event): +class BaseDataset(Generic[TData, TFile], Event, ABC): """Base class for Dataset objects. Datasets are collections of Data, with methods that simplify repeated operations on the data. """ + file_cls: type[TFile] + def __init__( self, data: list[TData], @@ -210,7 +213,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, dictionary: dict) -> BaseDataset: + def from_dict(cls, dictionary: dict) -> Self: """Deserialize a BaseDataset from a dictionary. Parameters @@ -224,19 +227,22 @@ def from_dict(cls, dictionary: dict) -> BaseDataset: The deserialized BaseDataset. """ - return cls( - [BaseData.from_dict(d) for d in dictionary["data"].values()], - name=dictionary["name"], - suffix=dictionary["suffix"], - folder=Path(dictionary["folder"]), - ) + data = cls.data_from_dict(dictionary["data"]) + name = dictionary["name"] + suffix = dictionary["suffix"] + folder = Path(dictionary["folder"]) + return cls(data=data, name=name, suffix=suffix, folder=folder) + + @classmethod + @abstractmethod + def data_from_dict(cls, dictionary: dict) -> list[TData]: ... def write_json(self, folder: Path) -> None: """Write a serialized BaseDataset to a JSON file.""" serialize_json(folder / f"{self.name}.json", self.to_dict()) @classmethod - def from_json(cls, file: Path) -> BaseDataset: + def from_json(cls, file: Path) -> Self: """Deserialize a BaseDataset from a JSON file. Parameters @@ -262,7 +268,8 @@ def from_files( # noqa: PLR0913 data_duration: Timedelta | None = None, overlap: float = 0.0, name: str | None = None, - ) -> BaseDataset: + **kwargs, + ) -> Self: """Return a base BaseDataset object from a list of Files. Parameters @@ -301,9 +308,9 @@ def from_files( # noqa: PLR0913 """ if mode == "files": - data_base = [BaseData.from_files([f]) for f in files] - data_base = BaseData.remove_overlaps(data_base) - return cls(data=data_base, name=name) + data = [cls.data_from_files([f], **kwargs) for f in files] + data = BaseData.remove_overlaps(data) + return cls(data=data, name=name) if not begin: begin = min(file.begin for file in files) @@ -311,35 +318,51 @@ def from_files( # noqa: PLR0913 end = max(file.end for file in files) if data_duration: data_base = ( - cls._get_base_data_from_files_timedelta_total( + cls._get_data_from_files_timedelta_total( begin=begin, end=end, data_duration=data_duration, files=files, overlap=overlap, + **kwargs, ) if mode == "timedelta_total" - else cls._get_base_data_from_files_timedelta_file( + else cls._get_data_from_files_timedelta_file( begin=begin, end=end, data_duration=data_duration, files=files, overlap=overlap, + **kwargs, ) ) else: - data_base = [BaseData.from_files(files, begin=begin, end=end)] + data_base = [ + cls.data_from_files(files=files, begin=begin, end=end, **kwargs), + ] return cls(data_base, name=name) @classmethod - def _get_base_data_from_files_timedelta_total( + @abstractmethod + def data_from_files( + cls, + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, + ) -> TData: ... + + @classmethod + def _get_data_from_files_timedelta_total( cls, begin: Timestamp, end: Timestamp, data_duration: Timedelta, files: list[TFile], overlap: float = 0, - ) -> list[BaseData]: + **kwargs, + ) -> list[TData]: if not 0 <= overlap < 1: msg = f"Overlap ({overlap}) must be between 0 and 1." raise ValueError(msg) @@ -366,24 +389,26 @@ def _get_base_data_from_files_timedelta_total( ): last_active_file_index += 1 output.append( - BaseData.from_files( + cls.data_from_files( files[active_file_index:last_active_file_index], data_begin, data_end, + **kwargs, ), ) return output @classmethod - def _get_base_data_from_files_timedelta_file( + def _get_data_from_files_timedelta_file( cls, begin: Timestamp, end: Timestamp, data_duration: Timedelta, files: list[TFile], overlap: float = 0, - ) -> list[BaseData]: + **kwargs, + ) -> list[TData]: if not 0 <= overlap < 1: msg = f"Overlap ({overlap}) must be between 0 and 1." raise ValueError(msg) @@ -416,7 +441,12 @@ def _get_base_data_from_files_timedelta_file( files_chunk.append(next_file) output.extend( - BaseData.from_files(files, data_begin, data_begin + data_duration) + cls.data_from_files( + files, + data_begin, + data_begin + data_duration, + **kwargs, + ) for data_begin in date_range( file.begin, files_chunk[-1].end, @@ -429,11 +459,9 @@ def _get_base_data_from_files_timedelta_file( @classmethod def from_folder( # noqa: PLR0913 - cls, + cls: type[Self], folder: Path, strptime_format: str | None, - file_class: type[TFile] = BaseFile, - supported_file_extensions: list[str] | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, timezone: str | pytz.timezone | None = None, @@ -442,7 +470,8 @@ def from_folder( # noqa: PLR0913 data_duration: Timedelta | None = None, first_file_begin: Timestamp | None = None, name: str | None = None, - ) -> BaseDataset: + **kwargs, + ) -> Self: """Return a BaseDataset from a folder containing the base files. Parameters @@ -455,10 +484,6 @@ def from_folder( # noqa: PLR0913 If None, the first audio file of the folder will start at first_file_begin, and each following file will start at the end of the previous one. - file_class: type[Tfile] - Derived type of BaseFile used to instantiate the dataset. - supported_file_extensions: list[str] - List of supported file extensions for parsing TFiles. begin: Timestamp | None The begin of the dataset. Defaulted to the begin of the first file. @@ -495,12 +520,10 @@ def from_folder( # noqa: PLR0913 Returns ------- - Basedataset: + Self: The base dataset. """ - if supported_file_extensions is None: - supported_file_extensions = [] valid_files = [] rejected_files = [] first_file_begin = first_file_begin or Timestamp("2020-01-01 00:00:00") @@ -508,10 +531,8 @@ def from_folder( # noqa: PLR0913 sorted(folder.iterdir()), disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), ): - is_file_ok = _parse_file( + is_file_ok = cls._parse_file( file=file, - file_class=file_class, - supported_file_extensions=supported_file_extensions, strptime_format=strptime_format, timezone=timezone, begin_timestamp=first_file_begin, @@ -528,9 +549,10 @@ def from_folder( # noqa: PLR0913 ) if not valid_files: - raise FileNotFoundError(f"No valid file found in {folder}.") + msg = f"No valid file found in {folder}" + raise FileNotFoundError(msg) - return BaseDataset.from_files( + return cls.from_files( files=valid_files, begin=begin, end=end, @@ -538,29 +560,33 @@ def from_folder( # noqa: PLR0913 overlap=overlap, data_duration=data_duration, name=name, + **kwargs, ) - -def _parse_file( - file: Path, - file_class: type, - supported_file_extensions: list[str], - strptime_format: str, - timezone: str | pytz.timezone | None, - begin_timestamp: Timestamp, - valid_files: list[BaseFile], - rejected_files: list[Path], -) -> bool: - if file.suffix.lower() not in supported_file_extensions: - return False - try: - if strptime_format is None: - f = file_class(file, begin=begin_timestamp, timezone=timezone) + @classmethod + def _parse_file( + cls: type[Self], + file: Path, + strptime_format: str, + timezone: str | pytz.timezone | None, + begin_timestamp: Timestamp, + valid_files: list[TFile], + rejected_files: list[Path], + ) -> bool: + if file.suffix.lower() not in cls.file_cls.supported_extensions: + return False + try: + if strptime_format is None: + f = cls.file_cls(file, begin=begin_timestamp, timezone=timezone) + else: + f = cls.file_cls( + file, + strptime_format=strptime_format, + timezone=timezone, + ) + valid_files.append(f) + except (ValueError, LibsndfileError): + rejected_files.append(file) + return False else: - f = file_class(file, strptime_format=strptime_format, timezone=timezone) - valid_files.append(f) - except (ValueError, LibsndfileError): - rejected_files.append(file) - return False - else: - return True + return True From c37a882dbe07537e9108c3f9e1a89eaeb9c3aee1 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 15 Jan 2026 10:32:20 +0100 Subject: [PATCH 04/17] fix core API tests --- src/osekit/core_api/audio_dataset.py | 31 +-- src/osekit/core_api/base_data.py | 8 +- src/osekit/core_api/base_dataset.py | 10 +- tests/test_core_api_base.py | 319 ++++++++++++++++----------- 4 files changed, 202 insertions(+), 166 deletions(-) diff --git a/src/osekit/core_api/audio_dataset.py b/src/osekit/core_api/audio_dataset.py index 1eb523ea..d480c29c 100644 --- a/src/osekit/core_api/audio_dataset.py +++ b/src/osekit/core_api/audio_dataset.py @@ -8,7 +8,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Self from osekit.core_api.audio_data import AudioData from osekit.core_api.audio_file import AudioFile @@ -166,7 +166,7 @@ def from_folder( # noqa: PLR0913 instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, **kwargs, - ) -> AudioDataset: + ) -> Self: """Return an AudioDataset from a folder containing the audio files. Parameters @@ -230,8 +230,8 @@ def from_folder( # noqa: PLR0913 strptime_format=strptime_format, begin=begin, end=end, - timezone=timezone, mode=mode, + timezone=timezone, overlap=overlap, data_duration=data_duration, sample_rate=sample_rate, @@ -329,30 +329,7 @@ def data_from_files( ) @classmethod - def from_base_dataset( - cls, - base_dataset: BaseDataset, - sample_rate: float | None = None, - name: str | None = None, - instrument: Instrument | None = None, - normalization: Normalization = Normalization.RAW, - ) -> AudioDataset: - """Return an AudioDataset object from a BaseDataset object.""" - return cls( - [ - AudioData.from_base_data( - data=data, - sample_rate=sample_rate, - normalization=normalization, - ) - for data in base_dataset.data - ], - name=name, - instrument=instrument, - ) - - @classmethod - def from_json(cls, file: Path) -> AudioDataset: + def from_json(cls, file: Path) -> Self: """Deserialize an AudioDataset from a JSON file. Parameters diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index 6ca68e57..6a4d13a1 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -63,7 +63,7 @@ def __init__( """ if not items: - items = [self.make_empty_item(begin=begin, end=end)] + items = [self.make_item(begin=begin, end=end)] self.items = items self._begin = min(item.begin for item in self.items) self._end = max(item.end for item in self.items) @@ -260,7 +260,7 @@ def from_files( name: str | None = None, **kwargs: Any, ) -> Self: - """Return a base DataBase object from a list of Files. + """Return a Data object from a list of Files. Parameters ---------- @@ -277,8 +277,8 @@ def from_files( Returns ------- - BaseData[TItem, TFile]: - The BaseData object. + Self: + The Data object. """ items = cls.items_from_files(files=files, begin=begin, end=end) diff --git a/src/osekit/core_api/base_dataset.py b/src/osekit/core_api/base_dataset.py index 2d679df5..7aaa8baa 100644 --- a/src/osekit/core_api/base_dataset.py +++ b/src/osekit/core_api/base_dataset.py @@ -270,7 +270,7 @@ def from_files( # noqa: PLR0913 name: str | None = None, **kwargs, ) -> Self: - """Return a base BaseDataset object from a list of Files. + """Return a base Dataset object from a list of Files. Parameters ---------- @@ -303,8 +303,8 @@ def from_files( # noqa: PLR0913 Returns ------- - BaseDataset[TItem, TFile]: - The DataBase object. + Self: + The Dataset object. """ if mode == "files": @@ -472,7 +472,7 @@ def from_folder( # noqa: PLR0913 name: str | None = None, **kwargs, ) -> Self: - """Return a BaseDataset from a folder containing the base files. + """Return a Dataset from a folder containing the base files. Parameters ---------- @@ -521,7 +521,7 @@ def from_folder( # noqa: PLR0913 Returns ------- Self: - The base dataset. + The dataset. """ valid_files = [] diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index ea17dfe5..3acde140 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -1,25 +1,85 @@ from __future__ import annotations +import typing from pathlib import Path -from typing import Literal +from typing import Literal, Self import numpy as np import pytest from pandas import Timedelta, Timestamp from osekit.config import TIMESTAMP_FORMATS_EXPORTED_FILES -from osekit.core_api.base_data import BaseData -from osekit.core_api.base_dataset import BaseDataset +from osekit.core_api.base_data import BaseData, TFile +from osekit.core_api.base_dataset import BaseDataset, TData from osekit.core_api.base_file import BaseFile +from osekit.core_api.base_item import BaseItem from osekit.core_api.event import Event +class DummyFile(BaseFile): + supported_extensions: typing.ClassVar = [".wav"] + + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: + pass + + +class DummyItem(BaseItem[DummyFile]): ... + + +class DummyData(BaseData[DummyItem, DummyFile]): + item_cls = DummyItem + + def write(self, folder: Path, link: bool = False) -> None: + pass + + def link(self, folder: Path) -> None: + pass + + @classmethod + def make_item( + cls, + file: TFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> DummyItem: + return DummyItem(file=file, begin=begin, end=end) + + @classmethod + def from_base_dict( + cls, + dictionary: dict, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + ) -> Self: + pass + + +class DummyDataset(BaseDataset[DummyData, DummyFile]): + @classmethod + def data_from_dict(cls, dictionary: dict) -> list[TData]: + pass + + @classmethod + def data_from_files( + cls, + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, + ) -> TData: + pass + + file_cls = DummyFile + + @pytest.mark.parametrize( ("base_files", "begin", "end", "duration", "expected_data_events"), [ pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -38,7 +98,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -57,7 +117,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -76,7 +136,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -95,7 +155,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -114,7 +174,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -133,7 +193,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -160,7 +220,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:50:00"), @@ -187,27 +247,27 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:10:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:10:00"), end=Timestamp("2016-02-05 00:20:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:20:00"), end=Timestamp("2016-02-05 00:30:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:30:00"), end=Timestamp("2016-02-05 00:40:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:40:00"), end=Timestamp("2016-02-05 00:50:00"), @@ -234,27 +294,27 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:10:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:10:00"), end=Timestamp("2016-02-05 00:20:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:20:00"), end=Timestamp("2016-02-05 00:30:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:30:00"), end=Timestamp("2016-02-05 00:40:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:40:00"), end=Timestamp("2016-02-05 00:50:00"), @@ -286,13 +346,13 @@ ], ) def test_base_dataset_from_files( - base_files: list[BaseFile], + base_files: list[DummyFile], begin: Timestamp | None, end: Timestamp | None, duration: Timedelta | None, expected_data_events: list[Event], ) -> None: - ads = BaseDataset.from_files( + ads = DummyDataset.from_files( base_files, begin=begin, end=end, @@ -342,9 +402,9 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No match=rf"Overlap \({overlap}\) must be between 0 and 1.", ) as e: assert ( - BaseDataset.from_files( + DummyDataset.from_files( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:10:00"), @@ -707,10 +767,9 @@ def test_base_dataset_from_folder( ) -> None: monkeypatch.setattr(Path, "iterdir", lambda x: files) - bds = BaseDataset.from_folder( + bds = DummyDataset.from_folder( folder=Path("foo"), strptime_format=strptime_format, - supported_file_extensions=supported_file_extensions, begin=begin, end=end, timezone=timezone, @@ -752,7 +811,7 @@ def test_move_file( ) -> None: filename = "cool.txt" (tmp_path / filename).touch(mode=0o666, exist_ok=True) - bf = BaseFile( + bf = DummyFile( tmp_path / filename, begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:13:12"), @@ -768,7 +827,7 @@ def test_move_file( def test_dataset_move( tmp_path: Path, - base_dataset: BaseDataset, + base_dataset: DummyDataset, ) -> None: origin_files = [Path(str(file.path)) for file in base_dataset.files] @@ -798,7 +857,7 @@ def test_dataset_move( [ pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -819,7 +878,7 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -847,12 +906,12 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:13"), end=Timestamp("2022-04-22 12:12:14"), @@ -894,7 +953,7 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -915,12 +974,12 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:13"), end=Timestamp("2022-04-22 12:12:14"), @@ -948,12 +1007,12 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:13"), end=Timestamp("2022-04-22 12:12:14"), @@ -983,12 +1042,12 @@ def test_dataset_move( ) def test_base_dataset_file_mode( tmp_path: pytest.fixture, - files: list[BaseFile], + files: list[DummyFile], mode: Literal["files", "timedelta_total"], data_duration: Timedelta | None, expected_data: list[tuple[Event, str]], ) -> None: - ds = BaseDataset.from_files( + ds = DummyDataset.from_files( files=files, mode=mode, data_duration=data_duration, @@ -1007,7 +1066,7 @@ def test_base_dataset_file_mode( [ pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1023,7 +1082,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1039,7 +1098,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1055,7 +1114,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1071,7 +1130,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1087,12 +1146,12 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:14"), end=Timestamp("2022-04-22 12:12:15"), @@ -1108,12 +1167,12 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:14"), end=Timestamp("2022-04-22 12:12:15"), @@ -1129,12 +1188,12 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:14"), end=Timestamp("2022-04-22 12:12:15"), @@ -1152,12 +1211,12 @@ def test_base_dataset_file_mode( ) def test_base_data_boundaries( monkeypatch: pytest.fixture, - files: list[BaseFile], + files: list[DummyFile], begin: Timestamp, end: Timestamp, expected_data: Event, ) -> None: - data = BaseData.from_files(files=files) + data = DummyData.from_files(files=files) if begin: data.begin = begin if end: @@ -1165,14 +1224,14 @@ def test_base_data_boundaries( assert data.begin == expected_data.begin assert data.end == expected_data.end - def mocked_get_value(self: BaseData) -> None: + def mocked_get_value(self: DummyData) -> None: for item in data.items: if item.is_empty: continue assert item.file.begin <= item.begin assert item.file.end >= item.end - monkeypatch.setattr(BaseData, "get_value", mocked_get_value) + monkeypatch.setattr(DummyData, "get_value", mocked_get_value) data.get_value() @@ -1181,9 +1240,9 @@ def mocked_get_value(self: BaseData) -> None: ("data1", "data2", "expected"), [ pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1192,9 +1251,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1207,9 +1266,9 @@ def mocked_get_value(self: BaseData) -> None: id="same_one_full_file", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1218,9 +1277,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1233,9 +1292,9 @@ def mocked_get_value(self: BaseData) -> None: id="same_one_full_file_explicit_timestamps", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1244,9 +1303,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1259,9 +1318,9 @@ def mocked_get_value(self: BaseData) -> None: id="different_begin", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1270,9 +1329,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1285,9 +1344,9 @@ def mocked_get_value(self: BaseData) -> None: id="different_end", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1296,9 +1355,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1311,9 +1370,9 @@ def mocked_get_value(self: BaseData) -> None: id="different_begin_and_end", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1322,9 +1381,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1337,14 +1396,14 @@ def mocked_get_value(self: BaseData) -> None: id="different_file", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1353,14 +1412,14 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1373,14 +1432,14 @@ def mocked_get_value(self: BaseData) -> None: id="same_two_files", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1389,14 +1448,14 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1410,7 +1469,7 @@ def mocked_get_value(self: BaseData) -> None: ), ], ) -def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> None: +def test_base_data_equality(data1: DummyData, data2: DummyData, expected: bool) -> None: assert (data1 == data2) == expected @@ -1418,9 +1477,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> ("data", "name", "expected"), [ pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1434,14 +1493,14 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="default_to_data_begin", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1457,9 +1516,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="default_to_data_begin_with_unordered_files", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1471,9 +1530,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="given_name", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1486,9 +1545,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="given_name_over_existing_name", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1504,7 +1563,7 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> ), ], ) -def test_data_name(data: BaseData, name: str | None, expected: str) -> None: +def test_data_name(data: DummyData, name: str | None, expected: str) -> None: data.name = name assert data.name == expected assert str(data) == expected @@ -1515,7 +1574,7 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: [ pytest.param( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1550,12 +1609,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1590,12 +1649,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1653,12 +1712,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1716,7 +1775,7 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1751,12 +1810,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1791,12 +1850,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1854,12 +1913,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:22"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:32"), end=Timestamp("2015-08-28 12:14:12"), @@ -1908,12 +1967,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:22"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:32"), end=Timestamp("2015-08-28 12:14:12"), @@ -1962,12 +2021,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2016,12 +2075,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2070,12 +2129,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:00"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2124,12 +2183,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:00"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2195,7 +2254,7 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ], ) def test_get_base_data_from_files( - files: list[BaseFile], + files: list[DummyFile], begin: Timestamp, end: Timestamp, data_duration: Timedelta, @@ -2203,7 +2262,7 @@ def test_get_base_data_from_files( overlap: float, expected: list[list[tuple[Event, str | None]]], ) -> None: - data = BaseDataset.from_files( + data = DummyDataset.from_files( files=files, begin=begin, end=end, From 4138f32eb2a0e24c317eaa55feb608e2953890b1 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 15 Jan 2026 11:34:49 +0100 Subject: [PATCH 05/17] adapt core base tests --- src/osekit/core_api/audio_data.py | 43 ++++++++++++++++++---------- src/osekit/core_api/base_data.py | 16 +++++++++-- tests/test_core_api_base.py | 47 +++++++++++++++---------------- 3 files changed, 64 insertions(+), 42 deletions(-) diff --git a/src/osekit/core_api/audio_data.py b/src/osekit/core_api/audio_data.py index ab2f8f48..84bcbe1a 100644 --- a/src/osekit/core_api/audio_data.py +++ b/src/osekit/core_api/audio_data.py @@ -7,7 +7,7 @@ from __future__ import annotations from math import ceil -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import numpy as np import soundfile as sf @@ -15,7 +15,7 @@ from osekit.core_api.audio_file import AudioFile from osekit.core_api.audio_item import AudioItem -from osekit.core_api.base_data import BaseData +from osekit.core_api.base_data import BaseData, TFile from osekit.core_api.instrument import Instrument from osekit.utils.audio_utils import Normalization, normalize, resample @@ -288,7 +288,7 @@ def split( nb_subdata: int = 2, *, pass_normalization: bool = True, - ) -> list[AudioData]: + ) -> list[Self]: """Split the audio data object in the specified number of audio subdata. Parameters @@ -314,16 +314,27 @@ def split( if any(self.normalization_values.values()) else self.get_normalization_values() ) - return [ - AudioData.from_base_data( - data=base_data, - sample_rate=self.sample_rate, - instrument=self.instrument, - normalization=self.normalization, - normalization_values=normalization_values, - ) - for base_data in super().split(nb_subdata) - ] + return super().split( + nb_subdata=nb_subdata, + normalization_values=normalization_values, + ) + + def make_split_data( + self, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + normalization_values: dict, + ) -> Self: + return AudioData.from_files( + files=files, + begin=begin, + end=end, + sample_rate=self.sample_rate, + instrument=self.instrument, + normalization=self.normalization, + normalization_values=normalization_values, + ) def split_frames( self, @@ -353,9 +364,11 @@ def split_frames( """ if start_frame < 0: - raise ValueError("Start_frame must be greater than or equal to 0.") + msg = "Start_frame must be greater than or equal to 0." + raise ValueError(msg) if stop_frame < -1 or stop_frame > self.length: - raise ValueError("Stop_frame must be lower than the length of the data.") + msg = "Stop_frame must be lower than the length of the data." + raise ValueError(msg) start_timestamp = self.begin + Timedelta( seconds=ceil(start_frame / self.sample_rate * 1e9) / 1e9, diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index 6a4d13a1..bbc5fd8e 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -6,6 +6,7 @@ from __future__ import annotations +import itertools from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Generic, Self, TypeVar @@ -230,7 +231,7 @@ def files(self) -> set[TFile]: """All files referred to by the Data.""" return {item.file for item in self.items if item.file is not None} - def split(self, nb_subdata: int = 2) -> list[BaseData]: + def split(self, nb_subdata: int = 2, **kwargs) -> list[BaseData]: """Split the data object in the specified number of subdata. Parameters @@ -245,12 +246,21 @@ def split(self, nb_subdata: int = 2) -> list[BaseData]: """ dates = date_range(self.begin, self.end, periods=nb_subdata + 1) - subdata_dates = zip(dates, dates[1:], strict=False) + subdata_dates = itertools.pairwise(dates) return [ - BaseData.from_files(files=list(self.files), begin=b, end=e) + self.make_split_data(files=list(self.files), begin=b, end=e, **kwargs) for b, e in subdata_dates ] + @abstractmethod + def make_split_data( + self, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + **kwargs, + ) -> Self: ... + @classmethod def from_files( cls, diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index 3acde140..c668eccf 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -35,6 +35,15 @@ def write(self, folder: Path, link: bool = False) -> None: def link(self, folder: Path) -> None: pass + def make_split_data( + self, + files: list[DummyFile], + begin: Timestamp, + end: Timestamp, + **kwargs, + ) -> Self: + return DummyData(files, begin, end, **kwargs) + @classmethod def make_item( cls, @@ -52,24 +61,33 @@ def from_base_dict( begin: Timestamp, end: Timestamp, ) -> Self: - pass + return cls.from_files( + files=files, + begin=begin, + end=end, + ) class DummyDataset(BaseDataset[DummyData, DummyFile]): @classmethod def data_from_dict(cls, dictionary: dict) -> list[TData]: - pass + return [DummyData.from_dict(data) for data in dictionary.values()] @classmethod def data_from_files( cls, - files: list[TFile], + files: list[DummyFile], begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, **kwargs, ) -> TData: - pass + return DummyData.from_files( + files=files, + begin=begin, + end=end, + name=name, + ) file_cls = DummyFile @@ -421,7 +439,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No @pytest.mark.parametrize( ( "strptime_format", - "supported_file_extensions", "begin", "end", "timezone", @@ -436,7 +453,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No [ pytest.param( "%y%m%d%H%M%S", - [".wav"], None, None, None, @@ -459,7 +475,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], None, None, None, @@ -482,7 +497,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], None, None, None, @@ -512,7 +526,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav", ".mp3"], None, None, None, @@ -521,7 +534,7 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No None, Timestamp("2023-12-01 00:00:00"), None, - [Path(r"cool.wav"), Path(r"fun.mp3"), Path("boring.flac")], + [Path(r"cool.wav"), Path("boring.shenanigan")], [ ( Event( @@ -530,19 +543,11 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), [Path(r"cool.wav")], ), - ( - Event( - begin=Timestamp("2023-12-01 00:00:01"), - end=Timestamp("2023-12-01 00:00:02"), - ), - [Path(r"fun.mp3")], - ), ], id="only_specified_formats_are_kept", ), pytest.param( None, - [".wav"], None, None, None, @@ -586,7 +591,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], None, None, None, @@ -630,7 +634,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], Timestamp("2023-12-01 00:00:00.5"), Timestamp("2023-12-01 00:00:01.5"), None, @@ -660,7 +663,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( "%y%m%d%H%M%S", - [".wav"], Timestamp("2023-12-01 00:00:00.5"), Timestamp("2023-12-01 00:00:01.5"), None, @@ -690,7 +692,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( "%y%m%d%H%M%S", - [".wav"], Timestamp("2023-12-01 00:00:00.5+01:00"), Timestamp("2023-12-01 00:00:01.5+01:00"), "Europe/Warsaw", @@ -720,7 +721,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( "%y%m%d%H%M%S%z", - [".wav"], Timestamp("2023-12-01 01:00:00.5+01:00"), Timestamp("2023-12-01 01:00:01.5+01:00"), "Europe/Warsaw", @@ -753,7 +753,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No def test_base_dataset_from_folder( monkeypatch: pytest.monkeypatch, strptime_format: str | None, - supported_file_extensions: list[str], begin: Timestamp | None, end: Timestamp | None, timezone: str | None, From f6c01ad09a5d03eb8262877b74669c1ab00afa54 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 15 Jan 2026 11:59:45 +0100 Subject: [PATCH 06/17] use Dummy classes in file tests --- tests/conftest.py | 20 ------------- tests/test_core_api_base.py | 59 +++++++++++++++++++++++++++++-------- tests/test_files.py | 10 +++---- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 522a0bd8..29b47339 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,8 +21,6 @@ from osekit.core_api import AudioFileManager from osekit.core_api.audio_data import AudioData from osekit.core_api.audio_file import AudioFile -from osekit.core_api.base_dataset import BaseDataset -from osekit.core_api.base_file import BaseFile from osekit.utils.audio_utils import generate_sample_audio @@ -169,24 +167,6 @@ def mock_open(self: AudioFileManager, path: Path) -> None: return opened_files -@pytest.fixture -def base_dataset(tmp_path: Path) -> BaseDataset: - files = [tmp_path / f"file_{i}.txt" for i in range(5)] - for file in files: - file.touch() - timestamps = pd.date_range( - start=pd.Timestamp("2000-01-01 00:00:00"), - freq="1s", - periods=5, - ) - - bfs = [ - BaseFile(path=file, begin=timestamp, end=timestamp + pd.Timedelta(seconds=1)) - for file, timestamp in zip(files, timestamps, strict=False) - ] - return BaseDataset.from_files(files=bfs, mode="files") - - @pytest.fixture def patch_audio_data(monkeypatch: pytest.MonkeyPatch) -> None: original_init = AudioData.__init__ diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index c668eccf..777f697e 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -5,6 +5,7 @@ from typing import Literal, Self import numpy as np +import pandas as pd import pytest from pandas import Timedelta, Timestamp @@ -19,8 +20,7 @@ class DummyFile(BaseFile): supported_extensions: typing.ClassVar = [".wav"] - def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: - pass + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: ... class DummyItem(BaseItem[DummyFile]): ... @@ -29,11 +29,9 @@ class DummyItem(BaseItem[DummyFile]): ... class DummyData(BaseData[DummyItem, DummyFile]): item_cls = DummyItem - def write(self, folder: Path, link: bool = False) -> None: - pass + def write(self, folder: Path, link: bool = False) -> None: ... - def link(self, folder: Path) -> None: - pass + def link(self, folder: Path) -> None: ... def make_split_data( self, @@ -67,6 +65,23 @@ def from_base_dict( end=end, ) + @classmethod + def from_files( + cls, + files: list[DummyFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, + ) -> Self: + return super().from_files( + files=files, + begin=begin, + end=end, + name=name, + **kwargs, + ) + class DummyDataset(BaseDataset[DummyData, DummyFile]): @classmethod @@ -92,6 +107,24 @@ def data_from_files( file_cls = DummyFile +@pytest.fixture +def dummy_dataset(tmp_path: Path) -> DummyDataset: + files = [tmp_path / f"file_{i}.txt" for i in range(5)] + for file in files: + file.touch() + timestamps = pd.date_range( + start=pd.Timestamp("2000-01-01 00:00:00"), + freq="1s", + periods=5, + ) + + dfs = [ + DummyFile(path=file, begin=timestamp, end=timestamp + pd.Timedelta(seconds=1)) + for file, timestamp in zip(files, timestamps, strict=False) + ] + return DummyDataset.from_files(files=dfs, mode="files") + + @pytest.mark.parametrize( ("base_files", "begin", "end", "duration", "expected_data_events"), [ @@ -826,27 +859,27 @@ def test_move_file( def test_dataset_move( tmp_path: Path, - base_dataset: DummyDataset, + dummy_dataset: DummyDataset, ) -> None: - origin_files = [Path(str(file.path)) for file in base_dataset.files] + origin_files = [Path(str(file.path)) for file in dummy_dataset.files] # The starting folder of the dataset is the folder where the files are located - assert base_dataset.folder == tmp_path + assert dummy_dataset.folder == tmp_path destination = tmp_path / "destination" - base_dataset.folder = destination + dummy_dataset.folder = destination # Setting the folder shouldn't move the files - assert all(file.path.parent == tmp_path for file in base_dataset.files) + assert all(file.path.parent == tmp_path for file in dummy_dataset.files) assert all(file.exists for file in origin_files) # Folder should be changed when dataset is moved new_destination = tmp_path / "new_destination" - base_dataset.move_files(new_destination) + dummy_dataset.move_files(new_destination) assert new_destination.exists() assert new_destination.is_dir() - assert base_dataset.folder == new_destination + assert dummy_dataset.folder == new_destination assert all((new_destination / file.name).exists() for file in origin_files) assert not any(file.exists() for file in origin_files) diff --git a/tests/test_files.py b/tests/test_files.py index ee8072ca..c88da641 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -6,8 +6,7 @@ import pytz from pandas import Timestamp -from osekit.core_api.base_dataset import BaseDataset -from osekit.core_api.base_file import BaseFile +from tests.test_core_api_base import DummyDataset, DummyFile @pytest.mark.parametrize( @@ -82,7 +81,7 @@ def test_file_localization( timezone: str | pytz.timezone | None, expected_begin: Timestamp, ) -> None: - file = BaseFile( + file = DummyFile( path=Path(file_name), strptime_format=strptime_format, timezone=timezone, @@ -169,13 +168,12 @@ def test_dataset_localization( expected_begins: list[Timestamp], ) -> None: for file in file_names: - (tmp_path / f"{file}.foo").touch() + (tmp_path / f"{file}.wav").touch() - dataset = BaseDataset.from_folder( + dataset = DummyDataset.from_folder( tmp_path, strptime_format=strptime_format, timezone=timezone, - supported_file_extensions=[".foo"], ) assert all( From 8af2fc5cd9b6117be1cc2f7098bd95403f05f2ac Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 15 Jan 2026 16:56:08 +0100 Subject: [PATCH 07/17] adapt SpectroData serialization --- src/osekit/core_api/audio_data.py | 5 +- src/osekit/core_api/base_data.py | 11 +- src/osekit/core_api/spectro_data.py | 152 ++++++++++++++++------------ src/osekit/core_api/spectro_file.py | 5 - src/osekit/core_api/spectro_item.py | 22 +--- tests/test_core_api_base.py | 4 + 6 files changed, 110 insertions(+), 89 deletions(-) diff --git a/src/osekit/core_api/audio_data.py b/src/osekit/core_api/audio_data.py index 84bcbe1a..c8ec428f 100644 --- a/src/osekit/core_api/audio_data.py +++ b/src/osekit/core_api/audio_data.py @@ -324,7 +324,7 @@ def make_split_data( files: list[TFile], begin: Timestamp, end: Timestamp, - normalization_values: dict, + **kwargs, ) -> Self: return AudioData.from_files( files=files, @@ -333,7 +333,7 @@ def make_split_data( sample_rate=self.sample_rate, instrument=self.instrument, normalization=self.normalization, - normalization_values=normalization_values, + normalization_values=kwargs["normalization_values"], ) def split_frames( @@ -427,6 +427,7 @@ def from_base_dict( files: list[AudioFile], begin: Timestamp, end: Timestamp, + **kwargs, ) -> AudioData: """Deserialize an AudioData from a dictionary. diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index bbc5fd8e..6c5d9ac9 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -176,7 +176,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, dictionary: dict) -> BaseData: + def from_dict(cls, dictionary: dict, **kwargs) -> BaseData: """Deserialize a BaseData from a dictionary. Parameters @@ -202,9 +202,16 @@ def from_dict(cls, dictionary: dict) -> BaseData: ] begin = Timestamp(dictionary["begin"]) end = Timestamp(dictionary["end"]) - return cls.from_base_dict(dictionary, files, begin, end) + return cls.from_base_dict( + dictionary=dictionary, + files=files, + begin=begin, + end=end, + **kwargs, + ) @classmethod + @abstractmethod def make_file(cls, path: Path, begin: Timestamp) -> type[TFile]: ... @classmethod diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index f4e3c1f3..ef81d3e2 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -8,7 +8,7 @@ import gc import itertools -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Self import matplotlib.pyplot as plt import numpy as np @@ -20,7 +20,7 @@ TIMESTAMP_FORMATS_EXPORTED_FILES, ) from osekit.core_api.audio_data import AudioData -from osekit.core_api.base_data import BaseData +from osekit.core_api.base_data import BaseData, TFile from osekit.core_api.spectro_file import SpectroFile from osekit.core_api.spectro_item import SpectroItem @@ -39,12 +39,15 @@ class SpectroData(BaseData[SpectroItem, SpectroFile]): The data is accessed via a SpectroItem object per SpectroFile. """ + item_cls = SpectroItem + def __init__( self, items: list[SpectroItem] | None = None, audio_data: AudioData = None, begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, fft: ShortTimeFFT | None = None, db_ref: float | None = None, v_lim: tuple[float, float] | None = None, @@ -64,6 +67,8 @@ def __init__( end: Timestamp | None Only effective if items is None. Set the end of the empty data. + name: str | None + Name of the exported files. fft: ShortTimeFFT The short time FFT used for computing the spectrogram. db_ref: float | None @@ -75,7 +80,7 @@ def __init__( Colormap to use for plotting the spectrogram. """ - super().__init__(items=items, begin=begin, end=end) + super().__init__(items=items, begin=begin, end=end, name=name) self.audio_data = audio_data self.fft = fft self._sx_dtype = complex @@ -174,9 +179,10 @@ def sx_dtype(self) -> type[complex]: return self._sx_dtype @sx_dtype.setter - def sx_dtype(self, dtype: type[complex]) -> [complex, float]: + def sx_dtype(self, dtype: type[complex]) -> None: if dtype not in (complex, float): - raise ValueError("dtype must be complex or float.") + msg = "dtype must be complex or float." + raise ValueError(msg) self._sx_dtype = dtype @property @@ -241,7 +247,8 @@ def get_value(self) -> np.ndarray: if not all(item.is_empty for item in self.items): return self._get_value_from_items(self.items) if not self.audio_data or not self.fft: - raise ValueError("SpectroData should have either items or audio_data.") + msg = "SpectroData should have either items or audio_data." + raise ValueError(msg) sx = self.fft.stft( x=self.audio_data.get_value_calibrated()[ @@ -275,9 +282,10 @@ def get_welch( self, nperseg: int | None = None, detrend: str | callable | False = "constant", - return_onesided: bool = True, scaling: Literal["density", "spectrum"] = "density", average: Literal["mean", "median"] = "mean", + *, + return_onesided: bool = True, ) -> np.ndarray: """Estimate power spectral density of the SpectroData using Welch's method. @@ -334,9 +342,10 @@ def write_welch( px: np.ndarray | None = None, nperseg: int | None = None, detrend: str | callable | False = "constant", - return_onesided: bool = True, scaling: Literal["density", "spectrum"] = "density", average: Literal["mean", "median"] = "mean", + *, + return_onesided: bool = True, ) -> None: """Write the psd (welch) of the SpectroData to a npz file. @@ -551,20 +560,25 @@ def link_audio_data(self, audio_data: AudioData) -> None: """ if self.begin != audio_data.begin: - raise ValueError("The begin of the audio data doesn't match.") + msg = "The begin of the audio data doesn't match." + raise ValueError(msg) if self.end != audio_data.end: - raise ValueError("The end of the audio data doesn't match.") + msg = "The end of the audio data doesn't match." + raise ValueError(msg) if self.fft.fs != audio_data.sample_rate: - raise ValueError("The sample rate of the audio data doesn't match.") + msg = "The sample rate of the audio data doesn't match." + raise ValueError(msg) self.audio_data = audio_data - def split(self, nb_subdata: int = 2) -> list[SpectroData]: + def split(self, nb_subdata: int = 2, **kwargs) -> list[SpectroData]: """Split the spectro data object in the specified number of spectro subdata. Parameters ---------- nb_subdata: int Number of subdata in which to split the data. + kwargs: dict + Additionnal keyword arguments. Returns ------- @@ -606,10 +620,12 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: for i in items[1:] if not i.is_empty ): - raise ValueError("Items don't have the same frequency bins.") + msg = "Items don't have the same frequency bins." + raise ValueError(msg) if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1: - raise ValueError("Items don't have the same time resolution.") + msg = "Items don't have the same time resolution." + raise ValueError(msg) return np.hstack( [item.get_value(fft=self.fft, sx_dtype=self.sx_dtype) for item in items], @@ -657,12 +673,55 @@ def get_overlapped_bins(cls, sd1: SpectroData, sd2: SpectroData) -> np.ndarray: p1_le = fft.lower_border_end[1] - fft.p_min return sd_part1[:, -p1_le:] + sd_part2[:, :p1_le] + @classmethod + def make_file(cls, path: Path, begin: Timestamp) -> SpectroFile: + return SpectroFile(path=path, begin=begin) + + @classmethod + def make_item( + cls, + file: TFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> SpectroItem: + return SpectroItem( + file=file, + begin=begin, + end=end, + ) + + @classmethod + def from_base_dict( + cls, + dictionary: dict, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + **kwargs, + ) -> Self: + return cls.from_files( + files=files, + begin=begin, + end=end, + colormap=dictionary["colormap"], + ) + + def make_split_data( + self, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + **kwargs, + ) -> Self: ... + @classmethod def from_files( cls, files: list[SpectroFile], begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, + **kwargs, ) -> SpectroData: """Return a SpectroData object from a list of SpectroFiles. @@ -683,49 +742,19 @@ def from_files( The SpectroData object. """ - instance = cls.from_base_data( - BaseData.from_files(files, begin, end), - fft=files[0].get_fft(), + fft = files[0].get_fft() + instance = super().from_files( + files=files, # This way, this static error doesn't appear to the user + begin=begin, + end=end, + name=name, + fft=fft, + **kwargs, ) if not any(file.sx_dtype is complex for file in files): instance.sx_dtype = float return instance - @classmethod - def from_base_data( - cls, - data: BaseData, - fft: ShortTimeFFT, - colormap: str | None = None, - ) -> SpectroData: - """Return an SpectroData object from a BaseData object. - - Parameters - ---------- - data: BaseData - BaseData object to convert to SpectroData. - fft: ShortTimeFFT - The ShortTimeFFT used to compute the spectrogram. - colormap: str - The colormap used to plot the spectrogram. - - Returns - ------- - SpectroData: - The SpectroData object. - - """ - items = [SpectroItem.from_base_item(item) for item in data.items] - db_ref = next((f.file.db_ref for f in items if f.file.db_ref is not None), None) - v_lim = next((f.file.v_lim for f in items if f.file.v_lim is not None), None) - return cls( - [SpectroItem.from_base_item(item) for item in data.items], - fft=fft, - db_ref=db_ref, - v_lim=v_lim, - colormap=colormap, - ) - @classmethod def from_audio_data( cls, @@ -815,7 +844,7 @@ def from_dict( cls, dictionary: dict, sft: ShortTimeFFT | None = None, - ) -> SpectroData: + ) -> Self: """Deserialize a SpectroData from a dictionary. Parameters @@ -832,20 +861,19 @@ def from_dict( The deserialized SpectroData. """ + if dictionary["audio_data"] is None: + return super().from_dict( + dictionary=dictionary, + colormap=dictionary["colormap"], + ) + if sft is None and dictionary["sft"] is None: - raise ValueError("Missing sft") + msg = "Missing SFT" + raise ValueError(msg) if sft is None: dictionary["sft"]["win"] = np.array(dictionary["sft"]["win"]) sft = ShortTimeFFT(**dictionary["sft"]) - if dictionary["audio_data"] is None: - base_data = BaseData.from_dict(dictionary) - return cls.from_base_data( - data=base_data, - fft=sft, - colormap=dictionary["colormap"], - ) - audio_data = AudioData.from_dict(dictionary["audio_data"]) v_lim = ( None if type(dictionary["v_lim"]) is object else tuple(dictionary["v_lim"]) diff --git a/src/osekit/core_api/spectro_file.py b/src/osekit/core_api/spectro_file.py index 5654fa90..d5b84b12 100644 --- a/src/osekit/core_api/spectro_file.py +++ b/src/osekit/core_api/spectro_file.py @@ -164,8 +164,3 @@ def get_fft(self) -> ShortTimeFFT: fs=self.sample_rate, mfft=self.mfft, ) - - @classmethod - def from_base_file(cls, file: BaseFile) -> SpectroFile: - """Return a SpectroFile object from a BaseFile object.""" - return cls(path=file.path, begin=file.begin) diff --git a/src/osekit/core_api/spectro_item.py b/src/osekit/core_api/spectro_item.py index 4183ed16..b0faba75 100644 --- a/src/osekit/core_api/spectro_item.py +++ b/src/osekit/core_api/spectro_item.py @@ -6,7 +6,6 @@ import numpy as np -from osekit.core_api.base_file import BaseFile from osekit.core_api.base_item import BaseItem from osekit.core_api.spectro_file import SpectroFile @@ -45,20 +44,6 @@ def time_resolution(self) -> Timedelta: """Time resolution of the associated SpectroFile.""" return None if self.is_empty else self.file.time_resolution - @classmethod - def from_base_item(cls, item: BaseItem) -> SpectroItem: - """Return a SpectroItem object from a BaseItem object.""" - file = item.file - if not file or isinstance(file, SpectroFile): - return cls(file=file, begin=item.begin, end=item.end) - if isinstance(file, BaseFile): - return cls( - file=SpectroFile.from_base_file(file), - begin=item.begin, - end=item.end, - ) - raise TypeError - def get_value( self, fft: ShortTimeFFT | None = None, @@ -75,10 +60,11 @@ def get_value( if sx_dtype is float: sx = abs(sx) ** 2 if sx_dtype is complex: - raise TypeError( - "Cannot convert absolute npz values to complex sx values." - "Change the SpectroData dtype to absolute.", + msg = ( + "Cannot convert absolute npz values to complex sx values. " + "Change the SpectroData dtype to absolute." ) + raise TypeError(msg) return sx diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index 777f697e..b58cae22 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -42,6 +42,10 @@ def make_split_data( ) -> Self: return DummyData(files, begin, end, **kwargs) + @classmethod + def make_file(cls, path: Path, begin: Timestamp) -> DummyFile: + return DummyFile(path=path, begin=begin) + @classmethod def make_item( cls, From 8ff3119ec9cd55546fc8bfba1dbd0a786202974e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 15 Jan 2026 17:18:00 +0100 Subject: [PATCH 08/17] fix SpectroData.from_files v_lim and db_ref parsing --- src/osekit/core_api/spectro_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index ef81d3e2..4d1fdb86 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -743,12 +743,16 @@ def from_files( """ fft = files[0].get_fft() + db_ref = next((f.db_ref for f in files if f.db_ref is not None), None) + v_lim = next((f.v_lim for f in files if f.v_lim is not None), None) instance = super().from_files( files=files, # This way, this static error doesn't appear to the user begin=begin, end=end, name=name, fft=fft, + db_ref=db_ref, + v_lim=v_lim, **kwargs, ) if not any(file.sx_dtype is complex for file in files): From 47ba13523e67a61abf8c994e47eda3377b85767f Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 15 Jan 2026 18:04:47 +0100 Subject: [PATCH 09/17] adapt SpectroDataset to ABC base --- src/osekit/core_api/audio_dataset.py | 7 +- src/osekit/core_api/spectro_dataset.py | 121 ++++++++++++++++++------- src/osekit/core_api/spectro_file.py | 3 + 3 files changed, 94 insertions(+), 37 deletions(-) diff --git a/src/osekit/core_api/audio_dataset.py b/src/osekit/core_api/audio_dataset.py index d480c29c..f3be47f1 100644 --- a/src/osekit/core_api/audio_dataset.py +++ b/src/osekit/core_api/audio_dataset.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -from pathlib import Path from typing import TYPE_CHECKING, Literal, Self from osekit.core_api.audio_data import AudioData @@ -18,6 +17,8 @@ from osekit.utils.multiprocess_utils import multiprocess if TYPE_CHECKING: + from pathlib import Path + import pytz from pandas import Timedelta, Timestamp @@ -294,8 +295,8 @@ def from_files( # noqa: PLR0913 Returns ------- - BaseDataset[TItem, TFile]: - The DataBase object. + AudioDataset: + The AudioDataset object. """ return super().from_files( diff --git a/src/osekit/core_api/spectro_dataset.py b/src/osekit/core_api/spectro_dataset.py index b2e96add..c8f1552a 100644 --- a/src/osekit/core_api/spectro_dataset.py +++ b/src/osekit/core_api/spectro_dataset.py @@ -7,14 +7,14 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Self import numpy as np from pandas import DataFrame from scipy.signal import ShortTimeFFT from osekit.config import DPDEFAULT -from osekit.core_api.base_dataset import BaseDataset +from osekit.core_api.base_dataset import BaseDataset, TFile from osekit.core_api.frequency_scale import Scale from osekit.core_api.json_serializer import deserialize_json from osekit.core_api.spectro_data import SpectroData @@ -40,6 +40,7 @@ class SpectroDataset(BaseDataset[SpectroData, SpectroFile]): sentinel_value = object() _bypass_multiprocessing_on_dataset = False data_cls = SpectroData + file_cls = SpectroFile def __init__( self, @@ -344,9 +345,8 @@ def link_audio_dataset( """ if len(audio_dataset.data) != len(self.data): - raise ValueError( - "The audio dataset doesn't contain the same number of data as the spectro dataset.", - ) + msg = "The audio dataset doesn't contain the same number of data as the spectro dataset." + raise ValueError(msg) last = len(self.data) if last is None else last @@ -485,9 +485,8 @@ def from_folder( # noqa: PLR0913 overlap: float = 0.0, data_duration: Timedelta | None = None, name: str | None = None, - v_lim: tuple[float, float] | None | object = sentinel_value, **kwargs: any, - ) -> SpectroDataset: + ) -> Self: """Return a SpectroDataset from a folder containing the spectro files. Parameters @@ -526,8 +525,6 @@ def from_folder( # noqa: PLR0913 Else, one data object will cover the whole time period. name: str|None Name of the dataset. - v_lim: tuple[float, float] | None - Limits (in dB) of the colormap used for plotting the spectrogram. kwargs: any Keyword arguments passed to the BaseDataset.from_folder classmethod. @@ -537,10 +534,7 @@ def from_folder( # noqa: PLR0913 The audio dataset. """ - kwargs.update( - {"file_class": SpectroFile, "supported_file_extensions": [".npz"]}, - ) - base_dataset = BaseDataset.from_folder( + return super().from_folder( folder=folder, strptime_format=strptime_format, begin=begin, @@ -549,35 +543,94 @@ def from_folder( # noqa: PLR0913 mode=mode, overlap=overlap, data_duration=data_duration, - **kwargs, + name=name, ) - sft = next(iter(base_dataset.files)).get_fft() - return cls.from_base_dataset( - base_dataset=base_dataset, - fft=sft, + + @classmethod + def from_files( # noqa: PLR0913 + cls, + files: list[SpectroFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total", + overlap: float = 0.0, + data_duration: Timedelta | None = None, + **kwargs, + ) -> AudioDataset: + """Return an SpectroDataset object from a list of SpectroFiles. + + Parameters + ---------- + files: list[SpectroFile] + The list of files contained in the Dataset. + begin: Timestamp | None + Begin of the first data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the last data object. + Defaulted to the end of the last file. + mode: Literal["files", "timedelta_total", "timedelta_file"] + Mode of creation of the dataset data from the original files. + "files": one data will be created for each file. + "timedelta_total": data objects of duration equal to data_duration will + be created from the begin timestamp to the end timestamp. + "timedelta_file": data objects of duration equal to data_duration will + be created from the beginning of the first file that the begin timestamp is into, until it would resume + in a data beginning between two files. Then, the next data object will be created from the + beginning of the next original file and so on. + overlap: float + Overlap percentage between consecutive data. + data_duration: Timedelta | None + Duration of the data objects. + If mode is set to "files", this parameter has no effect. + If provided, data will be evenly distributed between begin and end. + Else, one data object will cover the whole time period. + sample_rate: float | None + Sample rate of the audio data objects. + name: str|None + Name of the dataset. + instrument: Instrument | None + Instrument that might be used to obtain acoustic pressure from + the wav audio data. + normalization: Normalization + The type of normalization to apply to the audio data. + + Returns + ------- + SpectroDataset: + The SpectroDataset object. + + """ + return super().from_files( + files=files, + begin=begin, + end=end, name=name, - v_lim=v_lim, + mode=mode, + overlap=overlap, + data_duration=data_duration, ) @classmethod - def from_base_dataset( + def data_from_dict(cls, dictionary: dict) -> list[SpectroData]: + return [SpectroData.from_dict(data) for data in dictionary.values()] + + @classmethod + def data_from_files( cls, - base_dataset: BaseDataset, - fft: ShortTimeFFT, + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, name: str | None = None, - colormap: str | None = None, - scale: Scale | None = None, - v_lim: tuple[float, float] | None | object = sentinel_value, - ) -> SpectroDataset: - """Return a SpectroDataset object from a BaseDataset object.""" - return cls( - [ - SpectroData.from_base_data(data=data, fft=fft, colormap=colormap) - for data in base_dataset.data - ], + **kwargs, + ) -> SpectroData: + return SpectroData.from_files( + files=files, + begin=begin, + end=end, name=name, - scale=scale, - v_lim=v_lim, + **kwargs, ) @classmethod diff --git a/src/osekit/core_api/spectro_file.py b/src/osekit/core_api/spectro_file.py index d5b84b12..b846d3b7 100644 --- a/src/osekit/core_api/spectro_file.py +++ b/src/osekit/core_api/spectro_file.py @@ -6,6 +6,7 @@ from __future__ import annotations +import typing from typing import TYPE_CHECKING import numpy as np @@ -27,6 +28,8 @@ class SpectroFile(BaseFile): Metadata (time_resolution) are stored as separate arrays. """ + supported_extensions: typing.ClassVar = [".npz"] + def __init__( self, path: PathLike | str, From e38552fa9a70ee0ca2001098e7cf6bde991ed465 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 09:44:38 +0100 Subject: [PATCH 10/17] call super().__init__ in BaseFile init --- src/osekit/core_api/base_file.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/osekit/core_api/base_file.py b/src/osekit/core_api/base_file.py index 1566a4fc..93ebb53a 100644 --- a/src/osekit/core_api/base_file.py +++ b/src/osekit/core_api/base_file.py @@ -79,7 +79,7 @@ def __init__( if begin is None and strptime_format is None: raise ValueError("Either begin or strptime_format must be specified") - self.begin = ( + begin = ( begin if begin is not None else strptime_from_text( @@ -91,7 +91,8 @@ def __init__( if timezone: self.begin = localize_timestamp(self.begin, timezone) - self.end = end if end is not None else (self.begin + Timedelta(seconds=1)) + end = end if end is not None else (self.begin + Timedelta(seconds=1)) + super().__init__(begin=begin, end=end) @abstractmethod def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: From 44ec9716e8a2e284cae4b15837635668495382b2 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 09:49:12 +0100 Subject: [PATCH 11/17] lint Items --- src/osekit/core_api/audio_item.py | 2 +- src/osekit/core_api/base_item.py | 12 ++++++------ src/osekit/core_api/spectro_item.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/osekit/core_api/audio_item.py b/src/osekit/core_api/audio_item.py index c32442ce..9207eabc 100644 --- a/src/osekit/core_api/audio_item.py +++ b/src/osekit/core_api/audio_item.py @@ -36,7 +36,7 @@ def __init__( It is defaulted to the AudioFile end. """ - super().__init__(file, begin, end) + super().__init__(file=file, begin=begin, end=end) @property def sample_rate(self) -> float: diff --git a/src/osekit/core_api/base_item.py b/src/osekit/core_api/base_item.py index 7d8c89a4..8fc35538 100644 --- a/src/osekit/core_api/base_item.py +++ b/src/osekit/core_api/base_item.py @@ -6,7 +6,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, TypeVar import numpy as np @@ -19,7 +19,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseItem(Generic[TFile], Event, ABC): +class BaseItem[TFile: BaseFile](Event, ABC): """Base class for the Item objects. An Item correspond to a portion of a File object. @@ -52,10 +52,10 @@ def __init__( self.end = end return - self.begin = ( - max(begin, self.file.begin) if begin is not None else self.file.begin - ) - self.end = min(end, self.file.end) if end is not None else self.file.end + begin = max(begin, self.file.begin) if begin is not None else self.file.begin + end = min(end, self.file.end) if end is not None else self.file.end + + super().__init__(begin=begin, end=end) def get_value(self) -> np.ndarray: """Get the values from the File between the begin and stop timestamps. diff --git a/src/osekit/core_api/spectro_item.py b/src/osekit/core_api/spectro_item.py index b0faba75..31cc869a 100644 --- a/src/osekit/core_api/spectro_item.py +++ b/src/osekit/core_api/spectro_item.py @@ -37,7 +37,7 @@ def __init__( It is defaulted to the SpectroFile end. """ - super().__init__(file, begin, end) + super().__init__(file=file, begin=begin, end=end) @property def time_resolution(self) -> Timedelta: From 6ebd629a08b06ad329d31eb58a92b5e49897ddbb Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 09:52:49 +0100 Subject: [PATCH 12/17] fix BaseFile __init__ begin calls --- src/osekit/core_api/base_data.py | 1 + src/osekit/core_api/base_file.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index 6c5d9ac9..0dd968bd 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -69,6 +69,7 @@ def __init__( self._begin = min(item.begin for item in self.items) self._end = max(item.end for item in self.items) self._name = name + super().__init__(self._begin, self._end) def __eq__(self, other: BaseData) -> bool: """Override __eq__.""" diff --git a/src/osekit/core_api/base_file.py b/src/osekit/core_api/base_file.py index 93ebb53a..69e73967 100644 --- a/src/osekit/core_api/base_file.py +++ b/src/osekit/core_api/base_file.py @@ -77,7 +77,8 @@ def __init__( self.path = Path(path) if begin is None and strptime_format is None: - raise ValueError("Either begin or strptime_format must be specified") + msg = "Either begin or strptime_format must be specified" + raise ValueError(msg) begin = ( begin @@ -89,9 +90,9 @@ def __init__( ) if timezone: - self.begin = localize_timestamp(self.begin, timezone) + begin = localize_timestamp(begin, timezone) - end = end if end is not None else (self.begin + Timedelta(seconds=1)) + end = end if end is not None else (begin + Timedelta(seconds=1)) super().__init__(begin=begin, end=end) @abstractmethod From e6ab4ba390e596a1315164b250a8de07200065ef Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 09:59:26 +0100 Subject: [PATCH 13/17] set boolean args as keyword only --- src/osekit/core_api/audio_data.py | 1 + src/osekit/core_api/base_data.py | 5 ++++- src/osekit/core_api/spectro_data.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/osekit/core_api/audio_data.py b/src/osekit/core_api/audio_data.py index c8ec428f..c0cf0e9c 100644 --- a/src/osekit/core_api/audio_data.py +++ b/src/osekit/core_api/audio_data.py @@ -223,6 +223,7 @@ def write( self, folder: Path, subtype: str | None = None, + *, link: bool = False, ) -> None: """Write the audio data to file. diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index 0dd968bd..5ae7df7d 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -142,7 +142,7 @@ def create_directories(path: Path) -> None: path.mkdir(parents=True, exist_ok=True, mode=DPDEFAULT) @abstractmethod - def write(self, folder: Path, link: bool = False) -> None: + def write(self, folder: Path, *, link: bool = False) -> None: """Abstract method for writing data to file.""" @abstractmethod @@ -184,6 +184,8 @@ def from_dict(cls, dictionary: dict, **kwargs) -> BaseData: ---------- dictionary: dict The serialized dictionary representing the BaseData. + kwargs: + Keyword arguments that are passed to cls.from_base_dict(). Returns ------- @@ -232,6 +234,7 @@ def from_base_dict( files: list[TFile], begin: Timestamp, end: Timestamp, + **kwargs, ) -> Self: ... @property diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index 4d1fdb86..1e08adf8 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -485,6 +485,7 @@ def write( self, folder: Path, sx: np.ndarray | None = None, + *, link: bool = False, ) -> None: """Write the Spectro data to file. From 4caefd0f9d3a2acb500d8538d2e478b39749f3b4 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 10:49:11 +0100 Subject: [PATCH 14/17] lint AudioData --- src/osekit/core_api/audio_data.py | 126 +++++++++++++++++++++++++--- src/osekit/core_api/base_data.py | 56 ++++++++----- src/osekit/core_api/spectro_data.py | 39 ++++++++- tests/test_core_api_base.py | 8 +- 4 files changed, 189 insertions(+), 40 deletions(-) diff --git a/src/osekit/core_api/audio_data.py b/src/osekit/core_api/audio_data.py index c0cf0e9c..a2d54333 100644 --- a/src/osekit/core_api/audio_data.py +++ b/src/osekit/core_api/audio_data.py @@ -15,7 +15,7 @@ from osekit.core_api.audio_file import AudioFile from osekit.core_api.audio_item import AudioItem -from osekit.core_api.base_data import BaseData, TFile +from osekit.core_api.base_data import BaseData from osekit.core_api.instrument import Instrument from osekit.utils.audio_utils import Normalization, normalize, resample @@ -121,19 +121,60 @@ def normalization_values(self, value: dict | None) -> None: ) @classmethod - def make_item( + def _make_item( cls, file: AudioFile | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, ) -> AudioItem: + """Make an AudioItem for a given AudioFile between begin and end timestamps. + + Parameters + ---------- + file: AudioFile + AudioFile of the item. + begin: Timestamp + Begin of the item. + end: + End of the item. + + Returns + ------- + An AudioItem for the AudioFile file, between the begin and end timestamps. + + """ return AudioItem(file=file, begin=begin, end=end) @classmethod - def make_file(cls, path: Path, begin: Timestamp) -> AudioFile: + def _make_file(cls, path: Path, begin: Timestamp) -> AudioFile: + """Make an AudioFile from a path and a begin timestamp. + + Parameters + ---------- + path: Path + Path to the file. + begin: Timestamp + Begin of the file. + + Returns + ------- + AudioFile: + The audio file. + + """ return AudioFile(path=path, begin=begin) def get_normalization_values(self) -> dict: + """Return the values used for normalizing the audio data. + + Returns + ------- + dict: + "mean": mean value to substract to center values on 0. + "peak": peak value for PEAK normalization + "std": standard deviation used for z-score normalization + + """ values = np.array(self.get_raw_value()) return { "mean": values.mean(), @@ -320,13 +361,33 @@ def split( normalization_values=normalization_values, ) - def make_split_data( + def _make_split_data( self, - files: list[TFile], + files: list[AudioFile], begin: Timestamp, end: Timestamp, - **kwargs, - ) -> Self: + **kwargs: tuple[float, float, float], + ) -> AudioData: + """Return an AudioData object after an AudioData.split() call. + + Parameters + ---------- + files: list[AudioFile] + The AudioFiles of the original AudioData. + begin: Timestamp + The begin timestamp of the split AudioData. + end: Timestamp + The end timestamp of the split AudioData. + kwargs: + normalization_values: tuple[float, float, float] + Values used for normalizing the split AudioData. + + Returns + ------- + AudioData: + The AudioData instance. + + """ return AudioData.from_files( files=files, begin=begin, @@ -422,20 +483,31 @@ def to_dict(self) -> dict: ) @classmethod - def from_base_dict( + def _from_base_dict( cls, dictionary: dict, files: list[AudioFile], begin: Timestamp, end: Timestamp, - **kwargs, + **kwargs, # noqa: ANN003 ) -> AudioData: - """Deserialize an AudioData from a dictionary. + """Deserialize the AudioData-specific parts of a Data dictionary. + + This method is called within the BaseData.from_dict() method, which + deserializes the base files, begin and end parameters. Parameters ---------- dictionary: dict The serialized dictionary representing the AudioData. + files: list[AudioFile] + The list of deserialized AudioFiles. + begin: Timestamp + The deserialized begin timestamp. + end: Timestamp + The deserialized end timestamp. + kwargs: + None. Returns ------- @@ -459,14 +531,44 @@ def from_base_dict( ) @classmethod - def from_files( # noqa: D102 + def from_files( cls, files: list[AudioFile], # The method is redefined just to specify the type begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - **kwargs, + **kwargs, # noqa: ANN003 ) -> AudioData: + """Return a, AudioData object from a list of AudioFiles. + + Parameters + ---------- + files: list[AudioFile] + List of AudioFiles containing the data. + begin: Timestamp | None + Begin of the data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the data object. + Defaulted to the end of the last file. + name: str | None + Name of the exported files. + kwargs + Keyword arguments that are passed to the cls constructor. + sample_rate: int + The sample rate of the audio data. + instrument: Instrument | None + Instrument that might be used to obtain acoustic pressure from + the wav audio data. + normalization: Normalization + The type of normalization to apply to the audio data. + + Returns + ------- + Self: + The AudioData object. + + """ return super().from_files( files=files, # This way, this static error doesn't appear to the user begin=begin, diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index 5ae7df7d..2880f8ec 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -9,7 +9,7 @@ import itertools from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Generic, Self, TypeVar +from typing import Self, TypeVar import numpy as np from pandas import Timestamp, date_range @@ -29,7 +29,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseData(Generic[TItem, TFile], Event, ABC): +class BaseData[TItem: BaseItem, TFile: BaseFile](Event, ABC): """Base class for the Data objects. Data corresponds to data scattered through different Files. @@ -64,7 +64,7 @@ def __init__( """ if not items: - items = [self.make_item(begin=begin, end=end)] + items = [self._make_item(begin=begin, end=end)] self.items = items self._begin = min(item.begin for item in self.items) self._end = max(item.end for item in self.items) @@ -177,7 +177,11 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, dictionary: dict, **kwargs) -> BaseData: + def from_dict( + cls, + dictionary: dict, + **kwargs, # noqa: ANN003 + ) -> BaseData: """Deserialize a BaseData from a dictionary. Parameters @@ -194,7 +198,7 @@ def from_dict(cls, dictionary: dict, **kwargs) -> BaseData: """ files = [ - cls.make_file( + cls._make_file( path=Path(file["path"]), begin=strptime_from_text( file["begin"], @@ -205,7 +209,7 @@ def from_dict(cls, dictionary: dict, **kwargs) -> BaseData: ] begin = Timestamp(dictionary["begin"]) end = Timestamp(dictionary["end"]) - return cls.from_base_dict( + return cls._from_base_dict( dictionary=dictionary, files=files, begin=begin, @@ -215,11 +219,13 @@ def from_dict(cls, dictionary: dict, **kwargs) -> BaseData: @classmethod @abstractmethod - def make_file(cls, path: Path, begin: Timestamp) -> type[TFile]: ... + def _make_file(cls, path: Path, begin: Timestamp) -> type[TFile]: + """Make a File from a path and a begin timestamp.""" + ... @classmethod @abstractmethod - def make_item( + def _make_item( cls, file: TFile | None = None, begin: Timestamp | None = None, @@ -228,13 +234,13 @@ def make_item( @classmethod @abstractmethod - def from_base_dict( + def _from_base_dict( cls, dictionary: dict, files: list[TFile], begin: Timestamp, end: Timestamp, - **kwargs, + **kwargs, # noqa: ANN003 ) -> Self: ... @property @@ -242,13 +248,19 @@ def files(self) -> set[TFile]: """All files referred to by the Data.""" return {item.file for item in self.items if item.file is not None} - def split(self, nb_subdata: int = 2, **kwargs) -> list[BaseData]: + def split( + self, + nb_subdata: int = 2, + **kwargs, # noqa: ANN003 + ) -> list[BaseData]: """Split the data object in the specified number of subdata. Parameters ---------- nb_subdata: int Number of subdata in which to split the data. + **kwargs: + Keyword arguments that are passed to self.make_split_data(). Returns ------- @@ -259,18 +271,20 @@ def split(self, nb_subdata: int = 2, **kwargs) -> list[BaseData]: dates = date_range(self.begin, self.end, periods=nb_subdata + 1) subdata_dates = itertools.pairwise(dates) return [ - self.make_split_data(files=list(self.files), begin=b, end=e, **kwargs) + self._make_split_data(files=list(self.files), begin=b, end=e, **kwargs) for b, e in subdata_dates ] @abstractmethod - def make_split_data( + def _make_split_data( self, files: list[TFile], begin: Timestamp, end: Timestamp, - **kwargs, - ) -> Self: ... + **kwargs, # noqa: ANN003 + ) -> Self: + """Make a Data object after a .split() call.""" + ... @classmethod def from_files( @@ -279,7 +293,7 @@ def from_files( begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - **kwargs: Any, + **kwargs, # noqa: ANN003 ) -> Self: """Return a Data object from a list of Files. @@ -295,6 +309,8 @@ def from_files( Defaulted to the end of the last file. name: str | None Name of the exported files. + kwargs: + Keyword arguments that are passed to the cls constructor. Returns ------- @@ -342,13 +358,13 @@ def items_from_files( ] items = [ - cls.make_item(file=file, begin=begin, end=end) for file in included_files + cls._make_item(file=file, begin=begin, end=end) for file in included_files ] if not items: - items.append(cls.make_item(begin=begin, end=end)) + items.append(cls._make_item(begin=begin, end=end)) if (first_item := sorted(items, key=lambda item: item.begin)[0]).begin > begin: - items.append(cls.make_item(begin=begin, end=first_item.begin)) + items.append(cls._make_item(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: - items.append(cls.make_item(begin=last_item.end, end=end)) + items.append(cls._make_item(begin=last_item.end, end=end)) items = Event.remove_overlaps(items) return Event.fill_gaps(items, cls.item_cls) diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index 1e08adf8..1e186880 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -675,16 +675,47 @@ def get_overlapped_bins(cls, sd1: SpectroData, sd2: SpectroData) -> np.ndarray: return sd_part1[:, -p1_le:] + sd_part2[:, :p1_le] @classmethod - def make_file(cls, path: Path, begin: Timestamp) -> SpectroFile: + def _make_file(cls, path: Path, begin: Timestamp) -> SpectroFile: + """Make a SpectroFile from a path and a begin timestamp. + + Parameters + ---------- + path: Path + Path to the file. + begin: Timestamp + Begin of the file. + + Returns + ------- + SpectroFile: + The spectro file. + + """ return SpectroFile(path=path, begin=begin) @classmethod - def make_item( + def _make_item( cls, file: TFile | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, ) -> SpectroItem: + """Make a SpectroItem for a given SpectroFile between begin andend timestamps. + + Parameters + ---------- + file: SpectroFile + SpectroFile of the item. + begin: Timestamp + Begin of the item. + end: + End of the item. + + Returns + ------- + A SpectroItem for the SpectroFile file, between the begin and end timestamps. + + """ return SpectroItem( file=file, begin=begin, @@ -692,7 +723,7 @@ def make_item( ) @classmethod - def from_base_dict( + def _from_base_dict( cls, dictionary: dict, files: list[TFile], @@ -707,7 +738,7 @@ def from_base_dict( colormap=dictionary["colormap"], ) - def make_split_data( + def _make_split_data( self, files: list[TFile], begin: Timestamp, diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index b58cae22..64e138a9 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -33,7 +33,7 @@ def write(self, folder: Path, link: bool = False) -> None: ... def link(self, folder: Path) -> None: ... - def make_split_data( + def _make_split_data( self, files: list[DummyFile], begin: Timestamp, @@ -43,11 +43,11 @@ def make_split_data( return DummyData(files, begin, end, **kwargs) @classmethod - def make_file(cls, path: Path, begin: Timestamp) -> DummyFile: + def _make_file(cls, path: Path, begin: Timestamp) -> DummyFile: return DummyFile(path=path, begin=begin) @classmethod - def make_item( + def _make_item( cls, file: TFile | None = None, begin: Timestamp | None = None, @@ -56,7 +56,7 @@ def make_item( return DummyItem(file=file, begin=begin, end=end) @classmethod - def from_base_dict( + def _from_base_dict( cls, dictionary: dict, files: list[TFile], From 4f2afe19660a0794f90610d1fa48b5f8ba87832c Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 10:57:02 +0100 Subject: [PATCH 15/17] lint SpectroData --- src/osekit/core_api/spectro_data.py | 52 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index 1e186880..bc8e6d0e 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -571,15 +571,19 @@ def link_audio_data(self, audio_data: AudioData) -> None: raise ValueError(msg) self.audio_data = audio_data - def split(self, nb_subdata: int = 2, **kwargs) -> list[SpectroData]: + def split( + self, + nb_subdata: int = 2, + **kwargs, # noqa: ANN003 + ) -> list[SpectroData]: """Split the spectro data object in the specified number of spectro subdata. Parameters ---------- nb_subdata: int Number of subdata in which to split the data. - kwargs: dict - Additionnal keyword arguments. + kwargs: + None Returns ------- @@ -700,7 +704,7 @@ def _make_item( begin: Timestamp | None = None, end: Timestamp | None = None, ) -> SpectroItem: - """Make a SpectroItem for a given SpectroFile between begin andend timestamps. + """Make a SpectroItem for a given SpectroFile between begin and end timestamps. Parameters ---------- @@ -726,11 +730,35 @@ def _make_item( def _from_base_dict( cls, dictionary: dict, - files: list[TFile], + files: list[SpectroFile], begin: Timestamp, end: Timestamp, - **kwargs, - ) -> Self: + **kwargs, # noqa: ANN003 + ) -> SpectroData: + """Deserialize the SpectroData-specific parts of a Data dictionary. + + This method is called within the BaseData.from_dict() method, which + deserializes the base files, begin and end parameters. + + Parameters + ---------- + dictionary: dict + The serialized dictionary representing the SpectroData. + files: list[SpectroFile] + The list of deserialized SpectroFiles. + begin: Timestamp + The deserialized begin timestamp. + end: Timestamp + The deserialized end timestamp. + kwargs: + None. + + Returns + ------- + SpectroData + The deserialized SpectroData. + + """ return cls.from_files( files=files, begin=begin, @@ -743,8 +771,8 @@ def _make_split_data( files: list[TFile], begin: Timestamp, end: Timestamp, - **kwargs, - ) -> Self: ... + **kwargs, # noqa: ANN003 + ) -> SpectroData: ... @classmethod def from_files( @@ -767,6 +795,12 @@ def from_files( end: Timestamp | None End of the data object. Defaulted to the end of the last file. + name: str | None + Name of the exported files. + kwargs + Keyword arguments that are passed to the cls constructor. + colormap: str + Colormap to use for plotting the spectrogram. Returns ------- From 27d7a18e0a3f1574e8d332a2a6de0ac7ab613caf Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 11:13:44 +0100 Subject: [PATCH 16/17] lint AudioDataset --- src/osekit/core_api/audio_dataset.py | 31 ++++++++++++++-- src/osekit/core_api/base_dataset.py | 50 ++++++++++++++++---------- src/osekit/core_api/spectro_dataset.py | 4 +-- tests/test_core_api_base.py | 4 +-- 4 files changed, 64 insertions(+), 25 deletions(-) diff --git a/src/osekit/core_api/audio_dataset.py b/src/osekit/core_api/audio_dataset.py index f3be47f1..570ce3ba 100644 --- a/src/osekit/core_api/audio_dataset.py +++ b/src/osekit/core_api/audio_dataset.py @@ -148,7 +148,7 @@ def write( ) @classmethod - def data_from_dict(cls, dictionary: dict) -> list[AudioData]: + def _data_from_dict(cls, dictionary: dict) -> list[AudioData]: return [AudioData.from_dict(data) for data in dictionary.values()] @classmethod @@ -313,14 +313,39 @@ def from_files( # noqa: PLR0913 ) @classmethod - def data_from_files( + def _data_from_files( cls, files: list[AudioFile], begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - **kwargs, + **kwargs, # noqa: ANN003 ) -> AudioData: + """Return an AudioData object from a list of AudioFiles. + + The AudioData starts at the begin and ends at end. + + Parameters + ---------- + files: list[AudioFile] + List of AudioFiles contained in the AudioData. + begin: Timestamp | None + Begin of the AudioData. + Defaulted to the begin of the first AudioFile. + end: Timestamp | None + End of the AudioData. + Defaulted to the end of the last AudioFile. + name: str|None + Name of the AudioData. + kwargs: + Keyword arguments to pass to the AudioData.from_files() method. + + Returns + ------- + AudioData: + The AudioData object. + + """ return AudioData.from_files( files=files, begin=begin, diff --git a/src/osekit/core_api/base_dataset.py b/src/osekit/core_api/base_dataset.py index 7aaa8baa..b67fe6eb 100644 --- a/src/osekit/core_api/base_dataset.py +++ b/src/osekit/core_api/base_dataset.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from bisect import bisect from pathlib import Path -from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar +from typing import TYPE_CHECKING, Literal, Self, TypeVar from pandas import Timedelta, Timestamp, date_range from soundfile import LibsndfileError @@ -31,7 +31,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseDataset(Generic[TData, TFile], Event, ABC): +class BaseDataset[TData: BaseData, TFile: BaseFile](Event, ABC): """Base class for Dataset objects. Datasets are collections of Data, with methods @@ -160,7 +160,11 @@ def move_files(self, folder: Path) -> None: @property def data_duration(self) -> Timedelta: - """Return the most frequent duration among durations of the data of this dataset, rounded to the nearest second.""" + """Return the most frequent duration among the data of this dataset. + + The duration is rounded to the nearest second. + + """ data_durations = [ Timedelta(data.duration).round(freq="1s") for data in self.data ] @@ -169,9 +173,10 @@ def data_duration(self) -> Timedelta: def write( self, folder: Path, - link: bool = False, first: int = 0, last: int | None = None, + *, + link: bool = False, ) -> None: """Write all data objects in the specified folder. @@ -227,7 +232,7 @@ def from_dict(cls, dictionary: dict) -> Self: The deserialized BaseDataset. """ - data = cls.data_from_dict(dictionary["data"]) + data = cls._data_from_dict(dictionary["data"]) name = dictionary["name"] suffix = dictionary["suffix"] folder = Path(dictionary["folder"]) @@ -235,7 +240,9 @@ def from_dict(cls, dictionary: dict) -> Self: @classmethod @abstractmethod - def data_from_dict(cls, dictionary: dict) -> list[TData]: ... + def _data_from_dict(cls, dictionary: dict) -> list[TData]: + """Return a list of Data from a serialized dictionnary.""" + ... def write_json(self, folder: Path) -> None: """Write a serialized BaseDataset to a JSON file.""" @@ -268,7 +275,7 @@ def from_files( # noqa: PLR0913 data_duration: Timedelta | None = None, overlap: float = 0.0, name: str | None = None, - **kwargs, + **kwargs, # noqa: ANN003 ) -> Self: """Return a base Dataset object from a list of Files. @@ -288,8 +295,9 @@ def from_files( # noqa: PLR0913 "timedelta_total": data objects of duration equal to data_duration will be created from the begin timestamp to the end timestamp. "timedelta_file": data objects of duration equal to data_duration will - be created from the beginning of the first file that the begin timestamp is into, until it would resume - in a data beginning between two files. Then, the next data object will be created from the + be created from the beginning of the first file that the begin timestamp + is into, until it would resume in a data beginning between two files. + Then, the next data object will be created from the beginning of the next original file and so on. data_duration: Timedelta | None Duration of the data objects. @@ -300,6 +308,8 @@ def from_files( # noqa: PLR0913 Overlap percentage between consecutive data. name: str|None Name of the dataset. + kwargs: + Keyword arguments to pass to the cls.data_from_files() method. Returns ------- @@ -308,7 +318,7 @@ def from_files( # noqa: PLR0913 """ if mode == "files": - data = [cls.data_from_files([f], **kwargs) for f in files] + data = [cls._data_from_files([f], **kwargs) for f in files] data = BaseData.remove_overlaps(data) return cls(data=data, name=name) @@ -338,20 +348,22 @@ def from_files( # noqa: PLR0913 ) else: data_base = [ - cls.data_from_files(files=files, begin=begin, end=end, **kwargs), + cls._data_from_files(files=files, begin=begin, end=end, **kwargs), ] return cls(data_base, name=name) @classmethod @abstractmethod - def data_from_files( + def _data_from_files( cls, files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - **kwargs, - ) -> TData: ... + **kwargs, # noqa: ANN003 + ) -> TData: + """Return a base Dataset object between two timestamps from a list of Files.""" + ... @classmethod def _get_data_from_files_timedelta_total( @@ -361,7 +373,7 @@ def _get_data_from_files_timedelta_total( data_duration: Timedelta, files: list[TFile], overlap: float = 0, - **kwargs, + **kwargs, # noqa: ANN003 ) -> list[TData]: if not 0 <= overlap < 1: msg = f"Overlap ({overlap}) must be between 0 and 1." @@ -389,7 +401,7 @@ def _get_data_from_files_timedelta_total( ): last_active_file_index += 1 output.append( - cls.data_from_files( + cls._data_from_files( files[active_file_index:last_active_file_index], data_begin, data_end, @@ -441,7 +453,7 @@ def _get_data_from_files_timedelta_file( files_chunk.append(next_file) output.extend( - cls.data_from_files( + cls._data_from_files( files, data_begin, data_begin + data_duration, @@ -470,7 +482,7 @@ def from_folder( # noqa: PLR0913 data_duration: Timedelta | None = None, first_file_begin: Timestamp | None = None, name: str | None = None, - **kwargs, + **kwargs, # noqa: ANN003 ) -> Self: """Return a Dataset from a folder containing the base files. @@ -517,6 +529,8 @@ def from_folder( # noqa: PLR0913 Will be ignored if striptime_format is specified. name: str|None Name of the dataset. + kwargs: + Keyword arguments to pass to the cls.from_files() method. Returns ------- diff --git a/src/osekit/core_api/spectro_dataset.py b/src/osekit/core_api/spectro_dataset.py index c8f1552a..d03c64a0 100644 --- a/src/osekit/core_api/spectro_dataset.py +++ b/src/osekit/core_api/spectro_dataset.py @@ -613,11 +613,11 @@ def from_files( # noqa: PLR0913 ) @classmethod - def data_from_dict(cls, dictionary: dict) -> list[SpectroData]: + def _data_from_dict(cls, dictionary: dict) -> list[SpectroData]: return [SpectroData.from_dict(data) for data in dictionary.values()] @classmethod - def data_from_files( + def _data_from_files( cls, files: list[TFile], begin: Timestamp | None = None, diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index 64e138a9..f7f7df20 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -89,11 +89,11 @@ def from_files( class DummyDataset(BaseDataset[DummyData, DummyFile]): @classmethod - def data_from_dict(cls, dictionary: dict) -> list[TData]: + def _data_from_dict(cls, dictionary: dict) -> list[TData]: return [DummyData.from_dict(data) for data in dictionary.values()] @classmethod - def data_from_files( + def _data_from_files( cls, files: list[DummyFile], begin: Timestamp | None = None, From f6c72409c08a95131ec9c9b34a93afc68b7c325d Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 16 Jan 2026 11:22:34 +0100 Subject: [PATCH 17/17] lint SpectroDataset --- src/osekit/core_api/audio_dataset.py | 15 +++++- src/osekit/core_api/spectro_dataset.py | 63 +++++++++++++++++++++++--- 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/src/osekit/core_api/audio_dataset.py b/src/osekit/core_api/audio_dataset.py index 570ce3ba..6c50f362 100644 --- a/src/osekit/core_api/audio_dataset.py +++ b/src/osekit/core_api/audio_dataset.py @@ -149,6 +149,19 @@ def write( @classmethod def _data_from_dict(cls, dictionary: dict) -> list[AudioData]: + """Return the list of AudioData objects from the serialized dictionary. + + Parameters + ---------- + dictionary: dict + Dictionary representing the serialized AudioDataset. + + Returns + ------- + list[AudioData]: + The list of deserialized AudioData objects. + + """ return [AudioData.from_dict(data) for data in dictionary.values()] @classmethod @@ -166,7 +179,7 @@ def from_folder( # noqa: PLR0913 name: str | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, - **kwargs, + **kwargs, # noqa: ANN003 ) -> Self: """Return an AudioDataset from a folder containing the audio files. diff --git a/src/osekit/core_api/spectro_dataset.py b/src/osekit/core_api/spectro_dataset.py index d03c64a0..b8289f83 100644 --- a/src/osekit/core_api/spectro_dataset.py +++ b/src/osekit/core_api/spectro_dataset.py @@ -286,6 +286,7 @@ def _save_all_( data: SpectroData, matrix_folder: Path, spectrogram_folder: Path, + *, link: bool, ) -> SpectroData: """Save the data matrix and spectrogram to disk.""" @@ -298,9 +299,10 @@ def save_all( self, matrix_folder: Path, spectrogram_folder: Path, - link: bool = False, first: int = 0, last: int | None = None, + *, + link: bool = False, ) -> None: """Export both Sx matrices as npz files and spectrograms for each data. @@ -342,10 +344,17 @@ def link_audio_dataset( ---------- audio_dataset: AudioDataset The AudioDataset which data will be linked to the SpectroDataset data. + first: int + Index of the first SpectroData and AudioData to link. + last: int + Index of the last SpectroData and AudioData to link. """ if len(audio_dataset.data) != len(self.data): - msg = "The audio dataset doesn't contain the same number of data as the spectro dataset." + msg = ( + "The audio dataset doesn't contain the same number of data" + " as the spectro dataset." + ) raise ValueError(msg) last = len(self.data) if last is None else last @@ -485,7 +494,7 @@ def from_folder( # noqa: PLR0913 overlap: float = 0.0, data_duration: Timedelta | None = None, name: str | None = None, - **kwargs: any, + **kwargs, # noqa: ANN003 ) -> Self: """Return a SpectroDataset from a folder containing the spectro files. @@ -525,8 +534,8 @@ def from_folder( # noqa: PLR0913 Else, one data object will cover the whole time period. name: str|None Name of the dataset. - kwargs: any - Keyword arguments passed to the BaseDataset.from_folder classmethod. + kwargs: + None. Returns ------- @@ -556,7 +565,7 @@ def from_files( # noqa: PLR0913 mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total", overlap: float = 0.0, data_duration: Timedelta | None = None, - **kwargs, + **kwargs, # noqa: ANN003 ) -> AudioDataset: """Return an SpectroDataset object from a list of SpectroFiles. @@ -595,6 +604,8 @@ def from_files( # noqa: PLR0913 the wav audio data. normalization: Normalization The type of normalization to apply to the audio data. + kwargs: + None. Returns ------- @@ -614,6 +625,19 @@ def from_files( # noqa: PLR0913 @classmethod def _data_from_dict(cls, dictionary: dict) -> list[SpectroData]: + """Return the list of SpectroData objects from the serialized dictionary. + + Parameters + ---------- + dictionary: dict + Dictionary representing the serialized SpectroDataset. + + Returns + ------- + list[SpectroData]: + The list of deserialized SpectroData objects. + + """ return [SpectroData.from_dict(data) for data in dictionary.values()] @classmethod @@ -623,8 +647,33 @@ def _data_from_files( begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - **kwargs, + **kwargs, # noqa: ANN003 ) -> SpectroData: + """Return a SpectroData object from a list of SpectroFiles. + + The SpectroData starts at the begin and ends at end. + + Parameters + ---------- + files: list[SpectroFile] + List of SpectroFiles contained in the SpectroData. + begin: Timestamp | None + Begin of the SpectroData. + Defaulted to the begin of the first SpectroFile. + end: Timestamp | None + End of the SpectroData. + Defaulted to the end of the last SpectroFile. + name: str|None + Name of the SpectroData. + kwargs: + Keyword arguments to pass to the SpectroData.from_files() method. + + Returns + ------- + SpectroData: + The SpectroData object. + + """ return SpectroData.from_files( files=files, begin=begin,