Skip to content

Commit 79634d8

Browse files
committed
fix(sql): preserve timezone casts for extraction ops
Apply timezone-aware cast handling in Trino and localize DuckDB timestamp extraction/time operations so timezone casts behave consistently for columns and scalars. Add regression tests for timezone cast hour/time extraction and Trino SQL timezone-function generation. Fixes #11965 Refs #11527 Refs #11211 Refs #11879 Signed-off-by: Mridankan Mandal <iib2024017@iiita.ac.in>
1 parent e040215 commit 79634d8

3 files changed

Lines changed: 74 additions & 1 deletion

File tree

ibis/backends/duckdb/tests/test_client.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,38 @@ def test_to_other_sql(con, snapshot):
201201
snapshot.assert_match(sql, "out.sql")
202202

203203

204+
def test_timezone_cast_extracts_and_time():
205+
con = ibis.duckdb.connect()
206+
t = ibis.memtable({"x": ["2023-01-02"]})
207+
expr = t.select(
208+
ams_hour=t.x.cast("timestamp('Europe/Amsterdam')").hour(),
209+
utc_hour=t.x.cast("timestamp('UTC')").hour(),
210+
ams_time=t.x.cast("timestamp('Europe/Amsterdam')").time(),
211+
utc_time=t.x.cast("timestamp('UTC')").time(),
212+
)
213+
214+
result = con.execute(expr)
215+
216+
assert result.ams_hour.iat[0] == 1
217+
assert result.utc_hour.iat[0] == 0
218+
assert str(result.ams_time.iat[0]) == "01:00:00"
219+
assert str(result.utc_time.iat[0]) == "00:00:00"
220+
221+
222+
def test_to_trino_sql_timezone_cast_uses_timezone_functions():
223+
t = ibis.memtable({"x": ["2023-01-02"]})
224+
expr = t.select(
225+
casted=t.x.cast("timestamp('Europe/Paris')"),
226+
hour=t.x.cast("timestamp('Europe/Paris')").hour(),
227+
time=t.x.cast("timestamp('Europe/Paris')").time(),
228+
)
229+
230+
sql = ibis.to_sql(expr, dialect="trino")
231+
232+
assert "AT_TIMEZONE(" in sql
233+
assert "WITH_TIMEZONE(" in sql
234+
235+
204236
def test_insert_preserves_column_case(con):
205237
name1 = ibis.util.guid()
206238
name2 = ibis.util.guid()

ibis/backends/sql/compilers/duckdb.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,12 +392,41 @@ def visit_CountDistinctStar(self, op, *, where, arg):
392392
)
393393
return self.agg.count(sge.Distinct(expressions=[row]), where=where)
394394

395+
def _localize_timestamp_for_extract(self, op, *, arg):
396+
if op.arg.dtype.is_timestamp() and (timezone := op.arg.dtype.timezone) is not None:
397+
return self.f.timezone(timezone, arg)
398+
return arg
399+
400+
def visit_Time(self, op, *, arg):
401+
arg = self._localize_timestamp_for_extract(op, arg=arg)
402+
return super().visit_Time(op, arg=arg)
403+
404+
def visit_ExtractEpochSeconds(self, op, *, arg):
405+
if op.arg.dtype.is_timestamp() and op.arg.dtype.timezone is not None:
406+
return self.f.epoch(arg)
407+
return super().visit_ExtractEpochSeconds(op, arg=arg)
408+
409+
def visit_ExtractHour(self, op, *, arg):
410+
return self.f.extract("hour", self._localize_timestamp_for_extract(op, arg=arg))
411+
412+
def visit_ExtractMinute(self, op, *, arg):
413+
return self.f.extract(
414+
"minute", self._localize_timestamp_for_extract(op, arg=arg)
415+
)
416+
417+
def visit_ExtractSecond(self, op, *, arg):
418+
return self.f.extract(
419+
"second", self._localize_timestamp_for_extract(op, arg=arg)
420+
)
421+
395422
def visit_ExtractMillisecond(self, op, *, arg):
423+
arg = self._localize_timestamp_for_extract(op, arg=arg)
396424
return self.f.mod(self.f.extract("ms", arg), 1_000)
397425

398426
# DuckDB extracts subminute microseconds and milliseconds
399427
# so we have to finesse it a little bit
400428
def visit_ExtractMicrosecond(self, op, *, arg):
429+
arg = self._localize_timestamp_for_extract(op, arg=arg)
401430
return self.f.mod(self.f.extract("us", arg), 1_000_000)
402431

403432
def visit_TimestampFromUNIX(self, op, *, arg, unit):

ibis/backends/sql/compilers/trino.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,9 +585,21 @@ def visit_Cast(self, op, *, arg, to):
585585
if from_.is_integer():
586586
return self.f.from_unixtime(arg, tz)
587587
else:
588-
return self.f.from_unixtime_nanos(
588+
out = self.f.from_unixtime_nanos(
589589
self.cast(arg, dt.Decimal(38, 9)) * 1_000_000_000
590590
)
591+
return self.f.at_timezone(out, tz)
592+
593+
if to.is_timestamp() and (timezone := to.timezone) is not None:
594+
if from_.is_string() or from_.is_date():
595+
arg = self.cast(arg, dt.Timestamp(scale=to.scale))
596+
from_ = dt.Timestamp(scale=to.scale)
597+
598+
if from_.is_timestamp():
599+
if from_.timezone is None:
600+
arg = self.f.with_timezone(arg, "UTC")
601+
return self.f.at_timezone(arg, timezone)
602+
591603
return super().visit_Cast(op, arg=arg, to=to)
592604

593605
def visit_CountDistinctStar(self, op, *, arg, where):

0 commit comments

Comments
 (0)