Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions dlclive/pose_estimation_pytorch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -131,15 +132,25 @@ def __init__(
path: str | Path,
device: str = "auto",
precision: Literal["FP16", "FP32"] = "FP32",
single_animal: bool = True,
single_animal: bool | None = None,
dynamic: dict | dynamic_cropping.DynamicCropper | None = None,
top_down_config: dict | TopDownConfig | None = None,
) -> None:
super().__init__(path)
self.device = _parse_device(device)
self.precision = precision
if single_animal is not None:
warnings.warn(
"The `single_animal` parameter is deprecated and will be removed "
"in a future version. The number of individuals will be automaticalliy inferred "
"from the model configuration. Remove argument `single_animal` or set "
"`single_animal=None` to accept the inferred value and silence this warning.",
DeprecationWarning,
stacklevel=2,
)
self.single_animal = single_animal

self.n_individuals = None
self.n_bodyparts = None
self.cfg = None
self.detector = None
self.model = None
Expand Down Expand Up @@ -191,9 +202,14 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray:

frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections)
if len(frame_batch) == 0:
offsets_and_scales = [(0, 0), 1]
else:
tensor = frame_batch # still CHW, batched
zero_pose = (
np.zeros((self.n_bodyparts, 3))
if self.n_individuals < 2 else
np.zeros((self.n_individuals, self.n_bodyparts, 3))
)
return zero_pose

tensor = frame_batch # still CHW, batched

if self.dynamic is not None:
tensor = self.dynamic.crop(tensor)
Expand Down Expand Up @@ -260,6 +276,15 @@ def load_model(self) -> None:
raw_data = torch.load(self.path, map_location="cpu", weights_only=True)

self.cfg = raw_data["config"]

# Infer single animal mode and n_bodyparts from model configuration
individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1'])
bodyparts = self.cfg.get("metadata", {}).get("bodyparts", [])
self.n_individuals = len(individuals)
self.n_bodyparts = len(bodyparts)
if self.single_animal is None:
self.single_animal = self.n_individuals == 1

self.model = models.PoseModel.build(self.cfg["model"])
self.model.load_state_dict(raw_data["pose"])
self.model = self.model.to(self.device)
Expand Down