Merge pull request #71 from trocco-io/feature/match_by_column_name
support match_by_column_name
d-hrs authored Mar 27, 2024
2 parents 9cfaa02 + 19216a7 commit 5fd4e49
Showing 4 changed files with 147 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -27,6 +27,7 @@ Snowflake output plugin for Embulk loads records to Snowflake.
- **merge_keys**: key column names for merging records in merge mode (string array, required in merge mode if table doesn't have primary key)
- **merge_rule**: list of column assignments for updating existing records used in merge mode, for example `"foo" = T."foo" + S."foo"` (`T` means target table and `S` means source table). (string array, default: always overwrites with new values)
- **batch_size**: size of a single batch insert (integer, default: 16777216)
- **match_by_column_name**: specify whether to map input columns to target-table columns by name rather than by position, and whether the name comparison is case-sensitive. ("case_sensitive", "case_insensitive", "none", default: "none") A short example config follows this list.
- **default_timezone**: If input column type (embulk type) is timestamp, this plugin needs to format the timestamp into a SQL string. This default_timezone option is used to control the timezone. You can overwrite timezone for each columns using column_options option. (string, default: `UTC`)
- **column_options**: advanced: a key-value pairs where key is a column name and value is options for the column.
- **type**: type of a column when this plugin creates new tables (e.g. `VARCHAR(255)`, `INTEGER NOT NULL UNIQUE`). This used when this plugin creates intermediate tables (insert, truncate_insert and merge modes), when it creates the target table (insert_direct and replace modes), and when it creates nonexistent target table automatically. (string, default: depends on input column type. `BIGINT` if input column type is long, `BOOLEAN` if boolean, `DOUBLE PRECISION` if double, `CLOB` if string, `TIMESTAMP` if timestamp)
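For illustration, a minimal sketch of an Embulk config using the new option (the connection settings are omitted, and the `table`/`mode` values here are assumptions, not part of this change):

```yaml
out:
  type: snowflake
  # Connection settings (account, user, password, database, schema, warehouse)
  # are omitted here; see the full option list above.
  table: my_table   # hypothetical target table
  mode: insert      # one of the modes mentioned in this README
  match_by_column_name: case_insensitive
```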
78 changes: 78 additions & 0 deletions src/main/java/org/embulk/output/SnowflakeOutputPlugin.java
@@ -1,9 +1,12 @@
package org.embulk.output;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import java.io.IOException;
import java.sql.SQLException;
import java.sql.Types;
import java.util.*;
import java.util.function.BiFunction;
import net.snowflake.client.jdbc.internal.org.bouncycastle.operator.OperatorCreationException;
import net.snowflake.client.jdbc.internal.org.bouncycastle.pkcs.PKCSException;
import org.embulk.config.ConfigDiff;
@@ -78,6 +81,47 @@ public interface SnowflakePluginTask extends PluginTask {
@Config("delete_stage_on_error")
@ConfigDefault("false")
public boolean getDeleteStageOnError();

@Config("match_by_column_name")
@ConfigDefault("\"none\"")
public MatchByColumnName getMatchByColumnName();

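// The resolved column mapping is computed in doBegin() and stored on the task
// so that newBatchInsert() can pass it to SnowflakeCopyBatchInsert.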
public void setCopyIntoTableColumnNames(String[] columnNames);

public String[] getCopyIntoTableColumnNames();

public void setCopyIntoCSVColumnNumbers(int[] columnNumbers);

public int[] getCopyIntoCSVColumnNumbers();

public enum MatchByColumnName {
CASE_SENSITIVE,
CASE_INSENSITIVE,
NONE;

@JsonValue
@Override
public String toString() {
return name().toLowerCase(Locale.ENGLISH);
}

@JsonCreator
public static MatchByColumnName fromString(String value) {
switch (value) {
case "case_sensitive":
return CASE_SENSITIVE;
case "case_insensitive":
return CASE_INSENSITIVE;
case "none":
return NONE;
default:
throw new ConfigException(
String.format(
"Unknown match_by_column_name '%s'. Supported values are case_sensitive, case_insensitive, none",
value));
}
}
}
}

@Override
@@ -187,6 +231,38 @@ protected void doBegin(
JdbcOutputConnection con, PluginTask task, final Schema schema, int taskCount)
throws SQLException {
super.doBegin(con, task, schema, taskCount);

SnowflakePluginTask pluginTask = (SnowflakePluginTask) task;
SnowflakePluginTask.MatchByColumnName matchByColumnName = pluginTask.getMatchByColumnName();
if (matchByColumnName == SnowflakePluginTask.MatchByColumnName.NONE) {
pluginTask.setCopyIntoCSVColumnNumbers(new int[0]);
pluginTask.setCopyIntoTableColumnNames(new String[0]);
return;
}

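// CSV column numbers are 1-based: the COPY INTO transformation SELECT
// refers to staged file columns as t.$1, t.$2, ... (see buildCopySQL).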
List<String> copyIntoTableColumnNames = new ArrayList<>();
List<Integer> copyIntoCSVColumnNumbers = new ArrayList<>();
JdbcSchema targetTableSchema = pluginTask.getTargetTableSchema();
BiFunction<String, String, Boolean> compare =
matchByColumnName == SnowflakePluginTask.MatchByColumnName.CASE_SENSITIVE
? String::equals
: String::equalsIgnoreCase;
int columnNumber = 1;
for (int i = 0; i < targetTableSchema.getCount(); i++) {
JdbcColumn targetColumn = targetTableSchema.getColumn(i);
if (targetColumn.isSkipColumn()) {
continue;
}
Column schemaColumn = schema.getColumn(i);
if (compare.apply(schemaColumn.getName(), targetColumn.getName())) {
copyIntoTableColumnNames.add(targetColumn.getName());
copyIntoCSVColumnNumbers.add(columnNumber);
}
columnNumber += 1;
}
pluginTask.setCopyIntoTableColumnNames(copyIntoTableColumnNames.toArray(new String[0]));
pluginTask.setCopyIntoCSVColumnNumbers(
copyIntoCSVColumnNumbers.stream().mapToInt(i -> i).toArray());
}

@Override
@@ -201,6 +277,8 @@ protected BatchInsert newBatchInsert(PluginTask task, Optional<MergeConfig> merg
return new SnowflakeCopyBatchInsert(
getConnector(task, true),
StageIdentifierHolder.getStageIdentifier(pluginTask),
pluginTask.getCopyIntoTableColumnNames(),
pluginTask.getCopyIntoCSVColumnNumbers(),
false,
pluginTask.getMaxUploadRetries(),
pluginTask.getEmtpyFieldAsNull());
10 changes: 10 additions & 0 deletions SnowflakeCopyBatchInsert.java
@@ -40,9 +40,15 @@ public class SnowflakeCopyBatchInsert implements BatchInsert {
private List<Future<Void>> uploadAndCopyFutures;
private boolean emptyFieldAsNull;

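// Column mapping received from the plugin: target table column names paired
// with the 1-based CSV column numbers they are loaded from.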
private String[] copyIntoTableColumnNames;

private int[] copyIntoCSVColumnNumbers;

public SnowflakeCopyBatchInsert(
JdbcOutputConnector connector,
StageIdentifier stageIdentifier,
String[] copyIntoTableColumnNames,
int[] copyIntoCSVColumnNumbers,
boolean deleteStageFile,
int maxUploadRetries,
boolean emptyFieldAsNull)
@@ -51,6 +57,8 @@ public SnowflakeCopyBatchInsert(
openNewFile();
this.connector = connector;
this.stageIdentifier = stageIdentifier;
this.copyIntoTableColumnNames = copyIntoTableColumnNames;
this.copyIntoCSVColumnNumbers = copyIntoCSVColumnNumbers;
this.executorService = Executors.newCachedThreadPool();
this.deleteStageFile = deleteStageFile;
this.uploadAndCopyFutures = new ArrayList<>();
@@ -417,6 +425,8 @@ public Void call() throws SQLException, InterruptedException, ExecutionException
tableIdentifier,
stageIdentifier,
snowflakeStageFileName,
copyIntoTableColumnNames,
copyIntoCSVColumnNumbers,
delimiterString,
emptyFieldAsNull);

59 changes: 58 additions & 1 deletion SnowflakeOutputConnection.java
@@ -25,11 +25,24 @@ public void runCopy(
TableIdentifier tableIdentifier,
StageIdentifier stageIdentifier,
String filename,
String[] tableColumnNames,
int[] csvColumnNumbers,
String delimiterString,
boolean emptyFieldAsNull)
throws SQLException {
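// If column matching produced no columns (match_by_column_name: none, or no
// names matched), fall back to the positional COPY without a column list.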
String sql =
tableColumnNames != null && tableColumnNames.length > 0
? buildCopySQL(
tableIdentifier,
stageIdentifier,
filename,
tableColumnNames,
csvColumnNumbers,
delimiterString,
emptyFieldAsNull)
: buildCopySQL(
tableIdentifier, stageIdentifier, filename, delimiterString, emptyFieldAsNull);

runUpdate(sql);
}

@@ -196,6 +209,50 @@ protected String buildCopySQL(
return sb.toString();
}

protected String buildCopySQL(
TableIdentifier tableIdentifier,
StageIdentifier stageIdentifier,
String snowflakeStageFileName,
String[] tableColumnNames,
int[] csvColumnNumbers,
String delimiterString,
boolean emptyFieldAsNull) {
// Data load with transformation: maps specific CSV column numbers to the
// matched table column names instead of loading columns positionally.
// https://docs.snowflake.com/ja/sql-reference/sql/copy-into-table
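//
// For example, with tableColumnNames = {"foo", "bar"} and csvColumnNumbers =
// {1, 3}, the generated statement has roughly this shape (exact table and
// stage quoting comes from helper methods not shown in this diff):
//
//   COPY INTO "MY_TABLE" ("foo", "bar" )
//   FROM ( SELECT t.$1, t.$3 from @stage/file.csv t )
//   FILE_FORMAT = ( TYPE = CSV FIELD_DELIMITER = ',' );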

StringBuilder sb = new StringBuilder();
sb.append("COPY INTO ");
quoteTableIdentifier(sb, tableIdentifier);
sb.append(" (");
for (int i = 0; i < tableColumnNames.length; i++) {
if (i != 0) {
sb.append(", ");
}
String column = quoteIdentifierString(tableColumnNames[i]);
sb.append(column);
}
sb.append(" ) FROM ( SELECT ");
for (int i = 0; i < csvColumnNumbers.length; i++) {
if (i != 0) {
sb.append(", ");
}
sb.append("t.$");
sb.append(csvColumnNumbers[i]);
}
sb.append(" from ");
quoteInternalStoragePath(sb, stageIdentifier, snowflakeStageFileName);
sb.append(" t ) ");
sb.append(" FILE_FORMAT = ( TYPE = CSV FIELD_DELIMITER = '");
sb.append(delimiterString);
sb.append("'");
if (!emptyFieldAsNull) {
sb.append(" EMPTY_FIELD_AS_NULL = FALSE");
}
sb.append(" );");
return sb.toString();
}

protected String buildDeleteStageFileSQL(
StageIdentifier stageIdentifier, String snowflakeStageFileName) {
StringBuilder sb = new StringBuilder();
