diff --git a/stone/backend.py b/stone/backend.py index 5d34160d..bf2fdf6d 100644 --- a/stone/backend.py +++ b/stone/backend.py @@ -178,6 +178,10 @@ def _record_output_path(self, output_path): self.output_manifest.add_output(self.target_folder_path, output_path) return True + def _validate_output_path(self, output_path): + # type: (str) -> None + _relative_output_path(self.target_folder_path, output_path) + @contextmanager def output_to_relative_path(self, relative_path, mode='wb'): # type: (typing.Text, typing.Text) -> typing.Iterator[None] @@ -188,6 +192,7 @@ def output_to_relative_path(self, relative_path, mode='wb'): Clears the output buffer on enter and exit. """ full_path = os.path.join(self.target_folder_path, relative_path) + self._validate_output_path(full_path) if self._record_output_path(full_path): self.clear_output_buffer() yield @@ -208,6 +213,7 @@ def output_to_relative_path(self, relative_path, mode='wb'): def copy_to_path(self, src, dst, *copy_args, **copy_kwargs): output_path = os.path.join(dst, os.path.basename(src)) if os.path.isdir(dst) else dst + self._validate_output_path(output_path) if self._record_output_path(output_path): return output_path return shutil.copy(src, dst, *copy_args, **copy_kwargs) diff --git a/stone/backends/swift.py b/stone/backends/swift.py index 4588ddba..79130e10 100644 --- a/stone/backends/swift.py +++ b/stone/backends/swift.py @@ -270,6 +270,7 @@ def _write_output_in_target_folder(self, output, file_name): if not os.path.exists(full_path): os.mkdir(full_path) full_path = os.path.join(full_path, file_name) + self._validate_output_path(full_path) if self._record_output_path(full_path): return with open(full_path, "w", encoding='utf-8') as fh: diff --git a/test/test_output_manifest.py b/test/test_output_manifest.py index 7b444588..54426e47 100644 --- a/test/test_output_manifest.py +++ b/test/test_output_manifest.py @@ -74,5 +74,53 @@ def generate(self, api): self.assertEqual(existing_files, []) +class TestOutputRootValidation(unittest.TestCase): + + def test_output_to_relative_path_rejects_parent_paths(self): + class ValidatingBackend(CodeBackend): + preserve_aliases = True + + def generate(self, api): + pass + + with tempfile.TemporaryDirectory() as output_root: + backend = ValidatingBackend(output_root, []) + + with self.assertRaises(AssertionError): + with backend.output_to_relative_path('../Generated.py'): + backend.emit('generated = True') + + def test_copy_to_path_rejects_parent_paths(self): + class ValidatingBackend(Backend): + preserve_aliases = True + + def generate(self, api): + pass + + with tempfile.NamedTemporaryFile() as source_file: + with tempfile.TemporaryDirectory() as output_root: + backend = ValidatingBackend(output_root, []) + + with self.assertRaises(AssertionError): + backend.copy_to_path( + source_file.name, + os.path.join(output_root, '..', 'Copied.py')) + + def test_swift_output_rejects_parent_paths(self): + class ValidatingBackend(SwiftBaseBackend): + preserve_aliases = True + + def generate(self, api): + pass + + with tempfile.TemporaryDirectory() as output_root: + backend = ValidatingBackend(output_root, []) + + with self.assertRaises(AssertionError): + backend._write_output_in_target_folder( + 'final class Generated {}', + '../Generated.swift') + + if __name__ == '__main__': unittest.main()