Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ function variable_ref_type(::Type{GenericNonlinearExpr}, x::AbstractJuMPScalar)
return variable_ref_type(x)
end

function variable_ref_type(
::Type{GenericNonlinearExpr},
::AbstractArray{T},
) where {T<:AbstractJuMPScalar}
return variable_ref_type(T)
end

value_type(::Type{GenericNonlinearExpr{V}}) where {V} = value_type(V)

function _has_variable_ref_type(a)
Expand Down Expand Up @@ -330,14 +337,16 @@ Base.isreal(::GenericNonlinearExpr) = true

# Univariate operators

_is_real(::Any) = false
_is_real(::Real) = true
_is_real(::AbstractVariableRef) = true
_is_real(::GenericAffExpr{<:Real}) = true
_is_real(::GenericQuadExpr{<:Real}) = true
_is_real(::GenericNonlinearExpr) = true
_is_real(::NonlinearExpression) = true
_is_real(::NonlinearParameter) = true
_is_real(::Type) = false
_is_real(::Type{<:Real}) = true
_is_real(::Type{<:AbstractVariableRef}) = true
_is_real(::Type{<:GenericAffExpr{<:Real}}) = true
_is_real(::Type{<:GenericQuadExpr{<:Real}}) = true
_is_real(::Type{<:GenericNonlinearExpr}) = true
_is_real(::Type{<:NonlinearExpression}) = true
_is_real(::Type{<:NonlinearParameter}) = true
_is_real(::Type{<:AbstractArray{T}}) where {T} = _is_real(T)
_is_real(x) = _is_real(typeof(x))

function _throw_if_not_real(x)
if !_is_real(x)
Expand Down Expand Up @@ -569,6 +578,10 @@ end

moi_function(x::Number) = x

# `moi_function(::Array)` would be ambiguous with
# `moi_function(AbstractArray{<:AbstractVariableRef})`
moi_function(x::AbstractArray) = moi_function.(x)

function moi_function(f::GenericNonlinearExpr{V}) where {V}
ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[]
Expand Down
60 changes: 60 additions & 0 deletions test/test_nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
module TestNLPExpr

using JuMP
using LinearAlgebra
using Test

import LinearAlgebra
Expand Down Expand Up @@ -1232,4 +1233,63 @@ function test_extension_euler_to_exp(
return
end

function test_array()
model = Model()
@variable(model, x)
vov = MOI.VectorOfVariables([index(x)])
op_det = NonlinearOperator(det, :det)
@objective(model, Min, op_det([x]))
f = MOI.get(model, MOI.ObjectiveFunction{MOI.ScalarNonlinearFunction}())
@test f.head == :det
@test f.args == [vov]

op_dot = NonlinearOperator(dot, :dot)
a = [2.0]
@objective(model, Min, op_dot([x], a))
f = MOI.get(model, MOI.ObjectiveFunction{MOI.ScalarNonlinearFunction}())
@test f.head == :dot
@test length(f.args) == 2
@test f.args[1] == vov
@test f.args[2] == a
end

# Inspired from contiguous arrays in ArrayDiff and GenOpt
struct ContiguousVectorOfVariableRefs <: AbstractVector{JuMP.VariableRef}
offset::Int
length::Int
end

struct Contiguous end

function JuMP.Containers.container(
_,
axe::JuMP.Containers.VectorizedProductIterator{Tuple{Base.OneTo{Int}}},
::Contiguous,
)
# Correctness don't matter, it's not for the sake of this test
return ContiguousVectorOfVariableRefs(0, length(axe))
end

struct ContiguousVectorOfVariableIndices <: MOI.AbstractVectorFunction
offset::Int
length::Int
end

Base.copy(x::ContiguousVectorOfVariableIndices) = x

function JuMP.moi_function(x::ContiguousVectorOfVariableRefs)
return ContiguousVectorOfVariableIndices(x.offset, x.length)
end

function test_custom_array()
model = Model()
@variable(model, x[1:2], container = Contiguous())
@test x === ContiguousVectorOfVariableRefs(0, 2)
op_norm = NonlinearOperator(LinearAlgebra.norm, :norm)
@objective(model, Min, op_norm(x))
f = MOI.get(model, MOI.ObjectiveFunction{MOI.ScalarNonlinearFunction}())
@test f.head == :norm
@test f.args[] == ContiguousVectorOfVariableIndices(0, 2)
end

end # module
Loading