package com.googlecode.gaal.analysis.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import com.googlecode.gaal.analysis.api.Context;
import com.googlecode.gaal.data.api.Corpus;
import com.googlecode.gaal.data.api.IntSequence;
import com.googlecode.gaal.data.api.IntervalSet;
import com.googlecode.gaal.data.api.Multiset;
import com.googlecode.gaal.data.impl.TreeMultiset;
import com.googlecode.gaal.suffix.api.IntervalTree.Interval;
import com.googlecode.gaal.suffix.api.LinearizedSuffixTree;
import com.googlecode.gaal.suffix.api.LinearizedSuffixTree.BinaryInterval;
import com.googlecode.gaal.suffix.impl.LinearizedSuffixTreeImpl;
/**
 * Collects {@link Context}s (a left sequence and a right sequence together with the
 * fillers found for that pair) from a corpus.
 * <p>
 * Two trees are built in the constructor: one over the corpus sequence and one over the
 * reversed corpus sequence. The first tree is walked once, eagerly, and every context
 * found is recorded together with its fill intervals and a count per fill. Contexts that
 * end up with exactly one distinct fill are discarded before the constructor returns.
 * Iterating this object yields the remaining contexts in the order of the backing
 * {@link TreeMap}.
 *
 * @param <S> the symbol type of the corpus
 */
public class NestedMaximalityContextExtractor<S> implements Iterable<Context> {

    /** Tree built over the corpus sequence as given. */
    private final LinearizedSuffixTree lst;

    /** Tree built over the reversed corpus sequence. */
    private final LinearizedSuffixTree lpt;

    /** Interval set produced by {@link LocalMaximumSetBuilder} for {@link #lst}. */
    private final IntervalSet<BinaryInterval> lstMaximalSet;

    /** Interval set produced by {@link SingletonBwtSetBuilder} for {@link #lst}. */
    private final IntervalSet<BinaryInterval> lstBwtSet;

    /** Interval set produced by {@link LocalMaximumSetBuilder} for {@link #lpt}. */
    private final IntervalSet<BinaryInterval> lptMaximalSet;

    /** Every recorded context mapped to its fill intervals and the count accumulated per fill. */
    private final Map<Context, Map<Interval, Integer>> envMap = new TreeMap<Context, Map<Interval, Integer>>();

    /**
     * When {@code true}, the fill set is reset before each new fill is added during the
     * walk over {@link #lst}, and intervals of {@link #lpt} that are not in
     * {@link #lptMaximalSet} never produce a context.
     */
    private final boolean maximalOnly;

    /**
     * Builds both trees and their interval sets, runs the extraction, and drops every
     * context that has exactly one distinct fill.
     *
     * @param corpus      supplies the sequence and the alphabet size used to build the trees
     * @param maximalOnly see {@link #maximalOnly}
     */
    public NestedMaximalityContextExtractor(Corpus<S> corpus, boolean maximalOnly) {
        this.lst = new LinearizedSuffixTreeImpl(corpus.sequence(), corpus.alphabetSize());
        this.lpt = new LinearizedSuffixTreeImpl(corpus.sequence().reverse(), corpus.alphabetSize());
        lstMaximalSet = new LocalMaximumSetBuilder().buildIntervalSet(lst);
        lstBwtSet = new SingletonBwtSetBuilder().buildIntervalSet(lst);
        lptMaximalSet = new LocalMaximumSetBuilder().buildIntervalSet(lpt);
        this.maximalOnly = maximalOnly;

        traverseLeft(lst.top(), 0, new HashSet<Interval>());

        // Prune in place through the map's value view: removing via the iterator deletes
        // the whole entry without a temporary list or a second keyed lookup.
        Iterator<Map<Interval, Integer>> fills = envMap.values().iterator();
        while (fills.hasNext()) {
            if (fills.next().size() == 1) {
                fills.remove();
            }
        }
    }

    /**
     * Returns the fills recorded for the given context.
     *
     * @param tandem the context to look up
     * @return the internal (live, mutable) map from fill interval to accumulated count, or
     *         {@code null} if the context is not recorded
     */
    public Map<Interval, Integer> getFill(NestedContext tandem) {
        return envMap.get(tandem);
    }

