diff --git a/README.md b/README.md
index 2d27ffc..17220a7 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Snowflake output plugin for Embulk loads records to Snowflake.
 - **merge_keys**: key column names for merging records in merge mode (string array, required in merge mode if table doesn't have primary key)
 - **merge_rule**: list of column assignments for updating existing records used in merge mode, for example `"foo" = T."foo" + S."foo"` (`T` means target table and `S` means source table). (string array, default: always overwrites with new values)
 - **batch_size**: size of a single batch insert (integer, default: 16777216)
+- **match_by_column_name**: specify whether to load data into columns in the target table that match corresponding columns in the input data. ("case_sensitive", "case_insensitive", "none", default: "none")
 - **default_timezone**: If input column type (embulk type) is timestamp, this plugin needs to format the timestamp into a SQL string. This default_timezone option is used to control the timezone. You can overwrite timezone for each columns using column_options option. (string, default: `UTC`)
 - **column_options**: advanced: a key-value pairs where key is a column name and value is options for the column.
   - **type**: type of a column when this plugin creates new tables (e.g. `VARCHAR(255)`, `INTEGER NOT NULL UNIQUE`). This used when this plugin creates intermediate tables (insert, truncate_insert and merge modes), when it creates the target table (insert_direct and replace modes), and when it creates nonexistent target table automatically. (string, default: depends on input column type. `BIGINT` if input column type is long, `BOOLEAN` if boolean, `DOUBLE PRECISION` if double, `CLOB` if string, `TIMESTAMP` if timestamp)
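Reviewer note: with this change, enabling the new option in an Embulk config would look roughly like the sketch below. This is illustrative only; the connection settings are elided, the table name is hypothetical, and only `match_by_column_name` is new in this PR.

```yaml
out:
  type: snowflake
  # ... host, user, password, database, schema, etc. elided ...
  table: example_table                    # hypothetical target table
  mode: insert
  match_by_column_name: case_insensitive  # or case_sensitive / none (default)
```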
diff --git a/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java b/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java
index 9cafe3c..7f34019 100644
--- a/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java
+++ b/src/main/java/org/embulk/output/SnowflakeOutputPlugin.java
@@ -1,9 +1,12 @@
 package org.embulk.output;
 
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonValue;
 import java.io.IOException;
 import java.sql.SQLException;
 import java.sql.Types;
 import java.util.*;
+import java.util.function.BiFunction;
 import net.snowflake.client.jdbc.internal.org.bouncycastle.operator.OperatorCreationException;
 import net.snowflake.client.jdbc.internal.org.bouncycastle.pkcs.PKCSException;
 import org.embulk.config.ConfigDiff;
@@ -78,6 +81,47 @@ public interface SnowflakePluginTask extends PluginTask {
     @Config("delete_stage_on_error")
     @ConfigDefault("false")
     public boolean getDeleteStageOnError();
+
+    @Config("match_by_column_name")
+    @ConfigDefault("\"none\"")
+    public MatchByColumnName getMatchByColumnName();
+
+    public void setCopyIntoTableColumnNames(String[] columnNames);
+
+    public String[] getCopyIntoTableColumnNames();
+
+    public void setCopyIntoCSVColumnNumbers(int[] columnNumbers);
+
+    public int[] getCopyIntoCSVColumnNumbers();
+
+    public enum MatchByColumnName {
+      CASE_SENSITIVE,
+      CASE_INSENSITIVE,
+      NONE;
+
+      @JsonValue
+      @Override
+      public String toString() {
+        return name().toLowerCase(Locale.ENGLISH);
+      }
+
+      @JsonCreator
+      public static MatchByColumnName fromString(String value) {
+        switch (value) {
+          case "case_sensitive":
+            return CASE_SENSITIVE;
+          case "case_insensitive":
+            return CASE_INSENSITIVE;
+          case "none":
+            return NONE;
+          default:
+            throw new ConfigException(
+                String.format(
+                    "Unknown match_by_column_name '%s'. Supported values are case_sensitive, case_insensitive, none",
+                    value));
+        }
+      }
+    }
   }
 
   @Override
@@ -187,6 +231,38 @@ protected void doBegin(
       JdbcOutputConnection con, PluginTask task, final Schema schema, int taskCount)
       throws SQLException {
     super.doBegin(con, task, schema, taskCount);
+
+    SnowflakePluginTask pluginTask = (SnowflakePluginTask) task;
+    SnowflakePluginTask.MatchByColumnName matchByColumnName = pluginTask.getMatchByColumnName();
+    if (matchByColumnName == SnowflakePluginTask.MatchByColumnName.NONE) {
+      pluginTask.setCopyIntoCSVColumnNumbers(new int[0]);
+      pluginTask.setCopyIntoTableColumnNames(new String[0]);
+      return;
+    }
+
+    List<String> copyIntoTableColumnNames = new ArrayList<>();
+    List<Integer> copyIntoCSVColumnNumbers = new ArrayList<>();
+    JdbcSchema targetTableSchema = pluginTask.getTargetTableSchema();
+    BiFunction<String, String, Boolean> compare =
+        matchByColumnName == SnowflakePluginTask.MatchByColumnName.CASE_SENSITIVE
+            ? String::equals
+            : String::equalsIgnoreCase;
+    int columnNumber = 1;
+    for (int i = 0; i < targetTableSchema.getCount(); i++) {
+      JdbcColumn targetColumn = targetTableSchema.getColumn(i);
+      if (targetColumn.isSkipColumn()) {
+        continue;
+      }
+      Column schemaColumn = schema.getColumn(i);
+      if (compare.apply(schemaColumn.getName(), targetColumn.getName())) {
+        copyIntoTableColumnNames.add(targetColumn.getName());
+        copyIntoCSVColumnNumbers.add(columnNumber);
+      }
+      columnNumber += 1;
+    }
+    pluginTask.setCopyIntoTableColumnNames(copyIntoTableColumnNames.toArray(new String[0]));
+    pluginTask.setCopyIntoCSVColumnNumbers(
+        copyIntoCSVColumnNumbers.stream().mapToInt(i -> i).toArray());
   }
 
   @Override
@@ -201,6 +277,8 @@ protected BatchInsert newBatchInsert(PluginTask task, Optional<MergeConfig> mergeConfig)
     return new SnowflakeCopyBatchInsert(
         getConnector(task, true),
         StageIdentifierHolder.getStageIdentifier(pluginTask),
+        pluginTask.getCopyIntoTableColumnNames(),
+        pluginTask.getCopyIntoCSVColumnNumbers(),
         false,
         pluginTask.getMaxUploadRetries(),
         pluginTask.getEmtpyFieldAsNull());
diff --git a/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java b/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java
index a1ebd8c..724cee5 100644
--- a/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java
+++ b/src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java
@@ -40,9 +40,15 @@ public class SnowflakeCopyBatchInsert implements BatchInsert {
   private List<Future<Void>> uploadAndCopyFutures;
   private boolean emptyFieldAsNull;
 
+  private String[] copyIntoTableColumnNames;
+
+  private int[] copyIntoCSVColumnNumbers;
+
   public SnowflakeCopyBatchInsert(
       JdbcOutputConnector connector,
       StageIdentifier stageIdentifier,
+      String[] copyIntoTableColumnNames,
+      int[] copyIntoCSVColumnNumbers,
       boolean deleteStageFile,
       int maxUploadRetries,
       boolean emptyFieldAsNull)
@@ -51,6 +57,8 @@ public SnowflakeCopyBatchInsert(
     openNewFile();
     this.connector = connector;
     this.stageIdentifier = stageIdentifier;
+    this.copyIntoTableColumnNames = copyIntoTableColumnNames;
+    this.copyIntoCSVColumnNumbers = copyIntoCSVColumnNumbers;
     this.executorService = Executors.newCachedThreadPool();
     this.deleteStageFile = deleteStageFile;
     this.uploadAndCopyFutures = new ArrayList();
@@ -417,6 +425,8 @@ public Void call() throws SQLException, InterruptedException, ExecutionException
               tableIdentifier,
               stageIdentifier,
               snowflakeStageFileName,
+              copyIntoTableColumnNames,
+              copyIntoCSVColumnNumbers,
               delimiterString,
               emptyFieldAsNull);
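Reviewer note on the matching rule in `doBegin` above: the comparator is `String::equals` for `case_sensitive` and `String::equalsIgnoreCase` otherwise, and only matched columns are recorded in `copyIntoTableColumnNames` / `copyIntoCSVColumnNumbers` (the latter holding the 1-based CSV positions that feed Snowflake's `t.$N` references). A minimal, self-contained sketch of that behavior, with invented column names:

```java
import java.util.function.BiFunction;

public class MatchByColumnNameSketch {
  public static void main(String[] args) {
    // The same two comparators doBegin chooses between (see the diff above).
    BiFunction<String, String, Boolean> caseSensitive = String::equals;
    BiFunction<String, String, Boolean> caseInsensitive = String::equalsIgnoreCase;

    // Hypothetical pair: input-schema column name vs. target-table column name.
    String schemaColumn = "user_id";
    String targetColumn = "USER_ID";

    // case_sensitive: no match, so this column would be left out of the COPY column list.
    System.out.println(caseSensitive.apply(schemaColumn, targetColumn)); // false

    // case_insensitive: match, so the table column name and its 1-based CSV position are kept.
    System.out.println(caseInsensitive.apply(schemaColumn, targetColumn)); // true
  }
}
```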
diff --git a/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java b/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java
index 1168629..a46bc2b 100644
--- a/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java
+++ b/src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java
@@ -25,11 +25,24 @@ public void runCopy(
       TableIdentifier tableIdentifier,
       StageIdentifier stageIdentifier,
       String filename,
+      String[] tableColumnNames,
+      int[] csvColumnNumbers,
       String delimiterString,
       boolean emptyFieldAsNull)
       throws SQLException {
     String sql =
-        buildCopySQL(tableIdentifier, stageIdentifier, filename, delimiterString, emptyFieldAsNull);
+        tableColumnNames != null && tableColumnNames.length > 0
+            ? buildCopySQL(
+                tableIdentifier,
+                stageIdentifier,
+                filename,
+                tableColumnNames,
+                csvColumnNumbers,
+                delimiterString,
+                emptyFieldAsNull)
+            : buildCopySQL(
+                tableIdentifier, stageIdentifier, filename, delimiterString, emptyFieldAsNull);
+
     runUpdate(sql);
   }
@@ -196,6 +209,50 @@ protected String buildCopySQL(
     return sb.toString();
   }
 
+  protected String buildCopySQL(
+      TableIdentifier tableIdentifier,
+      StageIdentifier stageIdentifier,
+      String snowflakeStageFileName,
+      String[] tableColumnNames,
+      int[] csvColumnNumbers,
+      String delimiterString,
+      boolean emptyFieldAsNull) {
+    // Data load with transformation:
+    // the correspondence between CSV column numbers and table column names can be specified.
+    // https://docs.snowflake.com/ja/sql-reference/sql/copy-into-table
+
+    StringBuilder sb = new StringBuilder();
+    sb.append("COPY INTO ");
+    quoteTableIdentifier(sb, tableIdentifier);
+    sb.append(" (");
+    for (int i = 0; i < tableColumnNames.length; i++) {
+      if (i != 0) {
+        sb.append(", ");
+      }
+      String column = quoteIdentifierString(tableColumnNames[i]);
+      sb.append(column);
+    }
+    sb.append(" ) FROM ( SELECT ");
+    for (int i = 0; i < csvColumnNumbers.length; i++) {
+      if (i != 0) {
+        sb.append(", ");
+      }
+      sb.append("t.$");
+      sb.append(csvColumnNumbers[i]);
+    }
+    sb.append(" from ");
+    quoteInternalStoragePath(sb, stageIdentifier, snowflakeStageFileName);
+    sb.append(" t ) ");
+    sb.append(" FILE_FORMAT = ( TYPE = CSV FIELD_DELIMITER = '");
+    sb.append(delimiterString);
+    sb.append("'");
+    if (!emptyFieldAsNull) {
+      sb.append(" EMPTY_FIELD_AS_NULL = FALSE");
+    }
+    sb.append(" );");
+    return sb.toString();
+  }
+
   protected String buildDeleteStageFileSQL(
       StageIdentifier stageIdentifier, String snowflakeStageFileName) {
     StringBuilder sb = new StringBuilder();
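To sanity-check the new `buildCopySQL` overload, here is a standalone sketch that mirrors its string assembly for hypothetical inputs. The quoting helpers are inlined as plain double quotes, a made-up stage path stands in for `quoteInternalStoragePath`, and `emptyFieldAsNull` is assumed true so no `EMPTY_FIELD_AS_NULL` clause is appended:

```java
public class CopySqlSketch {
  public static void main(String[] args) {
    // Hypothetical output of the doBegin matching step: two matched
    // table columns and their 1-based positions in the staged CSV.
    String[] tableColumnNames = {"ID", "NAME"};
    int[] csvColumnNumbers = {1, 3};

    StringBuilder sb = new StringBuilder();
    sb.append("COPY INTO \"MY_TABLE\""); // stands in for quoteTableIdentifier(...)
    sb.append(" (");
    for (int i = 0; i < tableColumnNames.length; i++) {
      if (i != 0) {
        sb.append(", ");
      }
      sb.append('"').append(tableColumnNames[i]).append('"'); // stands in for quoteIdentifierString(...)
    }
    sb.append(" ) FROM ( SELECT ");
    for (int i = 0; i < csvColumnNumbers.length; i++) {
      if (i != 0) {
        sb.append(", ");
      }
      sb.append("t.$").append(csvColumnNumbers[i]); // Snowflake positional reference into the CSV
    }
    sb.append(" from ");
    sb.append("@my_stage/file.csv"); // stands in for quoteInternalStoragePath(...)
    sb.append(" t ) ");
    sb.append(" FILE_FORMAT = ( TYPE = CSV FIELD_DELIMITER = ',' );");

    // Prints:
    // COPY INTO "MY_TABLE" ("ID", "NAME" ) FROM ( SELECT t.$1, t.$3 from @my_stage/file.csv t )  FILE_FORMAT = ( TYPE = CSV FIELD_DELIMITER = ',' );
    System.out.println(sb);
  }
}
```

The stray spaces in the result (before `)` and the double space before `FILE_FORMAT`) come straight from the builder above; Snowflake parses the statement fine either way.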