Skip to content

Commit

Permalink
Add custom and iceberg table aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 20, 2024
1 parent 166c08d commit 52d7bb8
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 31 deletions.
55 changes: 54 additions & 1 deletion src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ func TestHandleQuery(t *testing.T) {
"description": {"bit_column"},
"values": {"1"},
},
"SELECT test_table.bit_column FROM public.test_table WHERE bit_column IS NOT NULL": {
"description": {"bit_column"},
"values": {"1"},
},
"SELECT bit_column FROM public.test_table WHERE bit_column IS NULL": {
"description": {"bit_column"},
"values": {""},
Expand Down Expand Up @@ -468,6 +472,10 @@ func TestHandleQuery(t *testing.T) {
"description": {"word", "catcode", "barelabel", "catdesc", "baredesc"},
"values": {"abort", "U", "t", "unreserved", "can be bare label"},
},
"SELECT pg_get_keywords.word FROM pg_catalog.pg_get_keywords() LIMIT 1": {
"description": {"word"},
"values": {"abort"},
},
"SELECT * FROM generate_series(1, 2) AS series(index) LIMIT 1": {
"description": {"index"},
"values": {"1"},
Expand Down Expand Up @@ -520,6 +528,51 @@ func TestHandleQuery(t *testing.T) {
"description": {"oid", "rolname"},
"values": {"10", "bemidb"},
},
// Table alias
"SELECT pg_shadow.usename FROM pg_shadow": {
"description": {"usename"},
"values": {"bemidb"},
},
"SELECT pg_roles.rolname FROM pg_roles": {
"description": {"rolname"},
"values": {"bemidb"},
},
"SELECT pg_extension.extname FROM pg_extension": {
"description": {"extname"},
"values": {"plpgsql"},
},
"SELECT pg_database.datname FROM pg_database": {
"description": {"datname"},
"values": {"bemidb"},
},
"SELECT pg_inherits.inhrelid FROM pg_inherits": {
"description": {"inhrelid"},
"values": {},
},
"SELECT pg_shdescription.objoid FROM pg_shdescription": {
"description": {"objoid"},
"values": {},
},
"SELECT pg_statio_user_tables.relid FROM pg_statio_user_tables": {
"description": {"relid"},
"values": {},
},
"SELECT pg_replication_slots.slot_name FROM pg_replication_slots": {
"description": {"slot_name"},
"values": {},
},
"SELECT pg_stat_gssapi.pid FROM pg_stat_gssapi": {
"description": {"pid"},
"values": {},
},
"SELECT pg_auth_members.oid FROM pg_auth_members": {
"description": {"oid"},
"values": {},
},
"SELECT tables.table_name FROM information_schema.tables": {
"description": {"table_name"},
"values": {"test_table"},
},
}

for query, responses := range responsesByQuery {
Expand Down Expand Up @@ -582,7 +635,7 @@ func TestHandleParseQuery(t *testing.T) {
&pgproto3.ParseComplete{},
})

remappedQuery := "SELECT usename, passwd FROM (VALUES ('bemidb', '10', 'FALSE', 'FALSE', 'TRUE', 'FALSE', 'bemidb-encrypted', 'NULL', 'NULL')) t(usename, usesysid, usecreatedb, usesuper, userepl, usebypassrls, passwd, valuntil, useconfig) WHERE usename = $1"
remappedQuery := "SELECT usename, passwd FROM (VALUES ('bemidb', '10', 'FALSE', 'FALSE', 'TRUE', 'FALSE', 'bemidb-encrypted', 'NULL', 'NULL')) pg_shadow(usename, usesysid, usecreatedb, usesuper, userepl, usebypassrls, passwd, valuntil, useconfig) WHERE usename = $1"
if preparedStatement.Query != remappedQuery {
t.Errorf("Expected the prepared statement query to be %v, got %v", remappedQuery, preparedStatement.Query)
}
Expand Down
23 changes: 10 additions & 13 deletions src/query_parser_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func (parser *QueryParserTable) NodeToQuerySchemaTable(node *pgQuery.Node) Query
}
}

func (parser *QueryParserTable) MakeEmptyTableNode(columns []string, alias string) *pgQuery.Node {
return parser.utils.MakeSubselectWithoutRowsNode(columns, alias)
func (parser *QueryParserTable) MakeEmptyTableNode(tableName string, columns []string, alias string) *pgQuery.Node {
return parser.utils.MakeSubselectWithoutRowsNode(tableName, columns, alias)
}

// pg_catalog.pg_shadow -> VALUES(values...) t(columns...)
Expand All @@ -60,7 +60,7 @@ func (parser *QueryParserTable) MakePgShadowNode(user string, encryptedPassword
}
rowsValues = append(rowsValues, rowValues)

return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias)
return parser.utils.MakeSubselectWithRowsNode(PG_TABLE_PG_SHADOW, columns, rowsValues, alias)
}

// pg_catalog.pg_roles -> VALUES(values...) t(columns...)
Expand All @@ -79,15 +79,15 @@ func (parser *QueryParserTable) MakePgRolesNode(user string, alias string) *pgQu
}
rowsValues = append(rowsValues, rowValues)

