PostGISCopyInsertBatch.java
/*
* Copyright (C) 2021 B3Partners B.V.
*
* SPDX-License-Identifier: MIT
*/
package nl.b3p.brmo.sql;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.SQLException;
import nl.b3p.brmo.sql.dialect.PostGISDialect;
import nl.b3p.brmo.sql.dialect.SQLDialect;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.text.StringEscapeUtils;
import org.locationtech.jts.geom.Geometry;
import org.postgresql.PGConnection;
import org.postgresql.copy.CopyIn;
/**
 * A {@link QueryBatch} that sends rows to PostgreSQL/PostGIS through the JDBC driver's copy API
 * ({@link CopyIn}) instead of executing insert statements.
 *
 * <p>The PostgreSQL JDBC driver does not support parallel copy operations, even with multiple
 * connections. So this class can cache the copy stream in memory (the {@code buffer} option) so
 * only one copy operation is active at a time, although this is significantly slower.
 *
 * <p>Rows are encoded as one line per row with tab-separated columns and {@code \N} for null
 * values. This presumably matches the COPY text format requested by the supplied SQL statement;
 * the statement itself is not inspected by this class.
 */
public class PostGISCopyInsertBatch implements QueryBatch {
  private static final Log LOG = LogFactory.getLog(PostGISCopyInsertBatch.class);

  /** Connection the copy operation is started on; must unwrap to a {@link PGConnection}. */
  protected Connection connection;

  /** Statement handed to the driver's copy API when a copy operation is started. */
  protected String sql;

  /** Dialect used to convert geometries to EWKT. */
  protected PostGISDialect dialect;

  /** Number of rows after which {@link #addBatch(Object[])} executes the batch by itself. */
  protected int batchSize;

  /** Passed through to the dialect when converting geometries to EWKT. */
  protected boolean linearizeCurves;

  /** Number of rows added since the last executed batch. */
  protected int count = 0;

  /** Encoded rows that have not yet been written to the copy operation. */
  protected StringEscapeUtils.Builder copyData = PostgresCopyEscapeUtils.builder();

  /** The copy operation currently in use, or {@code null} when none has been started. */
  protected CopyIn copyIn = null;

  /** The copy operation that was ended by the most recent successful {@link #executeBatch()}. */
  protected CopyIn lastCopyIn = null;

  /**
   * When {@code true} rows are kept in memory until {@link #executeBatch()}; when {@code false}
   * every row is written to the copy operation as soon as it is added.
   */
  protected boolean buffer;

  /**
   * Creates a new copy batch.
   *
   * @param connection connection to start copy operations on
   * @param sql statement passed to the copy API for every batch
   * @param batchSize number of rows after which the batch is executed automatically
   * @param dialect must be a {@link PostGISDialect}
   * @param buffer whether to keep rows in memory until the batch is executed
   * @param linearizeCurves passed to the dialect when converting geometries to EWKT
   * @throws IllegalArgumentException if {@code dialect} is not a {@link PostGISDialect}
   */
  public PostGISCopyInsertBatch(
      Connection connection,
      String sql,
      int batchSize,
      SQLDialect dialect,
      boolean buffer,
      boolean linearizeCurves) {
    if (!(dialect instanceof PostGISDialect)) {
      throw new IllegalArgumentException(
          "PostGISCopyInsertBatch requires a PostGISDialect, got: "
              + (dialect == null ? "null" : dialect.getClass().getName()));
    }
    this.connection = connection;
    this.sql = sql;
    this.dialect = (PostGISDialect) dialect;
    this.batchSize = batchSize;
    this.buffer = buffer;
    this.linearizeCurves = linearizeCurves;
  }

  /**
   * Starts a new copy operation for {@link #sql} and stores it in {@link #copyIn}.
   *
   * @throws SQLException if the connection cannot be unwrapped to a {@link PGConnection} or the
   *     copy operation cannot be started
   */
  protected void createCopyIn() throws SQLException {
    PGConnection pgConnection = connection.unwrap(PGConnection.class);
    this.copyIn = pgConnection.getCopyAPI().copyIn(sql);
  }

