Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 48 additions & 6 deletions src/microplex_us/targets/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,20 @@
ARCH_STATE_TO_NATIONAL_ROLLUP_VARIABLES = frozenset(
{
"aca_aptc_amount",
"actc_amount",
"actc_claims",
"charitable_amount",
"charitable_claims",
"ctc_amount",
"ctc_claims",
"medical_amount",
"medical_claims",
"mortgage_interest_amount",
"mortgage_interest_claims",
"qbi_amount",
"qbi_claims",
"salt_amount",
"salt_claims",
}
)

Expand Down Expand Up @@ -4582,7 +4594,7 @@ def arch_target_record_to_canonical_spec(
)
)
elif positive_measure is not None:
model_variable = positive_measure
model_variable = "tax_unit_count"
entity = EntityType.TAX_UNIT
filters.append(
TargetFilter(feature=positive_measure, operator=">", value=0)
Expand Down Expand Up @@ -5429,12 +5441,17 @@ def _ssi_category_filters_for_arch_constraint(
value: str,
) -> tuple[TargetFilter, ...]:
category = str(value).strip().lower()
if operator == "==" and category in {"aged", "blind", "disabled"}:
category_feature = {
"aged": "is_ssi_aged",
"blind": "is_blind",
"disabled": "is_ssi_disabled",
}.get(category)
if operator == "==" and category_feature is not None:
return (
TargetFilter(
feature="ssi_category",
operator=operator,
value=category.upper(),
feature=category_feature,
operator=">",
value=0,
),
)
return (TargetFilter(feature="ssi_category", operator=operator, value=value),)
Expand Down Expand Up @@ -5720,12 +5737,25 @@ def _arch_target_query_variables(
record: ArchTargetRecord,
target: CanonicalTargetSpec,
) -> set[str]:
metadata_variable = str(target.metadata.get("variable") or "")
domain_variables = _arch_target_domain_variables(target)
variables = {
record.variable,
str(target.metadata.get("variable")),
}
if not (
target.aggregation is TargetAggregation.COUNT
and metadata_variable in {
"household_count",
"person_count",
"spm_unit_count",
"tax_unit_count",
}
and domain_variables
):
variables.add(metadata_variable)
if target.measure is not None:
variables.add(str(target.measure))
variables.update(domain_variables)
if target.aggregation is TargetAggregation.SUM:
variables.update(_arch_target_cell_variables(target))
return {variable for variable in variables if variable}
Expand Down Expand Up @@ -5891,6 +5921,18 @@ def _target_domain_variables_match(
):
return True

count_positive_measure = _positive_measure_for_count_record(
str(target.metadata.get("arch_variable") or "")
)
if (
target.aggregation is TargetAggregation.COUNT
and count_positive_measure
and count_positive_measure in effective_target_domain_variables
and cell_domain_variables
== effective_target_domain_variables - {count_positive_measure}
):
return True

return False


Expand Down
90 changes: 78 additions & 12 deletions tests/targets/test_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,30 @@ def _insert_complete_state_rollup_targets(path: Path) -> None:
)
for index, (stratum_id, *_rest) in enumerate(ctc_strata)
]
deduction_targets = [
(
30_000 + index * 4 + offset,
stratum_id,
variable,
2024,
value + index,
target_type,
None,
"IRS_SOI",
"SOI Individual Returns - Itemized Deductions",
None,
None,
)
for index, (stratum_id, *_rest) in enumerate(ctc_strata)
for offset, (variable, value, target_type) in enumerate(
(
("qbi_amount", 2_000.0, "AMOUNT"),
("qbi_claims", 200.0, "COUNT"),
("medical_amount", 300.0, "AMOUNT"),
("medical_claims", 30.0, "COUNT"),
)
)
]
aca_targets = [
(
20_000 + index,
Expand Down Expand Up @@ -540,7 +564,7 @@ def _insert_complete_state_rollup_targets(path: Path) -> None:
notes
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
[*ctc_targets, *ctc_count_targets, *aca_targets],
[*ctc_targets, *ctc_count_targets, *deduction_targets, *aca_targets],
)
conn.commit()
conn.close()
Expand Down Expand Up @@ -583,6 +607,7 @@ def test_arch_provider_ages_soi_and_maps_return_counts_to_positive_amounts(tmp_p
assert count_target.metadata["display_label"] == count_target.description
assert count_target.metadata["target_semantic"] == "count"
assert count_target.metadata["model_variable_role"] == "preserved_input"
assert count_target.metadata["variable"] == "tax_unit_count"
assert count_target.measure is None
assert count_target.value == pytest.approx(11.0)
assert {
Expand Down Expand Up @@ -1057,8 +1082,8 @@ def test_arch_provider_matches_current_profile_aliases(tmp_path):
10: "self_employment_income",
11: "person_count",
12: "employment_income",
13: "employment_income",
14: "self_employment_income",
13: "tax_unit_count",
14: "tax_unit_count",
}


Expand Down Expand Up @@ -2258,7 +2283,8 @@ def test_arch_provider_maps_real_estate_tax_targets(tmp_path):

assert {target.metadata["target_id"] for target in target_set.targets} == {8, 9}
assert {target.metadata["variable"] for target in target_set.targets} == {
"real_estate_taxes"
"real_estate_taxes",
"tax_unit_count",
}


Expand Down Expand Up @@ -3400,6 +3426,26 @@ def test_arch_target_profile_coverage_rolls_complete_state_targets_to_national(
geo_level="national",
domain_variable="adjusted_gross_income,non_refundable_ctc",
),
PolicyEngineUSTargetCell(
"qualified_business_income_deduction",
geo_level="national",
domain_variable="qualified_business_income_deduction",
),
PolicyEngineUSTargetCell(
"tax_unit_count",
geo_level="national",
domain_variable="qualified_business_income_deduction",
),
PolicyEngineUSTargetCell(
"medical_expense_deduction",
geo_level="national",
domain_variable="medical_expense_deduction,tax_unit_itemizes",
),
PolicyEngineUSTargetCell(
"tax_unit_count",
geo_level="national",
domain_variable="medical_expense_deduction,tax_unit_itemizes",
),
PolicyEngineUSTargetCell(
"aca_ptc",
geo_level="national",
Expand All @@ -3408,8 +3454,8 @@ def test_arch_target_profile_coverage_rolls_complete_state_targets_to_national(
),
)

assert report.target_cell_count == 5
assert report.covered_cell_count == 5
assert report.target_cell_count == 9
assert report.covered_cell_count == 9
target_set = provider.load_target_set(
TargetQuery(
period=2024,
Expand All @@ -3419,20 +3465,40 @@ def test_arch_target_profile_coverage_rolls_complete_state_targets_to_national(
)
)
rollup_targets = {
(target.measure or target.metadata["variable"], target.aggregation): target
(
target.measure or target.metadata["variable"],
target.aggregation,
target.metadata["arch_variable"],
): target
for target in target_set.targets
if target.metadata["geo_level"] == "national"
and str(target.metadata["target_id"]).startswith("-")
}
assert rollup_targets[
("non_refundable_ctc", TargetAggregation.SUM)
("non_refundable_ctc", TargetAggregation.SUM, "ctc_amount")
].value == pytest.approx(sum(1_000.0 + index for index in range(51)))
assert rollup_targets[
("non_refundable_ctc", TargetAggregation.COUNT)
("tax_unit_count", TargetAggregation.COUNT, "ctc_claims")
].value == pytest.approx(sum(100.0 + index for index in range(51)))
assert rollup_targets[("aca_ptc", TargetAggregation.SUM)].value == pytest.approx(
sum(10_000.0 + index for index in range(51))
)
assert rollup_targets[
(
"qualified_business_income_deduction",
TargetAggregation.SUM,
"qbi_amount",
)
].value == pytest.approx(sum(2_000.0 + index for index in range(51)))
assert rollup_targets[
("tax_unit_count", TargetAggregation.COUNT, "qbi_claims")
].value == pytest.approx(sum(200.0 + index for index in range(51)))
assert rollup_targets[
("medical_expense_deduction", TargetAggregation.SUM, "medical_amount")
].value == pytest.approx(sum(300.0 + index for index in range(51)))
assert rollup_targets[
("tax_unit_count", TargetAggregation.COUNT, "medical_claims")
].value == pytest.approx(sum(30.0 + index for index in range(51)))
assert rollup_targets[
("aca_ptc", TargetAggregation.SUM, "aca_aptc_amount")
].value == pytest.approx(sum(10_000.0 + index for index in range(51)))


def test_arch_target_profile_coverage_reports_current_pe_profile(tmp_path):
Expand Down
36 changes: 16 additions & 20 deletions tests/targets/test_arch_facts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_tax_exempt_interest(
qualified_returns = targets_by_arch_variable["qualified_dividends_returns"]
qualified_amount = targets_by_arch_variable["qualified_dividends_amount"]

assert returns.metadata["variable"] == "tax_exempt_interest_income"
assert returns.metadata["variable"] == "tax_unit_count"
assert returns.aggregation.value == "count"
assert {
(
Expand All @@ -1330,7 +1330,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_tax_exempt_interest(
}
assert amount.metadata["variable"] == "tax_exempt_interest_income"
assert amount.measure == "tax_exempt_interest_income"
assert qualified_returns.metadata["variable"] == "qualified_dividend_income"
assert qualified_returns.metadata["variable"] == "tax_unit_count"
assert qualified_returns.aggregation.value == "count"
assert {
(
Expand Down Expand Up @@ -1408,7 +1408,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_schedule_c_self_employment(
partnership_returns = targets_by_arch_variable["partnership_scorp_income_returns"]
partnership_amount = targets_by_arch_variable["partnership_scorp_income_amount"]

assert returns.metadata["variable"] == "self_employment_income"
assert returns.metadata["variable"] == "tax_unit_count"
assert {
(
target_filter.feature,
Expand All @@ -1423,8 +1423,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_schedule_c_self_employment(
assert amount.metadata["variable"] == "self_employment_income"
assert amount.measure == "self_employment_income"
assert (
partnership_returns.metadata["variable"]
== "tax_unit_partnership_s_corp_income"
partnership_returns.metadata["variable"] == "tax_unit_count"
)
assert {
(
Expand Down Expand Up @@ -1656,7 +1655,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_table_2_1_itemized_details(
assert ("itemized_deductions", ">", "0") in _target_filter_tuples(charitable)

charitable_count = targets_by_arch_variable["charitable_returns"]
assert charitable_count.metadata["variable"] == "charitable_deduction"
assert charitable_count.metadata["variable"] == "tax_unit_count"
assert charitable_count.aggregation.value == "count"
assert ("charitable_deduction", ">", "0") in _target_filter_tuples(
charitable_count
Expand Down Expand Up @@ -2217,7 +2216,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_state_broad_soi_concepts(
}

schedule_c_returns = targets_by_arch_variable["schedule_c_income_returns"]
assert schedule_c_returns.metadata["variable"] == "self_employment_income"
assert schedule_c_returns.metadata["variable"] == "tax_unit_count"
assert schedule_c_returns.aggregation.value == "count"
assert ("self_employment_income", ">", "0") in _target_filter_tuples(
schedule_c_returns
Expand All @@ -2238,10 +2237,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_state_broad_soi_concepts(
assert qbi.measure == "qualified_business_income_deduction"

qbi_claims = targets_by_arch_variable["qbi_claims"]
assert (
qbi_claims.metadata["variable"]
== "qualified_business_income_deduction"
)
assert qbi_claims.metadata["variable"] == "tax_unit_count"
assert qbi_claims.aggregation.value == "count"
assert (
"qualified_business_income_deduction",
Expand All @@ -2254,7 +2250,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_state_broad_soi_concepts(
assert rental.measure == "rental_income"

rental_returns = targets_by_arch_variable["rental_royalty_income_returns"]
assert rental_returns.metadata["variable"] == "rental_income"
assert rental_returns.metadata["variable"] == "tax_unit_count"
assert rental_returns.aggregation.value == "count"
assert ("rental_income", ">", "0") in _target_filter_tuples(
rental_returns
Expand All @@ -2265,7 +2261,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_state_broad_soi_concepts(
assert ctc.measure == "non_refundable_ctc"

ctc_claims = targets_by_arch_variable["ctc_claims"]
assert ctc_claims.metadata["variable"] == "non_refundable_ctc"
assert ctc_claims.metadata["variable"] == "tax_unit_count"
assert ctc_claims.aggregation.value == "count"
assert ("non_refundable_ctc", ">", "0") in _target_filter_tuples(
ctc_claims
Expand All @@ -2276,7 +2272,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_state_broad_soi_concepts(
assert actc.measure == "refundable_ctc"

actc_claims = targets_by_arch_variable["actc_claims"]
assert actc_claims.metadata["variable"] == "refundable_ctc"
assert actc_claims.metadata["variable"] == "tax_unit_count"
assert actc_claims.aggregation.value == "count"
assert ("refundable_ctc", ">", "0") in _target_filter_tuples(
actc_claims
Expand Down Expand Up @@ -2403,7 +2399,7 @@ def test_arch_consumer_fact_jsonl_provider_maps_eitc_by_agi_and_children(
target = target_set.targets[0]

assert target.metadata["arch_variable"] == "eitc_claims"
assert target.metadata["variable"] == "eitc"
assert target.metadata["variable"] == "tax_unit_count"
assert target.aggregation.value == "count"
assert _target_filter_tuples(target) == {
("eitc", ">", "0"),
Expand Down Expand Up @@ -2912,7 +2908,7 @@ def find_target(

aged_count = find_target(
"ssi_recipients",
{("ssi_category", "==", "AGED"), ("ssi", ">", 0)},
{("is_ssi_aged", ">", 0), ("ssi", ">", 0)},
)
assert aged_count.measure is None
assert aged_count.entity.value == "person"
Expand All @@ -2921,7 +2917,7 @@ def find_target(
assert {
(target_filter.feature, target_filter.operator.value, target_filter.value)
for target_filter in aged_count.filters
} == {("ssi_category", "==", "AGED"), ("ssi", ">", 0)}
} == {("is_ssi_aged", ">", 0), ("ssi", ">", 0)}

ca_payments = find_target(
"ssi_total_payments",
Expand All @@ -2939,7 +2935,7 @@ def find_target(
ca_disabled_count = find_target(
"ssi_recipients",
{
("ssi_category", "==", "DISABLED"),
("is_ssi_disabled", ">", 0),
("ssi", ">", 0),
("state_fips", "==", "06"),
},
Expand All @@ -2950,7 +2946,7 @@ def find_target(
(target_filter.feature, target_filter.operator.value, target_filter.value)
for target_filter in ca_disabled_count.filters
} == {
("ssi_category", "==", "DISABLED"),
("is_ssi_disabled", ">", 0),
("ssi", ">", 0),
("state_fips", "==", "06"),
}
Expand Down Expand Up @@ -4144,7 +4140,7 @@ def test_arch_fact_provider_maps_soi_table_1_4_income_source_facts(
== "count"
)
assert getattr(wages_returns.entity, "value", wages_returns.entity) == "tax_unit"
assert wages_returns.metadata["variable"] == "employment_income"
assert wages_returns.metadata["variable"] == "tax_unit_count"
assert (
"employment_income",
">",
Expand Down
Loading