Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions tensorflow_datasets/core/features/image_feature_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ def make_none_fail():


def _unsupported_images_for_pil(
np_dtype: Type[np.generic], failing_lib: Optional[LibWithImportError]
np_dtype: Type[np.generic],
failing_lib: Optional[LibWithImportError],
channels: int,
) -> bool:
# PIL is used when OpenCV is not installed, but PIL doesn't support 16-bit
# images.
return np_dtype == np.uint16 and failing_lib == LibWithImportError.CV2
if np_dtype == np.uint16 and failing_lib == LibWithImportError.CV2:
return True
# PIL also doesn't support multi-channel 16-bit images for thumbnails.
if np_dtype == np.uint16 and channels > 1:
return True
return False


class ImageFeatureTest(
Expand All @@ -91,7 +98,7 @@ class ImageFeatureTest(
def test_images(self, make_lib_fail, dtypes, channels):
dtype, np_dtype = dtypes
with make_lib_fail() as failing_lib:
if _unsupported_images_for_pil(np_dtype, failing_lib):
if _unsupported_images_for_pil(np_dtype, failing_lib, channels):
return
img = randint(256, size=(128, 100, channels), dtype=np_dtype)
img_other_shape = randint(256, size=(64, 200, channels), dtype=np_dtype)
Expand Down Expand Up @@ -198,7 +205,7 @@ def test_images(self, make_lib_fail, dtypes, channels):
def test_images_with_invalid_shape(self, make_lib_fail, dtypes, channels):
dtype, np_dtype = dtypes
with make_lib_fail() as failing_lib:
if _unsupported_images_for_pil(np_dtype, failing_lib):
if _unsupported_images_for_pil(np_dtype, failing_lib, channels):
return
invalid_number_of_dimensions = testing.FeatureExpectationItem(
value=randint(256, size=(128, 128), dtype=np_dtype),
Expand Down
Loading