Package com.skyline.base.dao.impl

Source Code of com.skyline.base.dao.impl.BaseDaoImpl

package com.skyline.base.dao.impl;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataRetrievalFailureException;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.ColumnMapRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCallback;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.RowMapperResultSetExtractor;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.stereotype.Repository;
import org.springframework.util.Assert;

import com.skyline.base.dao.BaseDao;
import com.skyline.common.bean.Page;

/**
* DAO基础类,封装了公共的DAO方法。如分页等的实现
*
* @author wuqh
* @version 1.00
* */
@Repository("baseDao")
public class BaseDaoImpl implements BaseDao {
  @Autowired
  protected JdbcTemplate jdbcTemplate;
  protected Map<String, String> formattedSqls;

  /**
   * 获取分页结果
   *
   * @param sql
   * @param page
   * @param rowMapper
   * @param args
   */
  protected <T> List<T> getPaginationResult(String sql, Page page, RowMapper<T> rowMapper, Object... args) {
    int total = getCountResult(sql, args);
    page.setTotal(total);
    if (total <= 0) {
      return new ArrayList<T>();
    }
    int startIndex = page.getStartIndex();
    String pageSql = sql + " limit " + startIndex + "," + page.getSize();
    return jdbcTemplate.query(pageSql, rowMapper, args);
  }

  /**
   * 统计结果总数
   *
   * @param sql
   * @param args
   */
  protected int getCountResult(String sql, Object... args) {
    String countSql = "select count(1) from (" + sql + ") cntTbl;";
    return jdbcTemplate.queryForInt(countSql, args);
  }

  /**
   * 插入数据并返回ID
   *
   * @param sql
   * @param args
   */
  protected long insertWithIdReturn(final String sql, final Object... args) {
    KeyHolder keyHolder = new GeneratedKeyHolder();
    jdbcTemplate.update(new PreparedStatementCreator() {
      @Override
      public PreparedStatement createPreparedStatement(Connection conn) throws SQLException {
        PreparedStatement ps = conn.prepareStatement(sql, PreparedStatement.RETURN_GENERATED_KEYS);
        for (int i = 1; i <= args.length; i++) {
          ps.setObject(i, args[i - 1]);
        }
        return ps;
      }

    }, keyHolder);
    return keyHolder.getKey().longValue();
  }

  /**
   * 批量更新
   *
   * @param sql
   * @param argsList
   *
   * */
  protected int[] batchUpdate(final String sql, final List<Object[]> argsList) {
    Assert.notEmpty(argsList, "args can not be empty while batch insert");
    return jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
      @Override
      public void setValues(PreparedStatement ps, int i) throws SQLException {
        Object[] args = argsList.get(i);
        int length = args.length;
        for (int j = 1; j <= length; j++) {
          ps.setObject(j, args[j - 1]);
        }
      }

      @Override
      public int getBatchSize() {
        return argsList.size();
      }
    });
  }

  /**
   * 批量插入数据并返回ID(参考Spring批量插入,使用非InterruptibleBatchPreparedStatementSetter实现)
   *
   * @param sql
   * @param argsList
   */
  @SuppressWarnings("unchecked")
  protected <T extends Number> List<T> batchInsertWithIdReturn(final String sql, final List<Object[]> argsList) {
    Assert.notEmpty(argsList, "args can not be empty while batch insert");
    final KeyHolder keyHolder = new GeneratedKeyHolder();
    jdbcTemplate.execute(new PreparedStatementCreator() {
      @Override
      public PreparedStatement createPreparedStatement(Connection con) throws SQLException {
        PreparedStatement ps = con.prepareStatement(sql, PreparedStatement.RETURN_GENERATED_KEYS);
        return ps;
      }
    }, new PreparedStatementCallback<int[]>() {
      private void setValues(PreparedStatement ps, int i) throws SQLException {
        Object[] args = argsList.get(i);
        int length = args.length;
        for (int j = 1; j <= length; j++) {
          ps.setObject(j, args[j - 1]);
        }
      }

      private int[] processResult(PreparedStatement ps, int[] rows) throws SQLException {
        List<Map<String, Object>> generatedKeys = keyHolder.getKeyList();
        generatedKeys.clear();
        ResultSet keys = ps.getGeneratedKeys();
        if (keys != null) {
          try {
            RowMapperResultSetExtractor<Map<String, Object>> rse = new RowMapperResultSetExtractor<Map<String, Object>>(
                new ColumnMapRowMapper(), rows.length);
            generatedKeys.addAll(rse.extractData(keys));
          } finally {
            JdbcUtils.closeResultSet(keys);
          }
        }
        return rows;
      }

      @Override
      public int[] doInPreparedStatement(PreparedStatement ps) throws SQLException, DataAccessException {
        int batchSize = argsList.size();
        if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
          for (int i = 0; i < batchSize; i++) {
            setValues(ps, i);
            ps.addBatch();
          }
          int[] rows = ps.executeBatch();
          return processResult(ps, rows);
        } else {
          List<Integer> rowsAffected = new ArrayList<Integer>();
          for (int i = 0; i < batchSize; i++) {
            setValues(ps, i);
            rowsAffected.add(ps.executeUpdate());
          }
          int[] rowsAffectedArray = new int[rowsAffected.size()];
          for (int i = 0; i < rowsAffectedArray.length; i++) {
            rowsAffectedArray[i] = rowsAffected.get(i);
          }
          return processResult(ps, rowsAffectedArray);
        }
      }
    });

    List<T> keys = new ArrayList<T>(argsList.size());
    List<Map<String, Object>> generatedKeys = keyHolder.getKeyList();
    int length = generatedKeys.size();
    for (int i = 0; i < length; i++) {
      Iterator<Object> keyIter = generatedKeys.get(i).values().iterator();
      if (keyIter.hasNext()) {
        Object key = keyIter.next();
        if (!(key instanceof Number)) {
          throw new DataRetrievalFailureException("The generated key is not of a supported numeric type. " + "Unable to cast ["
              + (key != null ? key.getClass().getName() : null) + "] to [" + Number.class.getName() + "]");
        }
        keys.add((T) key);
      }
    }
    return keys;
  }

  /**
   * 动态生成SQL中的表名
   *
   * @param sqlNeedFormat
   *            需要生成表名的原sql,格式如“select * from {0} where colum<>1”
   * @param tableName
   *            表名用于替换sqlNeedFormat中的{0},格式“tableName”
   * @return 最后生成的SQL,如参数中给的例子就是“select * from tableName where colum<>1”
   * */
  protected String genSql(String sqlNeedFormat, String tableName) {
    // 采用懒加载,会有多线程并发时有概率多次创建sqlFormatters,把一些之前用用过的MessageFormat冲掉的问题,不过数量应该不会太大,可以忽略
    if (formattedSqls == null) {
      formattedSqls = new HashMap<String, String>();
    }
    String key = tableName + sqlNeedFormat;
    String sql = formattedSqls.get(key);
    if (sql == null) {
      sql = MessageFormat.format(sqlNeedFormat, tableName);
      formattedSqls.put(key, sql);
    }
    return sql;

  }
}
TOP

Related Classes of com.skyline.base.dao.impl.BaseDaoImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.