    /**
     * Recursively walks the intervals of {@link #lst} below {@code interval}.
     * <ul>
     * <li>Terminal intervals are ignored.</li>
     * <li>An interval that is not in {@link #lstMaximalSet} and whose lcp exceeds
     * {@code parentLcp} is added to {@code fillSet} (after clearing the set when
     * {@link #maximalOnly} is set).</li>
     * <li>An interval that is in {@link #lstMaximalSet}, reached with a non-empty fill set
     * and not contained in {@link #lstBwtSet}, is handed to
     * {@link #extendToRight(Interval, Set)} and the descent stops there.</li>
     * <li>In every other case both children are visited, each with its own copy of the
     * fill set.</li>
     * </ul>
     *
     * @param interval  the interval to visit
     * @param parentLcp the lcp of the interval this one was reached from
     * @param fillSet   the fills collected on the way down; modified by this call
     */
    private void traverseLeft(BinaryInterval interval, int parentLcp, Set<Interval> fillSet) {
        if (interval.isTerminal()) {
            return;
        }
        int lcp = interval.lcp();
        if (!lstMaximalSet.contains(interval)) {
            if (lcp > parentLcp) {
                if (maximalOnly) {
                    fillSet.clear();
                }
                fillSet.add(interval);
            }
        } else if (!fillSet.isEmpty() && !lstBwtSet.contains(interval)) {
            extendToRight(interval, fillSet);
            return;
        }
        // Each subtree gets an independent copy so fills added in one branch do not leak
        // into the sibling branch.
        traverseLeft(interval.leftChild(), lcp, new HashSet<Interval>(fillSet));
        traverseLeft(interval.rightChild(), lcp, new HashSet<Interval>(fillSet));
    }

    /**
     * Looks up, in {@link #lpt}, the interval whose label is the reverse of
     * {@code interval}'s label and, unless that interval is in {@link #lptMaximalSet},
     * starts collecting contexts below it.
     *
     * @param interval an interval of {@link #lst}
     * @param fillSet  the fills collected for {@code interval}
     */
    private void extendToRight(Interval interval, Set<Interval> fillSet) {
        BinaryInterval lptInterval = lpt.search(interval.label().reverse());
        // Invariants checked only when assertions are enabled.
        assert (lptInterval != null);
        assert (interval.label().size() == lptInterval.label().size());
        if (!lptMaximalSet.contains(lptInterval)) {
            collectLeftMaximal(lptInterval, lptInterval, interval, lptInterval.lcp(), fillSet);
        }
    }

    /**
     * Recursively walks the intervals of {@link #lpt} below {@code interval} and records a
     * context for each qualifying one.
     * <p>
     * An interval contained in {@link #lptMaximalSet} is recorded and ends the descent on
     * that branch. An interval outside the set is recorded only when {@link #maximalOnly}
     * is {@code false} and its lcp exceeds {@code parentLcp}; the descent then continues
     * into both children. Terminal intervals are ignored.
     *
     * @param interval    the interval of {@link #lpt} to visit
     * @param parent      the interval of {@link #lpt} the walk started from
     * @param lstInterval the interval of {@link #lst} the walk was started for
     * @param parentLcp   the lcp of the interval this one was reached from
     * @param fillSet     the fills to attach to every recorded context
     */
    private void collectLeftMaximal(BinaryInterval interval, Interval parent, Interval lstInterval, int parentLcp,
            Set<Interval> fillSet) {
        if (interval.isTerminal()) {
            return;
        }
        int lcp = interval.lcp();
        if (lptMaximalSet.contains(interval)) {
            addTandem(interval, parent, lstInterval, fillSet);
            return;
        }
        if (!maximalOnly && lcp > parentLcp) {
            addTandem(interval, parent, lstInterval, fillSet);
        }
        collectLeftMaximal(interval.leftChild(), parent, lstInterval, lcp, fillSet);
        collectLeftMaximal(interval.rightChild(), parent, lstInterval, lcp, fillSet);
    }

