Skip to content

Commit

Permalink
Feature/optimization (#28)
Browse files Browse the repository at this point in the history
* search optimizations
  • Loading branch information
mounaTay authored Aug 28, 2024
1 parent 652b8ee commit 3b0e76e
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ tasks:
command: |
. setup.sh
mvn spring-boot:run
ports:
- port: 8081
onOpen: open-browser
Expand Down
7 changes: 6 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@
<artifactId>jts-core</artifactId>
<version>1.19.0</version>
</dependency>
<dependency>
<dependency>
<groupId>org.locationtech.jts.io</groupId>
<artifactId>jts-io-common</artifactId>
<version>1.19.0</version>
</dependency>
<dependency>
<groupId>org.springdoc</groupId>
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
<version>2.5.0</version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.datastax.oss.cass_stac.entity.ItemCollection;
import com.datastax.oss.cass_stac.model.ItemSearchRequest;
import com.datastax.oss.cass_stac.service.ItemService;
import com.datastax.oss.cass_stac.util.GeoJsonParser;
import com.datastax.oss.cass_stac.util.SortUtils;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
Expand All @@ -13,9 +14,10 @@
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.locationtech.jts.geom.Geometry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.cassandra.core.query.CassandraPageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.ErrorResponse;
Expand Down Expand Up @@ -83,11 +85,15 @@ The bounding box is provided as four or six numbers, depending on whether the co
[
144.838158,
-37.927719
],
[
144.81543,
-37.927299
]
]
]
}
""", description = "Search items that intersect this polygon, coordinates should be of length 4")}) @RequestParam(required = false) Geometry intersects,
""", description = "Search items that intersect this polygon, coordinates should be of length 4")}) @RequestParam(required = false) String intersects,
@Parameter(description = "Either a date-time or an interval, open or closed. Date and time expressions adhere to RFC 3339. Open intervals are expressed using double-dots.",
examples = {
@ExampleObject(name = "A closed interval", value = "2023-01-30T00:00:00Z/2018-03-18T12:31:12Z"),
Expand All @@ -109,7 +115,7 @@ The bounding box is provided as four or six numbers, depending on whether the co
try {
ItemCollection response = itemService.search(
bbox,
intersects,
GeoJsonParser.parseGeometry(intersects),
datetime,
limit,
ids,
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/com/datastax/oss/cass_stac/dao/ItemDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@

import com.datastax.oss.cass_stac.entity.Item;
import com.datastax.oss.cass_stac.entity.ItemPrimaryKey;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.cassandra.repository.CassandraRepository;
import org.springframework.data.cassandra.repository.Query;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

import java.util.List;

@Repository
public interface ItemDao extends CassandraRepository<Item, ItemPrimaryKey> {
@NotNull
Slice<Item> findAll(@NotNull Pageable pageable);

@Query(value = "SELECT * FROM item where partition_id = :partition_id AND id = :id")
List<Item> findItemByPartitionIdAndId(@Param("partition_id") final String partition_id, @Param("id") final String id);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.locationtech.jts.geom.Geometry;
import org.n52.jackson.datatype.jts.JtsModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.Objects;

public class GeoJsonItemRequest extends PropertyObject {
private static final Logger logger = LoggerFactory.getLogger(GeoJsonItemRequest.class);

@JsonProperty("type")
private static final String TYPE = "Item";
Expand Down
103 changes: 71 additions & 32 deletions src/main/java/com/datastax/oss/cass_stac/service/ItemService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@
import org.locationtech.jts.geom.Polygon;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.cassandra.core.CassandraTemplate;
import org.springframework.data.cassandra.core.query.CassandraPageRequest;
import org.springframework.data.cassandra.core.query.Criteria;
import org.springframework.data.cassandra.core.query.Query;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeParseException;
import java.util.*;

Expand All @@ -45,6 +54,7 @@ public class ItemService {
private static final ObjectMapper objectMapper = new ObjectMapper().registerModule(new JavaTimeModule());

private static final Set<String> datetimeFields = new HashSet<>(Arrays.asList("datetime", "start_datetime", "end_datetime", "created", "updated"));
private final CassandraTemplate cassandraTemplate;

public ItemModelResponse getItemById(final String id) {
final ItemId itemId = itemIdDao.findById(id)
Expand Down Expand Up @@ -211,7 +221,7 @@ private static List<Float> convertJsonNodeToFloatArray(JsonNode jsonNode) {
private ItemDto convertItemToDto(final Item item) throws IOException {
JsonNode bbox = objectMapper.readValue(item.getAdditional_attributes(), JsonNode.class).get("bbox");
List<Float> floatArray = null;
if (bbox.isArray()) {
if (bbox != null && bbox.isArray()) {
floatArray = convertJsonNodeToFloatArray(bbox);
}
return ItemDto.builder()
Expand Down Expand Up @@ -257,42 +267,69 @@ public ItemCollection search(List<Float> bbox,
Boolean includeIds,
Boolean includeObjects) {

List<Item> allItems = (ids != null) ? itemDao.findItemByIds(ids) : itemDao.findAll();
List<Item> allItems = new ArrayList<>();

if (intersects != null) {
allItems = allItems.stream().filter(_item -> GeometryUtil.fromGeometryByteBuffer(_item.getGeometry())
.intersects(intersects)).toList();
Instant minDate = Instant.EPOCH;
Instant maxDate = Instant.now().plusSeconds(3155695200L);

if (datetime != null && datetime.contains("/")) {
String[] parts = datetime.split("/");
minDate = parts[0].equals("..") ? Instant.EPOCH : Instant.parse(parts[0]);
maxDate = parts[1].equals("..") ? Instant.now().plusSeconds(3155695200L) : Instant.parse(parts[1]);
} else if (datetime != null) {
minDate = Instant.parse(datetime);
maxDate = Instant.parse(datetime);
}

Query dbQuery = Query.empty();

if (collectionsArray != null) {
allItems = allItems.stream().filter(_item -> collectionsArray.contains(_item.getCollection())).toList();
dbQuery = dbQuery.and(Criteria.where("collection").in(collectionsArray)).withAllowFiltering();
}

if (datetime != null && datetime.contains("/")) {
String[] parts = datetime.split("/");
Instant start = parts[0].equals("..") ? Instant.EPOCH : Instant.parse(parts[0]);
Instant end = parts[1].equals("..") ? Instant.now().plusSeconds(3155695200L) : Instant.parse(parts[1]);
allItems = allItems.stream().filter(item -> item.getDatetime().compareTo(start) >= 0 && item.getDatetime().compareTo(end) <= 0).toList();
} else if (datetime != null) {
Instant instantDateTime = Instant.parse(datetime);
allItems = allItems.stream().filter(item -> item.getDatetime().equals(instantDateTime)).toList();
if (datetime != null) {
dbQuery = dbQuery.and(Criteria.where("datetime").lte(maxDate))
.and(Criteria.where("datetime").gte(minDate)).withAllowFiltering();
}

if (bbox != null) {
allItems = allItems.stream().filter(_item -> {
limit = limit == null ? 10 : limit;
Pageable pageable = PageRequest.of(0, 1500);
Slice<Item> itemPage;

do {
// Fetch a page of items
itemPage = cassandraTemplate.slice(dbQuery.pageRequest(pageable), Item.class);

// Add the current page content to the list
allItems.addAll(itemPage.getContent());

// Move to the next page
pageable = itemPage.hasNext() ? itemPage.nextPageable() : null;

} while (pageable != null);

allItems = allItems.stream().filter(_item -> {
boolean valid = true;
if (ids != null) {
valid = ids.contains(_item.getId().getId());
}

if (intersects != null)
valid = GeometryUtil.fromGeometryByteBuffer(_item.getGeometry())
.intersects(intersects);

if (bbox != null) {
ItemDto itemDto;
try {
itemDto = convertItemToDto(_item);
} catch (IOException e) {
throw new RuntimeException(e);
}
return BboxIntersects(itemDto.getBbox(), bbox);
}).toList();
}
valid = BboxIntersects(itemDto.getBbox(), bbox);
}

if (query != null) {
QueryEvaluator evaluator = new QueryEvaluator();
allItems = allItems.stream().filter(_item -> {
if (query != null) {
QueryEvaluator evaluator = new QueryEvaluator();
Map<String, Object> additionalAttributes;
JsonNode attributes;
try {
Expand All @@ -302,9 +339,11 @@ public ItemCollection search(List<Float> bbox,
}
additionalAttributes = objectMapper.convertValue(attributes, new TypeReference<>() {
});
return evaluator.evaluate(query, additionalAttributes);
}).toList();
}
valid = evaluator.evaluate(query, additionalAttributes);
}
return valid;
}).toList();


if (sort != null) {
allItems = SortUtils.sortItems(allItems, sort);
Expand Down Expand Up @@ -355,9 +394,9 @@ public ImageResponse getPartitions(

List<String> partitions = switch (request.getGeometry().getGeometryType()) {
case "Point" ->
getPointPartitions(request, minDate, maxDate, objectTypeFilter, whereClause, bindVars, useCentroid, filterObjectsByPolygon);
getPointPartitions(request.getGeometry(), minDate, maxDate, objectTypeFilter, whereClause, bindVars, useCentroid, filterObjectsByPolygon);
case "Polygon" ->
getPolygonPartitions(request, minDate, maxDate, objectTypeFilter, whereClause, bindVars, useCentroid, filterObjectsByPolygon);
getPolygonPartitions(request.getGeometry(), minDate, maxDate, objectTypeFilter, whereClause, bindVars, useCentroid, filterObjectsByPolygon);
default -> throw new IllegalStateException("Unexpected value: " + request.getGeometry().getGeometryType());
};
Optional<List<Item>> items = includeObjects
Expand All @@ -372,7 +411,7 @@ public ImageResponse getPartitions(
}

private List<String> getPointPartitions(
ItemModelRequest request,
Geometry geometry,
OffsetDateTime minDate,
OffsetDateTime maxDate,
List<String> objectTypeFilter,
Expand All @@ -383,15 +422,14 @@ private List<String> getPointPartitions(
final int geoResolution = 6;
final GeoTimePartition.TimeResolution timeResolution = GeoTimePartition.TimeResolution.valueOf("MONTH");

Geometry geometry = request.getGeometry();
Point point = geometry.getFactory().createPoint(geometry.getCoordinate());
return Collections.singletonList(minDate != null
? new GeoTimePartition(geoResolution, timeResolution).getGeoTimePartitionForPoint(point, minDate)
: new GeoPartition(geoResolution).getGeoPartitionForPoint(point));
}

private List<String> getPolygonPartitions(
ItemModelRequest request,
Geometry geometry,
OffsetDateTime minDate,
OffsetDateTime maxDate,
List<String> objectTypeFilter,
Expand All @@ -402,11 +440,9 @@ private List<String> getPolygonPartitions(
final int geoResolution = 6;
final GeoTimePartition.TimeResolution timeResolution = GeoTimePartition.TimeResolution.valueOf("MONTH");

Geometry geometry = request.getGeometry();
Polygon polygon = geometry.getFactory().createPolygon(geometry.getCoordinates());
return (maxDate != null && minDate != null) ? new GeoTimePartition(geoResolution, timeResolution)
.getGeoTimePartitions(polygon, minDate, maxDate) : new GeoPartition(geoResolution).getGeoPartitions(polygon);

}

public AggregationCollection agg(
Expand All @@ -419,6 +455,9 @@ public AggregationCollection agg(
List<String> aggregations,
List<AggregateRequest.Range> ranges
) {
if (datetime.isEmpty())
throw new RuntimeException("datetime is required to filter out data");

ItemCollection itemCollection = search(bbox, intersects, datetime, MAX_VALUE, ids, collections, query, null, false, false, true);

List<Aggregation> aggegationList = aggregations.stream().map(aggregationName -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import com.datastax.oss.cass_stac.model.ItemModelRequest;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.io.ParseException;
import org.locationtech.jts.io.geojson.GeoJsonReader;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -18,4 +22,9 @@ public static ItemModelRequest parseGeoJson(String geoJson) throws IOException {

return objectMapper.treeToValue(rootNode, ItemModelRequest.class);
}

public static Geometry parseGeometry(String geoJson) throws ParseException {
GeoJsonReader reader = new GeoJsonReader();
return reader.read(geoJson);
}
}
7 changes: 4 additions & 3 deletions src/main/java/com/datastax/oss/cass_stac/util/SortUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ public static List<Item> sortItems(List<Item> items, List<SortBy> sortBy) {
.sorted(comparator)
.collect(Collectors.toList());
}

private static Comparator<Item> getComparatorForField(String field) {
if (field.contains(".")) {
String[] fields = field.split("\\.");
return switch (fields[0]) {
case "properties" -> Comparator.comparing(Item::getProperties);
case "properties" -> Comparator.comparing(Item::getProperties);
case "additional_attributes" -> Comparator.comparing(Item::getAdditional_attributes);
default -> throw new IllegalArgumentException("Invalid sort field: " + field);
};
}else{
} else {
return switch (field) {
// TODO compare Ids
case "id" -> Comparator.comparing(item -> item.getId().getId());
case "collection" -> Comparator.comparing(Item::getCollection);
case "datetime" -> Comparator.comparing(Item::getDatetime);
case "properties" -> Comparator.comparing(Item::getProperties);
Expand Down

0 comments on commit 3b0e76e

Please sign in to comment.