diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 7883db82..27f1751e 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,8 @@ Release Notes Upcoming Version ---------------- +* Add documentation about `LinearExpression.where` with `drop=True`. Add `BaseExpression.variable_names` property. + **Features** *Inspect the solver after solving* diff --git a/examples/creating-expressions.ipynb b/examples/creating-expressions.ipynb index 370f3f74..cb41a2c6 100644 --- a/examples/creating-expressions.ipynb +++ b/examples/creating-expressions.ipynb @@ -321,10 +321,106 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "29", "metadata": {}, + "source": [ + "Sometimes `.where` may lead to a situation where some of the variables are completely masked" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "mask_a = xr.DataArray(False, coords=[time])\n", + "mask_b = xr.DataArray(time > 2, coords=[time])\n", + "\n", + "z = (x.where(mask_a) + y).where(mask_b)\n", + "z" + ] + }, + { + "cell_type": "markdown", + "id": "31", + "metadata": {}, + "source": [ + "In this example you can see that many of the elements of the LinearExpression are None. If you want to remove all the None terms, you can use `.where(.., drop=True)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "z = z.where(mask_b, drop=True)\n", + "z" + ] + }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "That looks nicer!
" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": {}, + "source": [ + "You may notice that the variable `x` is not used at all. The expression still contains two terms (one of them is unused) but it only has one variable `y`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "z.nterm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "z.variable_names" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": {}, + "source": [ + "You can get rid of the unused term with `.simplify()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "z = z.simplify()\n", + "z.nterm" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "39", + "metadata": {}, "source": [ "## Using `.shift` to shift the Variable along one dimension\n", "\n", @@ -336,7 +432,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "40", "metadata": {}, "outputs": [], "source": [ @@ -346,7 +442,7 @@ { "attachments": {}, "cell_type": "markdown", - "id": "31", + "id": "41", "metadata": {}, "source": [ "## Using `.groupby` to group by a key and apply operations on the groups\n", @@ -359,7 +455,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "42", "metadata": {}, "outputs": [], "source": [ @@ -370,7 +466,7 @@ { "attachments": {}, "cell_type": "markdown", - "id": "33", + "id": "43", "metadata": {}, "source": [ "## Using `.rolling` to perform a rolling operation\n", @@ -383,7 +479,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "44", "metadata": {}, "outputs": [], "source": [ diff --git a/linopy/expressions.py b/linopy/expressions.py index 70c45732..31868234 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1331,6 +1331,26 @@ def nterm(self) -> int: """ return len(self.data._term) + @property + def variable_names(self) -> set[str]: + """ + Get the names of the unique variables present in the expression. + """ + if self.nterm == 0: + return set() + + # Collect all unique labels from the expression (excluding -1) + all_labels = self.vars.values.ravel() + unique_labels = np.unique(all_labels[all_labels != -1]) + + if len(unique_labels) == 0: + return set() + + # Batch lookup variable names for all labels + positions = self.model.variables.get_label_position(unique_labels) + + return {p[0] for p in positions if p[0] is not None} + @property def shape(self) -> tuple[int, ...]: """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index e9535ad6..79a1029b 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1856,6 +1856,61 @@ def test_constant_only_expression_mul_linexpr_with_vars_and_const( assert (result_rev.const == expected_const).all() +def test_variable_names() -> None: + m = Model() + time = pd.Index(range(3), name="time") + + a = m.add_variables(name="a", coords=[time]) + b = m.add_variables(name="b", coords=[time]) + + expr = a + b + assert expr.nterm == 2 + assert expr.variable_names == {"a", "b"} + + mask = xr.DataArray(False, coords=[time]) + expr = a + (b * 1).where(mask) + assert expr.nterm == 2 + assert expr.variable_names == {"a"} + + expr = (b * 1).where(mask) + assert expr.nterm == 1 + assert expr.variable_names == set() + + expr = LinearExpression.from_constant(model=m, constant=5) + assert expr.nterm == 0 + assert expr.variable_names == set() + + # Single variable expression + expr = 1 * a + assert expr.variable_names == {"a"} + + # Repeated variable across terms (a + a) + expr = a + a + assert expr.variable_names == {"a"} + + +def test_nterm() -> None: + m = Model() + time = pd.Index(range(3), name="time") + all_false = xr.DataArray(False, coords=[time]) + not_0 = xr.DataArray([False, True, True], coords=[time]) + not_1 = xr.DataArray([True, False, True], coords=[time]) + not_2 = xr.DataArray([True, True, False], coords=[time]) + + a = m.add_variables(name="a", coords=[time]) + b = m.add_variables(name="b", coords=[time]) + c = m.add_variables(name="c", coords=[time]) + + expr = (a.where(not_0) + b.where(not_1) + c.where(not_2)).densify_terms() + assert expr.nterm == 3 + + expr = a + b.where(all_false) + assert expr.nterm == 2 + + expr = expr.simplify() + assert expr.nterm == 1 + + class TestJoinParameter: @pytest.fixture def m2(self) -> Model: diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index 3e21a60f..de6e28d7 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -360,3 +360,11 @@ def test_power_of_three(x: Variable) -> None: x**3 with pytest.raises(TypeError): (x * x) * (x * x) + + +def test_variable_names(x: Variable, y: Variable) -> None: + expr = 2 * (x * x) + 3 * y + 1 + assert expr.variable_names == {"x", "y"} + + expr = 2 * (x * x) + 1 + assert expr.variable_names == {"x"}