forked from DeepLabCut/DeepLabCut
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_superanimal_humanbody.py
More file actions
71 lines (59 loc) · 2.91 KB
/
test_superanimal_humanbody.py
File metadata and controls
71 lines (59 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python3
"""
Test script for superanimal_humanbody with torchvision detector
"""
import torch
import torchvision.models.detection as detection
from deeplabcut.pose_estimation_pytorch.modelzoo import load_super_animal_config
def test_torchvision_detector():
    """Check that superanimal_humanbody can be configured with a torchvision detector.

    Loads the ``superanimal_humanbody`` config for the ``rtmpose_x`` pose model
    with the ``fasterrcnn_mobilenet_v3_large_fpn`` detector, instantiates that
    detector directly from torchvision with its ``DEFAULT`` weights, and then
    verifies three properties of ``config["detector"]["model"]``:

    * ``variant`` equals ``"fasterrcnn_mobilenet_v3_large_fpn"``;
    * ``type`` equals ``"FasterRCNN"``;
    * ``pretrained`` is either absent or ``False``.

    Progress and results are reported with ``print``.

    Returns:
        bool: ``True`` when every check passes, ``False`` as soon as one fails.

    NOTE(review): the outcome is signalled through the return value rather
    than through ``assert``. The ``__main__`` entry point relies on that, but
    a pytest run presumably will not treat a ``False`` return as a failure --
    confirm how this test is meant to be collected.
    """
    expected_variant = "fasterrcnn_mobilenet_v3_large_fpn"

    # Load the superanimal_humanbody config
    config = load_super_animal_config(
        super_animal="superanimal_humanbody",
        model_name="rtmpose_x",
        detector_name=expected_variant,
    )
    print("Config loaded successfully!")
    print(f"Model method: {config['method']}")

    detector_config = config["detector"]["model"]
    # Read the variant once, tolerating a missing key, so that a config
    # without "variant" reaches the mismatch report below instead of raising
    # KeyError while printing.
    actual_variant = detector_config.get("variant", "")
    print(f"Detector variant: {actual_variant}")
    print(f"Detector config: {detector_config}")

    # Test loading the torchvision detector directly
    print("\nTesting torchvision detector loading...")
    weights = detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
    detector = detection.fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        box_score_thresh=0.6,
    )
    detector.eval()
    print("Torchvision detector loaded successfully!")

    # Test that the detector config matches what we expect for torchvision
    print("\nTesting detector config compatibility...")
    if actual_variant != expected_variant:
        print(f"❌ Detector variant mismatch. Expected: {expected_variant}, Got: {actual_variant}")
        return False
    print(f"✅ Detector variant matches expected: {expected_variant}")

    # The config must declare the FasterRCNN detector type
    if detector_config.get("type") != "FasterRCNN":
        print("❌ Detector type is not correctly set")
        return False
    print("✅ Detector type is correctly set to FasterRCNN")

    # Only an absent "pretrained" key or an explicit False is accepted; any
    # other value is reported as conflicting with the torchvision weights.
    if "pretrained" in detector_config and detector_config.get("pretrained") is not False:
        print("❌ Detector config has pretrained=True, which may conflict with torchvision weights")
        return False
    print("✅ Detector config allows torchvision weights")

    print("\n✅ All tests passed! The torchvision detector integration is working correctly.")
    return True
if __name__ == "__main__":
    # Script entry point: run the check once and report the overall outcome.
    print("Testing superanimal_humanbody with torchvision detector...")
    success = test_torchvision_detector()
    if success:
        print("\n✅ Test passed! The torchvision detector works with superanimal_humanbody")
    else:
        print("\n❌ Test failed! There's an issue with the torchvision detector integration")
        # Exit non-zero so shells and CI can detect the failure; otherwise the
        # script would end with status 0 even after printing the message above.
        raise SystemExit(1)