package org.jboss.fresh.persist;
import org.apache.log4j.Logger;
import org.jboss.fresh.io.IOUtils;
import org.jboss.fresh.registry.RegistryContext;
import javax.sql.DataSource;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Map;
/**
 * Primary-key generator that keeps the last issued key in a database table, one row per key name.
 * <p>
 * In the default {@code shared} mode every key generation reads, increments and writes back the stored value
 * inside a single transaction, after first updating a companion {@code <pkname>.txlock} row to serialize
 * concurrent generators. This way multiple machines can use multiple instances of the same generator, so the
 * scheme scales out. It is not as efficient as a HiLo generator, where contention only occurs on every 100th or
 * 1000th generation (when the hi number is increased); with this generator contention happens every time.
 * <p>
 * In {@code exclusive} mode ({@code mode=exclusive} in the configuration) the current value is cached in memory
 * and the database is only written to, never re-read, so this mode is only safe when a single generator
 * instance uses the key name.
 * <p>
 * Configuration is a properties file with the mandatory {@code pkname} (unless passed to the constructor),
 * {@code table.name}, {@code column.keyname}, {@code column.genvalue}, plus either {@code jndi.datasource} or
 * {@code jdbc.driver}/{@code jdbc.url} (with optional {@code jdbc.user}/{@code jdbc.pass}), and the optional
 * {@code type} ({@code integer}|{@code long}) and {@code mode} ({@code shared}|{@code exclusive}).
 * <p>
 * All public operations are {@code synchronized} on the generator instance.
 */
public class JDBCIncrementalPKGenerator implements PKGenerator {

    /** Key type: keys are handed out as {@link Integer}. */
    private static final int INTEGER = 0;
    /** Key type: keys are handed out as {@link Long}. */
    private static final int LONG = 1;
    /** Mode: the stored value is re-read (under the txlock row) on every operation. */
    private static final int SHARED = 0;
    /** Mode: the value is cached in {@link #cval}; only one instance may use the key name. */
    private static final int EXCLUSIVE = 1;

    private static final Logger log = Logger.getLogger(JDBCIncrementalPKGenerator.class);

    private int type = INTEGER;
    private int mode = SHARED;
    /** Key name, i.e. the value of {@link #namecol} identifying the row that holds the current value. */
    private String pkname = null;
    private String table;
    private String namecol;
    private String valcol;
    private String jdbcdriver;
    private String jdbcurl;
    /** JNDI name of the DataSource; used only when no {@code jdbc.driver} is configured. */
    private String ds;
    private String user, pass;
    /** DataSource looked up lazily from {@link #ds}. */
    private DataSource dts;
    /** Last value known to this instance; the authoritative value in exclusive mode. */
    private String cval;

    /**
     * Creates a generator for an explicitly named key.
     *
     * @param val   initial value stored if the key has no value yet
     * @param fname configuration file, looked up via the context class loader first, then on the file system
     * @param key   key name; when non-null it takes precedence over the {@code pkname} configuration property
     * @throws Exception if the configuration is missing/invalid or a database operation fails
     */
    public JDBCIncrementalPKGenerator(Object val, String fname, String key) throws Exception {
        this.pkname = key;
        init(val, fname);
    }

    /**
     * Creates a generator whose key name is taken from the {@code pkname} configuration property.
     *
     * @param val   initial value stored if the key has no value yet
     * @param fname configuration file, looked up via the context class loader first, then on the file system
     * @throws Exception if the configuration is missing/invalid or a database operation fails
     */
    public JDBCIncrementalPKGenerator(Object val, String fname) throws Exception {
        init(val, fname);
    }

    /**
     * Reads the configuration and makes sure the key has a stored value.
     * <p>
     * If no value is stored for the key yet, {@code String.valueOf(val)} is stored. In exclusive mode the value
     * already in the database is loaded first, so an existing value is never overwritten.
     * NOTE(review): a null {@code val} is stored as the string "null".
     *
     * @param val   initial key value, used only if none is stored yet
     * @param fname path of the properties configuration file (class path resource or file system path)
     * @throws Exception if the configuration is missing/invalid or a database operation fails
     */
    protected void init(Object val, String fname) throws Exception {
        if (fname == null) {
            throw new RuntimeException("Configuration file not specified!");
        }
        initialize(loadConfiguration(fname));
        if (mode == EXCLUSIVE) {
            // In exclusive mode setNewKeyAsString() trusts cval instead of the database, so cache the
            // persisted value first; otherwise the stored value would be overwritten with 'val' on every start.
            cval = readPersistedValue();
        }
        // setNewKeyAsString() caches the value in cval itself when it stores it.
        setNewKeyAsString(String.valueOf(val));
    }