    /**
     * Records one context per fill in {@code fillSet}.
     * <p>
     * The left sequence of each context is the reversed edge label of {@code leftInterval}
     * relative to {@code leftParent}; the right sequence is the edge label of
     * {@code rightInterval} relative to the fill. The count added for the fill is the
     * smaller of the two interval sizes; counts for a fill already recorded under the same
     * context are summed.
     *
     * @param leftInterval  interval of {@link #lpt} supplying the left sequence
     * @param leftParent    interval the left edge label is taken relative to
     * @param rightInterval interval of {@link #lst} supplying the right sequence
     * @param fillSet       the fills to record
     */
    private void addTandem(Interval leftInterval, Interval leftParent, Interval rightInterval, Set<Interval> fillSet) {
        if (fillSet.isEmpty()) {
            return;
        }
        // Neither value depends on the fill, so compute them once for the whole set.
        IntSequence leftLabel = leftInterval.edgeLabel(leftParent).reverse();
        int count = Math.min(rightInterval.size(), leftInterval.size());

        for (Interval fill : fillSet) {
            NestedContext env = new NestedContext(leftLabel, rightInterval.edgeLabel(fill));
            Map<Interval, Integer> fillMap = envMap.get(env);
            if (fillMap == null) {
                fillMap = new HashMap<Interval, Integer>();
                envMap.put(env, fillMap);
            }
            Integer previous = fillMap.get(fill);
            fillMap.put(fill, previous == null ? count : previous + count);
        }
    }

    /** Null-safe equality: {@code true} when both are {@code null} or {@code a.equals(b)}. */
    private static boolean sameOrBothNull(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /**
     * A context identified by its left and right sequences. Equality, hashing and ordering
     * use only those two sequences; the fillers are looked up in the enclosing extractor.
     */
    public class NestedContext implements Context, Comparable<NestedContext> {

        private final IntSequence left;
        private final IntSequence right;

        /** Lazily built by {@link #fillerSet()} and cached afterwards. */
        private Multiset<IntSequence> fillSet;

        /**
         * @param left  the left sequence of the context
         * @param right the right sequence of the context
         */
        protected NestedContext(IntSequence left, IntSequence right) {
            this.left = left;
            this.right = right;
        }

        /** @return the left sequence given at construction */
        @Override
        public IntSequence leftSequence() {
            return left;
        }

        /** @return the right sequence given at construction */
        @Override
        public IntSequence rightSequence() {
            return right;
        }

        /**
         * Returns the labels of all fill intervals recorded for this context in the
         * enclosing extractor. The multiset is built on first call and cached.
         *
         * @return the filler labels; empty if this context is not recorded in the extractor
         */
        @Override
        public Multiset<IntSequence> fillerSet() {
            if (fillSet == null) {
                Multiset<IntSequence> labels = new TreeMultiset<IntSequence>();
                Map<Interval, Integer> fills = envMap.get(this);
                // A context that is not (or no longer) in the map simply has no fillers.
                if (fills != null) {
                    for (Interval fill : fills.keySet()) {
                        labels.add(fill.label());
                    }
                }
                fillSet = labels;
            }
            return fillSet;
        }

        /** @return the size of {@link #fillerSet()} */
        @Override
        public int fillerSetSize() {
            return fillerSet().size();
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + ((left == null) ? 0 : left.hashCode());
            result = prime * result + ((right == null) ? 0 : right.hashCode());
            return result;
        }

        /**
         * Two contexts are equal when they are of the same class and have equal left and
         * right sequences.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Context other = (Context) obj;
            return sameOrBothNull(left, other.leftSequence())
                    && sameOrBothNull(right, other.rightSequence());
        }

        /**
         * Orders contexts by left sequence first and by right sequence on ties. The right
         * sequences are compared only when the left ones are equal.
         */
        @Override
        public int compareTo(NestedContext other) {
            int leftCompare = left.compareTo(other.left);
            if (leftCompare != 0) {
                return leftCompare;
            }
            return right.compareTo(other.right);
        }
    }

    /**
     * @return an iterator over the recorded contexts, backed directly by the key set of the
     *         internal map
     */
    @Override
    public Iterator<Context> iterator() {
        return envMap.keySet().iterator();
    }
}