
Merge remote-tracking branch 'remotes/origin/main' into feature/match_by_column_name
d-hrs committed Mar 26, 2024
2 parents 2d4b9bf + 9cfaa02 commit 19216a7
Showing 3 changed files with 42 additions and 29 deletions.
64 changes: 36 additions & 28 deletions src/main/java/org/embulk/output/SnowflakeOutputPlugin.java
@@ -11,6 +11,7 @@
 import net.snowflake.client.jdbc.internal.org.bouncycastle.pkcs.PKCSException;
 import org.embulk.config.ConfigDiff;
 import org.embulk.config.ConfigException;
+import org.embulk.config.ConfigSource;
 import org.embulk.config.TaskSource;
 import org.embulk.output.jdbc.*;
 import org.embulk.output.snowflake.PrivateKeyReader;
@@ -27,8 +28,6 @@
 import org.embulk.util.config.ConfigDefault;

 public class SnowflakeOutputPlugin extends AbstractJdbcOutputPlugin {
-  private StageIdentifier stageIdentifier;
-
   public interface SnowflakePluginTask extends PluginTask {
     @Config("driver_path")
     @ConfigDefault("null")
@@ -79,6 +78,10 @@ public interface SnowflakePluginTask extends PluginTask {
     @ConfigDefault("true")
     public boolean getEmtpyFieldAsNull();

+    @Config("delete_stage_on_error")
+    @ConfigDefault("false")
+    public boolean getDeleteStageOnError();
+
     @Config("match_by_column_name")
     @ConfigDefault("\"none\"")
     public MatchByColumnName getMatchByColumnName();
@@ -188,25 +191,39 @@ protected JdbcOutputConnector getConnector(PluginTask task, boolean retryableMet
   }

   @Override
-  public ConfigDiff resume(
-      TaskSource taskSource, Schema schema, int taskCount, OutputPlugin.Control control) {
-    throw new UnsupportedOperationException("snowflake output plugin does not support resuming");
-  }
-
-  @Override
-  protected void doCommit(JdbcOutputConnection con, PluginTask task, int taskCount)
-      throws SQLException {
-    super.doCommit(con, task, taskCount);
-    SnowflakeOutputConnection snowflakeCon = (SnowflakeOutputConnection) con;
-
+  public ConfigDiff transaction(
+      ConfigSource config, Schema schema, int taskCount, OutputPlugin.Control control) {
+    PluginTask task = CONFIG_MAPPER.map(config, this.getTaskClass());
     SnowflakePluginTask t = (SnowflakePluginTask) task;
-    if (this.stageIdentifier == null) {
-      this.stageIdentifier = StageIdentifierHolder.getStageIdentifier(t);
+    StageIdentifier stageIdentifier = StageIdentifierHolder.getStageIdentifier(t);
+    ConfigDiff configDiff;
+    SnowflakeOutputConnection snowflakeCon = null;
+
+    try {
+      snowflakeCon = (SnowflakeOutputConnection) getConnector(task, true).connect(true);
+      snowflakeCon.runCreateStage(stageIdentifier);
+      configDiff = super.transaction(config, schema, taskCount, control);
+      if (t.getDeleteStage()) {
+        snowflakeCon.runDropStage(stageIdentifier);
+      }
+    } catch (Exception e) {
+      if (t.getDeleteStage() && t.getDeleteStageOnError()) {
+        try {
+          snowflakeCon.runDropStage(stageIdentifier);
+        } catch (SQLException ex) {
+          throw new RuntimeException(ex);
+        }
+      }
+      throw new RuntimeException(e);
    }

-    if (t.getDeleteStage()) {
-      snowflakeCon.runDropStage(this.stageIdentifier);
-    }
+    return configDiff;
+  }
+
+  @Override
+  public ConfigDiff resume(
+      TaskSource taskSource, Schema schema, int taskCount, OutputPlugin.Control control) {
+    throw new UnsupportedOperationException("snowflake output plugin does not support resuming");
   }

   @Override
@@ -255,20 +272,11 @@ protected BatchInsert newBatchInsert(PluginTask task, Optional<MergeConfig> merg
       throw new UnsupportedOperationException(
           "Snowflake output plugin doesn't support 'merge_direct' mode.");
     }
-
-    SnowflakePluginTask t = (SnowflakePluginTask) task;
-    // TODO: put some where executes once
-    if (this.stageIdentifier == null) {
-      SnowflakeOutputConnection snowflakeCon =
-          (SnowflakeOutputConnection) getConnector(task, true).connect(true);
-      this.stageIdentifier = StageIdentifierHolder.getStageIdentifier(t);
-      snowflakeCon.runCreateStage(this.stageIdentifier);
-    }
     SnowflakePluginTask pluginTask = (SnowflakePluginTask) task;

     return new SnowflakeCopyBatchInsert(
         getConnector(task, true),
-        this.stageIdentifier,
+        StageIdentifierHolder.getStageIdentifier(pluginTask),
         pluginTask.getCopyIntoTableColumnNames(),
         pluginTask.getCopyIntoCSVColumnNumbers(),
         false,
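
Taken together, the hunks above move the stage lifecycle out of newBatchInsert() and into transaction(): the stage is created once before the load and dropped once after it, and the new delete_stage_on_error flag opts in to cleanup when the run fails. A minimal sketch of that control flow, using hypothetical stand-ins (Stage, runLoad) rather than the plugin's real classes:

    // Sketch of the create-once / drop-once stage lifecycle in transaction().
    // Stage and runLoad() are hypothetical stand-ins, not plugin APIs.
    public class StageLifecycleSketch {
      static class Stage {
        void create() { System.out.println("CREATE STAGE ..."); }
        void drop() { System.out.println("DROP STAGE ..."); }
      }

      static void runLoad() {
        // the bulk load (super.transaction(...) in the plugin) would run here
      }

      public static void main(String[] args) {
        boolean deleteStage = true;        // mirrors delete_stage
        boolean deleteStageOnError = true; // mirrors the new delete_stage_on_error (default false)
        Stage stage = new Stage();
        try {
          stage.create();                  // create once, up front
          runLoad();
          if (deleteStage) {
            stage.drop();                  // success path: drop only if delete_stage is set
          }
        } catch (Exception e) {
          if (deleteStage && deleteStageOnError) {
            stage.drop();                  // failure path: opt-in cleanup
          }
          throw new RuntimeException(e);   // always rethrow, as the plugin does
        }
      }
    }
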
1 change: 0 additions & 1 deletion src/main/java/org/embulk/output/snowflake/SnowflakeCopyBatchInsert.java
@@ -69,7 +69,6 @@ public SnowflakeCopyBatchInsert(

   @Override
   public void prepare(TableIdentifier loadTable, JdbcSchema insertSchema) throws SQLException {
     this.connection = (SnowflakeOutputConnection) connector.connect(true);
-    this.connection.runCreateStage(stageIdentifier);
     this.tableIdentifier = loadTable;
   }

6 changes: 6 additions & 0 deletions src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java
@@ -11,8 +11,12 @@
 import org.embulk.output.jdbc.JdbcSchema;
 import org.embulk.output.jdbc.MergeConfig;
 import org.embulk.output.jdbc.TableIdentifier;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 public class SnowflakeOutputConnection extends JdbcOutputConnection {
+  private final Logger logger = LoggerFactory.getLogger(SnowflakeOutputConnection.class);
+
   public SnowflakeOutputConnection(Connection connection) throws SQLException {
     super(connection, null);
   }
@@ -45,11 +49,13 @@ public void runCopy(
   public void runCreateStage(StageIdentifier stageIdentifier) throws SQLException {
     String sql = buildCreateStageSQL(stageIdentifier);
     runUpdate(sql);
+    logger.info("SQL: {}", sql);
   }

   public void runDropStage(StageIdentifier stageIdentifier) throws SQLException {
     String sql = buildDropStageSQL(stageIdentifier);
     runUpdate(sql);
+    logger.info("SQL: {}", sql);
   }

   public void runUploadFile(
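
On the configuration side, the new option is wired in through the task definition above. A hypothetical Embulk config sketch showing it in context — the connection fields (host, user, password, warehouse, database, schema, table) are illustrative placeholders, not taken from this commit; delete_stage_on_error only has an effect when delete_stage is also true:

    out:
      type: snowflake
      host: example.snowflakecomputing.com   # placeholder
      user: embulk_loader                    # placeholder
      password: secret                       # placeholder
      warehouse: load_wh                     # placeholder
      database: analytics                    # placeholder
      schema: public                         # placeholder
      table: events                          # placeholder
      delete_stage: true             # drop the named stage after a successful load
      delete_stage_on_error: true    # new: also drop it when the load fails (default false)
      match_by_column_name: none     # "none" is the @ConfigDefault shown above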