    /**
     * Opens the configuration (context class loader first, file system second) and reads it as a properties
     * map. The underlying stream is always closed.
     */
    private static Map loadConfiguration(String fname) throws Exception {
        ClassLoader cl = Thread.currentThread().getContextClassLoader();
        InputStream ins = cl == null ? null : cl.getResourceAsStream(fname);
        if (ins == null) {
            try {
                ins = new FileInputStream(fname);
            } catch (IOException ex) {
                throw new RuntimeException("Couldn't find configuration: " + fname, ex);
            }
        }
        final Reader rd = new InputStreamReader(ins);
        try {
            return IOUtils.readPropsAsMap(rd);
        } finally {
            try {
                rd.close();
            } catch (IOException ex) {
                log.warn("Exception while closing configuration " + fname + ":", ex);
            }
        }
    }

    /**
     * Validates the configuration map and copies it into the generator's fields; registers the JDBC driver when
     * {@code jdbc.driver} is configured.
     *
     * @throws RuntimeException if mandatory properties are missing or {@code type}/{@code mode} are unknown
     */
    private void initialize(Map conf) {
        // read from configuration if and only if it was not specified by the constructor.
        if (pkname == null) {
            pkname = (String) conf.get("pkname");
        }
        // the following must be present
        table = (String) conf.get("table.name");
        namecol = (String) conf.get("column.keyname");
        valcol = (String) conf.get("column.genvalue");
        if (pkname == null || table == null || namecol == null || valcol == null) {
            throw new RuntimeException(
                    "The following properties must be present in the configuration: pkname, table.name, column.keyname, column.genvalue");
        }
        // either
        jdbcdriver = (String) conf.get("jdbc.driver");
        jdbcurl = (String) conf.get("jdbc.url");
        // or
        ds = (String) conf.get("jndi.datasource");
        if (ds == null && (jdbcdriver == null || jdbcurl == null)) {
            throw new RuntimeException("Configuration must contain either jndi.datasource or jdbc.driver and jdbc.url");
        }
        if (jdbcdriver != null && jdbcurl == null) {
            throw new RuntimeException(
                    "If jdbc.driver is specified in the configuration then there must be jdbc.url specified as well");
        }
        if (jdbcdriver != null) {
            user = (String) conf.get("jdbc.user");
            pass = (String) conf.get("jdbc.pass");
            // register driver
            try {
                // The context class loader does not work here; the driver must be loaded with Class.forName.
                Class.forName(jdbcdriver);
            } catch (Exception ex) {
                throw new RuntimeException("Initialization of JDBCIncrementalPKGenerator failed", ex);
            }
        }
        String tmp = (String) conf.get("type");
        if (tmp != null) {
            if ("integer".equalsIgnoreCase(tmp)) {
                type = INTEGER;
            } else if ("long".equalsIgnoreCase(tmp)) {
                type = LONG;
            } else {
                throw new RuntimeException("Unknown type specified: " + tmp);
            }
        }
        tmp = (String) conf.get("mode");
        if (tmp != null) {
            // An empty value is still accepted as 'shared' for backward compatibility.
            if ("shared".equalsIgnoreCase(tmp) || tmp.length() == 0) {
                mode = SHARED;
            } else if ("exclusive".equalsIgnoreCase(tmp)) {
                mode = EXCLUSIVE;
            } else {
                throw new RuntimeException("Unknown mode specified: " + tmp);
            }
        }
    }

