Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 156 additions & 83 deletions src/osekit/core_api/audio_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

from math import ceil
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Self

import numpy as np
import soundfile as sf
Expand All @@ -30,11 +30,14 @@ class AudioData(BaseData[AudioItem, AudioFile]):
The data is accessed via an AudioItem object per AudioFile.
"""

item_cls = AudioItem

def __init__(
self,
items: list[AudioItem] | None = None,
begin: Timestamp | None = None,
end: Timestamp | None = None,
name: str | None = None,
sample_rate: int | None = None,
instrument: Instrument | None = None,
normalization: Normalization = Normalization.RAW,
Expand All @@ -54,14 +57,16 @@ def __init__(
end: Timestamp | None
Only effective if items is None.
Set the end of the empty data.
name: str | None
Name of the exported files.
instrument: Instrument | None
Instrument that might be used to obtain acoustic pressure from
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.

"""
super().__init__(items=items, begin=begin, end=end)
super().__init__(items=items, begin=begin, end=end, name=name)
self._set_sample_rate(sample_rate=sample_rate)
self.instrument = instrument
self.normalization = normalization
Expand Down Expand Up @@ -115,7 +120,61 @@ def normalization_values(self, value: dict | None) -> None:
}
)

@classmethod
def _make_item(
cls,
file: AudioFile | None = None,
begin: Timestamp | None = None,
end: Timestamp | None = None,
) -> AudioItem:
"""Make an AudioItem for a given AudioFile between begin and end timestamps.

Parameters
----------
file: AudioFile
AudioFile of the item.
begin: Timestamp
Begin of the item.
end:
End of the item.

Returns
-------
An AudioItem for the AudioFile file, between the begin and end timestamps.

"""
return AudioItem(file=file, begin=begin, end=end)

@classmethod
def _make_file(cls, path: Path, begin: Timestamp) -> AudioFile:
"""Make an AudioFile from a path and a begin timestamp.

Parameters
----------
path: Path
Path to the file.
begin: Timestamp
Begin of the file.

Returns
-------
AudioFile:
The audio file.

"""
return AudioFile(path=path, begin=begin)