return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias)
return parser.utils.MakeSubselectWithRowsNode(PG_TABLE_PG_ROLES, columns, rowsValues, alias)
}

// pg_catalog.pg_extension -> VALUES(values...) t(columns...)
func (parser *QueryParserTable) MakePgExtensionNode(alias string) *pgQuery.Node {
columns := PG_EXTENSION_VALUE_BY_COLUMN.Keys()
staticRowValues := PG_EXTENSION_VALUE_BY_COLUMN.Values()
rowsValues := [][]string{staticRowValues}
return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias)
return parser.utils.MakeSubselectWithRowsNode(PG_TABLE_PG_EXTENSION, columns, rowsValues, alias)
}

// pg_catalog.pg_database -> VALUES(values...) t(columns...)
Expand All @@ -105,7 +105,7 @@ func (parser *QueryParserTable) MakePgDatabaseNode(database string, alias string
}
rowsValues = append(rowsValues, rowValues)

return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias)
return parser.utils.MakeSubselectWithRowsNode(PG_TABLE_PG_DATABASE, columns, rowsValues, alias)
}

// System pg_* tables
Expand All @@ -120,7 +120,7 @@ func (parser *QueryParserTable) IsTableFromInformationSchema(qSchemaTable QueryS
}

// iceberg.table -> FROM iceberg_scan('path', skip_schema_inference = true)
func (parser *QueryParserTable) MakeIcebergTableNode(tablePath string, alias string) *pgQuery.Node {
func (parser *QueryParserTable) MakeIcebergTableNode(tablePath string, qSchemaTable QuerySchemaTable) *pgQuery.Node {
node := pgQuery.MakeSimpleRangeFunctionNode([]*pgQuery.Node{
pgQuery.MakeListNode([]*pgQuery.Node{
pgQuery.MakeFuncCallNode(
Expand All @@ -144,9 +144,6 @@ func (parser *QueryParserTable) MakeIcebergTableNode(tablePath string, alias str
),
}),
})
if alias == "" {
return node
}

// DuckDB doesn't support aliases on iceberg_scan() functions, so we need to wrap it in a nested select that can have an alias
selectStarNode := pgQuery.MakeResTargetNodeWithVal(
Expand All @@ -156,7 +153,7 @@ func (parser *QueryParserTable) MakeIcebergTableNode(tablePath string, alias str
),
0,
)
return parser.utils.MakeSubselectFromNode([]*pgQuery.Node{selectStarNode}, node, alias)
return parser.utils.MakeSubselectFromNode(qSchemaTable.Table, []*pgQuery.Node{selectStarNode}, node, qSchemaTable.Alias)
}

// pg_catalog.pg_get_keywords()
Expand Down Expand Up @@ -218,7 +215,7 @@ func (parser *QueryParserTable) MakePgGetKeywordsNode(node *pgQuery.Node) *pgQue
alias = node.GetAlias().Aliasname
}

return parser.utils.MakeSubselectWithRowsNode(columns, rows, alias)
return parser.utils.MakeSubselectWithRowsNode(PG_FUNCTION_PG_GET_KEYWORDS, columns, rows, alias)
}

// array_upper(array, 1)
Expand Down Expand Up @@ -380,7 +377,7 @@ func (parser *QueryParserTable) MakePgShowAllSettingsNode(node *pgQuery.Node) *p
alias = node.GetAlias().Aliasname
}

return parser.utils.MakeSubselectFromNode(targetList, fromNode, alias)
return parser.utils.MakeSubselectFromNode(PG_FUNCTION_PG_SHOW_ALL_SETTINGS, targetList, fromNode, alias)
}

