Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ Notes:
- `measured_quantities` and `known_p` input may also be symbolic (e.g. measured_quantities = [rs.X])
"""
function Catalyst.make_si_ode(rs::ReactionSystem; measured_quantities = [], known_p = [],
ignore_no_measured_warn = false, remove_conserved = true)
ignore_no_measured_warn = false, remove_conserved = true, mtkcompile::Bool = false)
# Creates a MTK ODE System, and a list of measured quantities (there are equations).
# Gives these to SI to create an SI ode model of its preferred form.
osys, conseqs, _, _ = make_osys(rs; remove_conserved)
osys, conseqs, _, _ = make_osys(rs; remove_conserved, mtkcompile)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p,
conseqs; ignore_no_measured_warn)
return SI.mtk_to_si(osys, measured_quantities)[1]
Expand Down Expand Up @@ -65,9 +65,10 @@ Notes:
"""
function SI.assess_local_identifiability(rs::ReactionSystem, args...;
measured_quantities = [], known_p = [], funcs_to_check = Vector(),
remove_conserved = true, ignore_no_measured_warn = false, kwargs...)
remove_conserved = true, ignore_no_measured_warn = false,
mtkcompile::Bool = false, kwargs...)
# Creates an ODE System, list of measured quantities, and functions to check, of SI's preferred form.
osys, conseqs, consconsts, vars = make_osys(rs; remove_conserved)
osys, conseqs, consconsts, vars = make_osys(rs; remove_conserved, mtkcompile)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p,
conseqs; ignore_no_measured_warn)
funcs_to_check = make_ftc(funcs_to_check, conseqs, vars)
Expand Down Expand Up @@ -105,9 +106,10 @@ Notes:
"""
function SI.assess_identifiability(rs::ReactionSystem, args...;
measured_quantities = [], known_p = [], funcs_to_check = Vector(),
remove_conserved = true, ignore_no_measured_warn = false, kwargs...)
remove_conserved = true, ignore_no_measured_warn = false,
mtkcompile::Bool = false, kwargs...)
# Creates an ODE System, list of measured quantities, and functions to check, of SI's preferred form.
osys, conseqs, consconsts, vars = make_osys(rs; remove_conserved)
osys, conseqs, consconsts, vars = make_osys(rs; remove_conserved, mtkcompile)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p,
conseqs; ignore_no_measured_warn)
funcs_to_check = make_ftc(funcs_to_check, conseqs, vars)
Expand Down Expand Up @@ -147,9 +149,9 @@ Notes:
"""
function SI.find_identifiable_functions(rs::ReactionSystem, args...;
measured_quantities = [], known_p = [], remove_conserved = true,
ignore_no_measured_warn = false, kwargs...)
ignore_no_measured_warn = false, mtkcompile::Bool = false, kwargs...)
# Creates an ODE System, and list of measured quantities, of SI's preferred form.
osys, conseqs, consconsts, _ = make_osys(rs; remove_conserved)
osys, conseqs, consconsts, _ = make_osys(rs; remove_conserved, mtkcompile)
measured_quantities = make_measured_quantities(rs, measured_quantities, known_p,
conseqs; ignore_no_measured_warn)

Expand All @@ -162,26 +164,38 @@ end

# From a reaction system, creates the corresponding MTK-style ODE System for SI application
# Also compute the, later needed, conservation law equations and list of system symbols (unknowns and parameters).
function make_osys(rs::ReactionSystem; remove_conserved = true)
function make_osys(rs::ReactionSystem; remove_conserved = true, mtkcompile::Bool = false)
# Creates the ODE System corresponding to the ReactionSystem (expanding functions and flattening it).
# Creates a list of the systems all symbols (unknowns and parameters).
if !ModelingToolkitBase.iscomplete(rs)
error("Identifiability should only be computed for complete systems. A ReactionSystem can be marked as complete using the `complete` function.")
end
rs = complete(Catalyst.expand_registered_functions(flatten(rs)))
osys = complete(ode_model(rs; remove_conserved))
# SI treats Γ as a free parameter and works purely symbolically, so we skip
# the conservation law bindings (Γ => missing) that are only needed for
# numerical initialization.
osys = ode_model(rs; remove_conserved, add_cl_bindings = false)
if mtkcompile
osys = ModelingToolkitBase.mtkcompile(osys)
elseif ModelingToolkitBase.has_alg_equations(rs)
error("The input ReactionSystem has algebraic equations. This requires setting `mtkcompile = true`.")
else
osys = complete(osys)
end
vars = [unknowns(rs); parameters(rs)]