def get_normalization_values(self) -> dict:
"""Return the values used for normalizing the audio data.

Returns
-------
dict:
"mean": mean value to substract to center values on 0.
"peak": peak value for PEAK normalization
"std": standard deviation used for z-score normalization

"""
values = np.array(self.get_raw_value())
return {
"mean": values.mean(),
Expand Down Expand Up @@ -205,6 +264,7 @@ def write(
self,
folder: Path,
subtype: str | None = None,
*,
link: bool = False,
) -> None:
"""Write the audio data to file.
Expand Down Expand Up @@ -270,7 +330,7 @@ def split(
nb_subdata: int = 2,
*,
pass_normalization: bool = True,
) -> list[AudioData]:
) -> list[Self]:
"""Split the audio data object in the specified number of audio subdata.

Parameters
Expand All @@ -296,16 +356,47 @@ def split(
if any(self.normalization_values.values())
else self.get_normalization_values()
)
return [
AudioData.from_base_data(
data=base_data,
sample_rate=self.sample_rate,
instrument=self.instrument,
normalization=self.normalization,
normalization_values=normalization_values,
)
for base_data in super().split(nb_subdata)
]
return super().split(
nb_subdata=nb_subdata,
normalization_values=normalization_values,
)

def _make_split_data(
self,
files: list[AudioFile],
begin: Timestamp,
end: Timestamp,
**kwargs: tuple[float, float, float],
) -> AudioData:
"""Return an AudioData object after an AudioData.split() call.

Parameters
----------
files: list[AudioFile]
The AudioFiles of the original AudioData.
begin: Timestamp
The begin timestamp of the split AudioData.
end: Timestamp
The end timestamp of the split AudioData.
kwargs:
normalization_values: tuple[float, float, float]
Values used for normalizing the split AudioData.

Returns
-------
AudioData:
The AudioData instance.

"""
return AudioData.from_files(
files=files,
begin=begin,
end=end,
sample_rate=self.sample_rate,
instrument=self.instrument,
normalization=self.normalization,
normalization_values=kwargs["normalization_values"],
)

def split_frames(
self,
Expand Down Expand Up @@ -335,9 +426,11 @@ def split_frames(

"""
if start_frame < 0:
raise ValueError("Start_frame must be greater than or equal to 0.")
msg = "Start_frame must be greater than or equal to 0."
raise ValueError(msg)
if stop_frame < -1 or stop_frame > self.length:
raise ValueError("Stop_frame must be lower than the length of the data.")
msg = "Stop_frame must be lower than the length of the data."
raise ValueError(msg)

start_timestamp = self.begin + Timedelta(
seconds=ceil(start_frame / self.sample_rate * 1e9) / 1e9,
Expand Down Expand Up @@ -390,46 +483,63 @@ def to_dict(self) -> dict:
)

@classmethod
def from_dict(cls, dictionary: dict) -> AudioData:
"""Deserialize an AudioData from a dictionary.
def _from_base_dict(
cls,
dictionary: dict,
files: list[AudioFile],
begin: Timestamp,
end: Timestamp,
**kwargs, # noqa: ANN003
) -> AudioData:
"""Deserialize the AudioData-specific parts of a Data dictionary.

This method is called within the BaseData.from_dict() method, which
deserializes the base files, begin and end parameters.

Parameters
----------
dictionary: dict
The serialized dictionary representing the AudioData.
files: list[AudioFile]
The list of deserialized AudioFiles.
begin: Timestamp
The deserialized begin timestamp.
end: Timestamp
The deserialized end timestamp.
kwargs:
None.

Returns
-------
AudioData
The deserialized AudioData.

"""
base_data = BaseData.from_dict(dictionary)
instrument = (
None
if dictionary["instrument"] is None
else Instrument.from_dict(dictionary["instrument"])
)
return cls.from_base_data(
data=base_data,
return cls.from_files(
files=files,
begin=begin,
end=end,
instrument=instrument,
sample_rate=dictionary["sample_rate"],
normalization=Normalization(dictionary["normalization"]),
normalization_values=dictionary["normalization_values"],
instrument=instrument,
)

@classmethod
def from_files(
cls,
files: list[AudioFile],
files: list[AudioFile], # The method is redefined just to specify the type
begin: Timestamp | None = None,
end: Timestamp | None = None,
sample_rate: float | None = None,
instrument: Instrument | None = None,
normalization: Normalization = Normalization.RAW,
normalization_values: dict | None = None,
name: str | None = None,
**kwargs, # noqa: ANN003
) -> AudioData:
"""Return an AudioData object from a list of AudioFiles.
"""Return a, AudioData object from a list of AudioFiles.

Parameters
----------
Expand All @@ -441,65 +551,28 @@ def from_files(
end: Timestamp | None
End of the data object.
Defaulted to the end of the last file.
sample_rate: float | None
Sample rate of the AudioData.
instrument: Instrument | None
Instrument that might be used to obtain acoustic pressure from
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.
normalization_values: dict|None
Mean, peak and std values with which to normalize the data.

Returns
-------
AudioData:
The AudioData object.

"""
return cls.from_base_data(
data=BaseData.from_files(files, begin, end),
sample_rate=sample_rate,
instrument=instrument,
normalization=normalization,
normalization_values=normalization_values,
)

@classmethod
def from_base_data(
cls,
data: BaseData,
sample_rate: float | None = None,
instrument: Instrument | None = None,
normalization: Normalization = Normalization.RAW,
normalization_values: dict | None = None,
) -> AudioData:
"""Return an AudioData object from a BaseData object.

Parameters
----------
data: BaseData
BaseData object to convert to AudioData.
sample_rate: float | None
Sample rate of the AudioData.
instrument: Instrument | None
Instrument that might be used to obtain acoustic pressure from
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.
normalization_values: dict|None
Mean, peak and std values with which to normalize the data.
name: str | None
Name of the exported files.
kwargs
Keyword arguments that are passed to the cls constructor.
sample_rate: int
The sample rate of the audio data.
instrument: Instrument | None
Instrument that might be used to obtain acoustic pressure from
the wav audio data.
normalization: Normalization
The type of normalization to apply to the audio data.

Returns
-------
AudioData:
The AudioData object.
Self:
The AudioData object.

"""
return cls(
items=[AudioItem.from_base_item(item) for item in data.items],
sample_rate=sample_rate,
instrument=instrument,
normalization=normalization,
normalization_values=normalization_values,
return super().from_files(
files=files, # This way, this static error doesn't appear to the user
begin=begin,
end=end,
name=name,
**kwargs,
)
Loading