diff --git a/tensorflow_datasets/core/features/image_feature_test.py b/tensorflow_datasets/core/features/image_feature_test.py index 40dd092be87..accbec34175 100644 --- a/tensorflow_datasets/core/features/image_feature_test.py +++ b/tensorflow_datasets/core/features/image_feature_test.py @@ -67,11 +67,18 @@ def make_none_fail(): def _unsupported_images_for_pil( - np_dtype: Type[np.generic], failing_lib: Optional[LibWithImportError] + np_dtype: Type[np.generic], + failing_lib: Optional[LibWithImportError], + channels: int, ) -> bool: # PIL is used when OpenCV is not installed, but PIL doesn't support 16-bit # images. - return np_dtype == np.uint16 and failing_lib == LibWithImportError.CV2 + if np_dtype == np.uint16 and failing_lib == LibWithImportError.CV2: + return True + # PIL also doesn't support multi-channel 16-bit images for thumbnails. + if np_dtype == np.uint16 and channels > 1: + return True + return False class ImageFeatureTest( @@ -91,7 +98,7 @@ class ImageFeatureTest( def test_images(self, make_lib_fail, dtypes, channels): dtype, np_dtype = dtypes with make_lib_fail() as failing_lib: - if _unsupported_images_for_pil(np_dtype, failing_lib): + if _unsupported_images_for_pil(np_dtype, failing_lib, channels): return img = randint(256, size=(128, 100, channels), dtype=np_dtype) img_other_shape = randint(256, size=(64, 200, channels), dtype=np_dtype) @@ -198,7 +205,7 @@ def test_images(self, make_lib_fail, dtypes, channels): def test_images_with_invalid_shape(self, make_lib_fail, dtypes, channels): dtype, np_dtype = dtypes with make_lib_fail() as failing_lib: - if _unsupported_images_for_pil(np_dtype, failing_lib): + if _unsupported_images_for_pil(np_dtype, failing_lib, channels): return invalid_number_of_dimensions = testing.FeatureExpectationItem( value=randint(256, size=(128, 128), dtype=np_dtype),