package com.lingbobu.flashdb.transfer.impl;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.sql.DataSource;
import org.springframework.jdbc.core.JdbcTemplate;
import com.lingbobu.flashdb.transfer.TransferInput;
import com.lingbobu.flashdb.transfer.TransferInput.PartableInput;
import com.lingbobu.flashdb.transfer.impl.InputDbQueryIN;
/**
 * IN-query data source spread over several shard tables (the IN query must be
 * made on the sharding field).
 *
 * <p>The configured SQL is a template containing the {@code {TABLE_ID}}
 * placeholder. For every shard table reported by {@link #splitShardsKeys(Object[])}
 * the placeholder is replaced by the table id and an {@link InputDbQueryIN} is
 * created for the values that belong to that table.
 */
public abstract class InputDbQueryInShards implements TransferInput, PartableInput {

    /** Placeholder in the SQL template that is substituted with the shard table id. */
    private static final String TABLE_ID_PLACEHOLDER = "{TABLE_ID}";

    /** Data source used both to convert the IN values and to run the per-shard queries. */
    private DataSource dataSource;

    /** SQL template; expected to contain the {@code {TABLE_ID}} placeholder. */
    private String dataSourceSql;

    /** Raw IN values as handed to {@link #init(String, Object)}; converted lazily. */
    private Object inValues;

    public InputDbQueryInShards() {
    }

    /**
     * Determines the shard table each field value belongs to.
     *
     * @param values the IN values to distribute over the shard tables
     * @return the field values grouped by shard table id
     */
    protected abstract Map<String, Object[]> splitShardsKeys(Object[] values);

    /**
     * Sets the data source the queries are run against.
     *
     * @param dataSource the data source to use
     */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Stores the SQL template and the IN values; nothing is queried here.
     *
     * @param dataSourceSql SQL template containing the {@code {TABLE_ID}} placeholder
     * @param inValues      raw IN values, later passed to
     *                      {@code InputDbQueryIN.convertInValues}
     */
    public void init(String dataSourceSql, Object inValues) {
        this.dataSourceSql = dataSourceSql;
        this.inValues = inValues;
    }

    /**
     * Returns an iterator that walks over all part inputs produced by
     * {@link #getPartInputs()}.
     */
    @Override
    public RowIterator iterator() {
        TransferInput[] parts = getPartInputs();
        return new MultiPartsRowIterator(parts);
    }

    /**
     * Builds the part inputs of every shard table.
     *
     * <p>The IN values are converted to an array, grouped per shard table by
     * {@link #splitShardsKeys(Object[])}, and the parts of each shard are appended
     * in the iteration order of the returned map.
     *
     * @return the part inputs of all shard tables, flattened into one array
     */
    @Override
    public TransferInput[] getPartInputs() {
        JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
        Object[] allKeys = InputDbQueryIN.convertInValues(inValues, jdbcTemplate);
        Map<String, Object[]> keysByTable = splitShardsKeys(allKeys);

        List<TransferInput> collected = new ArrayList<TransferInput>();
        for (Map.Entry<String, Object[]> shard : keysByTable.entrySet()) {
            TransferInput[] shardParts = shardPartInputs(shard.getKey(), shard.getValue());
            for (int i = 0; i < shardParts.length; i++) {
                collected.add(shardParts[i]);
            }
        }
        return collected.toArray(new TransferInput[collected.size()]);
    }

    /** Creates the part inputs of one shard table from the SQL template and that table's keys. */
    private TransferInput[] shardPartInputs(String tableId, Object[] shardKeys) {
        String shardSql = dataSourceSql.replace(TABLE_ID_PLACEHOLDER, tableId);
        InputDbQueryIN shardInput = new InputDbQueryIN(dataSource, shardSql, shardKeys);
        return shardInput.getPartInputs();
    }
}