/**
* Copyright (C) 2011 Brian Ferris <bdferris@onebusaway.org>
* Copyright (C) 2011 Google, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.onebusaway.transit_data_federation.impl;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import javax.annotation.PostConstruct;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.queryParser.MultiFieldQueryParser;
import org.apache.lucene.queryParser.ParseException;
import org.apache.lucene.queryParser.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Searcher;
import org.apache.lucene.search.TopDocCollector;
import org.apache.lucene.search.TopDocs;
import org.onebusaway.container.refresh.Refreshable;
import org.onebusaway.gtfs.model.AgencyAndId;
import org.onebusaway.transit_data_federation.model.SearchResult;
import org.onebusaway.transit_data_federation.services.FederatedTransitDataBundle;
import org.onebusaway.transit_data_federation.services.RouteCollectionSearchIndexConstants;
import org.onebusaway.transit_data_federation.services.RouteCollectionSearchService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
 * {@link RouteCollectionSearchService} implementation that answers route-name
 * queries from the Lucene index located at
 * {@link FederatedTransitDataBundle#getRouteSearchIndexPath()}.
 *
 * <p>
 * The index is (re)opened by {@link #initialize()}. When no index exists at
 * that path, every search returns an empty result.
 */
@Component
public class RouteCollectionSearchServiceImpl implements
    RouteCollectionSearchService {

  private static final Analyzer _analyzer = new StandardAnalyzer();

  /** Index fields a route-name query is matched against. */
  private static final String[] NAME_FIELDS = {
      RouteCollectionSearchIndexConstants.FIELD_ROUTE_SHORT_NAME,
      RouteCollectionSearchIndexConstants.FIELD_ROUTE_LONG_NAME};

  /** Compiled once; splits a route short name at word boundaries. */
  private static final Pattern WORD_BOUNDARY = Pattern.compile("\\b");

  private FederatedTransitDataBundle _bundle;

  /**
   * Current searcher, or null when no index is available. Volatile because
   * {@link #initialize()} replaces it on refresh while searches read it.
   */
  private volatile Searcher _searcher;

  /** Reader backing {@link #_searcher}; retained so it can be closed. */
  private IndexReader _reader;

  @Autowired
  public void setBundle(FederatedTransitDataBundle bundle) {
    _bundle = bundle;
  }

  /**
   * Opens the route search index if it exists and installs a searcher over it;
   * otherwise clears the searcher. The previously installed searcher and
   * reader, if any, are closed after the swap so repeated refreshes do not
   * accumulate open index files.
   *
   * @throws IOException if the index cannot be opened or the previous
   *           searcher/reader cannot be closed
   */
  @PostConstruct
  @Refreshable(dependsOn = RefreshableResources.ROUTE_COLLECTION_SEARCH_DATA)
  public void initialize() throws IOException {

    File path = _bundle.getRouteSearchIndexPath();

    // Open the replacement first: if this throws, the current searcher stays
    // in place untouched.
    IndexReader newReader = null;
    Searcher newSearcher = null;
    if (path.exists()) {
      newReader = IndexReader.open(path);
      newSearcher = new IndexSearcher(newReader);
    }

    Searcher oldSearcher = _searcher;
    IndexReader oldReader = _reader;

    _reader = newReader;
    _searcher = newSearcher;

    // NOTE(review): a search still running against the old searcher at this
    // moment may fail once it is closed; that is accepted here in exchange for
    // not leaking the old index's resources on every refresh.
    try {
      if (oldSearcher != null)
        oldSearcher.close();
    } finally {
      // The reader is closed explicitly because the searcher was constructed
      // from a caller-supplied reader.
      if (oldReader != null)
        oldReader.close();
    }
  }

  /**
   * Searches the short-name and long-name fields of the route index.
   *
   * @param value the query text, parsed with Lucene query syntax
   * @param maxResultCount maximum number of index hits to collect
   * @param minScoreToKeep hits scoring below this are dropped, unless the
   *          lower-cased query equals one of the tokens of the hit's route
   *          short name
   * @return matching route ids and their scores; empty if no index is loaded
   *         or {@code value} is null
   * @throws IOException if reading the index fails
   * @throws ParseException if {@code value} is not a valid query
   */
  public SearchResult<AgencyAndId> searchForRoutesByName(String value,
      int maxResultCount, double minScoreToKeep) throws IOException,
      ParseException {

    return search(new MultiFieldQueryParser(NAME_FIELDS, _analyzer), value,
        maxResultCount, minScoreToKeep);
  }

  /**
   * Runs the query and reduces the hits to one entry per route collection id,
   * keeping the best score seen for each id. Ids are returned in the order the
   * index ranked them.
   */
  private SearchResult<AgencyAndId> search(QueryParser parser, String value,
      int maxResultCount, double minScoreToKeep) throws IOException,
      ParseException {

    // Read the field exactly once so a concurrent refresh cannot swap the
    // searcher between running the query and loading the hit documents.
    Searcher searcher = _searcher;

    if (searcher == null || value == null)
      return new SearchResult<AgencyAndId>();

    TopDocCollector collector = new TopDocCollector(maxResultCount);
    Query query = parser.parse(value);
    searcher.search(query, collector);
    TopDocs top = collector.topDocs();

    // Insertion-ordered so the ranking of the hits carries over to the result.
    Map<AgencyAndId, Float> topScores = new LinkedHashMap<AgencyAndId, Float>();

    // Locale.ROOT keeps the comparison independent of the JVM default locale.
    String lowerCaseQueryValue = value.toLowerCase(Locale.ROOT);

    for (ScoreDoc sd : top.scoreDocs) {

      Document document = searcher.doc(sd.doc);

      // Result must have a minimum score to qualify, unless the query exactly
      // matches a token of the route short name
      if (sd.score < minScoreToKeep
          && !shortNameHasToken(document, lowerCaseQueryValue))
        continue;

      // Keep the best score for a particular id
      String agencyId = document.get(RouteCollectionSearchIndexConstants.FIELD_ROUTE_COLLECTION_AGENCY_ID);
      String id = document.get(RouteCollectionSearchIndexConstants.FIELD_ROUTE_COLLECTION_ID);
      AgencyAndId routeId = new AgencyAndId(agencyId, id);

      Float score = topScores.get(routeId);
      if (score == null || score < sd.score)
        topScores.put(routeId, sd.score);
    }

    List<AgencyAndId> ids = new ArrayList<AgencyAndId>(topScores.size());
    double[] scores = new double[topScores.size()];

    int index = 0;
    for (Map.Entry<AgencyAndId, Float> entry : topScores.entrySet()) {
      ids.add(entry.getKey());
      scores[index++] = entry.getValue();
    }

    return new SearchResult<AgencyAndId>(ids, scores);
  }

  /**
   * Returns true if the document's route short name, lower-cased and split at
   * word boundaries, contains {@code lowerCaseToken} as one of its pieces.
   */
  private static boolean shortNameHasToken(Document document,
      String lowerCaseToken) {

    String routeShortName = document.get(RouteCollectionSearchIndexConstants.FIELD_ROUTE_SHORT_NAME);
    if (routeShortName == null)
      return false;

    Set<String> tokens = new HashSet<String>();
    for (String token : WORD_BOUNDARY.split(routeShortName.toLowerCase(Locale.ROOT))) {
      if (!token.isEmpty())
        tokens.add(token);
    }
    return tokens.contains(lowerCaseToken);
  }
}