diff --git a/src/osekit/core_api/audio_data.py b/src/osekit/core_api/audio_data.py index 1dfa098e..a2d54333 100644 --- a/src/osekit/core_api/audio_data.py +++ b/src/osekit/core_api/audio_data.py @@ -7,7 +7,7 @@ from __future__ import annotations from math import ceil -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import numpy as np import soundfile as sf @@ -30,11 +30,14 @@ class AudioData(BaseData[AudioItem, AudioFile]): The data is accessed via an AudioItem object per AudioFile. """ + item_cls = AudioItem + def __init__( self, items: list[AudioItem] | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, sample_rate: int | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, @@ -54,6 +57,8 @@ def __init__( end: Timestamp | None Only effective if items is None. Set the end of the empty data. + name: str | None + Name of the exported files. instrument: Instrument | None Instrument that might be used to obtain acoustic pressure from the wav audio data. @@ -61,7 +66,7 @@ def __init__( The type of normalization to apply to the audio data. """ - super().__init__(items=items, begin=begin, end=end) + super().__init__(items=items, begin=begin, end=end, name=name) self._set_sample_rate(sample_rate=sample_rate) self.instrument = instrument self.normalization = normalization @@ -115,7 +120,61 @@ def normalization_values(self, value: dict | None) -> None: } ) + @classmethod + def _make_item( + cls, + file: AudioFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> AudioItem: + """Make an AudioItem for a given AudioFile between begin and end timestamps. + + Parameters + ---------- + file: AudioFile + AudioFile of the item. + begin: Timestamp + Begin of the item. + end: + End of the item. + + Returns + ------- + An AudioItem for the AudioFile file, between the begin and end timestamps. 
+ + """ + return AudioItem(file=file, begin=begin, end=end) + + @classmethod + def _make_file(cls, path: Path, begin: Timestamp) -> AudioFile: + """Make an AudioFile from a path and a begin timestamp. + + Parameters + ---------- + path: Path + Path to the file. + begin: Timestamp + Begin of the file. + + Returns + ------- + AudioFile: + The audio file. + + """ + return AudioFile(path=path, begin=begin) + def get_normalization_values(self) -> dict: + """Return the values used for normalizing the audio data. + + Returns + ------- + dict: + "mean": mean value to subtract to center values on 0. + "peak": peak value for PEAK normalization + "std": standard deviation used for z-score normalization + + """ values = np.array(self.get_raw_value()) return { "mean": values.mean(), @@ -205,6 +264,7 @@ def write( self, folder: Path, subtype: str | None = None, + *, link: bool = False, ) -> None: """Write the audio data to file. @@ -270,7 +330,7 @@ def split( nb_subdata: int = 2, *, pass_normalization: bool = True, - ) -> list[AudioData]: + ) -> list[Self]: """Split the audio data object in the specified number of audio subdata. Parameters @@ -296,16 +356,47 @@ def split( if any(self.normalization_values.values()) else self.get_normalization_values() ) - return [ - AudioData.from_base_data( - data=base_data, - sample_rate=self.sample_rate, - instrument=self.instrument, - normalization=self.normalization, - normalization_values=normalization_values, - ) - for base_data in super().split(nb_subdata) - ] + return super().split( + nb_subdata=nb_subdata, + normalization_values=normalization_values, + ) + + def _make_split_data( + self, + files: list[AudioFile], + begin: Timestamp, + end: Timestamp, + **kwargs: tuple[float, float, float], + ) -> AudioData: + """Return an AudioData object after an AudioData.split() call. + + Parameters + ---------- + files: list[AudioFile] + The AudioFiles of the original AudioData. + begin: Timestamp + The begin timestamp of the split AudioData.
+ end: Timestamp + The end timestamp of the split AudioData. + kwargs: + normalization_values: tuple[float, float, float] + Values used for normalizing the split AudioData. + + Returns + ------- + AudioData: + The AudioData instance. + + """ + return AudioData.from_files( + files=files, + begin=begin, + end=end, + sample_rate=self.sample_rate, + instrument=self.instrument, + normalization=self.normalization, + normalization_values=kwargs["normalization_values"], + ) def split_frames( self, @@ -335,9 +426,11 @@ def split_frames( """ if start_frame < 0: - raise ValueError("Start_frame must be greater than or equal to 0.") + msg = "Start_frame must be greater than or equal to 0." + raise ValueError(msg) if stop_frame < -1 or stop_frame > self.length: - raise ValueError("Stop_frame must be lower than the length of the data.") + msg = "Stop_frame must be lower than the length of the data." + raise ValueError(msg) start_timestamp = self.begin + Timedelta( seconds=ceil(start_frame / self.sample_rate * 1e9) / 1e9, @@ -390,13 +483,31 @@ def to_dict(self) -> dict: ) @classmethod - def from_dict(cls, dictionary: dict) -> AudioData: - """Deserialize an AudioData from a dictionary. + def _from_base_dict( + cls, + dictionary: dict, + files: list[AudioFile], + begin: Timestamp, + end: Timestamp, + **kwargs, # noqa: ANN003 + ) -> AudioData: + """Deserialize the AudioData-specific parts of a Data dictionary. + + This method is called within the BaseData.from_dict() method, which + deserializes the base files, begin and end parameters. Parameters ---------- dictionary: dict The serialized dictionary representing the AudioData. + files: list[AudioFile] + The list of deserialized AudioFiles. + begin: Timestamp + The deserialized begin timestamp. + end: Timestamp + The deserialized end timestamp. + kwargs: + None. Returns ------- @@ -404,32 +515,31 @@ def from_dict(cls, dictionary: dict) -> AudioData: The deserialized AudioData. 
""" - base_data = BaseData.from_dict(dictionary) instrument = ( None if dictionary["instrument"] is None else Instrument.from_dict(dictionary["instrument"]) ) - return cls.from_base_data( - data=base_data, + return cls.from_files( + files=files, + begin=begin, + end=end, + instrument=instrument, sample_rate=dictionary["sample_rate"], normalization=Normalization(dictionary["normalization"]), normalization_values=dictionary["normalization_values"], - instrument=instrument, ) @classmethod def from_files( cls, - files: list[AudioFile], + files: list[AudioFile], # The method is redefined just to specify the type begin: Timestamp | None = None, end: Timestamp | None = None, - sample_rate: float | None = None, - instrument: Instrument | None = None, - normalization: Normalization = Normalization.RAW, - normalization_values: dict | None = None, + name: str | None = None, + **kwargs, # noqa: ANN003 ) -> AudioData: - """Return an AudioData object from a list of AudioFiles. + """Return an AudioData object from a list of AudioFiles. Parameters ---------- @@ -441,65 +551,28 @@ def from_files( end: Timestamp | None End of the data object. Defaulted to the end of the last file. - sample_rate: float | None - Sample rate of the AudioData. - instrument: Instrument | None - Instrument that might be used to obtain acoustic pressure from - the wav audio data. - normalization: Normalization - The type of normalization to apply to the audio data. - normalization_values: dict|None - Mean, peak and std values with which to normalize the data. - - Returns - ------- - AudioData: - The AudioData object.
- - """ - return cls.from_base_data( - data=BaseData.from_files(files, begin, end), - sample_rate=sample_rate, - instrument=instrument, - normalization=normalization, - normalization_values=normalization_values, - ) - - @classmethod - def from_base_data( - cls, - data: BaseData, - sample_rate: float | None = None, - instrument: Instrument | None = None, - normalization: Normalization = Normalization.RAW, - normalization_values: dict | None = None, - ) -> AudioData: - """Return an AudioData object from a BaseData object. - - Parameters - ---------- - data: BaseData - BaseData object to convert to AudioData. - sample_rate: float | None - Sample rate of the AudioData. - instrument: Instrument | None - Instrument that might be used to obtain acoustic pressure from - the wav audio data. - normalization: Normalization - The type of normalization to apply to the audio data. - normalization_values: dict|None - Mean, peak and std values with which to normalize the data. + name: str | None + Name of the exported files. + kwargs + Keyword arguments that are passed to the cls constructor. + sample_rate: int + The sample rate of the audio data. + instrument: Instrument | None + Instrument that might be used to obtain acoustic pressure from + the wav audio data. + normalization: Normalization + The type of normalization to apply to the audio data. Returns ------- - AudioData: - The AudioData object. + Self: + The AudioData object. 
""" - return cls( - items=[AudioItem.from_base_item(item) for item in data.items], - sample_rate=sample_rate, - instrument=instrument, - normalization=normalization, - normalization_values=normalization_values, + return super().from_files( + files=files, # This way, this static error doesn't appear to the user + begin=begin, + end=end, + name=name, + **kwargs, ) diff --git a/src/osekit/core_api/audio_dataset.py b/src/osekit/core_api/audio_dataset.py index 398f38ed..6c50f362 100644 --- a/src/osekit/core_api/audio_dataset.py +++ b/src/osekit/core_api/audio_dataset.py @@ -7,8 +7,7 @@ from __future__ import annotations import logging -from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Self from osekit.core_api.audio_data import AudioData from osekit.core_api.audio_file import AudioFile @@ -18,6 +17,8 @@ from osekit.utils.multiprocess_utils import multiprocess if TYPE_CHECKING: + from pathlib import Path + import pytz from pandas import Timedelta, Timestamp @@ -32,6 +33,8 @@ class AudioDataset(BaseDataset[AudioData, AudioFile]): """ + file_cls = AudioFile + def __init__( self, data: list[AudioData], @@ -145,26 +148,21 @@ def write( ) @classmethod - def from_dict(cls, dictionary: dict) -> AudioDataset: - """Deserialize an AudioDataset from a dictionary. + def _data_from_dict(cls, dictionary: dict) -> list[AudioData]: + """Return the list of AudioData objects from the serialized dictionary. Parameters ---------- dictionary: dict - The serialized dictionary representing the AudioDataset. + Dictionary representing the serialized AudioDataset. Returns ------- - AudioDataset - The deserialized AudioDataset. + list[AudioData]: + The list of deserialized AudioData objects. 
""" - return cls( - [AudioData.from_dict(d) for d in dictionary["data"].values()], - name=dictionary["name"], - suffix=dictionary["suffix"], - folder=Path(dictionary["folder"]), - ) + return [AudioData.from_dict(data) for data in dictionary.values()] @classmethod def from_folder( # noqa: PLR0913 @@ -181,8 +179,8 @@ def from_folder( # noqa: PLR0913 name: str | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, - **kwargs: any, - ) -> AudioDataset: + **kwargs, # noqa: ANN003 + ) -> Self: """Return an AudioDataset from a folder containing the audio files. Parameters @@ -241,28 +239,18 @@ def from_folder( # noqa: PLR0913 The audio dataset. """ - kwargs.update( - { - "file_class": AudioFile, - "supported_file_extensions": [".wav", ".flac", ".mp3"], - }, - ) - base_dataset = BaseDataset.from_folder( + return super().from_folder( folder=folder, strptime_format=strptime_format, begin=begin, end=end, - timezone=timezone, mode=mode, + timezone=timezone, overlap=overlap, data_duration=data_duration, - **kwargs, - ) - return cls.from_base_dataset( - base_dataset=base_dataset, + sample_rate=sample_rate, name=name, instrument=instrument, - sample_rate=sample_rate, normalization=normalization, ) @@ -272,11 +260,11 @@ def from_files( # noqa: PLR0913 files: list[AudioFile], begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total", overlap: float = 0.0, data_duration: Timedelta | None = None, sample_rate: float | None = None, - name: str | None = None, instrument: Instrument | None = None, normalization: Normalization = Normalization.RAW, ) -> AudioDataset: @@ -320,51 +308,67 @@ def from_files( # noqa: PLR0913 Returns ------- - BaseDataset[TItem, TFile]: - The DataBase object. + AudioDataset: + The AudioDataset object. 
""" - base = BaseDataset.from_files( + return super().from_files( files=files, begin=begin, end=end, - mode=mode, - overlap=overlap, - data_duration=data_duration, - ) - return cls.from_base_dataset( - base, name=name, - sample_rate=sample_rate, instrument=instrument, normalization=normalization, + sample_rate=sample_rate, + mode=mode, + overlap=overlap, + data_duration=data_duration, ) @classmethod - def from_base_dataset( + def _data_from_files( cls, - base_dataset: BaseDataset, - sample_rate: float | None = None, + files: list[AudioFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, name: str | None = None, - instrument: Instrument | None = None, - normalization: Normalization = Normalization.RAW, - ) -> AudioDataset: - """Return an AudioDataset object from a BaseDataset object.""" - return cls( - [ - AudioData.from_base_data( - data=data, - sample_rate=sample_rate, - normalization=normalization, - ) - for data in base_dataset.data - ], + **kwargs, # noqa: ANN003 + ) -> AudioData: + """Return an AudioData object from a list of AudioFiles. + + The AudioData starts at the begin and ends at end. + + Parameters + ---------- + files: list[AudioFile] + List of AudioFiles contained in the AudioData. + begin: Timestamp | None + Begin of the AudioData. + Defaulted to the begin of the first AudioFile. + end: Timestamp | None + End of the AudioData. + Defaulted to the end of the last AudioFile. + name: str|None + Name of the AudioData. + kwargs: + Keyword arguments to pass to the AudioData.from_files() method. + + Returns + ------- + AudioData: + The AudioData object. + + """ + return AudioData.from_files( + files=files, + begin=begin, + end=end, name=name, - instrument=instrument, + **kwargs, ) @classmethod - def from_json(cls, file: Path) -> AudioDataset: + def from_json(cls, file: Path) -> Self: """Deserialize an AudioDataset from a JSON file. 
Parameters diff --git a/src/osekit/core_api/audio_file.py b/src/osekit/core_api/audio_file.py index 77508b11..5781f356 100644 --- a/src/osekit/core_api/audio_file.py +++ b/src/osekit/core_api/audio_file.py @@ -2,6 +2,7 @@ from __future__ import annotations +import typing from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -21,6 +22,8 @@ class AudioFile(BaseFile): """Audio file associated with timestamps.""" + supported_extensions: typing.ClassVar = [".wav", ".flac", ".mp3"] + def __init__( self, path: PathLike | str, @@ -116,11 +119,6 @@ def frames_indexes(self, start: Timestamp, stop: Timestamp) -> tuple[int, int]: stop_sample = round(((stop - self.begin) * self.sample_rate).total_seconds()) return start_sample, stop_sample - @classmethod - def from_base_file(cls, file: BaseFile) -> AudioFile: - """Return an AudioFile object from a BaseFile object.""" - return cls(path=file.path, begin=file.begin) - def move(self, folder: Path) -> None: """Move the file to the target folder. diff --git a/src/osekit/core_api/audio_item.py b/src/osekit/core_api/audio_item.py index 170c28fd..9207eabc 100644 --- a/src/osekit/core_api/audio_item.py +++ b/src/osekit/core_api/audio_item.py @@ -7,7 +7,6 @@ import numpy as np from osekit.core_api.audio_file import AudioFile -from osekit.core_api.base_file import BaseFile from osekit.core_api.base_item import BaseItem if TYPE_CHECKING: @@ -37,7 +36,7 @@ def __init__( It is defaulted to the AudioFile end. 
""" - super().__init__(file, begin, end) + super().__init__(file=file, begin=begin, end=end) @property def sample_rate(self) -> float: @@ -65,17 +64,3 @@ def get_value(self) -> np.ndarray: if self.is_empty: return np.zeros((1, self.nb_channels)) return super().get_value() - - @classmethod - def from_base_item(cls, item: BaseItem) -> AudioItem: - """Return an AudioItem object from a BaseItem object.""" - file = item.file - if not file or isinstance(file, AudioFile): - return cls(file=file, begin=item.begin, end=item.end) - if isinstance(file, BaseFile): - return cls( - file=AudioFile.from_base_file(file), - begin=item.begin, - end=item.end, - ) - raise TypeError diff --git a/src/osekit/core_api/base_data.py b/src/osekit/core_api/base_data.py index c1067c9b..2880f8ec 100644 --- a/src/osekit/core_api/base_data.py +++ b/src/osekit/core_api/base_data.py @@ -6,8 +6,10 @@ from __future__ import annotations +import itertools +from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, TypeVar +from typing import Self, TypeVar import numpy as np from pandas import Timestamp, date_range @@ -27,13 +29,16 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseData(Generic[TItem, TFile], Event): +class BaseData[TItem: BaseItem, TFile: BaseFile](Event, ABC): """Base class for the Data objects. Data corresponds to data scattered through different Files. The data is accessed via an Item object per File. 
""" + item_cls: type[TItem] + file_cls: type[TFile] + def __init__( self, items: list[TItem] | None = None, @@ -59,11 +64,12 @@ def __init__( """ if not items: - items = [BaseItem(begin=begin, end=end)] + items = [self._make_item(begin=begin, end=end)] self.items = items self._begin = min(item.begin for item in self.items) self._end = max(item.end for item in self.items) self._name = name + super().__init__(self._begin, self._end) def __eq__(self, other: BaseData) -> bool: """Override __eq__.""" @@ -135,9 +141,11 @@ def create_directories(path: Path) -> None: """ path.mkdir(parents=True, exist_ok=True, mode=DPDEFAULT) - def write(self, folder: Path, link: bool = False) -> None: + @abstractmethod + def write(self, folder: Path, *, link: bool = False) -> None: """Abstract method for writing data to file.""" + @abstractmethod def link(self, folder: Path) -> None: """Abstract method for linking data to a file in a given folder. @@ -169,13 +177,19 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, dictionary: dict) -> BaseData: + def from_dict( + cls, + dictionary: dict, + **kwargs, # noqa: ANN003 + ) -> BaseData: """Deserialize a BaseData from a dictionary. Parameters ---------- dictionary: dict The serialized dictionary representing the BaseData. + kwargs: + Keyword arguments that are passed to cls.from_base_dict(). 
Returns ------- @@ -184,35 +198,69 @@ def from_dict(cls, dictionary: dict) -> BaseData: """ files = [ - BaseFile( - Path(file["path"]), + cls._make_file( + path=Path(file["path"]), begin=strptime_from_text( file["begin"], datetime_template=TIMESTAMP_FORMATS_EXPORTED_FILES, ), - end=strptime_from_text( - file["end"], - datetime_template=TIMESTAMP_FORMATS_EXPORTED_FILES, - ), ) for file in dictionary["files"].values() ] begin = Timestamp(dictionary["begin"]) end = Timestamp(dictionary["end"]) - return cls.from_files(files, begin, end) + return cls._from_base_dict( + dictionary=dictionary, + files=files, + begin=begin, + end=end, + **kwargs, + ) + + @classmethod + @abstractmethod + def _make_file(cls, path: Path, begin: Timestamp) -> type[TFile]: + """Make a File from a path and a begin timestamp.""" + ... + + @classmethod + @abstractmethod + def _make_item( + cls, + file: TFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> TItem: ... + + @classmethod + @abstractmethod + def _from_base_dict( + cls, + dictionary: dict, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + **kwargs, # noqa: ANN003 + ) -> Self: ... @property def files(self) -> set[TFile]: """All files referred to by the Data.""" return {item.file for item in self.items if item.file is not None} - def split(self, nb_subdata: int = 2) -> list[BaseData]: + def split( + self, + nb_subdata: int = 2, + **kwargs, # noqa: ANN003 + ) -> list[BaseData]: """Split the data object in the specified number of subdata. Parameters ---------- nb_subdata: int Number of subdata in which to split the data. + **kwargs: + Keyword arguments that are passed to self.make_split_data(). 
Returns ------- @@ -221,12 +269,23 @@ def split(self, nb_subdata: int = 2) -> list[BaseData]: """ dates = date_range(self.begin, self.end, periods=nb_subdata + 1) - subdata_dates = zip(dates, dates[1:], strict=False) + subdata_dates = itertools.pairwise(dates) return [ - BaseData.from_files(files=list(self.files), begin=b, end=e) + self._make_split_data(files=list(self.files), begin=b, end=e, **kwargs) for b, e in subdata_dates ] + @abstractmethod + def _make_split_data( + self, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + **kwargs, # noqa: ANN003 + ) -> Self: + """Make a Data object after a .split() call.""" + ... + @classmethod def from_files( cls, @@ -234,8 +293,9 @@ def from_files( begin: Timestamp | None = None, end: Timestamp | None = None, name: str | None = None, - ) -> BaseData[TItem, TFile]: - """Return a base DataBase object from a list of Files. + **kwargs, # noqa: ANN003 + ) -> Self: + """Return a Data object from a list of Files. Parameters ---------- @@ -249,15 +309,17 @@ def from_files( Defaulted to the end of the last file. name: str | None Name of the exported files. + kwargs: + Keyword arguments that are passed to the cls constructor. Returns ------- - BaseData[TItem, TFile]: - The BaseData object. + Self: + The Data object. """ items = cls.items_from_files(files=files, begin=begin, end=end) - return cls(items=items, name=name) + return cls(items=items, name=name, **kwargs) @classmethod def items_from_files( @@ -265,7 +327,7 @@ def items_from_files( files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, - ) -> list[BaseItem]: + ) -> list[TItem]: """Return a list of Items from a list of Files and timestamps. The Items range from begin to end. 
@@ -295,12 +357,14 @@ def items_from_files( file for file in files if file.overlaps(Event(begin=begin, end=end)) ] - items = [BaseItem(file, begin, end) for file in included_files] + items = [ + cls._make_item(file=file, begin=begin, end=end) for file in included_files + ] if not items: - items.append(BaseItem(begin=begin, end=end)) + items.append(cls._make_item(begin=begin, end=end)) if (first_item := sorted(items, key=lambda item: item.begin)[0]).begin > begin: - items.append(BaseItem(begin=begin, end=first_item.begin)) + items.append(cls._make_item(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: - items.append(BaseItem(begin=last_item.end, end=end)) + items.append(cls._make_item(begin=last_item.end, end=end)) items = Event.remove_overlaps(items) - return Event.fill_gaps(items, BaseItem) + return Event.fill_gaps(items, cls.item_cls) diff --git a/src/osekit/core_api/base_dataset.py b/src/osekit/core_api/base_dataset.py index 222f6e62..b67fe6eb 100644 --- a/src/osekit/core_api/base_dataset.py +++ b/src/osekit/core_api/base_dataset.py @@ -7,9 +7,10 @@ from __future__ import annotations import os +from abc import ABC, abstractmethod from bisect import bisect from pathlib import Path -from typing import TYPE_CHECKING, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Literal, Self, TypeVar from pandas import Timedelta, Timestamp, date_range from soundfile import LibsndfileError @@ -30,13 +31,15 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseDataset(Generic[TData, TFile], Event): +class BaseDataset[TData: BaseData, TFile: BaseFile](Event, ABC): """Base class for Dataset objects. Datasets are collections of Data, with methods that simplify repeated operations on the data. 
""" + file_cls: type[TFile] + def __init__( self, data: list[TData], @@ -157,7 +160,11 @@ def move_files(self, folder: Path) -> None: @property def data_duration(self) -> Timedelta: - """Return the most frequent duration among durations of the data of this dataset, rounded to the nearest second.""" + """Return the most frequent duration among the data of this dataset. + + The duration is rounded to the nearest second. + + """ data_durations = [ Timedelta(data.duration).round(freq="1s") for data in self.data ] @@ -166,9 +173,10 @@ def data_duration(self) -> Timedelta: def write( self, folder: Path, - link: bool = False, first: int = 0, last: int | None = None, + *, + link: bool = False, ) -> None: """Write all data objects in the specified folder. @@ -210,7 +218,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, dictionary: dict) -> BaseDataset: + def from_dict(cls, dictionary: dict) -> Self: """Deserialize a BaseDataset from a dictionary. Parameters @@ -224,19 +232,24 @@ def from_dict(cls, dictionary: dict) -> BaseDataset: The deserialized BaseDataset. """ - return cls( - [BaseData.from_dict(d) for d in dictionary["data"].values()], - name=dictionary["name"], - suffix=dictionary["suffix"], - folder=Path(dictionary["folder"]), - ) + data = cls._data_from_dict(dictionary["data"]) + name = dictionary["name"] + suffix = dictionary["suffix"] + folder = Path(dictionary["folder"]) + return cls(data=data, name=name, suffix=suffix, folder=folder) + + @classmethod + @abstractmethod + def _data_from_dict(cls, dictionary: dict) -> list[TData]: + """Return a list of Data from a serialized dictionary.""" + ... def write_json(self, folder: Path) -> None: """Write a serialized BaseDataset to a JSON file.""" serialize_json(folder / f"{self.name}.json", self.to_dict()) @classmethod - def from_json(cls, file: Path) -> BaseDataset: + def from_json(cls, file: Path) -> Self: """Deserialize a BaseDataset from a JSON file.
Parameters @@ -262,8 +275,9 @@ def from_files( # noqa: PLR0913 data_duration: Timedelta | None = None, overlap: float = 0.0, name: str | None = None, - ) -> BaseDataset: - """Return a base BaseDataset object from a list of Files. + **kwargs, # noqa: ANN003 + ) -> Self: + """Return a base Dataset object from a list of Files. Parameters ---------- @@ -281,8 +295,9 @@ def from_files( # noqa: PLR0913 "timedelta_total": data objects of duration equal to data_duration will be created from the begin timestamp to the end timestamp. "timedelta_file": data objects of duration equal to data_duration will - be created from the beginning of the first file that the begin timestamp is into, until it would resume - in a data beginning between two files. Then, the next data object will be created from the + be created from the beginning of the first file that the begin timestamp + is into, until it would resume in a data beginning between two files. + Then, the next data object will be created from the beginning of the next original file and so on. data_duration: Timedelta | None Duration of the data objects. @@ -293,17 +308,19 @@ def from_files( # noqa: PLR0913 Overlap percentage between consecutive data. name: str|None Name of the dataset. + kwargs: + Keyword arguments to pass to the cls.data_from_files() method. Returns ------- - BaseDataset[TItem, TFile]: - The DataBase object. + Self: + The Dataset object. 
""" if mode == "files": - data_base = [BaseData.from_files([f]) for f in files] - data_base = BaseData.remove_overlaps(data_base) - return cls(data=data_base, name=name) + data = [cls._data_from_files([f], **kwargs) for f in files] + data = BaseData.remove_overlaps(data) + return cls(data=data, name=name) if not begin: begin = min(file.begin for file in files) @@ -311,35 +328,53 @@ def from_files( # noqa: PLR0913 end = max(file.end for file in files) if data_duration: data_base = ( - cls._get_base_data_from_files_timedelta_total( + cls._get_data_from_files_timedelta_total( begin=begin, end=end, data_duration=data_duration, files=files, overlap=overlap, + **kwargs, ) if mode == "timedelta_total" - else cls._get_base_data_from_files_timedelta_file( + else cls._get_data_from_files_timedelta_file( begin=begin, end=end, data_duration=data_duration, files=files, overlap=overlap, + **kwargs, ) ) else: - data_base = [BaseData.from_files(files, begin=begin, end=end)] + data_base = [ + cls._data_from_files(files=files, begin=begin, end=end, **kwargs), + ] return cls(data_base, name=name) @classmethod - def _get_base_data_from_files_timedelta_total( + @abstractmethod + def _data_from_files( + cls, + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, # noqa: ANN003 + ) -> TData: + """Return a base Dataset object between two timestamps from a list of Files.""" + ... + + @classmethod + def _get_data_from_files_timedelta_total( cls, begin: Timestamp, end: Timestamp, data_duration: Timedelta, files: list[TFile], overlap: float = 0, - ) -> list[BaseData]: + **kwargs, # noqa: ANN003 + ) -> list[TData]: if not 0 <= overlap < 1: msg = f"Overlap ({overlap}) must be between 0 and 1." 
raise ValueError(msg) @@ -366,24 +401,26 @@ def _get_base_data_from_files_timedelta_total( ): last_active_file_index += 1 output.append( - BaseData.from_files( + cls._data_from_files( files[active_file_index:last_active_file_index], data_begin, data_end, + **kwargs, ), ) return output @classmethod - def _get_base_data_from_files_timedelta_file( + def _get_data_from_files_timedelta_file( cls, begin: Timestamp, end: Timestamp, data_duration: Timedelta, files: list[TFile], overlap: float = 0, - ) -> list[BaseData]: + **kwargs, + ) -> list[TData]: if not 0 <= overlap < 1: msg = f"Overlap ({overlap}) must be between 0 and 1." raise ValueError(msg) @@ -416,7 +453,12 @@ def _get_base_data_from_files_timedelta_file( files_chunk.append(next_file) output.extend( - BaseData.from_files(files, data_begin, data_begin + data_duration) + cls._data_from_files( + files, + data_begin, + data_begin + data_duration, + **kwargs, + ) for data_begin in date_range( file.begin, files_chunk[-1].end, @@ -429,11 +471,9 @@ def _get_base_data_from_files_timedelta_file( @classmethod def from_folder( # noqa: PLR0913 - cls, + cls: type[Self], folder: Path, strptime_format: str | None, - file_class: type[TFile] = BaseFile, - supported_file_extensions: list[str] | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, timezone: str | pytz.timezone | None = None, @@ -442,8 +482,9 @@ def from_folder( # noqa: PLR0913 data_duration: Timedelta | None = None, first_file_begin: Timestamp | None = None, name: str | None = None, - ) -> BaseDataset: - """Return a BaseDataset from a folder containing the base files. + **kwargs, # noqa: ANN003 + ) -> Self: + """Return a Dataset from a folder containing the base files. Parameters ---------- @@ -455,10 +496,6 @@ def from_folder( # noqa: PLR0913 If None, the first audio file of the folder will start at first_file_begin, and each following file will start at the end of the previous one. 
- file_class: type[Tfile] - Derived type of BaseFile used to instantiate the dataset. - supported_file_extensions: list[str] - List of supported file extensions for parsing TFiles. begin: Timestamp | None The begin of the dataset. Defaulted to the begin of the first file. @@ -492,15 +529,15 @@ def from_folder( # noqa: PLR0913 Will be ignored if striptime_format is specified. name: str|None Name of the dataset. + kwargs: + Keyword arguments to pass to the cls.from_files() method. Returns ------- - Basedataset: - The base dataset. + Self: + The dataset. """ - if supported_file_extensions is None: - supported_file_extensions = [] valid_files = [] rejected_files = [] first_file_begin = first_file_begin or Timestamp("2020-01-01 00:00:00") @@ -508,10 +545,8 @@ def from_folder( # noqa: PLR0913 sorted(folder.iterdir()), disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"), ): - is_file_ok = _parse_file( + is_file_ok = cls._parse_file( file=file, - file_class=file_class, - supported_file_extensions=supported_file_extensions, strptime_format=strptime_format, timezone=timezone, begin_timestamp=first_file_begin, @@ -528,9 +563,10 @@ def from_folder( # noqa: PLR0913 ) if not valid_files: - raise FileNotFoundError(f"No valid file found in {folder}.") + msg = f"No valid file found in {folder}" + raise FileNotFoundError(msg) - return BaseDataset.from_files( + return cls.from_files( files=valid_files, begin=begin, end=end, @@ -538,29 +574,33 @@ def from_folder( # noqa: PLR0913 overlap=overlap, data_duration=data_duration, name=name, + **kwargs, ) - -def _parse_file( - file: Path, - file_class: type, - supported_file_extensions: list[str], - strptime_format: str, - timezone: str | pytz.timezone | None, - begin_timestamp: Timestamp, - valid_files: list[BaseFile], - rejected_files: list[Path], -) -> bool: - if file.suffix.lower() not in supported_file_extensions: - return False - try: - if strptime_format is None: - f = file_class(file, begin=begin_timestamp, 
timezone=timezone) + @classmethod + def _parse_file( + cls: type[Self], + file: Path, + strptime_format: str, + timezone: str | pytz.timezone | None, + begin_timestamp: Timestamp, + valid_files: list[TFile], + rejected_files: list[Path], + ) -> bool: + if file.suffix.lower() not in cls.file_cls.supported_extensions: + return False + try: + if strptime_format is None: + f = cls.file_cls(file, begin=begin_timestamp, timezone=timezone) + else: + f = cls.file_cls( + file, + strptime_format=strptime_format, + timezone=timezone, + ) + valid_files.append(f) + except (ValueError, LibsndfileError): + rejected_files.append(file) + return False else: - f = file_class(file, strptime_format=strptime_format, timezone=timezone) - valid_files.append(f) - except (ValueError, LibsndfileError): - rejected_files.append(file) - return False - else: - return True + return True diff --git a/src/osekit/core_api/base_file.py b/src/osekit/core_api/base_file.py index 370124b3..69e73967 100644 --- a/src/osekit/core_api/base_file.py +++ b/src/osekit/core_api/base_file.py @@ -5,7 +5,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import typing +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Self from osekit.config import ( TIMESTAMP_FORMAT_EXPORTED_FILES_LOCALIZED, @@ -28,12 +30,14 @@ from osekit.utils.timestamp_utils import strptime_from_text -class BaseFile(Event): +class BaseFile(Event, ABC): """Base class for the File objects. A File object associates file-written data to timestamps. 
""" + supported_extensions: typing.ClassVar = [] + def __init__( self, path: PathLike | str, @@ -73,9 +77,10 @@ def __init__( self.path = Path(path) if begin is None and strptime_format is None: - raise ValueError("Either begin or strptime_format must be specified") + msg = "Either begin or strptime_format must be specified" + raise ValueError(msg) - self.begin = ( + begin = ( begin if begin is not None else strptime_from_text( @@ -85,10 +90,12 @@ def __init__( ) if timezone: - self.begin = localize_timestamp(self.begin, timezone) + begin = localize_timestamp(begin, timezone) - self.end = end if end is not None else (self.begin + Timedelta(seconds=1)) + end = end if end is not None else (begin + Timedelta(seconds=1)) + super().__init__(begin=begin, end=end) + @abstractmethod def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the data that is between start and stop from the file. @@ -106,6 +113,7 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: The data between start and stop. """ + ... def to_dict(self) -> dict: """Serialize a BaseFile to a dictionary. @@ -123,7 +131,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, serialized: dict) -> BaseFile: + def from_dict(cls: type[Self], serialized: dict) -> type[Self]: """Return a BaseFile object from a dictionary. 
Parameters @@ -151,7 +159,7 @@ def __str__(self) -> str: """Overwrite __str__.""" return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES_LOCALIZED) - def __eq__(self, other: BaseFile): + def __eq__(self, other: BaseFile) -> bool: """Override __eq__.""" if not isinstance(other, BaseFile): return False diff --git a/src/osekit/core_api/base_item.py b/src/osekit/core_api/base_item.py index 1c5e7408..8fc35538 100644 --- a/src/osekit/core_api/base_item.py +++ b/src/osekit/core_api/base_item.py @@ -5,7 +5,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, TypeVar +from abc import ABC +from typing import TYPE_CHECKING, TypeVar import numpy as np @@ -18,7 +19,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseItem(Generic[TFile], Event): +class BaseItem[TFile: BaseFile](Event, ABC): """Base class for the Item objects. An Item correspond to a portion of a File object. @@ -51,10 +52,10 @@ def __init__( self.end = end return - self.begin = ( - max(begin, self.file.begin) if begin is not None else self.file.begin - ) - self.end = min(end, self.file.end) if end is not None else self.file.end + begin = max(begin, self.file.begin) if begin is not None else self.file.begin + end = min(end, self.file.end) if end is not None else self.file.end + + super().__init__(begin=begin, end=end) def get_value(self) -> np.ndarray: """Get the values from the File between the begin and stop timestamps. 
diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index f4e3c1f3..bc8e6d0e 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -8,7 +8,7 @@ import gc import itertools -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Self import matplotlib.pyplot as plt import numpy as np @@ -20,7 +20,7 @@ TIMESTAMP_FORMATS_EXPORTED_FILES, ) from osekit.core_api.audio_data import AudioData -from osekit.core_api.base_data import BaseData +from osekit.core_api.base_data import BaseData, TFile from osekit.core_api.spectro_file import SpectroFile from osekit.core_api.spectro_item import SpectroItem @@ -39,12 +39,15 @@ class SpectroData(BaseData[SpectroItem, SpectroFile]): The data is accessed via a SpectroItem object per SpectroFile. """ + item_cls = SpectroItem + def __init__( self, items: list[SpectroItem] | None = None, audio_data: AudioData = None, begin: Timestamp | None = None, end: Timestamp | None = None, + name: str | None = None, fft: ShortTimeFFT | None = None, db_ref: float | None = None, v_lim: tuple[float, float] | None = None, @@ -64,6 +67,8 @@ def __init__( end: Timestamp | None Only effective if items is None. Set the end of the empty data. + name: str | None + Name of the exported files. fft: ShortTimeFFT The short time FFT used for computing the spectrogram. db_ref: float | None @@ -75,7 +80,7 @@ def __init__( Colormap to use for plotting the spectrogram. 
""" - super().__init__(items=items, begin=begin, end=end) + super().__init__(items=items, begin=begin, end=end, name=name) self.audio_data = audio_data self.fft = fft self._sx_dtype = complex @@ -174,9 +179,10 @@ def sx_dtype(self) -> type[complex]: return self._sx_dtype @sx_dtype.setter - def sx_dtype(self, dtype: type[complex]) -> [complex, float]: + def sx_dtype(self, dtype: type[complex]) -> None: if dtype not in (complex, float): - raise ValueError("dtype must be complex or float.") + msg = "dtype must be complex or float." + raise ValueError(msg) self._sx_dtype = dtype @property @@ -241,7 +247,8 @@ def get_value(self) -> np.ndarray: if not all(item.is_empty for item in self.items): return self._get_value_from_items(self.items) if not self.audio_data or not self.fft: - raise ValueError("SpectroData should have either items or audio_data.") + msg = "SpectroData should have either items or audio_data." + raise ValueError(msg) sx = self.fft.stft( x=self.audio_data.get_value_calibrated()[ @@ -275,9 +282,10 @@ def get_welch( self, nperseg: int | None = None, detrend: str | callable | False = "constant", - return_onesided: bool = True, scaling: Literal["density", "spectrum"] = "density", average: Literal["mean", "median"] = "mean", + *, + return_onesided: bool = True, ) -> np.ndarray: """Estimate power spectral density of the SpectroData using Welch's method. @@ -334,9 +342,10 @@ def write_welch( px: np.ndarray | None = None, nperseg: int | None = None, detrend: str | callable | False = "constant", - return_onesided: bool = True, scaling: Literal["density", "spectrum"] = "density", average: Literal["mean", "median"] = "mean", + *, + return_onesided: bool = True, ) -> None: """Write the psd (welch) of the SpectroData to a npz file. @@ -476,6 +485,7 @@ def write( self, folder: Path, sx: np.ndarray | None = None, + *, link: bool = False, ) -> None: """Write the Spectro data to file. 
@@ -551,20 +561,29 @@ def link_audio_data(self, audio_data: AudioData) -> None: """ if self.begin != audio_data.begin: - raise ValueError("The begin of the audio data doesn't match.") + msg = "The begin of the audio data doesn't match." + raise ValueError(msg) if self.end != audio_data.end: - raise ValueError("The end of the audio data doesn't match.") + msg = "The end of the audio data doesn't match." + raise ValueError(msg) if self.fft.fs != audio_data.sample_rate: - raise ValueError("The sample rate of the audio data doesn't match.") + msg = "The sample rate of the audio data doesn't match." + raise ValueError(msg) self.audio_data = audio_data - def split(self, nb_subdata: int = 2) -> list[SpectroData]: + def split( + self, + nb_subdata: int = 2, + **kwargs, # noqa: ANN003 + ) -> list[SpectroData]: """Split the spectro data object in the specified number of spectro subdata. Parameters ---------- nb_subdata: int Number of subdata in which to split the data. + kwargs: + None Returns ------- @@ -606,10 +625,12 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: for i in items[1:] if not i.is_empty ): - raise ValueError("Items don't have the same frequency bins.") + msg = "Items don't have the same frequency bins." + raise ValueError(msg) if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1: - raise ValueError("Items don't have the same time resolution.") + msg = "Items don't have the same time resolution." + raise ValueError(msg) return np.hstack( [item.get_value(fft=self.fft, sx_dtype=self.sx_dtype) for item in items], @@ -658,56 +679,128 @@ def get_overlapped_bins(cls, sd1: SpectroData, sd2: SpectroData) -> np.ndarray: return sd_part1[:, -p1_le:] + sd_part2[:, :p1_le] @classmethod - def from_files( + def _make_file(cls, path: Path, begin: Timestamp) -> SpectroFile: + """Make a SpectroFile from a path and a begin timestamp. + + Parameters + ---------- + path: Path + Path to the file. 
+ begin: Timestamp + Begin of the file. + + Returns + ------- + SpectroFile: + The spectro file. + + """ + return SpectroFile(path=path, begin=begin) + + @classmethod + def _make_item( cls, - files: list[SpectroFile], + file: TFile | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, + ) -> SpectroItem: + """Make a SpectroItem for a given SpectroFile between begin and end timestamps. + + Parameters + ---------- + file: SpectroFile + SpectroFile of the item. + begin: Timestamp + Begin of the item. + end: + End of the item. + + Returns + ------- + A SpectroItem for the SpectroFile file, between the begin and end timestamps. + + """ + return SpectroItem( + file=file, + begin=begin, + end=end, + ) + + @classmethod + def _from_base_dict( + cls, + dictionary: dict, + files: list[SpectroFile], + begin: Timestamp, + end: Timestamp, + **kwargs, # noqa: ANN003 ) -> SpectroData: - """Return a SpectroData object from a list of SpectroFiles. + """Deserialize the SpectroData-specific parts of a Data dictionary. + + This method is called within the BaseData.from_dict() method, which + deserializes the base files, begin and end parameters. Parameters ---------- + dictionary: dict + The serialized dictionary representing the SpectroData. files: list[SpectroFile] - List of SpectroFiles containing the data. - begin: Timestamp | None - Begin of the data object. - Defaulted to the begin of the first file. - end: Timestamp | None - End of the data object. - Defaulted to the end of the last file. + The list of deserialized SpectroFiles. + begin: Timestamp + The deserialized begin timestamp. + end: Timestamp + The deserialized end timestamp. + kwargs: + None. Returns ------- - SpectroData: - The SpectroData object. + SpectroData + The deserialized SpectroData. 
""" - instance = cls.from_base_data( - BaseData.from_files(files, begin, end), - fft=files[0].get_fft(), + return cls.from_files( + files=files, + begin=begin, + end=end, + colormap=dictionary["colormap"], ) - if not any(file.sx_dtype is complex for file in files): - instance.sx_dtype = float - return instance + + def _make_split_data( + self, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + **kwargs, # noqa: ANN003 + ) -> SpectroData: ... @classmethod - def from_base_data( + def from_files( cls, - data: BaseData, - fft: ShortTimeFFT, - colormap: str | None = None, + files: list[SpectroFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, ) -> SpectroData: - """Return an SpectroData object from a BaseData object. + """Return a SpectroData object from a list of SpectroFiles. Parameters ---------- - data: BaseData - BaseData object to convert to SpectroData. - fft: ShortTimeFFT - The ShortTimeFFT used to compute the spectrogram. - colormap: str - The colormap used to plot the spectrogram. + files: list[SpectroFile] + List of SpectroFiles containing the data. + begin: Timestamp | None + Begin of the data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the data object. + Defaulted to the end of the last file. + name: str | None + Name of the exported files. + kwargs + Keyword arguments that are passed to the cls constructor. + colormap: str + Colormap to use for plotting the spectrogram. Returns ------- @@ -715,16 +808,22 @@ def from_base_data( The SpectroData object. 
""" - items = [SpectroItem.from_base_item(item) for item in data.items] - db_ref = next((f.file.db_ref for f in items if f.file.db_ref is not None), None) - v_lim = next((f.file.v_lim for f in items if f.file.v_lim is not None), None) - return cls( - [SpectroItem.from_base_item(item) for item in data.items], + fft = files[0].get_fft() + db_ref = next((f.db_ref for f in files if f.db_ref is not None), None) + v_lim = next((f.v_lim for f in files if f.v_lim is not None), None) + instance = super().from_files( + files=files, # This way, this static error doesn't appear to the user + begin=begin, + end=end, + name=name, fft=fft, db_ref=db_ref, v_lim=v_lim, - colormap=colormap, + **kwargs, ) + if not any(file.sx_dtype is complex for file in files): + instance.sx_dtype = float + return instance @classmethod def from_audio_data( @@ -815,7 +914,7 @@ def from_dict( cls, dictionary: dict, sft: ShortTimeFFT | None = None, - ) -> SpectroData: + ) -> Self: """Deserialize a SpectroData from a dictionary. Parameters @@ -832,20 +931,19 @@ def from_dict( The deserialized SpectroData. 
""" + if dictionary["audio_data"] is None: + return super().from_dict( + dictionary=dictionary, + colormap=dictionary["colormap"], + ) + if sft is None and dictionary["sft"] is None: - raise ValueError("Missing sft") + msg = "Missing SFT" + raise ValueError(msg) if sft is None: dictionary["sft"]["win"] = np.array(dictionary["sft"]["win"]) sft = ShortTimeFFT(**dictionary["sft"]) - if dictionary["audio_data"] is None: - base_data = BaseData.from_dict(dictionary) - return cls.from_base_data( - data=base_data, - fft=sft, - colormap=dictionary["colormap"], - ) - audio_data = AudioData.from_dict(dictionary["audio_data"]) v_lim = ( None if type(dictionary["v_lim"]) is object else tuple(dictionary["v_lim"]) diff --git a/src/osekit/core_api/spectro_dataset.py b/src/osekit/core_api/spectro_dataset.py index b2e96add..b8289f83 100644 --- a/src/osekit/core_api/spectro_dataset.py +++ b/src/osekit/core_api/spectro_dataset.py @@ -7,14 +7,14 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Self import numpy as np from pandas import DataFrame from scipy.signal import ShortTimeFFT from osekit.config import DPDEFAULT -from osekit.core_api.base_dataset import BaseDataset +from osekit.core_api.base_dataset import BaseDataset, TFile from osekit.core_api.frequency_scale import Scale from osekit.core_api.json_serializer import deserialize_json from osekit.core_api.spectro_data import SpectroData @@ -40,6 +40,7 @@ class SpectroDataset(BaseDataset[SpectroData, SpectroFile]): sentinel_value = object() _bypass_multiprocessing_on_dataset = False data_cls = SpectroData + file_cls = SpectroFile def __init__( self, @@ -285,6 +286,7 @@ def _save_all_( data: SpectroData, matrix_folder: Path, spectrogram_folder: Path, + *, link: bool, ) -> SpectroData: """Save the data matrix and spectrogram to disk.""" @@ -297,9 +299,10 @@ def save_all( self, matrix_folder: Path, spectrogram_folder: Path, - link: 
bool = False, first: int = 0, last: int | None = None, + *, + link: bool = False, ) -> None: """Export both Sx matrices as npz files and spectrograms for each data. @@ -341,12 +344,18 @@ def link_audio_dataset( ---------- audio_dataset: AudioDataset The AudioDataset which data will be linked to the SpectroDataset data. + first: int + Index of the first SpectroData and AudioData to link. + last: int + Index of the last SpectroData and AudioData to link. """ if len(audio_dataset.data) != len(self.data): - raise ValueError( - "The audio dataset doesn't contain the same number of data as the spectro dataset.", + msg = ( + "The audio dataset doesn't contain the same number of data" + " as the spectro dataset." ) + raise ValueError(msg) last = len(self.data) if last is None else last @@ -485,9 +494,8 @@ def from_folder( # noqa: PLR0913 overlap: float = 0.0, data_duration: Timedelta | None = None, name: str | None = None, - v_lim: tuple[float, float] | None | object = sentinel_value, - **kwargs: any, - ) -> SpectroDataset: + **kwargs, # noqa: ANN003 + ) -> Self: """Return a SpectroDataset from a folder containing the spectro files. Parameters @@ -526,10 +534,8 @@ def from_folder( # noqa: PLR0913 Else, one data object will cover the whole time period. name: str|None Name of the dataset. - v_lim: tuple[float, float] | None - Limits (in dB) of the colormap used for plotting the spectrogram. - kwargs: any - Keyword arguments passed to the BaseDataset.from_folder classmethod. + kwargs: + None. Returns ------- @@ -537,10 +543,7 @@ def from_folder( # noqa: PLR0913 The audio dataset. 
""" - kwargs.update( - {"file_class": SpectroFile, "supported_file_extensions": [".npz"]}, - ) - base_dataset = BaseDataset.from_folder( + return super().from_folder( folder=folder, strptime_format=strptime_format, begin=begin, @@ -549,35 +552,134 @@ def from_folder( # noqa: PLR0913 mode=mode, overlap=overlap, data_duration=data_duration, - **kwargs, + name=name, ) - sft = next(iter(base_dataset.files)).get_fft() - return cls.from_base_dataset( - base_dataset=base_dataset, - fft=sft, + + @classmethod + def from_files( # noqa: PLR0913 + cls, + files: list[SpectroFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total", + overlap: float = 0.0, + data_duration: Timedelta | None = None, + **kwargs, # noqa: ANN003 + ) -> AudioDataset: + """Return an SpectroDataset object from a list of SpectroFiles. + + Parameters + ---------- + files: list[SpectroFile] + The list of files contained in the Dataset. + begin: Timestamp | None + Begin of the first data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the last data object. + Defaulted to the end of the last file. + mode: Literal["files", "timedelta_total", "timedelta_file"] + Mode of creation of the dataset data from the original files. + "files": one data will be created for each file. + "timedelta_total": data objects of duration equal to data_duration will + be created from the begin timestamp to the end timestamp. + "timedelta_file": data objects of duration equal to data_duration will + be created from the beginning of the first file that the begin timestamp is into, until it would resume + in a data beginning between two files. Then, the next data object will be created from the + beginning of the next original file and so on. + overlap: float + Overlap percentage between consecutive data. + data_duration: Timedelta | None + Duration of the data objects. 
+ If mode is set to "files", this parameter has no effect. + If provided, data will be evenly distributed between begin and end. + Else, one data object will cover the whole time period. + sample_rate: float | None + Sample rate of the audio data objects. + name: str|None + Name of the dataset. + instrument: Instrument | None + Instrument that might be used to obtain acoustic pressure from + the wav audio data. + normalization: Normalization + The type of normalization to apply to the audio data. + kwargs: + None. + + Returns + ------- + SpectroDataset: + The SpectroDataset object. + + """ + return super().from_files( + files=files, + begin=begin, + end=end, name=name, - v_lim=v_lim, + mode=mode, + overlap=overlap, + data_duration=data_duration, ) @classmethod - def from_base_dataset( + def _data_from_dict(cls, dictionary: dict) -> list[SpectroData]: + """Return the list of SpectroData objects from the serialized dictionary. + + Parameters + ---------- + dictionary: dict + Dictionary representing the serialized SpectroDataset. + + Returns + ------- + list[SpectroData]: + The list of deserialized SpectroData objects. + + """ + return [SpectroData.from_dict(data) for data in dictionary.values()] + + @classmethod + def _data_from_files( cls, - base_dataset: BaseDataset, - fft: ShortTimeFFT, + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, name: str | None = None, - colormap: str | None = None, - scale: Scale | None = None, - v_lim: tuple[float, float] | None | object = sentinel_value, - ) -> SpectroDataset: - """Return a SpectroDataset object from a BaseDataset object.""" - return cls( - [ - SpectroData.from_base_data(data=data, fft=fft, colormap=colormap) - for data in base_dataset.data - ], + **kwargs, # noqa: ANN003 + ) -> SpectroData: + """Return a SpectroData object from a list of SpectroFiles. + + The SpectroData starts at the begin and ends at end. 
+ + Parameters + ---------- + files: list[SpectroFile] + List of SpectroFiles contained in the SpectroData. + begin: Timestamp | None + Begin of the SpectroData. + Defaulted to the begin of the first SpectroFile. + end: Timestamp | None + End of the SpectroData. + Defaulted to the end of the last SpectroFile. + name: str|None + Name of the SpectroData. + kwargs: + Keyword arguments to pass to the SpectroData.from_files() method. + + Returns + ------- + SpectroData: + The SpectroData object. + + """ + return SpectroData.from_files( + files=files, + begin=begin, + end=end, name=name, - scale=scale, - v_lim=v_lim, + **kwargs, ) @classmethod diff --git a/src/osekit/core_api/spectro_file.py b/src/osekit/core_api/spectro_file.py index 5654fa90..b846d3b7 100644 --- a/src/osekit/core_api/spectro_file.py +++ b/src/osekit/core_api/spectro_file.py @@ -6,6 +6,7 @@ from __future__ import annotations +import typing from typing import TYPE_CHECKING import numpy as np @@ -27,6 +28,8 @@ class SpectroFile(BaseFile): Metadata (time_resolution) are stored as separate arrays. """ + supported_extensions: typing.ClassVar = [".npz"] + def __init__( self, path: PathLike | str, @@ -164,8 +167,3 @@ def get_fft(self) -> ShortTimeFFT: fs=self.sample_rate, mfft=self.mfft, ) - - @classmethod - def from_base_file(cls, file: BaseFile) -> SpectroFile: - """Return a SpectroFile object from a BaseFile object.""" - return cls(path=file.path, begin=file.begin) diff --git a/src/osekit/core_api/spectro_item.py b/src/osekit/core_api/spectro_item.py index 4183ed16..31cc869a 100644 --- a/src/osekit/core_api/spectro_item.py +++ b/src/osekit/core_api/spectro_item.py @@ -6,7 +6,6 @@ import numpy as np -from osekit.core_api.base_file import BaseFile from osekit.core_api.base_item import BaseItem from osekit.core_api.spectro_file import SpectroFile @@ -38,27 +37,13 @@ def __init__( It is defaulted to the SpectroFile end. 
""" - super().__init__(file, begin, end) + super().__init__(file=file, begin=begin, end=end) @property def time_resolution(self) -> Timedelta: """Time resolution of the associated SpectroFile.""" return None if self.is_empty else self.file.time_resolution - @classmethod - def from_base_item(cls, item: BaseItem) -> SpectroItem: - """Return a SpectroItem object from a BaseItem object.""" - file = item.file - if not file or isinstance(file, SpectroFile): - return cls(file=file, begin=item.begin, end=item.end) - if isinstance(file, BaseFile): - return cls( - file=SpectroFile.from_base_file(file), - begin=item.begin, - end=item.end, - ) - raise TypeError - def get_value( self, fft: ShortTimeFFT | None = None, @@ -75,10 +60,11 @@ def get_value( if sx_dtype is float: sx = abs(sx) ** 2 if sx_dtype is complex: - raise TypeError( - "Cannot convert absolute npz values to complex sx values." - "Change the SpectroData dtype to absolute.", + msg = ( + "Cannot convert absolute npz values to complex sx values. " + "Change the SpectroData dtype to absolute." 
) + raise TypeError(msg) return sx diff --git a/tests/conftest.py b/tests/conftest.py index 522a0bd8..29b47339 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,8 +21,6 @@ from osekit.core_api import AudioFileManager from osekit.core_api.audio_data import AudioData from osekit.core_api.audio_file import AudioFile -from osekit.core_api.base_dataset import BaseDataset -from osekit.core_api.base_file import BaseFile from osekit.utils.audio_utils import generate_sample_audio @@ -169,24 +167,6 @@ def mock_open(self: AudioFileManager, path: Path) -> None: return opened_files -@pytest.fixture -def base_dataset(tmp_path: Path) -> BaseDataset: - files = [tmp_path / f"file_{i}.txt" for i in range(5)] - for file in files: - file.touch() - timestamps = pd.date_range( - start=pd.Timestamp("2000-01-01 00:00:00"), - freq="1s", - periods=5, - ) - - bfs = [ - BaseFile(path=file, begin=timestamp, end=timestamp + pd.Timedelta(seconds=1)) - for file, timestamp in zip(files, timestamps, strict=False) - ] - return BaseDataset.from_files(files=bfs, mode="files") - - @pytest.fixture def patch_audio_data(monkeypatch: pytest.MonkeyPatch) -> None: original_init = AudioData.__init__ diff --git a/tests/test_core_api_base.py b/tests/test_core_api_base.py index ea17dfe5..f7f7df20 100644 --- a/tests/test_core_api_base.py +++ b/tests/test_core_api_base.py @@ -1,25 +1,140 @@ from __future__ import annotations +import typing from pathlib import Path -from typing import Literal +from typing import Literal, Self import numpy as np +import pandas as pd import pytest from pandas import Timedelta, Timestamp from osekit.config import TIMESTAMP_FORMATS_EXPORTED_FILES -from osekit.core_api.base_data import BaseData -from osekit.core_api.base_dataset import BaseDataset +from osekit.core_api.base_data import BaseData, TFile +from osekit.core_api.base_dataset import BaseDataset, TData from osekit.core_api.base_file import BaseFile +from osekit.core_api.base_item import BaseItem from 
osekit.core_api.event import Event +class DummyFile(BaseFile): + supported_extensions: typing.ClassVar = [".wav"] + + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: ... + + +class DummyItem(BaseItem[DummyFile]): ... + + +class DummyData(BaseData[DummyItem, DummyFile]): + item_cls = DummyItem + + def write(self, folder: Path, link: bool = False) -> None: ... + + def link(self, folder: Path) -> None: ... + + def _make_split_data( + self, + files: list[DummyFile], + begin: Timestamp, + end: Timestamp, + **kwargs, + ) -> Self: + return DummyData(files, begin, end, **kwargs) + + @classmethod + def _make_file(cls, path: Path, begin: Timestamp) -> DummyFile: + return DummyFile(path=path, begin=begin) + + @classmethod + def _make_item( + cls, + file: TFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> DummyItem: + return DummyItem(file=file, begin=begin, end=end) + + @classmethod + def _from_base_dict( + cls, + dictionary: dict, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + ) -> Self: + return cls.from_files( + files=files, + begin=begin, + end=end, + ) + + @classmethod + def from_files( + cls, + files: list[DummyFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, + ) -> Self: + return super().from_files( + files=files, + begin=begin, + end=end, + name=name, + **kwargs, + ) + + +class DummyDataset(BaseDataset[DummyData, DummyFile]): + @classmethod + def _data_from_dict(cls, dictionary: dict) -> list[TData]: + return [DummyData.from_dict(data) for data in dictionary.values()] + + @classmethod + def _data_from_files( + cls, + files: list[DummyFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + name: str | None = None, + **kwargs, + ) -> TData: + return DummyData.from_files( + files=files, + begin=begin, + end=end, + name=name, + ) + + file_cls = DummyFile + + +@pytest.fixture +def dummy_dataset(tmp_path: Path) -> 
DummyDataset: + files = [tmp_path / f"file_{i}.txt" for i in range(5)] + for file in files: + file.touch() + timestamps = pd.date_range( + start=pd.Timestamp("2000-01-01 00:00:00"), + freq="1s", + periods=5, + ) + + dfs = [ + DummyFile(path=file, begin=timestamp, end=timestamp + pd.Timedelta(seconds=1)) + for file, timestamp in zip(files, timestamps, strict=False) + ] + return DummyDataset.from_files(files=dfs, mode="files") + + @pytest.mark.parametrize( ("base_files", "begin", "end", "duration", "expected_data_events"), [ pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -38,7 +153,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -57,7 +172,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -76,7 +191,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -95,7 +210,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -114,7 +229,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -133,7 +248,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 01:00:00"), @@ -160,7 +275,7 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:50:00"), @@ -187,27 +302,27 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:10:00"), ), - BaseFile( + DummyFile( path=Path("foo"), 
begin=Timestamp("2016-02-05 00:10:00"), end=Timestamp("2016-02-05 00:20:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:20:00"), end=Timestamp("2016-02-05 00:30:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:30:00"), end=Timestamp("2016-02-05 00:40:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:40:00"), end=Timestamp("2016-02-05 00:50:00"), @@ -234,27 +349,27 @@ ), pytest.param( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:10:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:10:00"), end=Timestamp("2016-02-05 00:20:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:20:00"), end=Timestamp("2016-02-05 00:30:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:30:00"), end=Timestamp("2016-02-05 00:40:00"), ), - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:40:00"), end=Timestamp("2016-02-05 00:50:00"), @@ -286,13 +401,13 @@ ], ) def test_base_dataset_from_files( - base_files: list[BaseFile], + base_files: list[DummyFile], begin: Timestamp | None, end: Timestamp | None, duration: Timedelta | None, expected_data_events: list[Event], ) -> None: - ads = BaseDataset.from_files( + ads = DummyDataset.from_files( base_files, begin=begin, end=end, @@ -342,9 +457,9 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No match=rf"Overlap \({overlap}\) must be between 0 and 1.", ) as e: assert ( - BaseDataset.from_files( + DummyDataset.from_files( [ - BaseFile( + DummyFile( path=Path("foo"), begin=Timestamp("2016-02-05 00:00:00"), end=Timestamp("2016-02-05 00:10:00"), @@ -361,7 +476,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No @pytest.mark.parametrize( ( "strptime_format", - "supported_file_extensions", "begin", "end", 
"timezone", @@ -376,7 +490,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No [ pytest.param( "%y%m%d%H%M%S", - [".wav"], None, None, None, @@ -399,7 +512,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], None, None, None, @@ -422,7 +534,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], None, None, None, @@ -452,7 +563,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav", ".mp3"], None, None, None, @@ -461,7 +571,7 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No None, Timestamp("2023-12-01 00:00:00"), None, - [Path(r"cool.wav"), Path(r"fun.mp3"), Path("boring.flac")], + [Path(r"cool.wav"), Path("boring.shenanigan")], [ ( Event( @@ -470,19 +580,11 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), [Path(r"cool.wav")], ), - ( - Event( - begin=Timestamp("2023-12-01 00:00:01"), - end=Timestamp("2023-12-01 00:00:02"), - ), - [Path(r"fun.mp3")], - ), ], id="only_specified_formats_are_kept", ), pytest.param( None, - [".wav"], None, None, None, @@ -526,7 +628,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], None, None, None, @@ -570,7 +671,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( None, - [".wav"], Timestamp("2023-12-01 00:00:00.5"), Timestamp("2023-12-01 00:00:01.5"), None, @@ -600,7 +700,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( "%y%m%d%H%M%S", - [".wav"], Timestamp("2023-12-01 00:00:00.5"), Timestamp("2023-12-01 00:00:01.5"), None, @@ -630,7 +729,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( "%y%m%d%H%M%S", - 
[".wav"], Timestamp("2023-12-01 00:00:00.5+01:00"), Timestamp("2023-12-01 00:00:01.5+01:00"), "Europe/Warsaw", @@ -660,7 +758,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No ), pytest.param( "%y%m%d%H%M%S%z", - [".wav"], Timestamp("2023-12-01 01:00:00.5+01:00"), Timestamp("2023-12-01 01:00:01.5+01:00"), "Europe/Warsaw", @@ -693,7 +790,6 @@ def test_base_dataset_from_files_overlap_errors(overlap: float, mode: str) -> No def test_base_dataset_from_folder( monkeypatch: pytest.monkeypatch, strptime_format: str | None, - supported_file_extensions: list[str], begin: Timestamp | None, end: Timestamp | None, timezone: str | None, @@ -707,10 +803,9 @@ def test_base_dataset_from_folder( ) -> None: monkeypatch.setattr(Path, "iterdir", lambda x: files) - bds = BaseDataset.from_folder( + bds = DummyDataset.from_folder( folder=Path("foo"), strptime_format=strptime_format, - supported_file_extensions=supported_file_extensions, begin=begin, end=end, timezone=timezone, @@ -752,7 +847,7 @@ def test_move_file( ) -> None: filename = "cool.txt" (tmp_path / filename).touch(mode=0o666, exist_ok=True) - bf = BaseFile( + bf = DummyFile( tmp_path / filename, begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:13:12"), @@ -768,27 +863,27 @@ def test_move_file( def test_dataset_move( tmp_path: Path, - base_dataset: BaseDataset, + dummy_dataset: DummyDataset, ) -> None: - origin_files = [Path(str(file.path)) for file in base_dataset.files] + origin_files = [Path(str(file.path)) for file in dummy_dataset.files] # The starting folder of the dataset is the folder where the files are located - assert base_dataset.folder == tmp_path + assert dummy_dataset.folder == tmp_path destination = tmp_path / "destination" - base_dataset.folder = destination + dummy_dataset.folder = destination # Setting the folder shouldn't move the files - assert all(file.path.parent == tmp_path for file in base_dataset.files) + assert all(file.path.parent == tmp_path 
for file in dummy_dataset.files) assert all(file.exists for file in origin_files) # Folder should be changed when dataset is moved new_destination = tmp_path / "new_destination" - base_dataset.move_files(new_destination) + dummy_dataset.move_files(new_destination) assert new_destination.exists() assert new_destination.is_dir() - assert base_dataset.folder == new_destination + assert dummy_dataset.folder == new_destination assert all((new_destination / file.name).exists() for file in origin_files) assert not any(file.exists() for file in origin_files) @@ -798,7 +893,7 @@ def test_dataset_move( [ pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -819,7 +914,7 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -847,12 +942,12 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:13"), end=Timestamp("2022-04-22 12:12:14"), @@ -894,7 +989,7 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -915,12 +1010,12 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:13"), end=Timestamp("2022-04-22 12:12:14"), @@ -948,12 +1043,12 @@ def test_dataset_move( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 
12:12:13"), end=Timestamp("2022-04-22 12:12:14"), @@ -983,12 +1078,12 @@ def test_dataset_move( ) def test_base_dataset_file_mode( tmp_path: pytest.fixture, - files: list[BaseFile], + files: list[DummyFile], mode: Literal["files", "timedelta_total"], data_duration: Timedelta | None, expected_data: list[tuple[Event, str]], ) -> None: - ds = BaseDataset.from_files( + ds = DummyDataset.from_files( files=files, mode=mode, data_duration=data_duration, @@ -1007,7 +1102,7 @@ def test_base_dataset_file_mode( [ pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1023,7 +1118,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1039,7 +1134,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1055,7 +1150,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1071,7 +1166,7 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), @@ -1087,12 +1182,12 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:14"), end=Timestamp("2022-04-22 12:12:15"), @@ -1108,12 +1203,12 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), 
begin=Timestamp("2022-04-22 12:12:14"), end=Timestamp("2022-04-22 12:12:15"), @@ -1129,12 +1224,12 @@ def test_base_dataset_file_mode( ), pytest.param( [ - BaseFile( + DummyFile( path=Path("cool"), begin=Timestamp("2022-04-22 12:12:12"), end=Timestamp("2022-04-22 12:12:13"), ), - BaseFile( + DummyFile( path=Path("fun"), begin=Timestamp("2022-04-22 12:12:14"), end=Timestamp("2022-04-22 12:12:15"), @@ -1152,12 +1247,12 @@ def test_base_dataset_file_mode( ) def test_base_data_boundaries( monkeypatch: pytest.fixture, - files: list[BaseFile], + files: list[DummyFile], begin: Timestamp, end: Timestamp, expected_data: Event, ) -> None: - data = BaseData.from_files(files=files) + data = DummyData.from_files(files=files) if begin: data.begin = begin if end: @@ -1165,14 +1260,14 @@ def test_base_data_boundaries( assert data.begin == expected_data.begin assert data.end == expected_data.end - def mocked_get_value(self: BaseData) -> None: + def mocked_get_value(self: DummyData) -> None: for item in data.items: if item.is_empty: continue assert item.file.begin <= item.begin assert item.file.end >= item.end - monkeypatch.setattr(BaseData, "get_value", mocked_get_value) + monkeypatch.setattr(DummyData, "get_value", mocked_get_value) data.get_value() @@ -1181,9 +1276,9 @@ def mocked_get_value(self: BaseData) -> None: ("data1", "data2", "expected"), [ pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1192,9 +1287,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1207,9 +1302,9 @@ def mocked_get_value(self: BaseData) -> None: id="same_one_full_file", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", 
begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1218,9 +1313,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1233,9 +1328,9 @@ def mocked_get_value(self: BaseData) -> None: id="same_one_full_file_explicit_timestamps", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1244,9 +1339,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1259,9 +1354,9 @@ def mocked_get_value(self: BaseData) -> None: id="different_begin", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1270,9 +1365,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1285,9 +1380,9 @@ def mocked_get_value(self: BaseData) -> None: id="different_end", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1296,9 +1391,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1311,9 +1406,9 @@ def mocked_get_value(self: BaseData) -> None: 
id="different_begin_and_end", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1322,9 +1417,9 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1337,14 +1432,14 @@ def mocked_get_value(self: BaseData) -> None: id="different_file", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1353,14 +1448,14 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1373,14 +1468,14 @@ def mocked_get_value(self: BaseData) -> None: id="same_two_files", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1389,14 +1484,14 @@ def mocked_get_value(self: BaseData) -> None: begin=None, end=None, ), - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "house", begin=Timestamp("2015-08-28 12:12:14"), end=Timestamp("2015-08-28 12:13:15"), @@ -1410,7 +1505,7 @@ def 
mocked_get_value(self: BaseData) -> None: ), ], ) -def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> None: +def test_base_data_equality(data1: DummyData, data2: DummyData, expected: bool) -> None: assert (data1 == data2) == expected @@ -1418,9 +1513,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> ("data", "name", "expected"), [ pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1434,14 +1529,14 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="default_to_data_begin", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "beach", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1457,9 +1552,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="default_to_data_begin_with_unordered_files", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1471,9 +1566,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="given_name", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1486,9 +1581,9 @@ def test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> id="given_name_over_existing_name", ), pytest.param( - BaseData.from_files( + DummyData.from_files( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1504,7 +1599,7 @@ def 
test_base_data_equality(data1: BaseData, data2: BaseData, expected: bool) -> ), ], ) -def test_data_name(data: BaseData, name: str | None, expected: str) -> None: +def test_data_name(data: DummyData, name: str | None, expected: str) -> None: data.name = name assert data.name == expected assert str(data) == expected @@ -1515,7 +1610,7 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: [ pytest.param( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1550,12 +1645,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1590,12 +1685,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1653,12 +1748,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1716,7 +1811,7 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), @@ -1751,12 +1846,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", 
begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1791,12 +1886,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:12"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:12"), end=Timestamp("2015-08-28 12:14:12"), @@ -1854,12 +1949,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:22"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:32"), end=Timestamp("2015-08-28 12:14:12"), @@ -1908,12 +2003,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:22"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:32"), end=Timestamp("2015-08-28 12:14:12"), @@ -1962,12 +2057,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2016,12 +2111,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:12"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2070,12 +2165,12 @@ def 
test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:00"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2124,12 +2219,12 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ), pytest.param( [ - BaseFile( + DummyFile( "depression", begin=Timestamp("2015-08-28 12:12:00"), end=Timestamp("2015-08-28 12:13:02"), ), - BaseFile( + DummyFile( "cherry", begin=Timestamp("2015-08-28 12:13:22"), end=Timestamp("2015-08-28 12:14:12"), @@ -2195,7 +2290,7 @@ def test_data_name(data: BaseData, name: str | None, expected: str) -> None: ], ) def test_get_base_data_from_files( - files: list[BaseFile], + files: list[DummyFile], begin: Timestamp, end: Timestamp, data_duration: Timedelta, @@ -2203,7 +2298,7 @@ def test_get_base_data_from_files( overlap: float, expected: list[list[tuple[Event, str | None]]], ) -> None: - data = BaseDataset.from_files( + data = DummyDataset.from_files( files=files, begin=begin, end=end, diff --git a/tests/test_files.py b/tests/test_files.py index ee8072ca..c88da641 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -6,8 +6,7 @@ import pytz from pandas import Timestamp -from osekit.core_api.base_dataset import BaseDataset -from osekit.core_api.base_file import BaseFile +from tests.test_core_api_base import DummyDataset, DummyFile @pytest.mark.parametrize( @@ -82,7 +81,7 @@ def test_file_localization( timezone: str | pytz.timezone | None, expected_begin: Timestamp, ) -> None: - file = BaseFile( + file = DummyFile( path=Path(file_name), strptime_format=strptime_format, timezone=timezone, @@ -169,13 +168,12 @@ def test_dataset_localization( expected_begins: list[Timestamp], ) -> None: for file in file_names: - (tmp_path / f"{file}.foo").touch() + (tmp_path / f"{file}.wav").touch() - dataset = 
BaseDataset.from_folder( + dataset = DummyDataset.from_folder( tmp_path, strptime_format=strptime_format, timezone=timezone, - supported_file_extensions=[".foo"], ) assert all(