Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 9 additions & 28 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections.abc import Iterator
from enum import IntEnum
from types import EllipsisType, ModuleType
from typing import Any, Final, Literal, SupportsIndex, Callable
from typing import Any, Literal, SupportsIndex, Callable

import numpy as np
import numpy.typing as npt
Expand All @@ -40,35 +40,11 @@
_real_to_complex_map,
_result_type,
)
from ._devices import CPU_DEVICE, Device, device_supports_dtype
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
from ._typing import PyCapsule


class Device:
_device: Final[str]
__slots__ = ("_device", "__weakref__")

def __init__(self, device: str = "CPU_DEVICE"):
if device not in ("CPU_DEVICE", "device1", "device2"):
raise ValueError(f"The device '{device}' is not a valid choice.")
self._device = device

def __repr__(self) -> str:
return f"array_api_strict.Device('{self._device}')"

def __eq__(self, other: object) -> bool:
if not isinstance(other, Device):
return False
return self._device == other._device

def __hash__(self) -> int:
return hash(("Device", self._device))


CPU_DEVICE = Device()
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))


class Array:
"""
n-d array object for the array API namespace.
Expand Down Expand Up @@ -113,10 +89,15 @@ def _new(cls, x: npt.NDArray[Any] | np.generic, /, device: Device | None) -> Arr
raise TypeError(
f"The array_api_strict namespace does not support the dtype '{x.dtype}'"
)
obj._array = x
obj._dtype = _dtype

if device is None:
device = CPU_DEVICE
if not device_supports_dtype(device, _dtype):
raise ValueError(f"Device {device!r} does not support dtype={_dtype!r}.")

obj._array = x
obj._dtype = _dtype

obj._device = device
return obj

Expand Down
115 changes: 85 additions & 30 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import numpy as np

from ._dtypes import DType, _all_dtypes, _np_dtype
from ._dtypes import DType, _all_dtypes, _np_dtype, bool as xp_bool
from ._devices import (
Device, device_supports_dtype, get_default_dtypes,
check_device as _check_device
)
from ._flags import get_array_api_strict_flags
from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack

Expand All @@ -14,7 +18,7 @@
from typing_extensions import TypeIs

# Circular import
from ._array_object import Array, Device
from ._array_object import Array


class Undef(Enum):
Expand All @@ -24,10 +28,15 @@ class Undef(Enum):
_undef = Undef.UNDEF


def _check_valid_dtype(dtype: DType | None) -> None:
def _check_valid_dtype(dtype: DType | None, device: Device | None = None) -> None:
# Note: Only spelling dtypes as the dtype objects is supported.
if dtype not in (None,) + _all_dtypes:
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
if dtype is not None:
if dtype not in _all_dtypes:
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")

if device is not None:
if not device_supports_dtype(device, dtype):
raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.")


def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
Expand All @@ -38,18 +47,6 @@ def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
return True


def _check_device(device: Device | None) -> None:
# _array_object imports in this file are inside the functions to avoid
# circular imports
from ._array_object import ALL_DEVICES, Device

if device is not None and not isinstance(device, Device):
raise ValueError(f"Unsupported device {device!r}")

if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")


def asarray(
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
Expand All @@ -65,11 +62,12 @@ def asarray(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
_np_dtype = None
if dtype is not None:
_np_dtype = dtype._np_dtype
_check_device(device)

if isinstance(obj, Array) and device is None:
device = obj.device

Expand Down Expand Up @@ -108,6 +106,27 @@ def asarray(
raise OverflowError("Integer out of bounds for array dtypes")

res = np.array(obj, dtype=_np_dtype, copy=copy)

# numpy default dtype may differ; if so, adjust the dtype
if dtype is None and device is not None:
res_dtype = DType(res.dtype)
if not device_supports_dtype(device, res_dtype):
# find out the default dtype for the device
from ._data_type_functions import isdtype
if isdtype(res_dtype, "bool"):
targ_dtype = DType("bool")
elif isdtype(res_dtype, "integral"):
targ_dtype = get_default_dtypes(device)["integral"]
elif isdtype(res_dtype, "real floating"):
targ_dtype = get_default_dtypes(device)["real floating"]
elif isdtype(res_dtype, "complex floating"):
targ_dtype = get_default_dtypes(device)["complex floating"]
else:
raise ValueError(f"{res_dtype = } not understood.")
del isdtype

res = res.astype(targ_dtype._np_dtype)

return Array._new(res, device=device)


Expand All @@ -127,8 +146,13 @@ def arange(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
if any(isinstance(x, float) for x in (start, stop, step)):
dtype = get_default_dtypes(device)["real floating"]
else:
dtype = get_default_dtypes(device)["integral"]

return Array._new(
np.arange(start, stop, step, dtype=_np_dtype(dtype)),
Expand All @@ -149,8 +173,10 @@ def empty(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -165,10 +191,12 @@ def empty_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -189,8 +217,10 @@ def eye(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(
np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device
Expand Down Expand Up @@ -237,12 +267,22 @@ def full(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)

if not isinstance(fill_value, bool | int | float | complex):
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
raise TypeError(msg)

if dtype is None:
if type(fill_value) == bool:
dtype = xp_bool
else:
kind = {
int: "integral", float: "real floating", complex: "complex floating"
}[type(fill_value)]
dtype = get_default_dtypes(device)[kind]

res = np.full(shape, fill_value, dtype=_np_dtype(dtype))
if DType(res.dtype) not in _all_dtypes:
# This will happen if the fill value is not something that NumPy
Expand All @@ -266,10 +306,12 @@ def full_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

if not isinstance(fill_value, bool | int | float | complex):
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
Expand Down Expand Up @@ -300,8 +342,13 @@ def linspace(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
if isinstance(start, complex) or isinstance(stop, complex):
dtype = get_default_dtypes(device)["complex floating"]
else:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(
np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint),
Expand Down Expand Up @@ -353,8 +400,10 @@ def ones(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -369,10 +418,12 @@ def ones_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device)

Expand Down Expand Up @@ -418,8 +469,10 @@ def zeros(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -434,9 +487,11 @@ def zeros_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device)
Loading
Loading