# Computes equations for system conservation laws.
# If there are no conserved equations, the `conseqs` variable must still have the `Vector{Pair{Any, Any}}` type.
# The element-typed comprehensions guarantee `Vector{Pair{Any, Any}}` even
# when the iterable is empty (empty list comprehensions otherwise infer as
# `Vector{Any}`, and the previous fix-up `Vector{Pair{Any, Any}}[]` was a
# typo that produced `Vector{Vector{Pair{Any, Any}}}`).
if remove_conserved
conseqs = [ceq.lhs => ceq.rhs for ceq in conservedequations(rs)]
consconsts = [cconst.lhs => cconst.rhs for cconst in conservationlaw_constants(rs)]
isempty(conseqs) && (conseqs = Vector{Pair{Any, Any}}[])
isempty(consconsts) && (consconsts = Vector{Pair{Any, Any}}[])
conseqs = Pair{Any, Any}[ceq.lhs => ceq.rhs for ceq in conservedequations(rs)]
consconsts = Pair{Any, Any}[cconst.lhs => cconst.rhs
for cconst in conservationlaw_constants(rs)]
else
conseqs = Vector{Pair{Any, Any}}[]
consconsts = Vector{Pair{Any, Any}}[]
conseqs = Pair{Any, Any}[]
consconsts = Pair{Any, Any}[]
end

return osys, conseqs, consconsts, vars
Expand Down
2 changes: 1 addition & 1 deletion src/network_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ function cache_conservationlaw_eqs!(rn::ReactionSystem, N::AbstractMatrix, col_o
# Declare the conservation constant parameters
#`using guesses is for consistency and possibly faster initialisation
guesses = [Initial(depspecs[i] + rhs_terms[i]) for i in 1:nullity]
Γs = @parameters $(CONSERVED_CONSTANT_SYMBOL)[1:nullity] = missing [conserved = true, guess = guesses]
Γs = @parameters $(CONSERVED_CONSTANT_SYMBOL)[1:nullity] [conserved = true, guess = guesses]
constants = unwrap(only(Γs))

# Creates the conservation constant and conservation equation equations.
Expand Down
45 changes: 30 additions & 15 deletions src/reactionsystem_conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,15 @@ end
# merge constraint components with the ReactionSystem components
# also handles removing BC and constant species
function addconstraints!(eqs, rs::ReactionSystem, ists, ispcs; remove_conserved = false,
compute_cl_initeqs = false, include_cl_as_eqs = false)
include_cl_as_eqs = false)
# if there are BC species, put them after the independent species
rssts = get_unknowns(rs)
sts = any(isbc, rssts) ? vcat(ists, filter(isbc, rssts)) : ists
ps = get_ps(rs)
initeqs = Equation[]
ics = MT.initial_conditions(rs)
obs = MT.observed(rs)
cl_bindings = Dict{Any, Any}()

