Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
CasualString128,
Identifier,
IntegerRange,
SystemString64,
TagString,
integer_range_regex,
)
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
Expand All @@ -63,7 +63,7 @@
)
async def tag_dataset(
data_id: Annotated[Identifier, Body()],
tag: Annotated[SystemString64, Body()],
tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
Expand All @@ -87,13 +87,13 @@ async def tag_dataset(

class TagInfo(TypedDict):
id: str
tag: NotRequired[SystemString64 | list[SystemString64]]
tag: NotRequired[TagString | list[TagString]]


@router.post(path="/untag", deprecated=True)
async def untag_dataset_like_php(
data_id: Annotated[Identifier, Body()],
tag: Annotated[SystemString64, Body()],
tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["data_untag"], TagInfo]:
Expand All @@ -110,7 +110,7 @@ async def untag_dataset_like_php(
@router.delete(path="/{identifier}/tag", status_code=HTTPStatus.NO_CONTENT)
async def untag_dataset(
identifier: Identifier,
tag: Annotated[SystemString64, Query()],
tag: Annotated[TagString, Query()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> None:
Expand Down Expand Up @@ -158,7 +158,7 @@ def _quality_clause(quality: str, range_: str | None) -> str:
async def list_datasets( # noqa: PLR0913, C901
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[CasualString128 | None, Body()] = None,
tag: Annotated[SystemString64 | None, Body()] = None,
tag: Annotated[TagString | None, Body()] = None,
data_version: Annotated[
Identifier | None,
Body(description="The dataset version to include in the search."),
Expand Down
6 changes: 3 additions & 3 deletions src/routers/openml/setups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from database.users import User
from routers.dependencies import expdb_connection, fetch_user_or_raise
from routers.types import Identifier, SystemString64
from routers.types import Identifier, TagString
from schemas.setups import SetupParameters, SetupResponse

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,7 +49,7 @@ async def get_setup(
@router.post(path="/tag")
async def tag_setup(
setup_id: Annotated[Identifier, Body()],
tag: Annotated[SystemString64, Body()],
tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
Expand All @@ -76,7 +76,7 @@ async def tag_setup(
@router.post(path="/untag")
async def untag_setup(
setup_id: Annotated[Identifier, Body()],
tag: Annotated[SystemString64, Body()],
tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
Expand Down
6 changes: 3 additions & 3 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
CasualString128,
Identifier,
IntegerRange,
SystemString64,
TagString,
integer_range_regex,
)
from schemas.datasets.openml import Task
Expand Down Expand Up @@ -231,8 +231,8 @@ def _quality_clause(quality: str, range_: str | None) -> str:
async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None,
tag: Annotated[SystemString64 | None, Body()] = None,
data_tag: Annotated[SystemString64 | None, Body()] = None,
tag: Annotated[TagString | None, Body()] = None,
data_tag: Annotated[TagString | None, Body()] = None,
status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE,
task_id: Annotated[
list[Identifier] | None,
Expand Down
4 changes: 3 additions & 1 deletion src/routers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from pydantic import Field

SystemString64 = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)]
# Known as SystemString64 in the XSD
TagString = Annotated[str, Field(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)]

CasualString128 = Annotated[str, Field(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)]
Identifier = Annotated[int, Field(gt=0)]

Expand Down
19 changes: 0 additions & 19 deletions tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,6 @@ async def test_dataset_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.Async
assert response.status_code == HTTPStatus.UNAUTHORIZED


@pytest.mark.parametrize(
"tag",
["", "h@", " a", "a" * 65],
ids=["too short", "@", "space", "too long"],
)
async def test_dataset_tag_invalid_tag_is_rejected(
# Constraints for the tag are handled by FastAPI
tag: str,
py_api: httpx.AsyncClient,
) -> None:
response = await py_api.post(
f"/datasets/tag?api_key={ApiKey.ADMIN}",
json={"data_id": 1, "tag": tag},
)

assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert response.json()["errors"][0]["loc"] == ["body", "tag"]


# ── Direct call tests: tag_dataset ──


Expand Down
50 changes: 50 additions & 0 deletions tests/types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import string

import pytest
from pydantic import TypeAdapter, ValidationError

from routers.types import Identifier, TagString

# Adapter that lets the tests below validate raw strings against Identifier.
_identifier = TypeAdapter(Identifier)


def test_identifier_accepts_positive_integer() -> None:
    """The string "1" validates as an Identifier and yields the int 1."""
    parsed = _identifier.validate_strings("1")
    assert parsed == 1


def test_identifier_rejects_non_integer() -> None:
    """Neither free text nor a decimal number is accepted as an Identifier."""
    for raw in ("foo", "1.2"):
        with pytest.raises(ValidationError):
            _identifier.validate_strings(raw)


def test_identifier_rejects_negative() -> None:
    """A negative integer string must fail Identifier validation."""
    with pytest.raises(ValidationError):
        # Use a genuinely negative value; "0" is already covered by
        # test_identifier_rejects_zero and would leave negatives untested.
        _identifier.validate_strings("-1")
Comment thread
PGijsbers marked this conversation as resolved.
Outdated


def test_identifier_rejects_zero() -> None:
    """Zero is not a valid Identifier; validation must raise."""
    zero = "0"
    with pytest.raises(ValidationError):
        _identifier.validate_strings(zero)

Comment thread
PGijsbers marked this conversation as resolved.

# Adapter used to validate raw strings against TagString in the tests below.
_tag_string = TypeAdapter(TagString)
# Punctuation characters that the tests treat as allowed inside a tag.
_valid_punctuation_tag = set("-._")
# All remaining ASCII punctuation; each character is expected to be rejected.
_invalid_punctuation_tag = set(string.punctuation).difference(_valid_punctuation_tag)


def test_tag_string_pattern() -> None:
    """The JSON schema generated for TagString exposes the expected regex."""
    schema = _tag_string.json_schema()
    assert schema["pattern"] == r"^[\w\-\.]+$"


@pytest.mark.parametrize(
    "tag",
    ["a", "c" * 64, "version2.0", "study-14", "study_15"],
)
def test_tag_string_accepts_valid(tag: str) -> None:
    """Well-formed tags validate and are returned unchanged."""
    validated = _tag_string.validate_strings(tag)
    assert validated == tag


@pytest.mark.parametrize(
    "tag",
    [
        "",  # shorter than min_length
        "c" * 65,  # longer than max_length
        "h@",  # invalid character following valid text
        " a",  # whitespace is not part of the allowed character class
        # Sort before unpacking: iterating a set of strings directly gives a
        # different order per interpreter run, which makes test collection
        # order unstable (and breaks distributed runners that require every
        # worker to collect the same tests in the same order).
        *sorted(_invalid_punctuation_tag),
    ],
)
def test_tag_string_rejects_invalid(tag: str) -> None:
    """Empty, over-long, or disallowed-character tags fail validation."""
    with pytest.raises(ValidationError):
        _tag_string.validate_strings(tag)
Loading