Skip to content

Commit

Permalink
Issue apache#32280/ bug fix: WHERE segment with ON CONFLICT segment i…
Browse files Browse the repository at this point in the history
…n INSERT throws exception (#2)

* add support for WHERE segment with ON CONFLICT segment in INSERT statement of postgres

* updated RELEASE-NOTES.md

* remove redundant commented code
  • Loading branch information
omkar-shitole authored Jan 2, 2025
1 parent 27b4042 commit ed4a82e
Show file tree
Hide file tree
Showing 13 changed files with 471 additions and 25 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
1. Encrypt: Use sql bind info in EncryptInsertPredicateColumnTokenGenerator to avoid wrong column table mapping - [#34110](https://github.com/apache/shardingsphere/pull/34110)
1. Mode: Fixes `JDBCRepository` improper handling of H2-database in memory mode - [#33281](https://github.com/apache/shardingsphere/issues/33281)
1. Mode: Fixes duplicate column names added when index changed in DDL - [#33982](https://github.com/apache/shardingsphere/issues/33281)
1. SQL Binder: Fixes bug: throwing exception while using WHERE statement in ON CONFLICT with INSERT INTO in Postgres [#32280](https://github.com/apache/shardingsphere/issues/32280)

### Change Logs

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.context.segment.insert.values;

import com.google.common.base.Preconditions;
import lombok.Getter;
import lombok.ToString;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

@Getter
@ToString
public final class OnConflictUpdateContext implements WhereAvailable {

private final int parameterCount;

private final List<ExpressionSegment> valueExpressions;

private final Collection<WhereSegment> whereSegments = new LinkedList<>();

private final Collection<ColumnSegment> columnSegments;

private final Collection<BinaryOperationExpression> joinConditions = new LinkedList<>();

private final List<ParameterMarkerExpressionSegment> parameterMarkerExpressions;

private final List<Object> parameters;

public OnConflictUpdateContext(final Collection<ColumnAssignmentSegment> assignments, final List<Object> params, final int parametersOffset, Optional<WhereSegment> segment) {
List<ExpressionSegment> expressionSegments = assignments.stream().map(ColumnAssignmentSegment::getValue).collect(Collectors.toList());
segment.ifPresent(whereSegments::add);
for (WhereSegment whereSegment : whereSegments) {
expressionSegments.add(whereSegment.getExpr());
}
columnSegments = assignments.stream().map(each -> each.getColumns().get(0)).collect(Collectors.toList());
ColumnExtractor.extractColumnSegments(columnSegments, whereSegments);
ExpressionExtractor.extractJoinConditions(joinConditions, whereSegments);
valueExpressions = getValueExpressions(expressionSegments);
parameterMarkerExpressions = ExpressionExtractor.getParameterMarkerExpressions(expressionSegments);
parameterCount = parameterMarkerExpressions.size();
parameters = getParameters(params, parametersOffset);
}

private List<ExpressionSegment> getValueExpressions(final Collection<ExpressionSegment> assignments) {
List<ExpressionSegment> result = new ArrayList<>(assignments.size());
result.addAll(assignments);
return result;
}

private List<Object> getParameters(final List<Object> params, final int paramsOffset) {
if (params.isEmpty() || 0 == parameterCount) {
return Collections.emptyList();
}
List<Object> result = new ArrayList<>(parameterCount);
result.addAll(params.subList(paramsOffset, paramsOffset + parameterCount));
return result;
}

/**
* Get value.
*
* @param index index
* @return value
*/
public Object getValue(final int index) {
ExpressionSegment valueExpression = valueExpressions.get(index);
if (valueExpression instanceof ParameterMarkerExpressionSegment) {
return parameters.get(getParameterIndex((ParameterMarkerExpressionSegment) valueExpression));
}
if (valueExpression instanceof FunctionSegment) {
return valueExpression;
}
return ((LiteralExpressionSegment) valueExpression).getLiterals();
}

private int getParameterIndex(final ParameterMarkerExpressionSegment paramMarkerExpression) {
int result = parameterMarkerExpressions.indexOf(paramMarkerExpression);
Preconditions.checkArgument(result >= 0, "Can not get parameter index.");
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertSelectContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.OnDuplicateUpdateContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.OnConflictUpdateContext;
import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
Expand All @@ -43,6 +44,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.OnConflictKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.OnDuplicateKeyColumnsSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
Expand All @@ -52,6 +54,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.TableExtractor;

import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -92,6 +95,9 @@ public final class InsertStatementContext extends CommonSQLStatementContext impl
@Getter
private OnDuplicateUpdateContext onDuplicateKeyUpdateValueContext;

@Getter
private OnConflictUpdateContext onConflictKeyUpdateValueContext;

private GeneratedKeyContext generatedKeyContext;

public InsertStatementContext(final ShardingSphereMetaData metaData, final List<Object> params, final InsertStatement sqlStatement, final String currentDatabaseName) {
Expand Down Expand Up @@ -179,6 +185,18 @@ private Optional<OnDuplicateUpdateContext> getOnDuplicateKeyUpdateValueContext(f
return Optional.of(onDuplicateUpdateContext);
}

private Optional<OnConflictUpdateContext> getOnConflictKeyUpdateValueContext(final List<Object> params, final AtomicInteger parametersOffset) {
Optional<OnConflictKeyColumnsSegment> onConflictKeyColumnsSegment = getSqlStatement().getOnConflictKeyColumns();
if (!onConflictKeyColumnsSegment.isPresent()) {
return Optional.empty();
}
Collection<ColumnAssignmentSegment> onConflictKeyColumns = onConflictKeyColumnsSegment.get().getColumns();
Optional<WhereSegment> whereSegment = getSqlStatement().getOnConflictKeyColumns().flatMap(OnConflictKeyColumnsSegment::getWhere);
OnConflictUpdateContext onConflictUpdateContext = new OnConflictUpdateContext(onConflictKeyColumns, params, parametersOffset.get(), whereSegment);
parametersOffset.addAndGet(onConflictUpdateContext.getParameterCount());
return Optional.of(onConflictUpdateContext);
}

private Collection<SimpleTableSegment> getAllSimpleTableSegments() {
TableExtractor tableExtractor = new TableExtractor();
tableExtractor.extractTablesFromInsert(getSqlStatement());
Expand Down Expand Up @@ -228,6 +246,15 @@ public List<Object> getOnDuplicateKeyUpdateParameters() {
return null == onDuplicateKeyUpdateValueContext ? new ArrayList<>() : onDuplicateKeyUpdateValueContext.getParameters();
}

/**
* Get on duplicate key update parameters.
*
* @return on duplicate key update parameters
*/
public List<Object> getOnConflictKeyUpdateParameters() {
return null == onConflictKeyUpdateValueContext ? new ArrayList<>() : onConflictKeyUpdateValueContext.getParameters();
}

/**
* Get generated key context.
*
Expand Down Expand Up @@ -315,6 +342,7 @@ public void setUpParameters(final List<Object> params) {
insertValueContexts = getInsertValueContexts(params, parametersOffset, valueExpressions);
insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, currentDatabaseName).orElse(null);
onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
onConflictKeyUpdateValueContext = getOnConflictKeyUpdateValueContext(params, parametersOffset).orElse(null);
ShardingSphereSchema schema = getSchema(metaData, currentDatabaseName);
generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNamesAndIndexes, insertValueContexts, params).orElse(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.InsertStatement;
import org.apache.shardingsphere.sql.parser.statement.postgresql.dml.PostgreSQLInsertStatement;

import java.util.Collection;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -63,6 +64,12 @@ private InsertStatement copy(final InsertStatement sqlStatement) {
result.getValues().addAll(sqlStatement.getValues());
sqlStatement.getSetAssignment().ifPresent(result::setSetAssignment);
sqlStatement.getOnDuplicateKeyColumns().ifPresent(result::setOnDuplicateKeyColumns);
sqlStatement.getOnConflictKeyColumns().ifPresent(segment -> {
if (result instanceof PostgreSQLInsertStatement) {
((PostgreSQLInsertStatement) result).setOnConflictKeyColumnsSegment(segment);
}
});
sqlStatement.getWithSegment().ifPresent(result::setWithSegment);
sqlStatement.getOutputSegment().ifPresent(result::setOutputSegment);
sqlStatement.getMultiTableInsertType().ifPresent(result::setMultiTableInsertType);
sqlStatement.getMultiTableInsertIntoSegment().ifPresent(result::setMultiTableInsertIntoSegment);
Expand Down
Loading

0 comments on commit ed4a82e

Please sign in to comment.