/*
* Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership. Crate licenses
* this file to you under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* However, if you have executed another commercial license agreement
* with Crate these terms will supersede the license and you may use the
* software solely pursuant to the terms of the relevant commercial agreement.
*/
package io.crate.executor.transport.task.elasticsearch;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.crate.Constants;
import io.crate.PartitionName;
import io.crate.executor.QueryResult;
import io.crate.executor.Task;
import io.crate.executor.TaskResult;
import io.crate.metadata.Functions;
import io.crate.metadata.ReferenceInfo;
import io.crate.metadata.Scalar;
import io.crate.operation.Input;
import io.crate.operation.ProjectorUpstream;
import io.crate.operation.projectors.FlatProjectorChain;
import io.crate.operation.projectors.ProjectionToProjectorVisitor;
import io.crate.operation.projectors.Projector;
import io.crate.planner.node.dql.ESGetNode;
import io.crate.planner.projection.Projection;
import io.crate.planner.projection.TopNProjection;
import io.crate.planner.symbol.*;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.get.*;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.search.fetch.source.FetchSourceContext;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.*;
public class ESGetTask implements Task<QueryResult> {
private final static Visitor VISITOR = new Visitor();
private final List<ListenableFuture<QueryResult>> results;
private final TransportAction transportAction;
private final ActionRequest request;
private final ActionListener listener;
public ESGetTask(Functions functions,
ProjectionToProjectorVisitor projectionToProjectorVisitor,
TransportMultiGetAction multiGetAction,
TransportGetAction getAction,
ESGetNode node) {
assert multiGetAction != null;
assert getAction != null;
assert node != null;
assert node.ids().size() > 0;
assert node.limit() == null || node.limit() != 0 : "shouldn't execute ESGetTask if limit is 0";
Map<String, Object> partitionValues = preparePartitionValues(node);
final Context ctx = new Context(functions, node.outputs().size(), partitionValues);
List<FieldExtractor> extractors = new ArrayList<>(node.outputs().size());
for (Symbol symbol : node.outputs()) {
extractors.add(VISITOR.process(symbol, ctx));
}
for (Symbol symbol : node.sortSymbols()) {
extractors.add(VISITOR.process(symbol, ctx));
}
final FetchSourceContext fsc = new FetchSourceContext(ctx.fields());
final SettableFuture<QueryResult> result = SettableFuture.create();
results = Arrays.<ListenableFuture<QueryResult>>asList(result);
if (node.ids().size() > 1) {
MultiGetRequest multiGetRequest = prepareMultiGetRequest(node, fsc);
transportAction = multiGetAction;
request = multiGetRequest;
FlatProjectorChain projectorChain = getFlatProjectorChain(projectionToProjectorVisitor, node);
listener = new MultiGetResponseListener(result, extractors, projectorChain);
} else {
GetRequest getRequest = prepareGetRequest(node, fsc);
transportAction = getAction;
request = getRequest;
listener = new GetResponseListener(result, extractors);
}
}
private Map<String, Object> preparePartitionValues(ESGetNode node) {
Map<String, Object> partitionValues;
if (node.partitionBy().isEmpty()) {
partitionValues = ImmutableMap.of();
} else {
PartitionName partitionName = PartitionName.fromStringSafe(node.index());
int numPartitionColumns = node.partitionBy().size();
partitionValues = new HashMap<>(numPartitionColumns);
for (int i = 0; i < node.partitionBy().size(); i++) {
ReferenceInfo info = node.partitionBy().get(i);
partitionValues.put(
info.ident().columnIdent().fqn(),
info.type().value(partitionName.values().get(i))
);
}
}
return partitionValues;
}
private GetRequest prepareGetRequest(ESGetNode node, FetchSourceContext fsc) {
GetRequest getRequest = new GetRequest(node.index(), Constants.DEFAULT_MAPPING_TYPE, node.ids().get(0));
getRequest.fetchSourceContext(fsc);
getRequest.realtime(true);
getRequest.routing(node.routingValues().get(0));
return getRequest;
}
private MultiGetRequest prepareMultiGetRequest(ESGetNode node, FetchSourceContext fsc) {
MultiGetRequest multiGetRequest = new MultiGetRequest();
for (int i = 0; i < node.ids().size(); i++) {
String id = node.ids().get(i);
MultiGetRequest.Item item = new MultiGetRequest.Item(node.index(), Constants.DEFAULT_MAPPING_TYPE, id);
item.fetchSourceContext(fsc);
item.routing(node.routingValues().get(i));
multiGetRequest.add(item);
}
multiGetRequest.realtime(true);
return multiGetRequest;
}
private FlatProjectorChain getFlatProjectorChain(ProjectionToProjectorVisitor projectionToProjectorVisitor,
ESGetNode node) {
FlatProjectorChain projectorChain = null;
if (node.limit() != null || node.offset() > 0 || !node.sortSymbols().isEmpty()) {
List<Symbol> orderBySymbols = new ArrayList<>(node.sortSymbols().size());
for (Symbol symbol : node.sortSymbols()) {
int i = node.outputs().indexOf(symbol);
if (i < 0 ) {
orderBySymbols.add(new InputColumn(node.outputs().size() + orderBySymbols.size()));
} else {
orderBySymbols.add(new InputColumn(i));
}
}
TopNProjection topNProjection = new TopNProjection(
com.google.common.base.Objects.firstNonNull(node.limit(), Constants.DEFAULT_SELECT_LIMIT),
node.offset(),
orderBySymbols,
node.reverseFlags(),
node.nullsFirst()
);
topNProjection.outputs(genInputColumns(node.outputs().size()));
projectorChain = new FlatProjectorChain(
Arrays.<Projection>asList(topNProjection),
projectionToProjectorVisitor
);
}
return projectorChain;
}
private static List<Symbol> genInputColumns(int size) {
List<Symbol> inputColumns = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
inputColumns.add(new InputColumn(i));
}
return inputColumns;
}
static class MultiGetResponseListener implements ActionListener<MultiGetResponse>, ProjectorUpstream {
private final SettableFuture<QueryResult> result;
private final List<FieldExtractor> fieldExtractor;
@Nullable
private final FlatProjectorChain projectorChain;
private final Projector downstream;
public MultiGetResponseListener(final SettableFuture<QueryResult> result,
List<FieldExtractor> extractors,
@Nullable FlatProjectorChain projectorChain) {
this.result = result;
this.fieldExtractor = extractors;
this.projectorChain = projectorChain;
if (projectorChain == null) {
downstream = null;
} else {
downstream = projectorChain.firstProjector();
downstream.registerUpstream(this);
Futures.addCallback(projectorChain.result(), new FutureCallback<Object[][]>() {
@Override
public void onSuccess(@Nullable Object[][] rows) {
result.set(new QueryResult(rows));
}
@Override
public void onFailure(@Nonnull Throwable t) {
result.setException(t);
}
});
}
}
@Override
public void onResponse(MultiGetResponse responses) {
if (projectorChain == null) {
List<Object[]> rows = new ArrayList<>(responses.getResponses().length);
for (MultiGetItemResponse response : responses) {
if (response.isFailed() || !response.getResponse().isExists()) {
continue;
}
final Object[] row = new Object[fieldExtractor.size()];
int c = 0;
for (FieldExtractor extractor : fieldExtractor) {
row[c] = extractor.extract(response.getResponse());
c++;
}
rows.add(row);
}
result.set(new QueryResult(rows.toArray(new Object[rows.size()][])));
} else {
projectorChain.startProjections();
try {
for (MultiGetItemResponse response : responses) {
if (response.isFailed() || !response.getResponse().isExists()) {
continue;
}
final Object[] row = new Object[fieldExtractor.size()];
int c = 0;
for (FieldExtractor extractor : fieldExtractor) {
row[c] = extractor.extract(response.getResponse());
c++;
}
if (!downstream.setNextRow(row)) {
break;
}
}
downstream.upstreamFinished();
} catch (Exception e) {
downstream.upstreamFailed(e);
}
}
}
@Override
public void onFailure(Throwable e) {
result.setException(e);
}
@Override
public void downstream(Projector downstream) {
throw new UnsupportedOperationException("Setting downstream isn't supported on MultiGetResponseListener");
}
@Override
public Projector downstream() {
return downstream ;
}
}
static class GetResponseListener implements ActionListener<GetResponse> {
private final SettableFuture<QueryResult> result;
private final List<FieldExtractor> extractors;
public GetResponseListener(SettableFuture<QueryResult> result, List<FieldExtractor> extractors) {
this.result = result;
this.extractors = extractors;
}
@Override
public void onResponse(GetResponse response) {
if (!response.isExists()) {
result.set( TaskResult.EMPTY_RESULT);
return;
}
final Object[][] rows = new Object[1][extractors.size()];
int c = 0;
for (FieldExtractor extractor : extractors) {
/**
* NOTE: mapping isn't applied. So if an Insert was done using the ES Rest Endpoint
* the data might be returned in the wrong format (date as string instead of long)
*/
rows[0][c] = extractor.extract(response);
c++;
}
result.set(new QueryResult(rows));
}
@Override
public void onFailure(Throwable e) {
result.setException(e);
}
}
@Override
@SuppressWarnings("unchecked")
public void start() {
transportAction.execute(request, listener);
}
@Override
public List<ListenableFuture<QueryResult>> result() {
return results;
}
@Override
public void upstreamResult(List<ListenableFuture<TaskResult>> result) {
throw new UnsupportedOperationException(
String.format(Locale.ENGLISH, "upstreamResult not supported on %s",
getClass().getSimpleName()));
}
static FieldExtractor buildExtractor(final String field, final Context context) {
if (field.equals("_version")) {
return new FieldExtractor() {
@Override
public Object extract(GetResponse response) {
return response.getVersion();
}
};
} else if (field.equals("_id")) {
return new FieldExtractor() {
@Override
public Object extract(GetResponse response) {
return response.getId();
}
};
} else if (context.partitionValues.containsKey(field)) {
return new FieldExtractor() {
@Override
public Object extract(GetResponse response) {
return context.partitionValues.get(field);
}
};
} else {
return new FieldExtractor() {
@Override
public Object extract(GetResponse response) {
assert response.getSourceAsMap() != null;
return XContentMapValues.extractValue(field, response.getSourceAsMap());
}
};
}
}
static class Context {
private final List<String> fields;
private final Functions functions;
private final Map<String, Object> partitionValues;
private String[] fieldsArray;
Context(Functions functions, int size, Map<String, Object> partitionValues) {
this.functions = functions;
this.partitionValues = partitionValues;
fields = new ArrayList<>(size);
}
public void addField(String fieldName) {
fields.add(fieldName);
}
public String[] fields() {
if (fieldsArray == null) {
fieldsArray = fields.toArray(new String[fields.size()]);
}
return fieldsArray;
}
}
static class Visitor extends SymbolVisitor<Context, FieldExtractor> {
@Override
public FieldExtractor visitReference(Reference symbol, Context context) {
String fieldName = symbol.info().ident().columnIdent().fqn();
context.addField(fieldName);
return buildExtractor(fieldName, context);
}
@Override
public FieldExtractor visitDynamicReference(DynamicReference symbol, Context context) {
return visitReference(symbol, context);
}
@Override
public FieldExtractor visitFunction(Function symbol, Context context) {
List<FieldExtractor> subExtractors = new ArrayList<>(symbol.arguments().size());
for (Symbol argument : symbol.arguments()) {
subExtractors.add(process(argument, context));
}
return new FunctionExtractor((Scalar) context.functions.getSafe(symbol.info().ident()), subExtractors);
}
@Override
public FieldExtractor visitLiteral(Literal symbol, Context context) {
return new LiteralExtractor(symbol.value());
}
@Override
protected FieldExtractor visitSymbol(Symbol symbol, Context context) {
throw new UnsupportedOperationException(
SymbolFormatter.format("Get operation not supported with symbol %s in the result column list", symbol));
}
}
private interface FieldExtractor {
Object extract(GetResponse response);
}
private static class LiteralExtractor implements FieldExtractor {
private final Object literal;
private LiteralExtractor(Object literal) {
this.literal = literal;
}
@Override
public Object extract(GetResponse response) {
return literal;
}
}
private static class FunctionExtractor implements FieldExtractor {
private final Scalar scalar;
private final List<FieldExtractor> subExtractors;
public FunctionExtractor(Scalar scalar, List<FieldExtractor> subExtractors) {
this.scalar = scalar;
this.subExtractors = subExtractors;
}
@Override
public Object extract(final GetResponse response) {
Input[] inputs = new Input[subExtractors.size()];
int idx = 0;
for (final FieldExtractor subExtractor : subExtractors) {
inputs[idx] = new Input() {
@Override
public Object value() {
return subExtractor.extract(response);
}
};
idx++;
}
//noinspection unchecked
return scalar.evaluate(inputs);
}
}
}