Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dlclive/modelzoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
load_super_animal_config,
download_super_animal_snapshot,
)
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ train_settings:
dataloader_workers: 0
dataloader_pin_memory: false
display_iters: 500
epochs: 250
epochs: 250
2 changes: 1 addition & 1 deletion dlclive/modelzoo/model_configs/ssdlite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ train_settings:
dataloader_workers: 0
dataloader_pin_memory: false
display_iters: 500
epochs: 250
epochs: 250
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@ corner2move2:
move2corner:

# Conversion tables to fine-tune SuperAnimal weights
SuperAnimalConversionTables:
SuperAnimalConversionTables:
37 changes: 27 additions & 10 deletions dlclive/modelzoo/pytorch_model_zoo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import torch

from dlclive.modelzoo.utils import load_super_animal_config, download_super_animal_snapshot
from dlclive.modelzoo.utils import (
load_super_animal_config,
download_super_animal_snapshot,
)


def export_modelzoo_model(
Expand All @@ -24,12 +27,14 @@ def export_modelzoo_model(
export_path: Arbitrary destination path for the exported .pt file.
super_animal: Super animal dataset name (e.g. "superanimal_quadruped").
model_name: Pose model architecture name (e.g. "resnet_50").
detector_name: Optional detector model name. If provided, detector
detector_name: Detector model name for top-down models. If provided, detector
weights are included in the export.
"""
Path(export_path).parent.mkdir(parents=True, exist_ok=True)
if Path(export_path).exists():
warnings.warn(f"Export path {export_path} already exists, skipping export", UserWarning)
warnings.warn(
f"Export path {export_path} already exists, skipping export", UserWarning
)
return

model_cfg = load_super_animal_config(
Expand All @@ -38,28 +43,40 @@ def export_modelzoo_model(
detector_name=detector_name,
)

def _load_model_weights(model_name: str, super_animal: str = super_animal) -> OrderedDict:
def _load_model_weights(
model_name: str, super_animal: str = super_animal
) -> OrderedDict:
"""Download the model weights from huggingface and load them in torch state dict"""
checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name)
checkpoint: Path = download_super_animal_snapshot(
dataset=super_animal, model_name=model_name
)
return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"]


# Skip downloading the detector weights for humanbody models, as they are not on huggingface
skip_detector_download = (detector_name is None) or (
super_animal == "superanimal_humanbody"
)
Comment on lines +54 to +58
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems not great to hardcode. idk what if it eventually is on huggingface? what if there are other detector names that can't be reached by the load model weights function? also silently changing explicitly specified functionality is bad to do - i told you to use detector_name, but there is some hidden interaction with my model choice, and how would I even know?

just throw an exception in _load_model_weights if the model weights can't be loaded

export_dict = {
"config": model_cfg,
"pose": _load_model_weights(model_name),
"detector": _load_model_weights(detector_name) if detector_name is not None else None,
"detector": None
if skip_detector_download
else _load_model_weights(detector_name),
}
torch.save(export_dict, export_path)


if __name__ == "__main__":
"""Example usage"""
"""Example usage"""
from utils import _MODELZOO_PATH

model_name = "resnet_50"
super_animal = "superanimal_quadruped"

export_modelzoo_model(
export_path=_MODELZOO_PATH / 'exported_models' / f'exported_{super_animal}_{model_name}.pt',
export_path=_MODELZOO_PATH
/ "exported_models"
/ f"exported_{super_animal}_{model_name}.pt",
super_animal=super_animal,
model_name=model_name,
)
42 changes: 34 additions & 8 deletions dlclive/modelzoo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
from pathlib import Path

from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths
from dlclibrary.dlcmodelzoo.modelzoo_download import (
_load_model_names as huggingface_model_paths,
)
from ruamel.yaml import YAML

from dlclive.modelzoo.resolve_config import update_config
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
SUPPORTED_TORCHVISION_DETECTORS,
)
Comment on lines +18 to +20
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Importing SUPPORTED_TORCHVISION_DETECTORS from the torchvision detector module makes dlclive.modelzoo.utils import-time dependent on torch/torchvision. If torchvision is an optional dependency, this can cause config/modelzoo utilities to fail to import even when the user isn’t using detectors. Consider moving the allowlist to a lightweight location (e.g., modelzoo module constants) or doing a local/lazy import inside _get_torchvision_detector_config with a clear error if torchvision isn’t available.

Suggested change
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
SUPPORTED_TORCHVISION_DETECTORS,
)
try:
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
SUPPORTED_TORCHVISION_DETECTORS,
)
except ImportError: # torchvision (and its dependencies) is optional
# Fallback to an empty collection so this module can be imported even when
# torchvision is not installed. Code that relies on detectors should
# handle the absence of supported detectors appropriately.
SUPPORTED_TORCHVISION_DETECTORS = ()

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad suggestion IMO. utils modules are junk drawers. _get_torchvision_detector_config is a torch-specific function, and it should exist in a torch-specific subpackage or module. don't listen to the LLMs which will fallback your code into a verbose oblivion - just put the code in the place it should be.


_MODELZOO_PATH = Path(__file__).parent

Expand Down Expand Up @@ -96,6 +101,25 @@ def add_metadata(
return config


def _get_torchvision_detector_config(detector_name: str) -> dict:
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_torchvision_detector_config is annotated as taking detector_name: str, but it explicitly handles None and is called with detector_name that may be None. Update the type signature to accept str | None (and/or adjust the call sites) so type checkers and IDEs reflect the actual contract.

Suggested change
def _get_torchvision_detector_config(detector_name: str) -> dict:
def _get_torchvision_detector_config(detector_name: str | None) -> dict:

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this function exists - we're just putting a string in a hardcoded dictionary? the main thing it does is throw exceptions, but the exception should be thrown in whatever consumes this dict. this is way way far on the LBYL side of look before you leap vs. easier to ask forgiveness - we are checking if a value is valid in some utility method very distant from where it's used. if this is needed, it sort of screams that whatever super animal is needs a proper data model - if super animal requires a detector name, there should be a data model that requires a detector name, not in some random utility function somewhere!

"""Get a torchvision detector configuration for the superanimal humanbody model"""
if detector_name is None:
raise ValueError(
f"Detector name is required for superanimal humanbody models. Must be one of {SUPPORTED_TORCHVISION_DETECTORS}."
)
if detector_name not in SUPPORTED_TORCHVISION_DETECTORS:
raise ValueError(
f"Unsupported humanbody detector {detector_name}. Should be one of {SUPPORTED_TORCHVISION_DETECTORS}"
)
return {
"type": "TorchvisionDetectorAdaptor",
"model": detector_name,
"weights": "COCO_V1",
"num_classes": None,
"box_score_thresh": 0.6,
}


# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
def load_super_animal_config(
Expand Down Expand Up @@ -125,16 +149,18 @@ def load_super_animal_config(
model_config = add_metadata(project_config, model_config)
model_config = update_config(model_config, max_individuals, device)

if detector_name is None and super_animal != "superanimal_humanbody":
if detector_name is None:
model_config["method"] = "BU"
else:
model_config["method"] = "TD"
if super_animal != "superanimal_humanbody":
detector_cfg_path = get_super_animal_model_config_path(
model_name=detector_name
)
detector_cfg = read_config_as_dict(detector_cfg_path)
model_config["detector"] = detector_cfg
detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name)
detector_cfg = read_config_as_dict(detector_cfg_path)
model_config["detector"] = detector_cfg

if super_animal == "superanimal_humanbody":
# Raises ValueError if Detector name is not one of SUPPORTED_TORCHVISION_DETECTORS
torchvision_detector_config = _get_torchvision_detector_config(detector_name)
model_config["detector"]["model"] = torchvision_detector_config
Comment on lines +152 to +163
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

soooo remember what we were saying about how copying and pasting code is bad because then you have to ensure that it never changes or you always need to remember to update in both places? this is one of those times. this function should have been located in some shared root package (dlclibrary!!!!) and imported in both in the first place, but if we are sprouting divergent functionality, need to make that clear: "I do this all the time in normal deeplabcut, how come it raises an error here!"

this is also strangely conflicting with the other place where we silence rather than raise an error - https://github.com/DeepLabCut/DeepLabCut-live/pull/164/changes#r2913689689

so we are handling the same thing twice in opposing directions.

return model_config


Expand Down
3 changes: 3 additions & 0 deletions dlclive/pose_estimation_pytorch/models/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
DETECTORS,
BaseDetector,
)
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
TorchvisionDetectorAdaptor,
)
from dlclive.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN
from dlclive.pose_estimation_pytorch.models.detectors.ssd import SSDLite
13 changes: 10 additions & 3 deletions dlclive/pose_estimation_pytorch/models/detectors/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
# Licensed under GNU Lesser General Public License v3.0
#
"""Module to adapt torchvision detectors for DeepLabCut"""

from __future__ import annotations

import torch
import torchvision.models.detection as detection

from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector
from dlclive.pose_estimation_pytorch.models.detectors.base import (
DETECTORS,
BaseDetector,
)

SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"]
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class docstring lists fasterrcnn_resnet50_fpn_v2 as supported, but SUPPORTED_TORCHVISION_DETECTORS only includes fasterrcnn_mobilenet_v3_large_fpn. This inconsistency will confuse users (and currently the modelzoo validation rejects the resnet50 variant). Either update the docstring to match the allowlist, or expand SUPPORTED_TORCHVISION_DETECTORS and ensure the humanbody path supports that detector.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or, better suggestion, don't list the models in the docstring and just say "one of {variable name}"



@DETECTORS.register_module
class TorchvisionDetectorAdaptor(BaseDetector):
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class docstring lists fasterrcnn_resnet50_fpn_v2 as supported, but SUPPORTED_TORCHVISION_DETECTORS only includes fasterrcnn_mobilenet_v3_large_fpn. This inconsistency will confuse users (and currently the modelzoo validation rejects the resnet50 variant). Either update the docstring to match the allowlist, or expand SUPPORTED_TORCHVISION_DETECTORS and ensure the humanbody path supports that detector.

Copilot uses AI. Check for mistakes.
"""An adaptor for torchvision detectors

Expand All @@ -26,8 +33,8 @@ class TorchvisionDetectorAdaptor(BaseDetector):
- fasterrcnn_mobilenet_v3_large_fpn
- fasterrcnn_resnet50_fpn_v2
Comment on lines 33 to 34
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class docstring lists fasterrcnn_resnet50_fpn_v2 as supported, but SUPPORTED_TORCHVISION_DETECTORS only includes fasterrcnn_mobilenet_v3_large_fpn. This inconsistency will confuse users (and currently the modelzoo validation rejects the resnet50 variant). Either update the docstring to match the allowlist, or expand SUPPORTED_TORCHVISION_DETECTORS and ensure the humanbody path supports that detector.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how many times can copilot make the same bad suggestion lmao


This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or
SSDLite) should be used instead.
This class can be used directly (e.g. with pre-trained COCO weights) or through its
subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection.
Comment on lines +36 to +37
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why! usually base classes that shouldn't be used have some abstract methods and only serve as the external public API contract, so it doesn't seem like this was serving as an ABC before, but usually mixing levels like this is a bad call, having a "pre trained" or whatever subclass might be a good idea to discourage modifying the base class to accommodate any specific needs for this use that doesn't apply to the other subclasses and warps the contract made by the ABC.


The torchvision implementation does not allow to get both predictions and losses
with a single forward pass. Therefore, during evaluation only bounding box metrics
Expand Down
21 changes: 18 additions & 3 deletions dlclive/pose_estimation_pytorch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,24 @@ def load_model(self) -> None:
self.model = self.model.half()

self.detector = None
if self.dynamic is None and raw_data.get("detector") is not None:
detector_cfg = self.cfg.get("detector")
has_detector_weights = raw_data.get("detector") is not None
if detector_cfg is not None:
detector_model_cfg = detector_cfg["model"]
uses_pretrained = (
detector_model_cfg.get("pretrained", False)
or detector_model_cfg.get("weights") is not None
)
Comment on lines +271 to +278
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The runner now has new behavior: it should build a detector when weights/pretrained is set even if raw_data["detector"] is None, and it should skip load_state_dict in that case. There are modelzoo tests added, but no tests appear to cover this updated runner logic. Consider adding a focused unit test that mocks models.DETECTORS.build and verifies (1) build is invoked when weights is set, (2) load_state_dict is not invoked when no detector weights are present, and (3) load_state_dict is invoked when detector weights are present.

Copilot uses AI. Check for mistakes.
else:
uses_pretrained = False

if self.dynamic is None and (has_detector_weights or uses_pretrained):
self.detector = models.DETECTORS.build(self.cfg["detector"]["model"])
self.detector.to(self.device)
self.detector.load_state_dict(raw_data["detector"])

if has_detector_weights:
self.detector.load_state_dict(raw_data["detector"])
Comment on lines +282 to +287
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The runner now has new behavior: it should build a detector when weights/pretrained is set even if raw_data["detector"] is None, and it should skip load_state_dict in that case. There are modelzoo tests added, but no tests appear to cover this updated runner logic. Consider adding a focused unit test that mocks models.DETECTORS.build and verifies (1) build is invoked when weights is set, (2) load_state_dict is not invoked when no detector weights are present, and (3) load_state_dict is invoked when detector weights are present.

Copilot uses AI. Check for mistakes.

self.detector.eval()
if self.precision == "FP16":
self.detector = self.detector.half()
Expand All @@ -281,7 +295,8 @@ def load_model(self) -> None:
self.top_down_config.read_config(self.cfg)

detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
detector_data_cfg = detector_cfg.get("data", {}).get("inference", {})
if detector_data_cfg.get("normalize_images", False):
detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
self.detector_transform = v2.Compose(detector_transforms)

Expand Down
29 changes: 29 additions & 0 deletions tests/test_modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,35 @@ def test_get_config_model_paths(super_animal, model_name, detector_name):
assert "detector" in model_config


def test_humanbody_requires_detector_name():
    """Loading a humanbody config with no detector name must raise ValueError."""
    load_kwargs = dict(
        super_animal="superanimal_humanbody",
        model_name="hrnet_w32",
        detector_name=None,
    )
    with pytest.raises(ValueError):
        modelzoo.load_super_animal_config(**load_kwargs)


def test_humanbody_rejects_unsupported_detector():
    """A detector outside the supported torchvision allowlist must raise ValueError."""
    load_kwargs = dict(
        super_animal="superanimal_humanbody",
        model_name="hrnet_w32",
        detector_name="fasterrcnn_resnet50_fpn_v2",
    )
    with pytest.raises(ValueError):
        modelzoo.load_super_animal_config(**load_kwargs)


def test_humanbody_uses_torchvision_detector_config():
    """A supported humanbody detector yields a top-down config with the torchvision adaptor."""
    config = modelzoo.load_super_animal_config(
        super_animal="superanimal_humanbody",
        model_name="hrnet_w32",
        detector_name="fasterrcnn_mobilenet_v3_large_fpn",
    )
    # The config must select the top-down method and route the detector
    # through the TorchvisionDetectorAdaptor model type.
    assert config["method"].lower() == "td"
    assert config["detector"]["model"]["type"] == "TorchvisionDetectorAdaptor"


def test_download_huggingface_model(tmp_path_factory, model="full_cat"):
folder = tmp_path_factory.mktemp("temp")
dlclibrary.download_huggingface_model(model, str(folder))
Expand Down