  /**
   * Writes the pending rows in {@link #copyData} to the copy operation as UTF-8, starting a copy
   * operation first when none exists, and then starts a fresh, empty row buffer.
   *
   * @throws SQLException if starting the copy operation or writing to it fails
   */
  protected void writeToCopy() throws SQLException {
    if (copyIn == null) {
      createCopyIn();
    }
    byte[] bytes = copyData.toString().getBytes(StandardCharsets.UTF_8);
    copyIn.writeToCopy(bytes, 0, bytes.length);
    copyData = PostgresCopyEscapeUtils.builder();
  }

  /**
   * Encodes one row and adds it to the batch. When not buffering, the row is written to the copy
   * operation immediately.
   *
   * @param params column values of the row, in column order; {@code null} elements become {@code
   *     \N}, geometries are written as EWKT, booleans as {@code t}/{@code f} and everything else
   *     as its escaped {@code toString()} value
   * @return {@code true} if adding this row filled the batch and the batch was executed
   * @throws Exception if writing to the copy operation or executing the batch fails
   */
  @Override
  public boolean addBatch(Object[] params) throws Exception {
    for (int i = 0; i < params.length; i++) {
      if (i != 0) {
        copyData.append("\t");
      }
      Object param = params[i];
      if (param == null) {
        copyData.append("\\N");
      } else if (param instanceof Geometry) {
        Geometry geometry = (Geometry) param;
        copyData.append(dialect.getEWkt(geometry, linearizeCurves));
      } else if (param instanceof Boolean) {
        copyData.append((Boolean) param ? "t" : "f");
      } else {
        // TODO any more types need special conversion?
        copyData.escape(param.toString());
      }
    }
    copyData.append("\n");

    if (!buffer) {
      writeToCopy();
    }

    count++;
    if (count == batchSize) {
      this.executeBatch();
      return true;
    }
    return false;
  }

  /**
   * Ends the copy operation for the rows added since the previous batch. Does nothing when no rows
   * were added. When buffering, the rows held in memory are written to a new copy operation first.
   *
   * <p>State is only reset after the copy operation ended successfully; after a failure {@link
   * #close()} cancels the operation if it is still active.
   *
   * @throws Exception if writing to or ending the copy operation fails
   */
  @Override
  public void executeBatch() throws Exception {
    if (count == 0) {
      return;
    }

    if (buffer) {
      if (LOG.isDebugEnabled()) {
        // To get the number of bytes this call is duplicated from writeToCopy()...
        int bytes = copyData.toString().getBytes(StandardCharsets.UTF_8).length;
        LOG.debug(
            String.format(
                "execute buffered copy batch, %d bytes, %d rows, sql: %s", bytes, count, sql));
      }
      writeToCopy();
    } else if (LOG.isDebugEnabled()) {
      // Guarded so the message is not formatted when debug logging is disabled
      LOG.debug(String.format("execute copy batch, %d rows, sql: %s", count, sql));
    }

    copyIn.endCopy();
    lastCopyIn = copyIn;
    // reset so writeToCopy will create a new one
    copyIn = null;
    count = 0;
  }

  /**
   * Releases the copy operation held by this batch. If a copy operation was started but never
   * ended (rows were added without a following {@link #executeBatch()}, or an earlier call
   * failed), it is cancelled so it does not stay active on the connection; the rows written to it
   * are discarded. Failures while cancelling are logged, not thrown.
   */
  @Override
  public void close() {
    CopyIn unfinished = copyIn;
    if (unfinished == null) {
      return;
    }
    copyIn = null;
    // The rows counted so far belonged to the unfinished operation
    count = 0;
    if (!unfinished.isActive()) {
      return;
    }
    try {
      LOG.warn("cancelling unfinished copy operation on close, sql: " + sql);
      unfinished.cancelCopy();
    } catch (SQLException e) {
      LOG.error("error cancelling unfinished copy operation, sql: " + sql, e);
    }
  }
}