diff --git a/src/supervision/detection/vlm.py b/src/supervision/detection/vlm.py index 1aa0fb5182..e287dcc1c7 100644 --- a/src/supervision/detection/vlm.py +++ b/src/supervision/detection/vlm.py @@ -229,7 +229,7 @@ def from_paligemma( matches = np.array(matches) if matches else np.empty((0, 5)) if matches.shape[0] == 0: - return np.empty((0, 4)), None, np.empty(0, dtype=str) + return np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0, dtype=str) xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4] xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h]) @@ -626,7 +626,7 @@ def from_google_gemini_2_0( try: data = json.loads(result) except json.JSONDecodeError: - return np.empty((0, 4)), None, np.empty((0,), dtype=str) + return np.empty((0, 4)), np.empty((0,), dtype=int), np.empty((0,), dtype=str) labels = [] xyxy = [] @@ -640,7 +640,7 @@ def from_google_gemini_2_0( xyxy.append([box[1], box[0], box[3], box[2]]) if len(xyxy) == 0: - return np.empty((0, 4)), None, np.empty((0,), dtype=str) + return np.empty((0, 4)), np.empty((0,), dtype=int), np.empty((0,), dtype=str) xyxy = denormalize_boxes( np.array(xyxy, dtype=np.float64), diff --git a/tests/detection/test_vlm.py b/tests/detection/test_vlm.py index c21679b820..b5d035b065 100644 --- a/tests/detection/test_vlm.py +++ b/tests/detection/test_vlm.py @@ -27,49 +27,49 @@ "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # empty text ( does_not_raise(), "", (1000, 1000), ["cat", "dog"], - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # empty text, classes ( does_not_raise(), "\n", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # newline only ( does_not_raise(), "the quick brown fox jumps over the lazy dog.", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # random text, no location ( does_not_raise(), " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # partial location ( does_not_raise(), " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # extra loc ( does_not_raise(), "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0).astype(str)), ), # no class ( does_not_raise(), @@ -436,21 +436,21 @@ def test_from_qwen_2_5_vl( "random text", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0, dtype=str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0, dtype=str)), ), # random text without JSON format ( does_not_raise(), "```json\ninvalid json\n```", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0, dtype=str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0, dtype=str)), ), # invalid JSON within code blocks ( does_not_raise(), "```json\n[]\n```", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0, dtype=str)), + (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty(0, dtype=str)), ), # empty JSON array ( does_not_raise(),