diff --git a/scripts/sqlpp23-ddl2cpp b/scripts/sqlpp23-ddl2cpp index 355154f8..822b0c3a 100755 --- a/scripts/sqlpp23-ddl2cpp +++ b/scripts/sqlpp23-ddl2cpp @@ -178,7 +178,7 @@ class DdlParser: # Data type parsers def get_type_parser(base_type, data_type): - type_names = getattr(cls, f"ddl_{base_type}_types") + type_names = getattr(cls, f"ddl_{base_type}_types").copy() if custom_types and (base_type in custom_types): type_names.extend(custom_types[base_type]) return pp.Or( @@ -194,18 +194,31 @@ class DdlParser: ddl_date = get_type_parser("date", "date") ddl_timestamp = get_type_parser("timestamp", "timestamp") ddl_time = get_type_parser("time", "time") + + custom_parsers = [] + if custom_types: + for base_type, sql_types in custom_types.items(): + if not cls.is_base_type(base_type): + custom_parsers.append( + pp.Or(map(pp.CaselessKeyword, sorted(sql_types, reverse=True))) + .set_parse_action(pp.replace_with(base_type)) + ) + ddl_unknown = pp.Word(pp.alphanums).set_parse_action(pp.replace_with("UNKNOWN")) - cls.ddl_type = ( - ddl_boolean - | ddl_integral - | ddl_serial - | ddl_floating_point - | ddl_text - | ddl_blob - | ddl_timestamp - | ddl_date - | ddl_time - | ddl_unknown + cls.ddl_type = pp.MatchFirst( + custom_parsers + + [ + ddl_boolean, + ddl_integral, + ddl_serial, + ddl_floating_point, + ddl_text, + ddl_blob, + ddl_timestamp, + ddl_date, + ddl_time, + ddl_unknown, + ] ) # Constraints parser @@ -697,7 +710,8 @@ class ModelWriter: + cls._escape_if_reserved(column.name) + ", " + column_member + ");" , file=header) const_prefix = "const " if column.is_const else "" - type_str = column.cpp_type if column.cpp_type else "::sqlpp::" + column.data_type + raw_type = column.cpp_type if column.cpp_type else column.data_type + type_str = raw_type if "::" in raw_type else "::sqlpp::" + raw_type if column.is_nullable: print(" using data_type = " + const_prefix + "std::optional<" + type_str + ">;", file=header) else: @@ -845,8 +859,7 @@ def get_custom_types(filename): name_from_camel = re.sub(r"[A-Z]", lambda m : ("_" if m.start() else "") + m[0].lower(), name_ident) if DdlParser.is_base_type(name_from_camel): return name_from_camel - logging.error(f"Custom types file uses an unknown base type {name_ident}") - sys.exit(ExitCode.BAD_CUSTOM_TYPES) + return name_ident for row in reader: values = [cleaned for value in row["custom_types"] if (cleaned := clean_custom_type(value)) != ""] if values: diff --git a/scripts/sqlpp23-ddl2cpp-unit-tests b/scripts/sqlpp23-ddl2cpp-unit-tests index 08efef52..ad244bb6 100755 --- a/scripts/sqlpp23-ddl2cpp-unit-tests +++ b/scripts/sqlpp23-ddl2cpp-unit-tests @@ -109,6 +109,13 @@ class SelfTest(unittest.TestCase): result = ddl2cpp.DdlParser.ddl_type.parse_string(t, parse_all=True) self.assertEqual(result[0], "UNKNOWN") + def test_custom_type(self): + DdlParser.initialize({"my_custom_base": ["MY_SQL_TYPE"]}) + result = ddl2cpp.DdlParser.ddl_type.parse_string("MY_SQL_TYPE", parse_all=True) + self.assertEqual(result[0], "my_custom_base") + # Reset parser for other tests + DdlParser.initialize() + def test_column(self): test_data = [ { @@ -605,6 +612,53 @@ class SelfTest(unittest.TestCase): tables = ddl2cpp.DdlExecutor.execute([parsed], make_args(), {}) self.assertEqual(tables["t"].columns["id"].cpp_type, "MyUuidType") + # Annotation with qualified vs unqualified type + with self.subTest("Annotation with qualified vs unqualified type"): + import io + from contextlib import redirect_stdout + + parsed = ddl2cpp.DdlParser.ddl.parse_string(""" + CREATE TABLE t ( + -- cpp_type:my_ns::Type + a bigint NOT NULL, + -- cpp_type:unqualified_type + b bigint NOT NULL + ) + """, parse_all=True) + tables = ddl2cpp.DdlExecutor.execute([parsed], make_args(), {}) + + f = io.StringIO() + with redirect_stdout(f): + ddl2cpp.ModelWriter._write_table(tables["t"], f, make_args(namespace="ns", path_to_module=None, generate_table_creation_helper=False, naming_style="identity")) + output = f.getvalue() + self.assertIn("using data_type = my_ns::Type;", output) + self.assertIn("using data_type = ::sqlpp::unqualified_type;", output) + + # Mapping with qualified vs unqualified type + with self.subTest("Mapping with qualified vs unqualified type"): + import io + from contextlib import redirect_stdout + + # Re-initialize with custom mappings + DdlParser.initialize({"my_ns::Type": ["MY_TYPE_A"], "unqualified_type": ["MY_TYPE_B"]}) + parsed = ddl2cpp.DdlParser.ddl.parse_string(""" + CREATE TABLE t ( + a MY_TYPE_A NOT NULL, + b MY_TYPE_B NOT NULL + ) + """, parse_all=True) + tables = ddl2cpp.DdlExecutor.execute([parsed], make_args(), {}) + + f = io.StringIO() + with redirect_stdout(f): + ddl2cpp.ModelWriter._write_table(tables["t"], f, make_args(namespace="ns", path_to_module=None, generate_table_creation_helper=False, naming_style="identity")) + output = f.getvalue() + self.assertIn("using data_type = my_ns::Type;", output) + self.assertIn("using data_type = ::sqlpp::unqualified_type;", output) + + # Restore parser + DdlParser.initialize() + # Schema-qualified ALTER TABLE resolved with --postgresql-schema with self.subTest("Schema-qualified ALTER TABLE resolved with --postgresql-schema"): parsed = ddl2cpp.DdlParser.ddl.parse_string(""" diff --git a/tests/scripts/ddl2cpp_sample_good_custom_type.cpp b/tests/scripts/ddl2cpp_sample_good_custom_type.cpp index 25aff756..f4436159 100644 --- a/tests/scripts/ddl2cpp_sample_good_custom_type.cpp +++ b/tests/scripts/ddl2cpp_sample_good_custom_type.cpp @@ -1,6 +1,13 @@ +#include + +namespace my_ns { +struct uuid { + bool operator==(const uuid&) const = default; +}; +} // namespace my_ns + #include #include -#include template void test_db_model() { @@ -29,6 +36,7 @@ void test_db_model() { tab_foo.builtinDate = std::chrono::sys_days{}; tab_foo.builtinDateTime = std::chrono::system_clock::now(); tab_foo.builtinTime = std::chrono::seconds{10}; + tab_foo.myUuid = my_ns::uuid{}; } int main() { diff --git a/tests/scripts/ddl2cpp_sample_good_custom_type.sql b/tests/scripts/ddl2cpp_sample_good_custom_type.sql index 6eec0c2f..20197ffc 100644 --- a/tests/scripts/ddl2cpp_sample_good_custom_type.sql +++ b/tests/scripts/ddl2cpp_sample_good_custom_type.sql @@ -49,6 +49,7 @@ CREATE TABLE tab_foo builtinBlob BINARY, builtinDate DATE, builtinDateTime TIMESTAMPTZ, - builtinTime TIME WITH TIME ZONE + builtinTime TIME WITH TIME ZONE, + myUuid UUID ) WITH SYSTEM VERSIONING; -- enable System-Versioning for this table diff --git a/tests/scripts/ddl2cpp_sample_good_custom_type_new.csv b/tests/scripts/ddl2cpp_sample_good_custom_type_new.csv index 7e32e920..e6679db8 100644 --- a/tests/scripts/ddl2cpp_sample_good_custom_type_new.csv +++ b/tests/scripts/ddl2cpp_sample_good_custom_type_new.csv @@ -7,3 +7,4 @@ serial, CustomSerialType text, CustomTextType, another_text_type time, CustomTimeType timestamp, CustomTimestampType +my_ns::uuid, UUID diff --git a/tests/scripts/ddl2cpp_sample_good_custom_type_old.csv b/tests/scripts/ddl2cpp_sample_good_custom_type_old.csv index b5f96da1..60b98b42 100644 --- a/tests/scripts/ddl2cpp_sample_good_custom_type_old.csv +++ b/tests/scripts/ddl2cpp_sample_good_custom_type_old.csv @@ -7,3 +7,4 @@ Serial, CustomSerialType Text, CustomTextType, another_text_type Time, CustomTimeType Timestamp, CustomTimestampType +my_ns::uuid, UUID