diff --git a/gigl/common/types/uri/http_uri.py b/gigl/common/types/uri/http_uri.py index b94d4c0ac..41ae2740c 100644 --- a/gigl/common/types/uri/http_uri.py +++ b/gigl/common/types/uri/http_uri.py @@ -10,8 +10,7 @@ class HttpUri(Uri): """Represents an HTTP URI.""" def __init__(self, uri: Union[str, Path, HttpUri]) -> None: - self._has_valid_prefix(uri=uri) - self._has_no_backslash(uri=uri) + self.is_valid(uri=self._token_to_string(uri), raise_exception=True) super().__init__(uri=uri) @classmethod diff --git a/gigl/common/types/uri/uri.py b/gigl/common/types/uri/uri.py index 6e7657d80..12f699f5e 100644 --- a/gigl/common/types/uri/uri.py +++ b/gigl/common/types/uri/uri.py @@ -29,13 +29,49 @@ def _token_to_string(token: _URI_LIKE) -> str: return str(token) return "" - # TODO(kmonte): You should not be able to join a Uri with a Uri of a different type. - # *or* join HTTP on HTTP or GCS on GCS. - # This is not backwards compatible, so come around to this later. + @classmethod + def _is_absolute(cls, token: _URI_LIKE) -> bool: + token_str = cls._token_to_string(token) + + # Note: "://" is used to detect GcsUri and HttpUri prefixes, but will incorrectly + # classify relative LocalUri path's containing "://" as absolute + return Path(token_str).is_absolute() or "://" in token_str + + @classmethod + def _has_matching_uri_type( + cls, token: _URI_LIKE, allow_base_uri_join: bool = False + ) -> bool: + token_is_path_like = not isinstance(token, Uri) + if token_is_path_like: + # str and Path tokens do not have a Uri type to match. + return True + + token_matches_join_type = token.__class__ is cls + if token_matches_join_type: + # Same concrete Uri type, e.g. GcsUri.join(GcsUri(...)). + return True + + base_join_with_concrete_token = allow_base_uri_join and cls is Uri + if base_join_with_concrete_token: + # Uri.join(GcsUri(...), "suffix") stays supported for compatibility. + return True + + # Remaining Uri tokens are cross-type mixes, e.g. GcsUri with HttpUri. + return False + @classmethod def join(cls, token: _URI_LIKE, *tokens: _URI_LIKE) -> Self: """ Join multiple tokens to create a new Uri instance. + The first token may be an absolute or relative path. Every additional token must be relative. + + Note: + - Rejecting suffix strings containing "://" may break callers that + intentionally store URI-looking strings inside paths. + - Rejecting absolute LocalUri suffix tokens is stricter than os.path.join, + which implicitly discards earlier tokens on the join path. + - Concrete Uri joins cannot mix Uri token types; base Uri.join keeps + the first token generic for compatibility. Args: token: The first token to join. @@ -45,6 +81,23 @@ def join(cls, token: _URI_LIKE, *tokens: _URI_LIKE) -> Self: A new Uri instance representing the joined URI. """ + # Keep base Uri.join generic for existing callers that pass concrete Uri instances. + if not cls._has_matching_uri_type(token, allow_base_uri_join=True): + raise TypeError( + f"Cannot join {cls.__name__} with {token.__class__.__name__}" + ) + + for suffix in tokens: + if not cls._has_matching_uri_type(suffix): + raise TypeError( + f"Cannot join {cls.__name__} with {suffix.__class__.__name__}" + ) + + if cls._is_absolute(suffix): + raise TypeError( + f"URI join suffixes must be relative; got absolute path: {suffix}" + ) + token = cls._token_to_string(token) token_strs: list[str] = [cls._token_to_string(token) for token in tokens] joined_tmp_path = os.path.join(token, *token_strs) @@ -88,10 +141,6 @@ def __eq__(self, other: Any) -> bool: return False def __truediv__(self, other: _URI_LIKE) -> Self: - if isinstance(other, Uri) and not isinstance(other, type(self)): - raise TypeError( - f"Cannot use '/' operator to join {type(self).__name__} with {type(other).__name__}" - ) return self.join(self, other) diff --git a/tests/integration/common/file_loader_test.py b/tests/integration/common/file_loader_test.py index efa05aee0..211b03b37 100644 --- a/tests/integration/common/file_loader_test.py +++ b/tests/integration/common/file_loader_test.py @@ -218,7 +218,7 @@ def test_local_to_gcs_dir(self): local_files = ["a.txt", "b.txt", "c.txt", "d.txt"] local_src_dir: LocalUri = LocalUri.join(self.test_asset_directory, "src") gcs_dst_dir: GcsUri = GcsUri.join( - self.gcs_test_asset_directory, self.test_asset_directory, "dst" + self.gcs_test_asset_directory, self.test_asset_directory.uri, "dst" ) local_file_paths_src: list[LocalUri] = [ @@ -243,7 +243,7 @@ def test_local_to_gcs_dir(self): self.assertTrue(self.gcs_utils.does_gcs_file_exist(gcs_path=gcs_file)) self.gcs_utils.delete_files_in_bucket_dir( gcs_path=GcsUri.join( - self.gcs_test_asset_directory, self.test_asset_directory + self.gcs_test_asset_directory, self.test_asset_directory.uri ) ) @@ -251,7 +251,7 @@ def test_gcs_to_local_dir(self): local_files = ["a.txt", "b.txt", "c.txt", "d.txt"] local_src_dir: LocalUri = LocalUri.join(self.test_asset_directory, "src") gcs_src_dir: GcsUri = GcsUri.join( - self.gcs_test_asset_directory, self.test_asset_directory, "src" + self.gcs_test_asset_directory, self.test_asset_directory.uri, "src" ) local_dst_dir: LocalUri = LocalUri.join(self.test_asset_directory, "dst") @@ -294,16 +294,16 @@ def test_gcs_to_local_dir(self): self.assertTrue(local_fs.does_path_exist(file)) self.gcs_utils.delete_files_in_bucket_dir( gcs_path=GcsUri.join( - self.gcs_test_asset_directory, self.test_asset_directory + self.gcs_test_asset_directory, self.test_asset_directory.uri ) ) def test_gcs_to_gcs_dir(self): gcs_src_dir: GcsUri = GcsUri.join( - self.gcs_test_asset_directory, self.test_asset_directory, "src" + self.gcs_test_asset_directory, self.test_asset_directory.uri, "src" ) gcs_dst_dir: GcsUri = GcsUri.join( - self.gcs_test_asset_directory, self.test_asset_directory, "dst" + self.gcs_test_asset_directory, self.test_asset_directory.uri, "dst" ) dir_uri_map: dict[Uri, Uri] = {gcs_src_dir: gcs_dst_dir} diff --git a/tests/unit/common/types/uri_test.py b/tests/unit/common/types/uri_test.py index 7de450f28..13e492f39 100644 --- a/tests/unit/common/types/uri_test.py +++ b/tests/unit/common/types/uri_test.py @@ -38,6 +38,55 @@ def test_join(self): with self.subTest("LocalUri with Path"): joined = LocalUri.join("/foo/bar", Path("file.text")) self.assertEqual(joined, LocalUri("/foo/bar/file.text")) + with self.subTest("Uri with concrete first token"): + joined = Uri.join(GcsUri("gs://bucket"), "file.txt") + self.assertEqual(joined, Uri("gs://bucket/file.txt")) + self.assertIsInstance(joined, Uri) + with self.subTest("LocalUri suffix"): + relative_local_uri = LocalUri("file.txt") + joined = LocalUri.join("/foo/bar", relative_local_uri) + self.assertEqual(joined, LocalUri("/foo/bar/file.txt")) + self.assertIsInstance(joined, LocalUri) + + def test_join_invalid_suffix(self): + with self.subTest("relative LocalUri suffix with non-local join"): + relative_local_uri = LocalUri("file.txt") + with self.assertRaises(TypeError): + GcsUri.join("gs://bucket/path", relative_local_uri) + + with self.subTest("mixed Uri first token"): + with self.assertRaises(TypeError): + LocalUri.join(GcsUri("gs://bucket/path"), "file.txt") + + with self.subTest("absolute LocalUri suffix"): + absolute_local_uri = LocalUri("/other/file.txt") + with self.assertRaises(TypeError): + LocalUri.join("/foo/bar", absolute_local_uri) + + with self.subTest("absolute HttpUri suffix"): + http_uri = HttpUri("http://abc.com/file.txt") + with self.assertRaises(TypeError): + HttpUri.join("http://abc.com/xyz", http_uri) + + with self.subTest("absolute GcsUri suffix"): + gcs_uri = GcsUri("gs://bucket/file.txt") + with self.assertRaises(TypeError): + GcsUri.join("gs://bucket/path", gcs_uri) + + def test_join_rejects_relative_path_with_uri_separator(self): + with self.assertRaises(TypeError): + LocalUri.join("/foo/bar", "folder://file.txt") + + def test_base_uri_join_rejects_concrete_uri_suffix(self): + # Concrete Uri suffixes require a matching concrete join, not base Uri.join. + relative_local_uri = LocalUri("file.txt") + + with self.assertRaises(TypeError): + Uri.join("/foo/bar", relative_local_uri) + + def test_http_uri_constructor_rejects_invalid_remote_path(self): + with self.assertRaises(TypeError): + HttpUri("file.txt") def test_div_join(self): joined: Uri