package com.alibaba.druid.bvt.filter.wall;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.LinkedList;
import java.util.List;
import junit.framework.TestCase;
import org.junit.Assert;
import com.alibaba.druid.filter.Filter;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.wall.WallConfig;
import com.alibaba.druid.wall.WallFilter;
public class WallFilterTest3 extends TestCase {
private DruidDataSource dataSource;
private WallFilter wallFilter;
protected void setUp() throws Exception {
dataSource = new DruidDataSource();
dataSource.setUrl("jdbc:h2:mem:wall_test;");
// dataSource.setFilters("wall");
dataSource.setDbType(JdbcConstants.MARIADB);
WallConfig config = new WallConfig();
config.setTenantCallBack(new TenantTestCallBack());
wallFilter = new WallFilter();
wallFilter.setConfig(config);
wallFilter.setDbType(JdbcConstants.MARIADB);
List<Filter> filters = new LinkedList<Filter>();
filters.add(wallFilter);
dataSource.setProxyFilters(filters);
dataSource.init();
}
protected void tearDown() throws Exception {
dataSource.close();
}
public void test_wallFilter() throws Exception {
Assert.assertEquals(JdbcConstants.MARIADB, wallFilter.getDbType());
Assert.assertFalse(wallFilter.isLogViolation());
wallFilter.setLogViolation(true);
Assert.assertTrue(wallFilter.isLogViolation());
wallFilter.setLogViolation(false);
Assert.assertFalse(wallFilter.isLogViolation());
Assert.assertTrue(wallFilter.isThrowException());
wallFilter.setThrowException(false);
Assert.assertFalse(wallFilter.isThrowException());
wallFilter.setThrowException(true);
Assert.assertTrue(wallFilter.isThrowException());
wallFilter.clearProviderCache();
wallFilter.getProviderWhiteList();
Assert.assertTrue(wallFilter.isInited());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.execute("CREATE TABLE t (FID INTEGER, FNAME VARCHAR(50), TENANT VARCHAR(32))");
stmt.close();
conn.close();
}
Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getCreateCount());
{
Connection conn = dataSource.getConnection();
String sql = "INSERT INTO t (FID, FNAME) VALUES (?, ?)";
for (int i = 0; i < 10; ++i) {
PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS);
stmt.setInt(1, i + 10);
stmt.setString(2, "a" + (i + 10));
stmt.execute();
stmt.close();
}
conn.close();
}
Assert.assertEquals(10, wallFilter.getProvider().getTableStat("t").getInsertCount());
Assert.assertEquals(10, wallFilter.getProvider().getTableStat("t").getInsertDataCount());
{
Connection conn = dataSource.getConnection();
String sql = "INSERT INTO t (FID, FNAME) VALUES (?, ?)";
PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS);
for (int i = 0; i < 10; ++i) {
stmt.setInt(1, i + 20);
stmt.setString(2, "a" + (i + 20));
stmt.addBatch();
}
stmt.executeBatch();
stmt.close();
conn.close();
}
Assert.assertEquals(11, wallFilter.getProvider().getTableStat("t").getInsertCount());
Assert.assertEquals(20, wallFilter.getProvider().getTableStat("t").getInsertDataCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
for (int i = 0; i < 10; ++i) {
stmt.addBatch("INSERT INTO t (FID, FNAME) VALUES (" + i + ", 'a" + i + "')");
}
stmt.executeBatch();
stmt.close();
conn.close();
}
Assert.assertEquals(21, wallFilter.getProvider().getTableStat("t").getInsertCount());
Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
String sql = "SELECT * FROM T";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(30, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY,
ResultSet.HOLD_CURSORS_OVER_COMMIT);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(60, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql, new int[0]);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(70, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(sql, new String[0]);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(80, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareCall(sql);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(90, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(100, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T";
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY,
ResultSet.HOLD_CURSORS_OVER_COMMIT);
ResultSet rs = stmt.executeQuery();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(130, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
stmt.execute(sql, Statement.NO_GENERATED_KEYS);
ResultSet rs = stmt.getResultSet();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(140, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY,
ResultSet.HOLD_CURSORS_OVER_COMMIT);
stmt.execute(sql, new int[0]);
ResultSet rs = stmt.getResultSet();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(150, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
String sql = "SELECT * FROM T LIMIT 10";
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY,
ResultSet.HOLD_CURSORS_OVER_COMMIT);
stmt.execute(sql, new String[0]);
ResultSet rs = stmt.getResultSet();
while (rs.next()) {
}
rs.close();
stmt.close();
conn.close();
}
Assert.assertEquals(160, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.executeUpdate("DELETE from t where FID = 0");
stmt.close();
conn.close();
}
Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.executeUpdate("DELETE from t where FID = 1 OR FID = 2", Statement.NO_GENERATED_KEYS);
stmt.close();
conn.close();
}
Assert.assertEquals(3, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.executeUpdate("DELETE from t where FID = 3", new int[0]);
stmt.close();
conn.close();
}
Assert.assertEquals(4, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.executeUpdate("DELETE from t where FID = 4", new String[0]);
stmt.close();
conn.close();
}
Assert.assertEquals(5, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement("DELETE from t where FID = ?");
stmt.setInt(1, 5);
stmt.executeUpdate();
stmt.close();
conn.close();
}
Assert.assertEquals(6, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.execute("update t SET fname = 'xx' where FID = 13 OR FID = 14");
stmt.close();
conn.close();
}
Assert.assertEquals(2, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ? OR FID = ?");
stmt.setInt(1, 13);
stmt.setInt(2, 14);
stmt.execute();
stmt.close();
conn.close();
}
Assert.assertEquals(4, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ? OR FID = ?");
stmt.setInt(1, 13);
stmt.setInt(2, 14);
stmt.execute();
stmt.close();
conn.close();
}
Assert.assertEquals(6, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ?");
stmt.setInt(1, 13);
stmt.addBatch();
stmt.setInt(1, 14);
stmt.addBatch();
stmt.executeBatch();
stmt.close();
conn.close();
}
Assert.assertEquals(8, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.execute("truncate table t");
stmt.close();
conn.close();
}
Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getTruncateCount());
{
Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
stmt.execute("drop table t");
stmt.close();
conn.close();
}
Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getDropCount());
Assert.assertEquals(0, wallFilter.getViolationCount());
wallFilter.resetViolationCount();
wallFilter.checkValid("select 1");
Assert.assertEquals(0, wallFilter.getViolationCount());
}
}