add torchvision detector functionality #164
base: main
Changes from all commits
01a2f2e
333f714
3a791aa
15b265c
9ee9768
0639363
3ca1d4e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,4 +48,4 @@ train_settings: | |
| dataloader_workers: 0 | ||
| dataloader_pin_memory: false | ||
| display_iters: 500 | ||
| epochs: 250 | ||
| epochs: 250 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,4 +47,4 @@ train_settings: | |
| dataloader_workers: 0 | ||
| dataloader_pin_memory: false | ||
| display_iters: 500 | ||
| epochs: 250 | ||
| epochs: 250 | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,10 +9,15 @@ | |||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model | ||||||||||||||||||||||||||
| from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths | ||||||||||||||||||||||||||
| from dlclibrary.dlcmodelzoo.modelzoo_download import ( | ||||||||||||||||||||||||||
| _load_model_names as huggingface_model_paths, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| from ruamel.yaml import YAML | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from dlclive.modelzoo.resolve_config import update_config | ||||||||||||||||||||||||||
| from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( | ||||||||||||||||||||||||||
| SUPPORTED_TORCHVISION_DETECTORS, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
Comment on lines +18 to +20
|
||||||||||||||||||||||||||
| from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( | |
| SUPPORTED_TORCHVISION_DETECTORS, | |
| ) | |
| try: | |
| from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( | |
| SUPPORTED_TORCHVISION_DETECTORS, | |
| ) | |
| except ImportError: # torchvision (and its dependencies) is optional | |
| # Fallback to an empty collection so this module can be imported even when | |
| # torchvision is not installed. Code that relies on detectors should | |
| # handle the absence of supported detectors appropriately. | |
| SUPPORTED_TORCHVISION_DETECTORS = () |
bad suggestion IMO. utils modules are junk drawers. _get_torchvision_detector_config is a torch-specific function, and it should exist in a torch-specific subpackage or module. don't listen to the LLMs which will fallback your code into a verbose oblivion - just put the code in the place it should be.
Copilot
AI
Mar 10, 2026
_get_torchvision_detector_config is annotated as taking detector_name: str, but it explicitly handles None and is called with detector_name that may be None. Update the type signature to accept str | None (and/or adjust the call sites) so type checkers and IDEs reflect the actual contract.
| def _get_torchvision_detector_config(detector_name: str) -> dict: | |
| def _get_torchvision_detector_config(detector_name: str | None) -> dict: |
I don't understand why this function exists - we're just putting a string in a hardcoded dictionary? the main thing it does is throw exceptions, but the exception should be thrown in whatever consumes this dict. this is way way far on the LBYL side of look before you leap vs. easier to ask forgiveness - we are checking if a value is valid in some utility method very distant from where it's used. if this is needed, it sort of screams that whatever super animal is needs a proper data model - if super animal requires a detector name, there should be a data model that requires a detector name, not in some random utility function somewhere!
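One way to read the "proper data model" suggestion, as a rough sketch: make the detector name a required field of a small model, so a missing name fails where the model is constructed rather than in a distant utility. The class name `SuperAnimalModelSpec`, its fields, the example values, and the config layout are all hypothetical and not part of this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SuperAnimalModelSpec:
    """Hypothetical data model: every field, including the detector, is required."""

    super_animal: str
    pose_model_name: str
    detector_name: str  # no default, so omitting it fails at construction time

    def detector_config(self) -> dict:
        # Hypothetical config layout. The consumer that builds the detector
        # is where an unsupported name would raise, not this model.
        return {
            "model": {
                "type": "TorchvisionDetectorAdaptor",
                "model": self.detector_name,
            }
        }


# Illustrative values only.
spec = SuperAnimalModelSpec(
    super_animal="superanimal_quadruped",
    pose_model_name="hrnet_w32",
    detector_name="fasterrcnn_mobilenet_v3_large_fpn",
)
```

Leaving out `detector_name` raises a `TypeError` from the generated `__init__`, so the requirement is stated once, in the type, instead of being re-checked in helper functions.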
soooo remember what we were saying about how copying and pasting code is bad because then you have to ensure that it never changes or you always need to remember to update in both places? this is one of those times. this function should have been located in some shared root package (dlclibrary!!!!) and imported in both in the first place, but if we are sprouting divergent functionality, need to make that clear: "I do this all the time in normal deeplabcut, how come it raises an error here!"
this is also strangely conflicting with the other place where we silence rather than raise an error - https://github.com/DeepLabCut/DeepLabCut-live/pull/164/changes#r2913689689
so we are handling the same thing twice in opposing directions.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,14 +9,21 @@ | |
| # Licensed under GNU Lesser General Public License v3.0 | ||
| # | ||
| """Module to adapt torchvision detectors for DeepLabCut""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
| import torchvision.models.detection as detection | ||
|
|
||
| from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector | ||
| from dlclive.pose_estimation_pytorch.models.detectors.base import ( | ||
| DETECTORS, | ||
| BaseDetector, | ||
| ) | ||
|
|
||
| SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"] | ||
|
||
|
|
||
|
|
||
| @DETECTORS.register_module | ||
| class TorchvisionDetectorAdaptor(BaseDetector): | ||
|
||
| """An adaptor for torchvision detectors | ||
|
|
||
|
|
@@ -26,8 +33,8 @@ class TorchvisionDetectorAdaptor(BaseDetector): | |
| - fasterrcnn_mobilenet_v3_large_fpn | ||
| - fasterrcnn_resnet50_fpn_v2 | ||
|
Comment on lines 33 to 34
|
||
|
|
||
| This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or | ||
| SSDLite) should be used instead. | ||
| This class can be used directly (e.g. with pre-trained COCO weights) or through its | ||
| subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection. | ||
|
Comment on lines +36 to +37
Collaborator
not sure why! usually base classes that shouldn't be used have some abstract methods and only serve as the external public API contract, so it doesn't seem like this was serving as an ABC before. but mixing levels like this is usually a bad call: having a "pre trained" (or whatever) subclass might be a good idea, to discourage modifying the base class to accommodate needs that are specific to this use, don't apply to the other subclasses, and warp the contract made by the ABC.
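A rough sketch of the layering described in this comment, under stated assumptions: an abstract base that only defines the contract, plus a dedicated subclass for the pre-trained case. All class and method names below are hypothetical stand-ins, not the PR's `TorchvisionDetectorAdaptor`, and no torchvision model is actually built. The 2-class figure comes from the docstring in this diff; 91 classes and the `COCO_V1` weights name are what torchvision uses for its COCO detection weights.

```python
from abc import ABC, abstractmethod
from typing import Optional


class AbstractTorchvisionAdaptor(ABC):
    """Hypothetical base: defines the contract and cannot be instantiated."""

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    @abstractmethod
    def num_classes(self) -> int:
        """Number of classes the detection head predicts."""

    @abstractmethod
    def weights(self) -> Optional[str]:
        """Name of the pre-trained weights to load, or None for a snapshot."""


class PretrainedCocoDetector(AbstractTorchvisionAdaptor):
    """Hypothetical subclass for using COCO weights as-is."""

    def num_classes(self) -> int:
        return 91  # torchvision's COCO detection heads use 91 classes

    def weights(self) -> Optional[str]:
        return "COCO_V1"


class DLCTwoClassDetector(AbstractTorchvisionAdaptor):
    """Hypothetical subclass adapted for DLC's 2-class detection."""

    def num_classes(self) -> int:
        return 2

    def weights(self) -> Optional[str]:
        return None  # weights come from a DLC snapshot instead
```

Needs that only apply to the pre-trained path then live in `PretrainedCocoDetector`, and the base class stays a pure contract.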
||
|
|
||
| The torchvision implementation does not allow to get both predictions and losses | ||
| with a single forward pass. Therefore, during evaluation only bounding box metrics | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -268,10 +268,24 @@ def load_model(self) -> None: | |
| self.model = self.model.half() | ||
|
|
||
| self.detector = None | ||
| if self.dynamic is None and raw_data.get("detector") is not None: | ||
| detector_cfg = self.cfg.get("detector") | ||
| has_detector_weights = raw_data.get("detector") is not None | ||
| if detector_cfg is not None: | ||
| detector_model_cfg = detector_cfg["model"] | ||
| uses_pretrained = ( | ||
| detector_model_cfg.get("pretrained", False) | ||
| or detector_model_cfg.get("weights") is not None | ||
| ) | ||
|
Comment on lines +271 to +278
|
||
| else: | ||
| uses_pretrained = False | ||
|
|
||
| if self.dynamic is None and (has_detector_weights or uses_pretrained): | ||
| self.detector = models.DETECTORS.build(self.cfg["detector"]["model"]) | ||
| self.detector.to(self.device) | ||
| self.detector.load_state_dict(raw_data["detector"]) | ||
|
|
||
| if has_detector_weights: | ||
| self.detector.load_state_dict(raw_data["detector"]) | ||
|
Comment on lines +282 to +287
|
||
|
|
||
| self.detector.eval() | ||
| if self.precision == "FP16": | ||
| self.detector = self.detector.half() | ||
|
|
@@ -281,7 +295,8 @@ def load_model(self) -> None: | |
| self.top_down_config.read_config(self.cfg) | ||
|
|
||
| detector_transforms = [v2.ToDtype(torch.float32, scale=True)] | ||
| if self.cfg["detector"]["data"]["inference"].get("normalize_images", False): | ||
| detector_data_cfg = detector_cfg.get("data", {}).get("inference", {}) | ||
| if detector_data_cfg.get("normalize_images", False): | ||
| detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) | ||
| self.detector_transform = v2.Compose(detector_transforms) | ||
|
|
||
|
|
||
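The decision logic added to `load_model` in the hunk above can be restated as a small pure function, which may make the branches easier to review and test. This is a sketch under stated assumptions: the function name `should_build_detector` is hypothetical, and `cfg`, `raw_data`, and `dynamic` are plain stand-ins for `self.cfg`, the loaded snapshot, and `self.dynamic`.

```python
def should_build_detector(cfg: dict, raw_data: dict, dynamic=None) -> tuple[bool, bool]:
    """Return (build_detector, load_snapshot_weights), mirroring the diff above."""
    detector_cfg = cfg.get("detector")
    has_detector_weights = raw_data.get("detector") is not None

    uses_pretrained = False
    if detector_cfg is not None:
        model_cfg = detector_cfg["model"]
        # Either flag signals that the detector brings its own weights.
        uses_pretrained = (
            bool(model_cfg.get("pretrained", False))
            or model_cfg.get("weights") is not None
        )

    build = dynamic is None and (has_detector_weights or uses_pretrained)
    # Snapshot weights are only loaded when the snapshot actually contains them.
    return build, build and has_detector_weights
```

For example, a config with `{"detector": {"model": {"weights": "COCO_V1"}}}` and a snapshot without a `"detector"` entry yields `(True, False)`: the detector is built, but `load_state_dict` is skipped.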
this seems not great to hardcode. idk, what if it eventually is on huggingface? what if there are other detector names that can't be reached by the load model weights function? also, silently changing explicitly specified functionality is bad to do - i told you to use detector_name, but there is some hidden interaction with my model choice, and how would I even know? just throw an exception in _load_model_weights if the model weights can't be loaded