Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,36 @@ def make_obs_var(
TensorVariable
"""
name = rv_var.name
data = convert_observed_data(data).astype(rv_var.dtype)
data = convert_observed_data(data)

# Check if float observed data passed to an integer-dtype distribution
# would lose information when cast (e.g. Binomial, Poisson).
# Only applies to concrete numpy arrays with float dtype going to integer dtype.
# Float arrays containing exact integer values (e.g. [0.0, 1.0]) are allowed.
rv_dtype = np.dtype(rv_var.dtype)
data_dtype = getattr(data, "dtype", None)
if (
data_dtype is not None
and not isinstance(data, Variable)
and np.issubdtype(rv_dtype, np.integer)
and np.issubdtype(data_dtype, np.floating)
):
if isinstance(data, np.ma.MaskedArray):
check_data = data.compressed()
elif sparse.issparse(data):
check_data = data.data
else:
check_data = np.asarray(data)
cast_check = check_data.astype(rv_dtype)
if not np.array_equal(check_data, cast_check):
raise TypeError(
f"Observed data for '{name}' has dtype {data_dtype} with non-integer values "
f"that cannot be safely cast to the expected dtype {rv_dtype} of the "
f"{rv_var.owner.op} distribution. If the cast is intentional, convert the "
f"data explicitly before passing it as observed."
)

data = data.astype(rv_var.dtype)

if data.ndim != rv_var.ndim:
raise ShapeError(
Expand Down
39 changes: 39 additions & 0 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,45 @@ def test_make_obs_var():
del fake_model.named_vars[fake_distribution.name]


@pytest.mark.parametrize(
    "dist_cls,dist_kwargs",
    [
        (pm.Binomial, {"n": 5, "p": 0.5}),
        (pm.Poisson, {"mu": 3.0}),
    ],
)
def test_discrete_float_observed_raises(dist_cls, dist_kwargs):
    """Reject fractional float observations given to a discrete distribution.

    Each parametrized distribution is built inside a fresh model with a
    float32 array whose entries are not whole numbers; construction is
    expected to fail with a ``TypeError`` mentioning "non-integer values".

    Regression test for https://github.com/pymc-devs/pymc/issues/8282.
    """
    # None of these values is a whole number.
    fractional_obs = np.array([0.11, 1.5, 2.89], dtype=np.float32)

    with pm.Model(), pytest.raises(TypeError, match="non-integer values"):
        dist_cls("bad", observed=fractional_obs, **dist_kwargs)


def test_discrete_float_observed_exact_ints_allowed():
    """Accept whole-number observations for discrete distributions.

    Covers a float64 array holding only whole numbers and an array that is
    already int64; neither registration is expected to raise.

    Regression test for https://github.com/pymc-devs/pymc/issues/8282.
    """
    # Common pattern: floatX zeros/ones handed to a Bernoulli.
    whole_valued_floats = np.array([0.0, 1.0, 0.0], dtype=np.float64)
    # Integer-valued data already stored in a native integer dtype.
    native_ints = np.array([0, 1, 2], dtype=np.int64)

    with pm.Model():
        pm.Bernoulli("b", p=0.5, observed=whole_valued_floats)
        pm.Binomial("x", n=5, p=0.5, observed=native_ints)


def test_discrete_float_observed_scalar():
    """Reject a fractional float scalar observed by a discrete distribution.

    A ``np.float64`` scalar of 1.5 passed as ``observed`` to a Poisson is
    expected to fail with a ``TypeError`` mentioning "non-integer values".

    Regression test for https://github.com/pymc-devs/pymc/issues/8282.
    """
    scalar_obs = np.float64(1.5)

    with pm.Model(), pytest.raises(TypeError, match="non-integer values"):
        pm.Poisson("bad", mu=3.0, observed=scalar_obs)


def test_initial_point():
with pm.Model() as model:
a = pm.Uniform("a")
Expand Down