Skip to content

Commit

Permalink
Return proper messages when handling empty queries
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Jan 9, 2025
1 parent 6278ad8 commit a642cc0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 20 deletions.
33 changes: 23 additions & 10 deletions src/query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ func (queryHandler *QueryHandler) HandleQuery(originalQuery string) ([]pgproto3.
return nil, err
}

if query == "" {
return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, nil
}

rows, err := queryHandler.duckdb.QueryContext(context.Background(), query)
if err != nil {
errorMessage := err.Error()
Expand Down Expand Up @@ -217,23 +221,24 @@ func (queryHandler *QueryHandler) HandleParseQuery(message *pgproto3.Parse) ([]p
return nil, nil, err
}

statement, err := queryHandler.duckdb.PrepareContext(ctx, query)
if err != nil {
LogError(queryHandler.config, "Couldn't prepare query via DuckDB:", query+"\n"+err.Error())
return nil, nil, err
}

preparedStatement := &PreparedStatement{
Name: message.Name,
OriginalQuery: originalQuery,
Query: query,
Statement: statement,
ParameterOIDs: message.ParameterOIDs,
}
if query == "" {
return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, preparedStatement, nil
}

messages := []pgproto3.Message{&pgproto3.ParseComplete{}}
statement, err := queryHandler.duckdb.PrepareContext(ctx, query)
preparedStatement.Statement = statement
if err != nil {
LogError(queryHandler.config, "Couldn't prepare query via DuckDB:", query+"\n"+err.Error())
return nil, nil, err
}

return messages, preparedStatement, nil
return []pgproto3.Message{&pgproto3.ParseComplete{}}, preparedStatement, nil
}

func (queryHandler *QueryHandler) HandleBindQuery(message *pgproto3.Bind, preparedStatement *PreparedStatement) ([]pgproto3.Message, *PreparedStatement, error) {
Expand Down Expand Up @@ -287,6 +292,10 @@ func (queryHandler *QueryHandler) HandleDescribeQuery(message *pgproto3.Describe
}
}

if preparedStatement.Query == "" {
return []pgproto3.Message{&pgproto3.NoData{}}, preparedStatement, nil
}

if len(preparedStatement.ParameterOIDs) != len(preparedStatement.Variables) { // Bind step didn't happen before
return []pgproto3.Message{&pgproto3.NoData{}}, preparedStatement, nil
}
Expand All @@ -311,6 +320,10 @@ func (queryHandler *QueryHandler) HandleExecuteQuery(message *pgproto3.Execute,
return nil, errors.New("portal mismatch")
}

if preparedStatement.Query == "" {
return []pgproto3.Message{&pgproto3.EmptyQueryResponse{}}, nil
}

if preparedStatement.Rows == nil { // If Describe step didn't have Bind step before
rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...)
if err != nil {
Expand Down Expand Up @@ -396,7 +409,7 @@ func (queryHandler *QueryHandler) remapQuery(query string) (string, error) {
}

if strings.HasSuffix(query, " --INSPECT") {
LogDebug(queryHandler.config, queryTree.Stmts[0].Stmt)
LogDebug(queryHandler.config, queryTree.Stmts)
}

queryTree.Stmts, err = queryHandler.queryRemapper.RemapStatements(queryTree.Stmts)
Expand Down
52 changes: 45 additions & 7 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,6 @@ func TestHandleQuery(t *testing.T) {
"values": {"memory", "public", "test_table"},
},

// Empty query
"-- ping": {
"description": {"1"},
"types": {Uint32ToString(pgtype.Int4OID)},
"values": {"1"},
},

// DISCARD
"DISCARD ALL": {
"description": {"1"},
Expand Down Expand Up @@ -921,6 +914,17 @@ func TestHandleQuery(t *testing.T) {
testDataRowValues(t, messages[1], []string{"UTC"})
testCommandCompleteTag(t, messages[2], "SHOW")
})

t.Run("Handles an empty query", func(t *testing.T) {
queryHandler := initQueryHandler()

messages, err := queryHandler.HandleQuery("-- ping")

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.EmptyQueryResponse{},
})
})
}

func TestHandleParseQuery(t *testing.T) {
Expand Down Expand Up @@ -1024,6 +1028,22 @@ func TestHandleDescribeQuery(t *testing.T) {
}
})

t.Run("Handles DESCRIBE extended query step if query is empty", func(t *testing.T) {
queryHandler := initQueryHandler()
parseMessage := &pgproto3.Parse{Query: ""}
_, preparedStatement, _ := queryHandler.HandleParseQuery(parseMessage)
bindMessage := &pgproto3.Bind{}
_, preparedStatement, _ = queryHandler.HandleBindQuery(bindMessage, preparedStatement)
message := &pgproto3.Describe{ObjectType: 'P'}

messages, _, err := queryHandler.HandleDescribeQuery(message, preparedStatement)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.NoData{},
})
})

t.Run("Handles DESCRIBE (Statement) extended query step if there was no BIND step", func(t *testing.T) {
queryHandler := initQueryHandler()
query := "SELECT usename, passwd FROM pg_shadow WHERE usename=$1"
Expand Down Expand Up @@ -1061,6 +1081,24 @@ func TestHandleExecuteQuery(t *testing.T) {
})
testDataRowValues(t, messages[0], []string{"bemidb", "bemidb-encrypted"})
})

t.Run("Handles EXECUTE extended query step if query is empty", func(t *testing.T) {
queryHandler := initQueryHandler()
parseMessage := &pgproto3.Parse{Query: ""}
_, preparedStatement, _ := queryHandler.HandleParseQuery(parseMessage)
bindMessage := &pgproto3.Bind{}
_, preparedStatement, _ = queryHandler.HandleBindQuery(bindMessage, preparedStatement)
describeMessage := &pgproto3.Describe{ObjectType: 'P'}
_, preparedStatement, _ = queryHandler.HandleDescribeQuery(describeMessage, preparedStatement)
message := &pgproto3.Execute{}

messages, err := queryHandler.HandleExecuteQuery(message, preparedStatement)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.EmptyQueryResponse{},
})
})
}

func TestHandleMultipleQueries(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions src/query_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,18 @@ func NewQueryRemapper(config *Config, icebergReader *IcebergReader, duckdb *Duck
}

func (remapper *QueryRemapper) RemapStatements(statements []*pgQuery.RawStmt) ([]*pgQuery.RawStmt, error) {
// Empty query
if len(statements) == 0 {
return FALLBACK_QUERY_TREE.Stmts, nil
return statements, nil
}

for i, stmt := range statements {
node := stmt.Stmt

switch {
// Empty query
// Empty statement
case node == nil:
return nil, errors.New("empty query")
return nil, errors.New("empty statement")

// SELECT ...
case node.GetSelectStmt() != nil:
Expand Down

0 comments on commit a642cc0

Please sign in to comment.