# make dependent species observables and add conservation constants as parameters
if remove_conserved && !isempty(conservedequations(rs))
Expand All @@ -527,6 +528,18 @@ function addconstraints!(eqs, rs::ReactionSystem, ists, ispcs; remove_conserved
ps = copy(ps)
push!(ps, nps.conservedconst)

# Bind Γ => missing and provide initialization equations so MTK solves for Γ
# during initialization (see MTK docs on parameter initialization).
# The binding is placed at the System level (not as a variable metadata default)
# so that SI.jl doesn't encounter Missing in the parameter symtype.
cl_bindings[nps.conservedconst] = missing
if !include_cl_as_eqs
initialmap = Dict(u => Initial(u) for u in species(rs))
for eq in nps.constantdefs
push!(initeqs, Symbolics.substitute(eq, initialmap))
end
end

# add the dependent species as observed. If `include_cl_as_eqs = true` add them as
# algebraic equations instead.
if !include_cl_as_eqs
Expand All @@ -535,13 +548,6 @@ function addconstraints!(eqs, rs::ReactionSystem, ists, ispcs; remove_conserved
else
append!(eqs, [0 ~ ceq.rhs - ceq.lhs for ceq in conservedequations(rs)])
end

# create initialization equations (only used for nonlinear systems)
if compute_cl_initeqs && !include_cl_as_eqs
initialmap = Dict(u => Initial(u) for u in species(rs))
conseqs = conservationlaw_constants(rs)
initeqs = [Symbolics.substitute(conseq, initialmap) for conseq in conseqs]
end
end

ceqs = Equation[eq for eq in get_eqs(rs) if eq isa Equation]
Expand All @@ -558,7 +564,7 @@ function addconstraints!(eqs, rs::ReactionSystem, ists, ispcs; remove_conserved
append!(eqs, ceqs)
end

eqs, sts, ps, obs, ics, initeqs
eqs, sts, ps, obs, ics, initeqs, cl_bindings
end

### Utility ###
Expand Down Expand Up @@ -626,6 +632,7 @@ function hybrid_model(rs::ReactionSystem;
combinatoric_ratelaws = get_combinatoric_ratelaws(rs),
include_zero_odes = true,
remove_conserved = false,
add_cl_bindings = true,
expand_catalyst_funs = true,
save_positions = (true, true),
checks = false,
Expand Down Expand Up @@ -715,8 +722,10 @@ function hybrid_model(rs::ReactionSystem;
jumps = vcat(rxn_jumps, user_jumps)

# --- Add constraints (BC species, constraint equations, conserved species) ---
initeqs = Equation[]
cl_bindings = Dict{Any, Any}()
if has_continuous
eqs, us, ps, obs, ics = addconstraints!(eqs, flatrs, ists, ispcs; remove_conserved)
eqs, us, ps, obs, ics, initeqs, cl_bindings = addconstraints!(eqs, flatrs, ists, ispcs; remove_conserved)
else
# Pure jump case.
any(isbc, get_unknowns(flatrs)) &&
Expand All @@ -729,12 +738,16 @@ function hybrid_model(rs::ReactionSystem;

# --- Construct unified System ---
# Note: brownians is a positional arg (5th) in the System constructor.
all_bindings = add_cl_bindings ? merge(MT.get_bindings(flatrs), cl_bindings) :
MT.get_bindings(flatrs)
all_initeqs = add_cl_bindings ? initeqs : Equation[]
MT.System(eqs, get_iv(flatrs), us, ps, brownian_vars;
poissonians = user_poissonians,
jumps,
observed = obs,
name,
bindings = MT.get_bindings(flatrs),
initialization_eqs = all_initeqs,
bindings = all_bindings,
initial_conditions = merge(initial_conditions, ics),
checks,
continuous_events = MT.get_continuous_events(flatrs),
Expand Down Expand Up @@ -910,8 +923,8 @@ function ss_ode_model(rs::ReactionSystem; name = nameof(rs),
ists, ispcs = get_indep_sts(fullrs, (remove_conserved && !include_cl_as_eqs))
eqs = assemble_drift(fullrs, ispcs; combinatoric_ratelaws, remove_conserved,
as_odes = false, include_zero_odes = false, expand_catalyst_funs)
eqs, us, ps, obs, ics, initeqs = addconstraints!(eqs, fullrs, ists, ispcs;
remove_conserved, compute_cl_initeqs = !include_cl_as_eqs, include_cl_as_eqs)
eqs, us, ps, obs, ics, initeqs, cl_bindings = addconstraints!(eqs, fullrs, ists, ispcs;
remove_conserved, include_cl_as_eqs)

# Comoutes the correct initial conditions and bindings.
initial_conditions, bindings = MT.convert_bindings_for_time_independent_system(rs)
Expand All @@ -924,7 +937,7 @@ function ss_ode_model(rs::ReactionSystem; name = nameof(rs),
System(eqs, us, ps;
name,
observed = obs, initialization_eqs = initeqs,
bindings,
bindings = merge(bindings, cl_bindings),
initial_conditions,
checks,
metadata = MT.get_metadata(rs),
Expand Down Expand Up @@ -1026,12 +1039,14 @@ function sde_model(rs::ReactionSystem;
remove_conserved, expand_catalyst_funs, use_jump_ratelaws)
noiseeqs = assemble_diffusion(flatrs, ists, ispcs; combinatoric_ratelaws,
remove_conserved, expand_catalyst_funs, use_jump_ratelaws)
eqs, us, ps, obs, ics = addconstraints!(eqs, flatrs, ists, ispcs; remove_conserved)
eqs, us, ps, obs, ics, initeqs, cl_bindings = addconstraints!(eqs, flatrs, ists, ispcs; remove_conserved)

return MT.System(eqs, get_iv(flatrs), us, ps;
noise_eqs = noiseeqs,
observed = obs,
name,
initialization_eqs = initeqs,
bindings = cl_bindings,
initial_conditions = merge(initial_conditions, ics),
checks,
continuous_events = MT.get_continuous_events(flatrs),
Expand Down
Loading
Loading