diff --git a/dlclive/modelzoo/__init__.py b/dlclive/modelzoo/__init__.py index 1f4f018..a70c575 100644 --- a/dlclive/modelzoo/__init__.py +++ b/dlclive/modelzoo/__init__.py @@ -6,4 +6,4 @@ load_super_animal_config, download_super_animal_snapshot, ) -from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model \ No newline at end of file +from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model diff --git a/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml b/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml index 27d147e..a78d93e 100644 --- a/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml +++ b/dlclive/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml @@ -48,4 +48,4 @@ train_settings: dataloader_workers: 0 dataloader_pin_memory: false display_iters: 500 - epochs: 250 \ No newline at end of file + epochs: 250 diff --git a/dlclive/modelzoo/model_configs/ssdlite.yaml b/dlclive/modelzoo/model_configs/ssdlite.yaml index 04e694f..307bf92 100644 --- a/dlclive/modelzoo/model_configs/ssdlite.yaml +++ b/dlclive/modelzoo/model_configs/ssdlite.yaml @@ -47,4 +47,4 @@ train_settings: dataloader_workers: 0 dataloader_pin_memory: false display_iters: 500 - epochs: 250 \ No newline at end of file + epochs: 250 diff --git a/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml b/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml index d1e665c..e4d891b 100644 --- a/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml +++ b/dlclive/modelzoo/project_configs/superanimal_humanbody.yaml @@ -87,4 +87,4 @@ corner2move2: move2corner: # Conversion tables to fine-tune SuperAnimal weights -SuperAnimalConversionTables: \ No newline at end of file +SuperAnimalConversionTables: diff --git a/dlclive/modelzoo/pytorch_model_zoo_export.py b/dlclive/modelzoo/pytorch_model_zoo_export.py index f5cc39a..9c8de46 100644 --- a/dlclive/modelzoo/pytorch_model_zoo_export.py +++ b/dlclive/modelzoo/pytorch_model_zoo_export.py @@ -4,7 +4,10 @@ import torch -from dlclive.modelzoo.utils import load_super_animal_config, download_super_animal_snapshot +from dlclive.modelzoo.utils import ( + load_super_animal_config, + download_super_animal_snapshot, +) def export_modelzoo_model( @@ -24,12 +27,14 @@ def export_modelzoo_model( export_path: Arbitrary destination path for the exported .pt file. super_animal: Super animal dataset name (e.g. "superanimal_quadruped"). model_name: Pose model architecture name (e.g. "resnet_50"). - detector_name: Optional detector model name. If provided, detector + detector_name: Detector model name for top-down models. If provided, detector weights are included in the export. """ Path(export_path).parent.mkdir(parents=True, exist_ok=True) if Path(export_path).exists(): - warnings.warn(f"Export path {export_path} already exists, skipping export", UserWarning) + warnings.warn( + f"Export path {export_path} already exists, skipping export", UserWarning + ) return model_cfg = load_super_animal_config( @@ -38,28 +43,40 @@ def export_modelzoo_model( detector_name=detector_name, ) - def _load_model_weights(model_name: str, super_animal: str = super_animal) -> OrderedDict: + def _load_model_weights( + model_name: str, super_animal: str = super_animal + ) -> OrderedDict: """Download the model weights from huggingface and load them in torch state dict""" - checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name) + checkpoint: Path = download_super_animal_snapshot( + dataset=super_animal, model_name=model_name + ) return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"] - + + # Skip downloading the detector weights for humanbody models, as they are not on huggingface + skip_detector_download = (detector_name is None) or ( + super_animal == "superanimal_humanbody" + ) export_dict = { "config": model_cfg, "pose": _load_model_weights(model_name), - "detector": _load_model_weights(detector_name) if detector_name is not None else None, + "detector": None + if skip_detector_download + else _load_model_weights(detector_name), } torch.save(export_dict, export_path) if __name__ == "__main__": - """Example usage""" + """Example usage""" from utils import _MODELZOO_PATH - + model_name = "resnet_50" super_animal = "superanimal_quadruped" export_modelzoo_model( - export_path=_MODELZOO_PATH / 'exported_models' / f'exported_{super_animal}_{model_name}.pt', + export_path=_MODELZOO_PATH + / "exported_models" + / f"exported_{super_animal}_{model_name}.pt", super_animal=super_animal, model_name=model_name, ) diff --git a/dlclive/modelzoo/utils.py b/dlclive/modelzoo/utils.py index f9bf2f7..6123142 100644 --- a/dlclive/modelzoo/utils.py +++ b/dlclive/modelzoo/utils.py @@ -9,10 +9,15 @@ from pathlib import Path from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model -from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths +from dlclibrary.dlcmodelzoo.modelzoo_download import ( + _load_model_names as huggingface_model_paths, +) from ruamel.yaml import YAML from dlclive.modelzoo.resolve_config import update_config +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( + SUPPORTED_TORCHVISION_DETECTORS, +) _MODELZOO_PATH = Path(__file__).parent @@ -96,6 +101,25 @@ def add_metadata( return config +def _get_torchvision_detector_config(detector_name: str) -> dict: + """Get a torchvision detector configuration for the superanimal humanbody model""" + if detector_name is None: + raise ValueError( + f"Detector name is required for superanimal humanbody models. Must be one of {SUPPORTED_TORCHVISION_DETECTORS}." + ) + if detector_name not in SUPPORTED_TORCHVISION_DETECTORS: + raise ValueError( + f"Unsupported humanbody detector {detector_name}. Should be one of {SUPPORTED_TORCHVISION_DETECTORS}" + ) + return { + "type": "TorchvisionDetectorAdaptor", + "model": detector_name, + "weights": "COCO_V1", + "num_classes": None, + "box_score_thresh": 0.6, + } + + # NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase # from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py def load_super_animal_config( @@ -125,16 +149,18 @@ def load_super_animal_config( model_config = add_metadata(project_config, model_config) model_config = update_config(model_config, max_individuals, device) - if detector_name is None and super_animal != "superanimal_humanbody": + if detector_name is None: model_config["method"] = "BU" else: model_config["method"] = "TD" - if super_animal != "superanimal_humanbody": - detector_cfg_path = get_super_animal_model_config_path( - model_name=detector_name - ) - detector_cfg = read_config_as_dict(detector_cfg_path) - model_config["detector"] = detector_cfg + detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name) + detector_cfg = read_config_as_dict(detector_cfg_path) + model_config["detector"] = detector_cfg + + if super_animal == "superanimal_humanbody": + # Raises ValueError if Detector name is not one of SUPPORTED_TORCHVISION_DETECTORS + torchvision_detector_config = _get_torchvision_detector_config(detector_name) + model_config["detector"]["model"] = torchvision_detector_config return model_config diff --git a/dlclive/pose_estimation_pytorch/models/detectors/__init__.py b/dlclive/pose_estimation_pytorch/models/detectors/__init__.py index e9a99a6..c89902a 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/__init__.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/__init__.py @@ -12,5 +12,8 @@ DETECTORS, BaseDetector, ) +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( + TorchvisionDetectorAdaptor, +) from dlclive.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN from dlclive.pose_estimation_pytorch.models.detectors.ssd import SSDLite diff --git a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py index 72dd54b..6c2d291 100644 --- a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py +++ b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py @@ -9,14 +9,21 @@ # Licensed under GNU Lesser General Public License v3.0 # """Module to adapt torchvision detectors for DeepLabCut""" + from __future__ import annotations import torch import torchvision.models.detection as detection -from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector +from dlclive.pose_estimation_pytorch.models.detectors.base import ( + DETECTORS, + BaseDetector, +) + +SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"] +@DETECTORS.register_module class TorchvisionDetectorAdaptor(BaseDetector): """An adaptor for torchvision detectors @@ -26,8 +33,8 @@ class TorchvisionDetectorAdaptor(BaseDetector): - fasterrcnn_mobilenet_v3_large_fpn - fasterrcnn_resnet50_fpn_v2 - This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or - SSDLite) should be used instead. + This class can be used directly (e.g. with pre-trained COCO weights) or through its + subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection. The torchvision implementation does not allow to get both predictions and losses with a single forward pass. Therefore, during evaluation only bounding box metrics diff --git a/dlclive/pose_estimation_pytorch/runner.py b/dlclive/pose_estimation_pytorch/runner.py index 2c59605..a22506e 100644 --- a/dlclive/pose_estimation_pytorch/runner.py +++ b/dlclive/pose_estimation_pytorch/runner.py @@ -268,10 +268,24 @@ def load_model(self) -> None: self.model = self.model.half() self.detector = None - if self.dynamic is None and raw_data.get("detector") is not None: + detector_cfg = self.cfg.get("detector") + has_detector_weights = raw_data.get("detector") is not None + if detector_cfg is not None: + detector_model_cfg = detector_cfg["model"] + uses_pretrained = ( + detector_model_cfg.get("pretrained", False) + or detector_model_cfg.get("weights") is not None + ) + else: + uses_pretrained = False + + if self.dynamic is None and (has_detector_weights or uses_pretrained): self.detector = models.DETECTORS.build(self.cfg["detector"]["model"]) self.detector.to(self.device) - self.detector.load_state_dict(raw_data["detector"]) + + if has_detector_weights: + self.detector.load_state_dict(raw_data["detector"]) + self.detector.eval() if self.precision == "FP16": self.detector = self.detector.half() @@ -281,7 +295,8 @@ def load_model(self) -> None: self.top_down_config.read_config(self.cfg) detector_transforms = [v2.ToDtype(torch.float32, scale=True)] - if self.cfg["detector"]["data"]["inference"].get("normalize_images", False): + detector_data_cfg = detector_cfg.get("data", {}).get("inference", {}) + if detector_data_cfg.get("normalize_images", False): detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) self.detector_transform = v2.Compose(detector_transforms) diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index c2a0d70..0cee129 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -30,6 +30,35 @@ def test_get_config_model_paths(super_animal, model_name, detector_name): assert "detector" in model_config +def test_humanbody_requires_detector_name(): + with pytest.raises(ValueError): + modelzoo.load_super_animal_config( + super_animal="superanimal_humanbody", + model_name="hrnet_w32", + detector_name=None, + ) + + +def test_humanbody_rejects_unsupported_detector(): + with pytest.raises(ValueError): + modelzoo.load_super_animal_config( + super_animal="superanimal_humanbody", + model_name="hrnet_w32", + detector_name="fasterrcnn_resnet50_fpn_v2", + ) + + +def test_humanbody_uses_torchvision_detector_config(): + model_config = modelzoo.load_super_animal_config( + super_animal="superanimal_humanbody", + model_name="hrnet_w32", + detector_name="fasterrcnn_mobilenet_v3_large_fpn", + ) + detector_model_cfg = model_config["detector"]["model"] + assert model_config["method"].lower() == "td" + assert detector_model_cfg["type"] == "TorchvisionDetectorAdaptor" + + def test_download_huggingface_model(tmp_path_factory, model="full_cat"): folder = tmp_path_factory.mktemp("temp") dlclibrary.download_huggingface_model(model, str(folder))