func (parser *QueryParserTable) isPgCatalogSchema(qSchemaTable QuerySchemaTable) bool {
Expand Down
16 changes: 6 additions & 10 deletions src/query_parser_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ import (
pgQuery "github.com/pganalyze/pg_query_go/v5"
)

const (
DEFAULT_ALIAS = "t"
)

type QueryParserUtils struct {
config *Config
}
Expand All @@ -16,7 +12,7 @@ func NewQueryParserUtils(config *Config) *QueryParserUtils {
return &QueryParserUtils{config: config}
}

func (utils *QueryParserUtils) MakeSubselectWithRowsNode(columns []string, rowsValues [][]string, alias string) *pgQuery.Node {
func (utils *QueryParserUtils) MakeSubselectWithRowsNode(tableName string, columns []string, rowsValues [][]string, alias string) *pgQuery.Node {
columnNodes := make([]*pgQuery.Node, len(columns))
for i, column := range columns {
columnNodes[i] = pgQuery.MakeStrNode(column)
Expand All @@ -32,7 +28,7 @@ func (utils *QueryParserUtils) MakeSubselectWithRowsNode(columns []string, rowsV
}

if alias == "" {
alias = DEFAULT_ALIAS
alias = tableName
}

return &pgQuery.Node{
Expand All @@ -54,7 +50,7 @@ func (utils *QueryParserUtils) MakeSubselectWithRowsNode(columns []string, rowsV
}
}

func (utils *QueryParserUtils) MakeSubselectWithoutRowsNode(columns []string, alias string) *pgQuery.Node {
func (utils *QueryParserUtils) MakeSubselectWithoutRowsNode(tableName string, columns []string, alias string) *pgQuery.Node {
columnNodes := make([]*pgQuery.Node, len(columns))
for i, column := range columns {
columnNodes[i] = pgQuery.MakeStrNode(column)
Expand All @@ -69,7 +65,7 @@ func (utils *QueryParserUtils) MakeSubselectWithoutRowsNode(columns []string, al
}

if alias == "" {
alias = DEFAULT_ALIAS
alias = tableName
}

return &pgQuery.Node{
Expand All @@ -92,9 +88,9 @@ func (utils *QueryParserUtils) MakeSubselectWithoutRowsNode(columns []string, al
}
}

func (utils *QueryParserUtils) MakeSubselectFromNode(targetList []*pgQuery.Node, fromNode *pgQuery.Node, alias string) *pgQuery.Node {
func (utils *QueryParserUtils) MakeSubselectFromNode(tableName string, targetList []*pgQuery.Node, fromNode *pgQuery.Node, alias string) *pgQuery.Node {
if alias == "" {
alias = DEFAULT_ALIAS
alias = tableName
}

return &pgQuery.Node{
Expand Down
14 changes: 7 additions & 7 deletions src/select_remapper_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,35 +66,35 @@ func (remapper *SelectRemapperTable) RemapTable(node *pgQuery.Node) *pgQuery.Nod
return node
case PG_TABLE_PG_INHERITS:
// pg_catalog.pg_inherits -> return nothing
tableNode := parser.MakeEmptyTableNode(PG_INHERITS_COLUMNS, qSchemaTable.Alias)
tableNode := parser.MakeEmptyTableNode(PG_TABLE_PG_INHERITS, PG_INHERITS_COLUMNS, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_SHDESCRIPTION:
// pg_catalog.pg_shdescription -> return nothing
tableNode := parser.MakeEmptyTableNode(PG_SHDESCRIPTION_COLUMNS, qSchemaTable.Alias)
tableNode := parser.MakeEmptyTableNode(PG_TABLE_PG_SHDESCRIPTION, PG_SHDESCRIPTION_COLUMNS, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_STATIO_USER_TABLES:
// pg_catalog.pg_statio_user_tables -> return nothing
tableNode := parser.MakeEmptyTableNode(PG_STATIO_USER_TABLES_COLUMNS, qSchemaTable.Alias)
tableNode := parser.MakeEmptyTableNode(PG_TABLE_PG_STATIO_USER_TABLES, PG_STATIO_USER_TABLES_COLUMNS, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_EXTENSION:
// pg_catalog.pg_extension -> return hard-coded extension info
tableNode := parser.MakePgExtensionNode(qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_REPLICATION_SLOTS:
// pg_replication_slots -> return nothing
tableNode := parser.MakeEmptyTableNode(PG_REPLICATION_SLOTS_COLUMNS, qSchemaTable.Alias)
tableNode := parser.MakeEmptyTableNode(PG_TABLE_PG_REPLICATION_SLOTS, PG_REPLICATION_SLOTS_COLUMNS, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_DATABASE:
// pg_catalog.pg_database -> return hard-coded database info
tableNode := parser.MakePgDatabaseNode(remapper.config.Database, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_STAT_GSSAPI:
// pg_catalog.pg_stat_gssapi -> return nothing
tableNode := parser.MakeEmptyTableNode(PG_STAT_GSSAPI_COLUMNS, qSchemaTable.Alias)
tableNode := parser.MakeEmptyTableNode(PG_TABLE_PG_STAT_GSSAPI, PG_STAT_GSSAPI_COLUMNS, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
case PG_TABLE_PG_AUTH_MEMBERS:
// pg_catalog.pg_auth_members -> return empty table
tableNode := parser.MakeEmptyTableNode(PG_AUTH_MEMBERS_COLUMNS, qSchemaTable.Alias)
tableNode := parser.MakeEmptyTableNode(PG_TABLE_PG_AUTH_MEMBERS, PG_AUTH_MEMBERS_COLUMNS, qSchemaTable.Alias)
return remapper.overrideTable(node, tableNode)
default:
// pg_catalog.pg_* other system tables -> return as is
Expand Down Expand Up @@ -127,7 +127,7 @@ func (remapper *SelectRemapperTable) RemapTable(node *pgQuery.Node) *pgQuery.Nod
}
}
icebergPath := remapper.icebergReader.MetadataFilePath(schemaTable)
tableNode := parser.MakeIcebergTableNode(icebergPath, qSchemaTable.Alias)
tableNode := parser.MakeIcebergTableNode(icebergPath, qSchemaTable)
return remapper.overrideTable(node, tableNode)
}

Expand Down

0 comments on commit 52d7bb8

Please sign in to comment.