    /**
     * Obtains a connection with auto-commit disabled, either from {@link DriverManager} or from the JNDI
     * DataSource (the latter is additionally switched to SERIALIZABLE isolation). If the connection cannot be
     * set up it is closed before the exception propagates.
     */
    private Connection getConnection() throws Exception {
        final boolean pooled = jdbcdriver == null;
        Connection con;
        if (!pooled) {
            try {
                con = user == null
                        ? DriverManager.getConnection(jdbcurl)
                        : DriverManager.getConnection(jdbcurl, user, pass);
            } catch (SQLException ex) {
                throw new RuntimeException("failed to get connection for jdbcurl: " + jdbcurl, ex);
            }
        } else {
            if (dts == null) {
                RegistryContext ctx = new RegistryContext();
                dts = (DataSource) ctx.lookup(ds);
            }
            con = dts.getConnection();
        }
        try {
            if (pooled) {
                log.info("Got connection from pool: " + con + " (autocommit:" + con.getAutoCommit() + ')');
            }
            con.setAutoCommit(false);
            if (pooled) {
                con.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE);
            }
        } catch (SQLException ex) {
            // do not leak the connection when its setup fails
            returnConnectionQuietly(con);
            throw ex;
        }
        return con;
    }

    /** A unit of database work executed by {@link #inTransaction(TxWork)}. */
    private interface TxWork {
        /**
         * @param con connection with auto-commit disabled
         * @return result handed back to the caller of {@code inTransaction}
         */
        Object run(Connection con) throws Exception;
    }

    /**
     * Runs {@code work} on a fresh connection and commits. On any failure the transaction is rolled back and the
     * original failure is rethrown (a rollback failure is only logged, so it cannot mask it). The connection is
     * always returned afterwards.
     */
    private Object inTransaction(TxWork work) throws Exception {
        Connection con = getConnection();
        try {
            startTx(con);
            Object result = work.run(con);
            commitTx(con);
            return result;
        } catch (Exception ex) {
            rollbackAfterFailure(con, ex);
            throw ex;
        } catch (Error err) {
            rollbackAfterFailure(con, err);
            throw err;
        } finally {
            returnConnectionQuietly(con);
        }
    }

    /**
     * Returns the current value: in shared mode the txlock row is locked and the stored value read, in exclusive
     * mode the cached value is returned. May be {@code null} if no value is stored yet.
     */
    private String currentValue(Connection con) throws SQLException {
        if (mode == SHARED) {
            txlock(con);
            return readValue(con);
        }
        return cval;
    }

    /** Reads the value stored in the database for {@link #pkname} (no txlock), or null if there is none. */
    private String readPersistedValue() throws Exception {
        return (String) inTransaction(new TxWork() {
            public Object run(Connection con) throws SQLException {
                return readValue(con);
            }
        });
    }

    /** Converts a stored value to the configured key type ({@link Integer} or {@link Long}). */
    private Object toKey(String value) {
        if (type == INTEGER) {
            return Integer.valueOf(value);
        } else if (type == LONG) {
            return Long.valueOf(value);
        }
        return value;
    }

    /**
     * Issues the next key. The current value is read (locking the txlock row in shared mode), incremented,
     * written back and committed.
     *
     * @return the new key as {@link Integer} or {@link Long}, depending on the configured type
     * @throws Exception if a database operation fails; a RuntimeException if the stored value is not a number
     *                   of the configured type or incrementing it would overflow
     */
    public synchronized Object newKey() throws Exception {
        String next = (String) inTransaction(new TxWork() {
            public Object run(Connection con) throws SQLException {
                String incremented = increment(currentValue(con));
                cval = incremented;
                updateValue(con, incremented);
                log.info("new val: " + cval);
                return incremented;
            }
        });
        return toKey(next);
    }

    /**
     * Gives back a key. The stored value is decremented only if it still equals {@code String.valueOf(o)},
     * i.e. if {@code o} is the most recently issued key.
     *
     * @param o key to give back
     * @return true if the key was given back, false if it was not the current value
     * @throws Exception if a database operation fails
     */
    public synchronized boolean returnKey(final Object o) throws Exception {
        Boolean returned = (Boolean) inTransaction(new TxWork() {
            public Object run(Connection con) throws SQLException {
                String current = currentValue(con);
                if (current == null || !current.equals(String.valueOf(o))) {
                    return Boolean.FALSE;
                }
                cval = decrement(current);
                updateValue(con, cval);
                return Boolean.TRUE;
            }
        });
        return returned.booleanValue();
    }

    /**
     * Returns the current key without changing it. In shared mode it is read from the database under the txlock
     * row and the transaction is committed, so no lock is left open on the returned connection; in exclusive
     * mode the cached value is returned.
     *
     * @throws RuntimeException wrapping any database failure
     */
    public synchronized Object getCurrentKey() {
        String val;
        if (mode == SHARED) {
            try {
                val = (String) inTransaction(new TxWork() {
                    public Object run(Connection con) throws SQLException {
                        return currentValue(con);
                    }
                });
            } catch (RuntimeException ex) {
                throw ex;
            } catch (Exception ex) {
                throw new RuntimeException(ex);
            }
        } else {
            val = cval;
        }
        return toKey(val);
    }

    /**
     * Unconditionally stores {@code val} (update, or insert if the row is missing) and caches it.
     *
     * @throws Exception if a database operation fails
     */
    public synchronized void setKeyAsString(final String val) throws Exception {
        inTransaction(new TxWork() {
            public Object run(Connection con) throws SQLException {
                cval = val;
                updateValue(con, cval);
                return null;
            }
        });
    }

    /**
     * Stores {@code nval} only if no current value exists yet.
     *
     * @return true if {@code nval} was stored, false if a value already existed
     * @throws Exception if a database operation fails
     */
    public synchronized boolean setNewKeyAsString(final String nval) throws Exception {
        Boolean stored = (Boolean) inTransaction(new TxWork() {
            public Object run(Connection con) throws SQLException {
                if (currentValue(con) != null) {
                    return Boolean.FALSE;
                }
                updateValue(con, nval);
                cval = nval;
                return Boolean.TRUE;
            }
        });
        return stored.booleanValue();
    }

    /**
     * Serializes generators sharing this key by updating the {@code <pkname>.txlock} row inside the current
     * transaction, creating the row if it does not exist yet. On databases with row-level locking this presumably
     * blocks other generators until the transaction ends.
     */
    private void txlock(Connection con) throws SQLException {
        final String lockName = pkname + ".txlock";
        final String lockValue = cval == null ? "0" : cval;
        final String sql = "UPDATE " + table + " SET " + valcol + "=? WHERE " + namecol + "=?";
        final PreparedStatement pst = con.prepareStatement(sql);
        try {
            pst.setString(1, lockValue);
            pst.setString(2, lockName);
            log.debug("Executing update: " + pst + " (" + lockName + ", " + cval + ')');
            boolean locked;
            try {
                locked = pst.executeUpdate() != 0;
                if (!locked) {
                    log.debug("txlock row " + lockName + " not there yet.");
                }
            } catch (SQLException ex) {
                log.error("Could not execute update: " + sql, ex);
                locked = false;
            }
            if (!locked) {
                insertTxlockRow(con, pst, lockName, lockValue);
            }
        } finally {
            closeStatement(pst);
        }
    }

    /**
     * Inserts the txlock row. If the insert fails (presumably because another generator inserted it
     * concurrently) the row is locked by re-running {@code update} instead.
     */
    private void insertTxlockRow(Connection con, PreparedStatement update, String lockName, String lockValue)
            throws SQLException {
        final String insert = "INSERT INTO " + table + " (" + namecol + ", " + valcol + ") VALUES (?,?)";
        final PreparedStatement cst = con.prepareStatement(insert);
        try {
            cst.setString(1, lockName);
            cst.setString(2, lockValue);
            log.debug("Executing insert: " + cst + " (" + lockName + ", " + cval + ')');
            cst.execute();
        } catch (SQLException e) {
            // NOTE(review): databases that abort the whole transaction after a failed statement
            // (e.g. PostgreSQL) will reject this retry as well.
            log.error("Could not execute insert: " + insert, e);
            update.setString(1, lockValue);
            update.setString(2, lockName);
            log.debug("Re-Executing update: " + update + " (" + lockName + ", " + cval + ')');
            if (update.executeUpdate() == 0) {
                // continuing here would run without the lock the caller relies on
                SQLException failure = new SQLException("Could neither insert nor update txlock row " + lockName);
                failure.initCause(e);
                throw failure;
            }
            log.info("update done.");
        } finally {
            closeStatement(cst);
        }
    }

    /**
     * Reads the value stored for {@link #pkname}.
     *
     * @return the stored value, or null if there is no row for the key
     * @throws RuntimeException if more than one row exists for the key
     */
    private String readValue(Connection con) throws SQLException {
        final PreparedStatement pst =
                con.prepareStatement("SELECT " + valcol + " FROM " + table + " WHERE " + namecol + "=?");
        try {
            pst.setString(1, pkname);
            final ResultSet rs = pst.executeQuery();
            try {
                String value = null;
                if (rs.next()) {
                    value = rs.getString(1);
                }
                if (rs.next()) {
                    throw new RuntimeException("More than one row with key " + pkname + " found in the table: " + table);
                }
                return value;
            } finally {
                try {
                    rs.close();
                } catch (Exception ex) {
                    log.warn("Exception while closing resultset:", ex);
                }
            }
        } finally {
            closeStatement(pst);
        }
    }

    /**
     * Stores {@code val} for {@link #pkname} with an UPDATE, falling back to an INSERT when no row was updated or
     * the UPDATE failed.
     */
    private void updateValue(Connection con, String val) throws SQLException {
        boolean missing;
        final PreparedStatement pst =
                con.prepareStatement("UPDATE " + table + " SET " + valcol + "=? WHERE " + namecol + "=?");
        try {
            pst.setString(1, val);
            pst.setString(2, pkname);
            try {
                missing = pst.executeUpdate() == 0;
            } catch (SQLException ex) {
                log.warn("Failed to update value for pk " + pkname, ex);
                missing = true;
            }
        } finally {
            closeStatement(pst);
        }
        if (missing) {
            insertValue(con, val);
        }
    }

    /**
     * Inserts the row for {@link #pkname} with value {@code val}.
     *
     * @throws RuntimeException if the driver reports zero inserted rows without an exception
     */
    private void insertValue(Connection con, String val) throws SQLException {
        int updCount;
        final PreparedStatement pst = con.prepareStatement(
                "INSERT INTO " + table + " ( " + namecol + ", " + valcol + ") VALUES (?, ?)");
        try {
            pst.setString(1, pkname);
            pst.setString(2, val);
            updCount = pst.executeUpdate();
        } finally {
            closeStatement(pst);
        }
        if (updCount == 0) {
            throw new RuntimeException("Failed to insert - updateCount==0 but no exception. Possibly a driver problem.");
        }
    }

    /**
     * Returns {@code str} parsed as the configured type, plus one.
     *
     * @throws RuntimeException if {@code str} is not a number of the configured type or the result overflows
     */
    protected String increment(String str) {
        return adjust(str, 1);
    }

    /**
     * Returns {@code str} parsed as the configured type, minus one.
     *
     * @throws RuntimeException if {@code str} is not a number of the configured type or the result overflows
     */
    protected String decrement(String str) {
        return adjust(str, -1);
    }

    /** Adds {@code delta} to {@code str} in the configured type, refusing to wrap around on overflow. */
    private String adjust(String str, int delta) {
        if (type == INTEGER) {
            int val;
            try {
                val = Integer.parseInt(str);
            } catch (NumberFormatException ex) {
                throw new RuntimeException("type mismatch (tried to parse " + str + " to INTEGER)", ex);
            }
            long next = (long) val + delta;
            if (next > Integer.MAX_VALUE || next < Integer.MIN_VALUE) {
                throw new RuntimeException("INTEGER key overflow (cannot add " + delta + " to " + str + ')');
            }
            return String.valueOf((int) next);
        } else if (type == LONG) {
            long val;
            try {
                val = Long.parseLong(str);
            } catch (NumberFormatException ex) {
                throw new RuntimeException("type mismatch (tried to parse " + str + " to LONG)", ex);
            }
            if (delta > 0 ? val > Long.MAX_VALUE - delta : val < Long.MIN_VALUE - delta) {
                throw new RuntimeException("LONG key overflow (cannot add " + delta + " to " + str + ')');
            }
            return String.valueOf(val + delta);
        }
        throw new RuntimeException("Internal application error. Failed to break at init time on bad type: " + type);
    }

    /** Hook for starting a transaction; a no-op because auto-commit is disabled in {@link #getConnection()}. */
    private void startTx(Connection con) throws SQLException {
        //con.begin();
    }

    private static void commitTx(Connection con) throws SQLException {
        log.debug("Commiting connection: " + con);
        con.commit();
    }

    private static void rollbackTx(Connection con) throws SQLException {
        log.debug("Rolling back connection: " + con);
        con.rollback();
    }

    /** Rolls back after {@code cause}; a rollback failure is logged so it does not replace {@code cause}. */
    private static void rollbackAfterFailure(Connection con, Throwable cause) {
        try {
            rollbackTx(con);
        } catch (Exception ex) {
            log.error("Rollback failed after: " + cause, ex);
        }
    }

    private static void returnConnection(Connection con) throws SQLException {
        log.debug("Returning connection: " + con);
        con.close();
    }

    /** Closes the connection, logging instead of throwing on failure. */
    private static void returnConnectionQuietly(Connection con) {
        try {
            returnConnection(con);
        } catch (Exception ex) {
            log.warn("Exception while returning connection to the pool:", ex);
        }
    }

    /** Closes the statement, logging instead of throwing on failure. */
    private static void closeStatement(PreparedStatement pst) {
        try {
            pst.close();
        } catch (Exception ex) {
            log.warn("Exception while closing prepared statement:", ex);
        }
    }
}