@@ -335,65 +335,66 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
item = get_data(run_config.dataset_path, "item", run_config.suffix)
date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)

cross_items = build_cross_items(
store_sales, catalog_sales, web_sales, item, date_dim, year=year
)
average_sales = build_average_sales(
store_sales, catalog_sales, web_sales, date_dim, year=year
)

# week_dates is ≤7 rows (one calendar week), computed once as a 1-partition frame.
# Push the week filter into each channel before the UNION via a semi-join so that
# ~99% of rows (everything outside the target week) are dropped before the
# expensive cross_items join and groupby.
target_week = (
date_dim.filter(
(pl.col("d_year") == year + 1)
& (pl.col("d_moy") == 12)
& (pl.col("d_dom") == day)
)
.select("d_week_seq")
.unique()
)
week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk")

all_sales = pl.concat(
[
store_sales.select(
store_sales.join(
week_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi"
).select(
[
pl.lit("store").alias("channel"),
pl.col("ss_item_sk").alias("item_sk"),
pl.col("ss_quantity").alias("quantity"),
pl.col("ss_list_price").alias("list_price"),
pl.col("ss_sold_date_sk").alias("date_sk"),
]
),
catalog_sales.select(
catalog_sales.join(
week_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi"
).select(
[
pl.lit("catalog").alias("channel"),
pl.col("cs_item_sk").alias("item_sk"),
pl.col("cs_quantity").alias("quantity"),
pl.col("cs_list_price").alias("list_price"),
pl.col("cs_sold_date_sk").alias("date_sk"),
]
),
web_sales.select(
web_sales.join(
week_dates, left_on="ws_sold_date_sk", right_on="d_date_sk", how="semi"
).select(
[
pl.lit("web").alias("channel"),
pl.col("ws_item_sk").alias("item_sk"),
pl.col("ws_quantity").alias("quantity"),
pl.col("ws_list_price").alias("list_price"),
pl.col("ws_sold_date_sk").alias("date_sk"),
]
),
]
)

cross_items = build_cross_items(
store_sales, catalog_sales, web_sales, item, date_dim, year=year
)
average_sales = build_average_sales(
store_sales, catalog_sales, web_sales, date_dim, year=year
)

# d_week_seq target is the same for all 3 channels; compute it once.
target_week = (
date_dim.filter(
(pl.col("d_year") == year + 1)
& (pl.col("d_moy") == 12)
& (pl.col("d_dom") == day)
)
.select("d_week_seq")
.unique()
)
week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk")

# Build y: all 3 channels in a single pipeline.
# cross_items and average_sales each appear once — no CSE needed.
# After group_by the frame is tiny, so the cross join with the 1-row
# average_sales frame is negligible even if Polars fuses it into an IEJoin.
y = (
all_sales.join(cross_items, left_on="item_sk", right_on="ss_item_sk")
.join(item, left_on="item_sk", right_on="i_item_sk")
.join(week_dates, left_on="date_sk", right_on="d_date_sk")
.group_by(["channel", "i_brand_id", "i_class_id", "i_category_id"])
.agg(
[
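The change above pushes the target-week filter into each channel with a semi-join before the UNION. A minimal standalone sketch of that pattern, using made-up miniature frames (all names and values here are illustrative, not from the benchmark):

import polars as pl

# Hypothetical tiny fact and dimension frames.
sales = pl.LazyFrame(
    {"sold_date_sk": [1, 2, 3, 4], "quantity": [10, 20, 30, 40]}
)
date_dim = pl.LazyFrame(
    {"d_date_sk": [1, 2, 3, 4], "d_week_seq": [100, 100, 101, 101]}
)

# Keys of the one target week; at most 7 rows.
week_dates = date_dim.filter(pl.col("d_week_seq") == 100).select("d_date_sk")

# A semi-join keeps only matching rows and adds no columns, so rows outside
# the week are dropped before any later join or group_by sees them.
filtered = sales.join(
    week_dates, left_on="sold_date_sk", right_on="d_date_sk", how="semi"
)
print(filtered.collect())  # keeps the rows with sold_date_sk 1 and 2

The same three-line join is repeated per channel in the diff because each channel has its own date-key column name.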
@@ -110,38 +110,103 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
sort_by = {"i_item_id": False, "i_item_desc": False, "s_state": False}
limit = 100

store_sales_base = (
q1 = f"{year}Q1"
q1_q3 = [f"{year}Q1", f"{year}Q2", f"{year}Q3"]

# Pre-filter date_dim to only qualifying d_date_sk values.
d1_dates = date_dim.filter(pl.col("d_quarter_name") == q1).select("d_date_sk")
d_q3_dates = date_dim.filter(pl.col("d_quarter_name").is_in(q1_q3)).select(
"d_date_sk"
)

# store_returns has [6] partitions — at the broadcast limit. Filter it to Q1-Q3 dates
# first, then use the (customer, item) pairs it contains to pre-filter both store_sales
# and catalog_sales before those larger tables enter the expensive shuffle joins.
store_returns_filtered = store_returns.join(
d_q3_dates, left_on="sr_returned_date_sk", right_on="d_date_sk", how="semi"
).select(["sr_customer_sk", "sr_item_sk", "sr_ticket_number", "sr_return_quantity"])

# (customer, item) pairs present in any qualifying store return; stays at [6] partitions
# so broadcast is free. Polars will CACHE this shared subplan.
sr_customer_item = store_returns_filtered.select(["sr_customer_sk", "sr_item_sk"])

store_sales_filtered = (
store_sales.join(
date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk", suffix="_d1"
d1_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi"
)
.join(
sr_customer_item,
left_on=["ss_customer_sk", "ss_item_sk"],
right_on=["sr_customer_sk", "sr_item_sk"],
how="semi",
)
.select(
[
"ss_customer_sk",
"ss_item_sk",
"ss_store_sk",
"ss_ticket_number",
"ss_quantity",
]
)
.join(
item.select(["i_item_sk", "i_item_id", "i_item_desc"]),
left_on="ss_item_sk",
right_on="i_item_sk",
)
.join(
store.select(["s_store_sk", "s_state"]),
left_on="ss_store_sk",
right_on="s_store_sk",
)
.select(
[
"ss_customer_sk",
"ss_item_sk",
"ss_ticket_number",
"ss_quantity",
"i_item_id",
"i_item_desc",
"s_state",
]
)
.join(item, left_on="ss_item_sk", right_on="i_item_sk")
.join(store, left_on="ss_store_sk", right_on="s_store_sk")
.filter(pl.col("d_quarter_name") == f"{year}Q1")
)

store_returns_base = store_returns.join(
date_dim, left_on="sr_returned_date_sk", right_on="d_date_sk", suffix="_d2"
).filter(pl.col("d_quarter_name").is_in([f"{year}Q1", f"{year}Q2", f"{year}Q3"]))

catalog_sales_base = catalog_sales.join(
date_dim, left_on="cs_sold_date_sk", right_on="d_date_sk", suffix="_d3"
).filter(pl.col("d_quarter_name").is_in([f"{year}Q1", f"{year}Q2", f"{year}Q3"]))
catalog_sales_filtered = (
catalog_sales.join(
d_q3_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi"
)
.join(
sr_customer_item,
left_on=["cs_bill_customer_sk", "cs_item_sk"],
right_on=["sr_customer_sk", "sr_item_sk"],
how="semi",
)
.select(["cs_bill_customer_sk", "cs_item_sk", "cs_quantity"])
)

return QueryResult(
frame=(
store_sales_base.join(
store_returns_base,
store_sales_filtered.join(
store_returns_filtered,
left_on=["ss_customer_sk", "ss_item_sk", "ss_ticket_number"],
right_on=["sr_customer_sk", "sr_item_sk", "sr_ticket_number"],
how="inner",
suffix="_sr",
)
.select(
[
"ss_customer_sk",
"ss_item_sk",
"ss_quantity",
"sr_return_quantity",
"i_item_id",
"i_item_desc",
"s_state",
]
)
.join(
catalog_sales_base,
catalog_sales_filtered,
left_on=["ss_customer_sk", "ss_item_sk"],
right_on=["cs_bill_customer_sk", "cs_item_sk"],
how="inner",
suffix="_cs",
)
.group_by(["i_item_id", "i_item_desc", "s_state"])
.agg(
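A hedged sketch of the pre-filter idea above: filter the small returns table once, then reuse its (customer, item) key pairs as a multi-column semi-join against a larger fact table. The frames and values below are toy stand-ins, not benchmark data.

import polars as pl

store_returns_filtered = pl.LazyFrame(
    {"sr_customer_sk": [1, 2], "sr_item_sk": [10, 20], "sr_return_quantity": [5, 7]}
)
store_sales = pl.LazyFrame(
    {"ss_customer_sk": [1, 2, 3], "ss_item_sk": [10, 99, 30], "ss_quantity": [4, 6, 8]}
)

# Key pairs from the already-filtered small table; the same frame can be
# reused to pre-filter several fact tables, as in the diff.
sr_customer_item = store_returns_filtered.select(["sr_customer_sk", "sr_item_sk"])

store_sales_filtered = store_sales.join(
    sr_customer_item,
    left_on=["ss_customer_sk", "ss_item_sk"],
    right_on=["sr_customer_sk", "sr_item_sk"],
    how="semi",
)
print(store_sales_filtered.collect())  # only the (customer 1, item 10) sale survives

Because the inner join on (customer, item, ticket) still happens later, the semi-join only removes rows that could never have matched; it does not change the result.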
@@ -121,10 +121,7 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
catalog_sales = get_data(
run_config.dataset_path, "catalog_sales", run_config.suffix
)
customer_demographics_1 = get_data(
run_config.dataset_path, "customer_demographics", run_config.suffix
)
customer_demographics_2 = get_data(
customer_demographics = get_data(
run_config.dataset_path, "customer_demographics", run_config.suffix
)
customer = get_data(run_config.dataset_path, "customer", run_config.suffix)
@@ -134,30 +131,52 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)
item = get_data(run_config.dataset_path, "item", run_config.suffix)

# Pre-filter each dimension table before joining against catalog_sales [45 partitions].
# d_year not in GROUP BY — semi-join keeps only the date key in the pipeline.
filtered_dates = date_dim.filter(pl.col("d_year") == year).select("d_date_sk")
filtered_cd1 = customer_demographics.filter(
(pl.col("cd_gender") == gen) & (pl.col("cd_education_status") == es)
).select(["cd_demo_sk", "cd_dep_count"])
filtered_customer = customer.filter(pl.col("c_birth_month").is_in(month)).select(
["c_customer_sk", "c_current_cdemo_sk", "c_current_addr_sk", "c_birth_year"]
)
filtered_addr = customer_address.filter(pl.col("ca_state").is_in(state)).select(
["ca_address_sk", "ca_county", "ca_state", "ca_country"]
)

base_query = (
catalog_sales.join(date_dim, left_on="cs_sold_date_sk", right_on="d_date_sk")
.join(item, left_on="cs_item_sk", right_on="i_item_sk")
catalog_sales.select(
[
"cs_sold_date_sk",
"cs_item_sk",
"cs_bill_cdemo_sk",
"cs_bill_customer_sk",
"cs_quantity",
"cs_list_price",
"cs_coupon_amt",
"cs_sales_price",
"cs_net_profit",
]
)
.join(
customer_demographics_1,
left_on="cs_bill_cdemo_sk",
right_on="cd_demo_sk",
suffix="_cd1",
filtered_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi"
)
.join(
item.select(["i_item_sk", "i_item_id"]),
left_on="cs_item_sk",
right_on="i_item_sk",
)
.join(customer, left_on="cs_bill_customer_sk", right_on="c_customer_sk")
.join(filtered_cd1, left_on="cs_bill_cdemo_sk", right_on="cd_demo_sk")
.join(
customer_demographics_2,
filtered_customer, left_on="cs_bill_customer_sk", right_on="c_customer_sk"
)
.join(
customer_demographics.select("cd_demo_sk"),
left_on="c_current_cdemo_sk",
right_on="cd_demo_sk",
suffix="_cd2",
)
.join(customer_address, left_on="c_current_addr_sk", right_on="ca_address_sk")
.filter(
(pl.col("cd_gender") == gen)
& (pl.col("cd_education_status") == es)
& pl.col("c_birth_month").is_in(month)
& (pl.col("d_year") == year)
& pl.col("ca_state").is_in(state)
how="semi",
)
.join(filtered_addr, left_on="c_current_addr_sk", right_on="ca_address_sk")
)

agg_exprs = [
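One more sketch, under the same caveat (hypothetical toy frames), of the two dimension-join flavors the rewrite above uses: a semi-join against a pre-filtered key column when the dimension only supplies a predicate, and an inner join against a projected dimension when one of its columns is needed downstream.

import polars as pl

catalog_sales = pl.LazyFrame(
    {"cs_sold_date_sk": [1, 2], "cs_item_sk": [10, 20], "cs_quantity": [3, 4]}
)
date_dim = pl.LazyFrame({"d_date_sk": [1, 2], "d_year": [2001, 2002]})
item = pl.LazyFrame(
    {"i_item_sk": [10, 20], "i_item_id": ["AAA", "BBB"], "i_color": ["red", "blue"]}
)

# d_year is only a predicate: semi-join, so no date_dim column enters the plan.
filtered_dates = date_dim.filter(pl.col("d_year") == 2001).select("d_date_sk")

result = (
    catalog_sales.join(
        filtered_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi"
    )
    # i_item_id is needed in the output: inner join, but project the dimension
    # to just the key and the used column first.
    .join(
        item.select(["i_item_sk", "i_item_id"]),
        left_on="cs_item_sk",
        right_on="i_item_sk",
    )
)
print(result.collect())  # one row: the 2001 sale, with its i_item_id attached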
@@ -136,10 +136,6 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
),
]
)
# Step 2: Create wswscs CTE equivalent (aggregate by week and day of week)
# First join with date_dim to get day names
wscs_with_dates = wscs.join(date_dim, left_on="sold_date_sk", right_on="d_date_sk")
# Create separate aggregations for each day to better control null handling
days = (
"Sunday",
"Monday",
@@ -158,35 +154,26 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
"fri_sales",
"sat_sales",
)
# Start with all week sequences
all_weeks = wscs_with_dates.select("d_week_seq").unique()
wswscs = all_weeks

# Pre-filter date_dim to 4 years ([year-1, year, year+1, year+2]) to capture
# boundary weeks that span year transitions (e.g. from Dec 28 to Jan 3). Filtering to
# only [year, year+1] incorrectly excludes Dec days whose d_week_seq also
# appears in year's date_dim, producing null sales for those boundary weeks.
date_dim_prefilter = date_dim.filter(
pl.col("d_year").is_in([year - 1, year, year + 1, year + 2])
).select(["d_date_sk", "d_week_seq", "d_day_name"])
wswscs = (
wscs_with_dates.with_columns(
wscs.join(date_dim_prefilter, left_on="sold_date_sk", right_on="d_date_sk")
.group_by("d_week_seq")
.agg(
[
pl.when(pl.col("d_day_name") == day)
.then(pl.col("sales_price"))
.otherwise(None)
.sum()
.alias(name)
for day, name in zip(days, day_cols, strict=True)
]
)
.group_by("d_week_seq")
.agg(
*(pl.col(name).sum().alias(name) for name in day_cols),
*(pl.col(name).count().alias(f"{name}_count") for name in day_cols),
)
.with_columns(
[
pl.when(pl.col(f"{name}_count") > 0)
.then(pl.col(name))
.otherwise(None)
.alias(name)
for name in day_cols
]
)
.select(["d_week_seq", *day_cols])
)

# Step 3: Create year data (y subquery equivalent)
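The wswscs rewrite above folds the per-day aggregation into one group_by with a conditional sum per weekday. A minimal sketch of that expression pattern with two days and toy data (the diff iterates over all seven day names):

import polars as pl

wscs = pl.LazyFrame(
    {
        "d_week_seq": [100, 100, 101],
        "d_day_name": ["Sunday", "Monday", "Sunday"],
        "sales_price": [10.0, 20.0, 30.0],
    }
)

days = ("Sunday", "Monday")
day_cols = ("sun_sales", "mon_sales")

# One conditional sum per day: rows belonging to other days contribute null,
# and nulls are skipped by sum(), so each column holds that day's total.
wswscs = wscs.group_by("d_week_seq").agg(
    [
        pl.when(pl.col("d_day_name") == day)
        .then(pl.col("sales_price"))
        .otherwise(None)
        .sum()
        .alias(name)
        for day, name in zip(days, day_cols, strict=True)
    ]
)
print(wswscs.sort("d_week_seq").collect())
# e.g. week 100 -> sun_sales 10.0, mon_sales 20.0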