From 01a2f2e0bfbba3da521be19659dfa362797af990 Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter Date: Wed, 18 Feb 2026 16:17:11 +0100 Subject: [PATCH 1/6] update TorchvisionDetectorAdaptor: register as usable model --- .../pose_estimation_pytorch/models/detectors/__init__.py | 3 +++ .../models/detectors/torchvision.py | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/__init__.py b/dlclive/pose_estimation_pytorch/models/detectors/__init__.py index e9a99a6e..c89902ae 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/__init__.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/__init__.py @@ -12,5 +12,8 @@ DETECTORS, BaseDetector, ) +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( + TorchvisionDetectorAdaptor, +) from dlclive.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN from dlclive.pose_estimation_pytorch.models.detectors.ssd import SSDLite diff --git a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py index 72dd54b8..85418702 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py @@ -14,9 +14,10 @@ import torch import torchvision.models.detection as detection -from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector +from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector +@DETECTORS.register_module class TorchvisionDetectorAdaptor(BaseDetector): """An adaptor for torchvision detectors @@ -26,8 +27,8 @@ class TorchvisionDetectorAdaptor(BaseDetector): - fasterrcnn_mobilenet_v3_large_fpn - fasterrcnn_resnet50_fpn_v2 - This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or - SSDLite) should be used instead. + This class can be used directly (e.g. with pre-trained COCO weights) or through its + subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection. The torchvision implementation does not allow to get both predictions and losses with a single forward pass. Therefore, during evaluation only bounding box metrics From 333f714fe581c674ad383360fa0e0acc423a5359 Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter Date: Wed, 18 Feb 2026 16:17:53 +0100 Subject: [PATCH 2/6] update runner: consider pretrained detectors (no weights in raw_data) --- dlclive/pose_estimation_pytorch/runner.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/dlclive/pose_estimation_pytorch/runner.py b/dlclive/pose_estimation_pytorch/runner.py index 2c59605f..a22506e1 100644 --- a/dlclive/pose_estimation_pytorch/runner.py +++ b/dlclive/pose_estimation_pytorch/runner.py @@ -268,10 +268,24 @@ def load_model(self) -> None: self.model = self.model.half() self.detector = None - if self.dynamic is None and raw_data.get("detector") is not None: + detector_cfg = self.cfg.get("detector") + has_detector_weights = raw_data.get("detector") is not None + if detector_cfg is not None: + detector_model_cfg = detector_cfg["model"] + uses_pretrained = ( + detector_model_cfg.get("pretrained", False) + or detector_model_cfg.get("weights") is not None + ) + else: + uses_pretrained = False + + if self.dynamic is None and (has_detector_weights or uses_pretrained): self.detector = models.DETECTORS.build(self.cfg["detector"]["model"]) self.detector.to(self.device) - self.detector.load_state_dict(raw_data["detector"]) + + if has_detector_weights: + self.detector.load_state_dict(raw_data["detector"]) + self.detector.eval() if self.precision == "FP16": self.detector = self.detector.half() @@ -281,7 +295,8 @@ def load_model(self) -> None: self.top_down_config.read_config(self.cfg) detector_transforms = [v2.ToDtype(torch.float32, scale=True)] - if self.cfg["detector"]["data"]["inference"].get("normalize_images", False): + detector_data_cfg = detector_cfg.get("data", {}).get("inference", {}) + if detector_data_cfg.get("normalize_images", False): detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) self.detector_transform = v2.Compose(detector_transforms) From 3a791aaa5b3df12f2cddd2472df57a054a44be08 Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter Date: Wed, 18 Feb 2026 16:43:19 +0100 Subject: [PATCH 3/6] Add specific export config for torchvision detectors --- dlclive/modelzoo/pytorch_model_zoo_export.py | 4 +++- dlclive/modelzoo/utils.py | 22 ++++++++++++++----- .../models/detectors/torchvision.py | 2 ++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/dlclive/modelzoo/pytorch_model_zoo_export.py b/dlclive/modelzoo/pytorch_model_zoo_export.py index 616857d1..b3554c1e 100644 --- a/dlclive/modelzoo/pytorch_model_zoo_export.py +++ b/dlclive/modelzoo/pytorch_model_zoo_export.py @@ -32,10 +32,12 @@ def _load_model_weights(model_name: str, super_animal: str = super_animal) -> Or checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name) return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"] + # Skip downloading the detector weights for humanbody models, as they are not on huggingface + skip_detector_download = (detector_name is None) or (super_animal == "superanimal_humanbody") export_dict = { "config": model_cfg, "pose": _load_model_weights(model_name), - "detector": _load_model_weights(detector_name) if detector_name is not None else None, + "detector": None if skip_detector_download else _load_model_weights(detector_name), } torch.save(export_dict, export_path) diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 3857d141..3376fe13 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -12,6 +12,7 @@ from ruamel.yaml import YAML from dlclive.modelzoo.resolve_config import update_config +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import SUPPORTED_TORCHVISION_DETECTORS _MODELZOO_PATH = Path(__file__).parent @@ -131,12 +132,21 @@ def load_super_animal_config( model_config["method"] = "BU" else: model_config["method"] = "TD" - if super_animal != "superanimal_humanbody": - detector_cfg_path = get_super_animal_model_config_path( - model_name=detector_name - ) - detector_cfg = read_config_as_dict(detector_cfg_path) - model_config["detector"] = detector_cfg + detector_cfg_path = get_super_animal_model_config_path( + model_name=detector_name + ) + detector_cfg = read_config_as_dict(detector_cfg_path) + model_config["detector"] = detector_cfg + if super_animal == "superanimal_humanbody": + # Apply specific updates required to run the torchvision detector with pretrained weights + assert detector_name in SUPPORTED_TORCHVISION_DETECTORS + model_config["detector"]['model']= { + "type": "TorchvisionDetectorAdaptor", + "model": detector_name, + "weights": "COCO_V1", + "num_classes": None, + "box_score_thresh": 0.6, + } return model_config diff --git a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py index 85418702..8790a1d7 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py @@ -16,6 +16,8 @@ from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector +SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"] + @DETECTORS.register_module class TorchvisionDetectorAdaptor(BaseDetector): From 9ee97685b1ff5f2adf247597e1d27acd78b94f6b Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter van Steveninck <32810691+deruyter92@users.noreply.github.com> Date: Tue, 10 Mar 2026 10:53:52 +0100 Subject: [PATCH 4/6] refactor modelzoo export: more explicit handling of torchvision detector case --- dlclive/modelzoo/pytorch_model_zoo_export.py | 2 +- dlclive/modelzoo/utils.py | 32 +++++++++++++------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/dlclive/modelzoo/pytorch_model_zoo_export.py b/dlclive/modelzoo/pytorch_model_zoo_export.py index 4f97f0da..d5e7fab7 100644 --- a/dlclive/modelzoo/pytorch_model_zoo_export.py +++ b/dlclive/modelzoo/pytorch_model_zoo_export.py @@ -24,7 +24,7 @@ def export_modelzoo_model( export_path: Arbitrary destination path for the exported .pt file. super_animal: Super animal dataset name (e.g. "superanimal_quadruped"). model_name: Pose model architecture name (e.g. "resnet_50"). - detector_name: Optional detector model name. If provided, detector + detector_name: Detector model name for top-down models. If provided, detector weights are included in the export. """ Path(export_path).parent.mkdir(parents=True, exist_ok=True) diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 22cef5a0..93985619 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -97,6 +97,21 @@ def add_metadata( return config +def _get_torchvision_detector_config(detector_name: str) -> dict: + """Get a torchvision detector configuration for the superanimal humanbody model""" + if detector_name is None: + raise ValueError(f"Detector name is required for superanimal humanbody models. Must be one of {SUPPORTED_TORCHVISION_DETECTORS}.") + if detector_name not in SUPPORTED_TORCHVISION_DETECTORS: + raise ValueError(f"Unsupported humanbody detector {detector_name}. Should be one of {SUPPORTED_TORCHVISION_DETECTORS}") + return { + "type": "TorchvisionDetectorAdaptor", + "model": detector_name, + "weights": "COCO_V1", + "num_classes": None, + "box_score_thresh": 0.6, + } + + # NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase # from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py def load_super_animal_config( @@ -126,7 +141,7 @@ def load_super_animal_config( model_config = add_metadata(project_config, model_config) model_config = update_config(model_config, max_individuals, device) - if detector_name is None and super_animal != "superanimal_humanbody": + if detector_name is None: model_config["method"] = "BU" else: model_config["method"] = "TD" @@ -135,16 +150,11 @@ def load_super_animal_config( ) detector_cfg = read_config_as_dict(detector_cfg_path) model_config["detector"] = detector_cfg - if super_animal == "superanimal_humanbody": - # Apply specific updates required to run the torchvision detector with pretrained weights - assert detector_name in SUPPORTED_TORCHVISION_DETECTORS - model_config["detector"]['model']= { - "type": "TorchvisionDetectorAdaptor", - "model": detector_name, - "weights": "COCO_V1", - "num_classes": None, - "box_score_thresh": 0.6, - } + + if super_animal == "superanimal_humanbody": + # Raises ValueError if Detector name is not one of SUPPORTED_TORCHVISION_DETECTORS + torchvision_detector_config = _get_torchvision_detector_config(detector_name) + model_config["detector"]["model"] = torchvision_detector_config return model_config From 0639363d6d78f9ad136ff7ac387026b554f3ccac Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter van Steveninck <32810691+deruyter92@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:11:55 +0100 Subject: [PATCH 5/6] add basic tests for modelzoo configuration export with superanimal_humanbody --- tests/test_modelzoo.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index c2a0d701..0cee1299 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -30,6 +30,35 @@ def test_get_config_model_paths(super_animal, model_name, detector_name): assert "detector" in model_config +def test_humanbody_requires_detector_name(): + with pytest.raises(ValueError): + modelzoo.load_super_animal_config( + super_animal="superanimal_humanbody", + model_name="hrnet_w32", + detector_name=None, + ) + + +def test_humanbody_rejects_unsupported_detector(): + with pytest.raises(ValueError): + modelzoo.load_super_animal_config( + super_animal="superanimal_humanbody", + model_name="hrnet_w32", + detector_name="fasterrcnn_resnet50_fpn_v2", + ) + + +def test_humanbody_uses_torchvision_detector_config(): + model_config = modelzoo.load_super_animal_config( + super_animal="superanimal_humanbody", + model_name="hrnet_w32", + detector_name="fasterrcnn_mobilenet_v3_large_fpn", + ) + detector_model_cfg = model_config["detector"]["model"] + assert model_config["method"].lower() == "td" + assert detector_model_cfg["type"] == "TorchvisionDetectorAdaptor" + + def test_download_huggingface_model(tmp_path_factory, model="full_cat"): folder = tmp_path_factory.mktemp("temp") dlclibrary.download_huggingface_model(model, str(folder)) From 3ca1d4ed42e30af452b3dc4b8d3d009c2042d7f2 Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter van Steveninck <32810691+deruyter92@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:25:14 +0100 Subject: [PATCH 6/6] apply pre-commit hooks (linting) --- dlclive/modelzoo/__init__.py | 2 +- .../fasterrcnn_resnet50_fpn_v2.yaml | 2 +- dlclive/modelzoo/model_configs/ssdlite.yaml | 2 +- .../superanimal_humanbody.yaml | 2 +- dlclive/modelzoo/pytorch_model_zoo_export.py | 35 +++++++++++++------ dlclive/modelzoo/utils.py | 22 +++++++----- .../models/detectors/torchvision.py | 6 +++- 7 files changed, 48 insertions(+), 23 deletions(-) diff --git a/dlclive/modelzoo/__init__.py b/dlclive/modelzoo/__init__.py index 1f4f0182..a70c5750 100644 --- a/dlclive/modelzoo/__init__.py +++ b/dlclive/modelzoo/__init__.py @@ -6,4 +6,4 @@ load_super_animal_config, download_super_animal_snapshot, ) -from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model \ No newline at end of file +from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model diff --git a/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml b/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml index 27d147e3..a78d93eb 100644 --- a/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml +++ b/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml @@ -48,4 +48,4 @@ train_settings: dataloader_workers: 0 dataloader_pin_memory: false display_iters: 500 - epochs: 250 \ No newline at end of file + epochs: 250 diff --git a/dlclive/modelzoo/model_configs/ssdlite.yaml b/dlclive/modelzoo/model_configs/ssdlite.yaml index 04e694fa..307bf92e 100644 --- a/dlclive/modelzoo/model_configs/ssdlite.yaml +++ b/dlclive/modelzoo/model_configs/ssdlite.yaml @@ -47,4 +47,4 @@ train_settings: dataloader_workers: 0 dataloader_pin_memory: false display_iters: 500 - epochs: 250 \ No newline at end of file + epochs: 250 diff --git a/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml b/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml index d1e665c1..e4d891b2 100644 --- a/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml +++ b/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml @@ -87,4 +87,4 @@ corner2move2: move2corner: # Conversion tables to fine-tune SuperAnimal weights -SuperAnimalConversionTables: \ No newline at end of file +SuperAnimalConversionTables: diff --git a/dlclive/modelzoo/pytorch_model_zoo_export.py b/dlclive/modelzoo/pytorch_model_zoo_export.py index d5e7fab7..9c8de46c 100644 --- a/dlclive/modelzoo/pytorch_model_zoo_export.py +++ b/dlclive/modelzoo/pytorch_model_zoo_export.py @@ -4,7 +4,10 @@ import torch -from dlclive.modelzoo.utils import load_super_animal_config, download_super_animal_snapshot +from dlclive.modelzoo.utils import ( + load_super_animal_config, + download_super_animal_snapshot, +) def export_modelzoo_model( @@ -29,7 +32,9 @@ def export_modelzoo_model( """ Path(export_path).parent.mkdir(parents=True, exist_ok=True) if Path(export_path).exists(): - warnings.warn(f"Export path {export_path} already exists, skipping export", UserWarning) + warnings.warn( + f"Export path {export_path} already exists, skipping export", UserWarning + ) return model_cfg = load_super_animal_config( @@ -38,30 +43,40 @@ def export_modelzoo_model( detector_name=detector_name, ) - def _load_model_weights(model_name: str, super_animal: str = super_animal) -> OrderedDict: + def _load_model_weights( + model_name: str, super_animal: str = super_animal + ) -> OrderedDict: """Download the model weights from huggingface and load them in torch state dict""" - checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name) + checkpoint: Path = download_super_animal_snapshot( + dataset=super_animal, model_name=model_name + ) return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"] - + # Skip downloading the detector weights for humanbody models, as they are not on huggingface - skip_detector_download = (detector_name is None) or (super_animal == "superanimal_humanbody") + skip_detector_download = (detector_name is None) or ( + super_animal == "superanimal_humanbody" + ) export_dict = { "config": model_cfg, "pose": _load_model_weights(model_name), - "detector": None if skip_detector_download else _load_model_weights(detector_name), + "detector": None + if skip_detector_download + else _load_model_weights(detector_name), } torch.save(export_dict, export_path) if __name__ == "__main__": - """Example usage""" + """Example usage""" from utils import _MODELZOO_PATH - + model_name = "resnet_50" super_animal = "superanimal_quadruped" export_modelzoo_model( - export_path=_MODELZOO_PATH / 'exported_models' / f'exported_{super_animal}_{model_name}.pt', + export_path=_MODELZOO_PATH + / "exported_models" + / f"exported_{super_animal}_{model_name}.pt", super_animal=super_animal, model_name=model_name, ) diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index 93985619..6123142f 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -9,11 +9,15 @@ from pathlib import Path from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model -from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths +from dlclibrary.dlcmodelzoo.modelzoo_download import ( + _load_model_names as huggingface_model_paths, +) from ruamel.yaml import YAML from dlclive.modelzoo.resolve_config import update_config -from dlclive.pose_estimation_pytorch.models.detectors.torchvision import SUPPORTED_TORCHVISION_DETECTORS +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( + SUPPORTED_TORCHVISION_DETECTORS, +) _MODELZOO_PATH = Path(__file__).parent @@ -100,9 +104,13 @@ def add_metadata( def _get_torchvision_detector_config(detector_name: str) -> dict: """Get a torchvision detector configuration for the superanimal humanbody model""" if detector_name is None: - raise ValueError(f"Detector name is required for superanimal humanbody models. Must be one of {SUPPORTED_TORCHVISION_DETECTORS}.") + raise ValueError( + f"Detector name is required for superanimal humanbody models. Must be one of {SUPPORTED_TORCHVISION_DETECTORS}." + ) if detector_name not in SUPPORTED_TORCHVISION_DETECTORS: - raise ValueError(f"Unsupported humanbody detector {detector_name}. Should be one of {SUPPORTED_TORCHVISION_DETECTORS}") + raise ValueError( + f"Unsupported humanbody detector {detector_name}. Should be one of {SUPPORTED_TORCHVISION_DETECTORS}" + ) return { "type": "TorchvisionDetectorAdaptor", "model": detector_name, @@ -145,12 +153,10 @@ def load_super_animal_config( model_config["method"] = "BU" else: model_config["method"] = "TD" - detector_cfg_path = get_super_animal_model_config_path( - model_name=detector_name - ) + detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name) detector_cfg = read_config_as_dict(detector_cfg_path) model_config["detector"] = detector_cfg - + if super_animal == "superanimal_humanbody": # Raises ValueError if Detector name is not one of SUPPORTED_TORCHVISION_DETECTORS torchvision_detector_config = _get_torchvision_detector_config(detector_name) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py index 8790a1d7..6c2d291f 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py @@ -9,12 +9,16 @@ # Licensed under GNU Lesser General Public License v3.0 # """Module to adapt torchvision detectors for DeepLabCut""" + from __future__ import annotations import torch import torchvision.models.detection as detection -from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector +from dlclive.pose_estimation_pytorch.models.detectors.base import ( + DETECTORS, + BaseDetector, +) SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"]