// DropAllTablesOperation.java
/*
* Copyright (C) 2022 B3Partners B.V.
*/
package nl.b3p.brmo.test.util.database.dbunit;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseSequenceFilter;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.FilteredDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.filter.ITableFilter;
import org.dbunit.operation.DatabaseOperation;
/**
 * DbUnit operation that drops tables instead of manipulating their contents.
 *
 * <p>The tables to drop are taken from the supplied dataset or, when no dataset is given, from
 * the dataset created by the database connection. A {@link DatabaseSequenceFilter} is applied to
 * order the tables, and the drop statements are issued in the reverse of that order.
 */
public class DropAllTablesOperation extends DatabaseOperation {
  private static final Log LOG = LogFactory.getLog(DropAllTablesOperation.class);

  /** Drop statement template used for every database that is not detected as Oracle. */
  private static final String DROP_PGSQL = "drop table %s cascade";

  /**
   * Drop statement template used for Oracle. Oracle's {@code DROP TABLE} grammar requires the
   * {@code CONSTRAINTS} keyword after {@code CASCADE}; {@code PURGE} is an optional trailing
   * clause.
   */
  private static final String DROP_ORACLE = "drop table %s cascade constraints purge";

  /**
   * Drops the tables of the given dataset, or the tables of the dataset created by the connection
   * when {@code dataSet} is {@code null}.
   *
   * @param connection the database connection.
   * @param dataSet the dataset naming the tables to drop, or {@code null} to use the dataset
   *     created by {@code connection}.
   * @throws DatabaseUnitException if DbUnit fails to build or read the filtered dataset.
   * @throws SQLException if creating the statement or executing the drop batch fails.
   */
  @Override
  public void execute(IDatabaseConnection connection, IDataSet dataSet)
      throws DatabaseUnitException, SQLException {
    // the sequence filter determines the table order; the drops run in the reverse of it
    final ITableFilter filter = new DatabaseSequenceFilter(connection);

    // the database flavour is derived from the class name of the configured data type factory;
    // a missing factory is treated as "not Oracle" instead of failing with a NullPointerException
    final Object dataTypeFactory =
        connection.getConfig().getProperty(DatabaseConfig.PROPERTY_DATATYPE_FACTORY);
    final boolean isOracle =
        dataTypeFactory != null && dataTypeFactory.getClass().getName().contains("Oracle");
    final String dropStmt = isOracle ? DROP_ORACLE : DROP_PGSQL;

    final IDataSet source = (dataSet == null) ? connection.createDataSet() : dataSet;
    final IDataSet filteredDataSet = new FilteredDataSet(filter, source);

    final String[] tableNames = filteredDataSet.getTableNames();
    if (LOG.isDebugEnabled()) {
      LOG.debug("to be dropped table names: " + Arrays.toString(tableNames));
    }
    if (tableNames.length == 0) {
      // nothing to drop, so there is no need to open a statement for an empty batch
      return;
    }

    // reverse a copy: Arrays.asList is a view, so reversing it directly would also reorder the
    // array handed out by the dataset
    final List<String> reversedTableNames = Arrays.asList(tableNames.clone());
    Collections.reverse(reversedTableNames);
    if (LOG.isDebugEnabled()) {
      LOG.debug("to be dropped table names in order: " + reversedTableNames);
    }

    try (Statement stmt = connection.getConnection().createStatement()) {
      for (String table : reversedTableNames) {
        stmt.addBatch(String.format(dropStmt, table));
      }
      stmt.executeBatch();
    }
  }
}