Skip to content

Commit

Permalink
Add other types to give a more robust grouping proc
Browse files Browse the repository at this point in the history
  • Loading branch information
gem-neo4j committed Jan 21, 2025
1 parent fb799e4 commit 45578eb
Show file tree
Hide file tree
Showing 2 changed files with 359 additions and 18 deletions.
207 changes: 191 additions & 16 deletions core/src/main/java/apoc/nodes/Grouping.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,32 @@
import static java.util.Collections.*;

import apoc.Pools;
import apoc.convert.ConvertUtils;
import apoc.result.VirtualNode;
import apoc.result.VirtualRelationship;
import apoc.util.Util;
import apoc.util.collection.Iterables;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.OffsetTime;
import java.time.ZonedDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.exceptions.ArithmeticException;
import org.neo4j.graphdb.*;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.NotThreadSafe;
import org.neo4j.procedure.Procedure;
import org.neo4j.values.storable.DurationValue;
import org.neo4j.values.storable.PointValue;

/**
* @author mh
Expand Down Expand Up @@ -368,6 +377,10 @@ private <C extends Collection<T>, T extends Entity> C fixAggregates(C pcs) {
double[] values = (double[]) v;
entry.setValue(values[1] == 0 ? 0 : values[0] / values[1]);
}
if (k.matches("^avg_.+") && v instanceof DurationValue) {
Long count = ((Number) pc.getProperty(k + "_count", 0)).longValue();
entry.setValue(divDurationValue((DurationValue) v, count));
}
if (k.matches("^collect_.+") && v instanceof Collection) {
entry.setValue(((Collection) v).toArray());
}
Expand All @@ -376,6 +389,21 @@ private <C extends Collection<T>, T extends Entity> C fixAggregates(C pcs) {
return pcs;
}

// This is copied from the monorepo (as there was no way to use it outside of Neo)
public DurationValue divDurationValue(DurationValue div, Long number) {
double divisor = number.doubleValue();

try {
return div.approximate(
(double) div.get("months").longValue() / divisor,
(double) div.get("days").longValue() / divisor,
(double) div.get("seconds").longValue() / divisor,
(double) div.get("nanoseconds").longValue() / divisor);
} catch (ArithmeticException | java.lang.ArithmeticException e) {
return div;
}
}

private void aggregate(Entity pc, Map<String, List<String>> aggregations, Map<String, Object> properties) {
aggregations.forEach((k2, aggNames) -> {
for (String aggName : aggNames) {
Expand All @@ -395,28 +423,36 @@ private void aggregate(Entity pc, Map<String, List<String>> aggregations, Map<St
pc.setProperty(key, ((Number) pc.getProperty(key, 0)).longValue() + 1);
break;
case "sum":
pc.setProperty(
key, ((Number) pc.getProperty(key, 0)).doubleValue() + Util.toDouble(value));
if (value instanceof DurationValue) {
DurationValue dv =
(DurationValue) pc.getProperty(key, DurationValue.duration(0, 0, 0, 0));
pc.setProperty(key, ((DurationValue) value).add(dv));
} else if (value instanceof Number) {
pc.setProperty(
key,
((Number) pc.getProperty(key, 0)).doubleValue() + Util.toDouble(value));
}
break;
case "min":
pc.setProperty(
key,
Math.min(
((Number) pc.getProperty(key, Double.MAX_VALUE)).doubleValue(),
Util.toDouble(value)));
pc.setProperty(key, getMin(key, pc, value));
break;
case "max":
pc.setProperty(
key,
Math.max(
((Number) pc.getProperty(key, Double.MIN_VALUE)).doubleValue(),
Util.toDouble(value)));
pc.setProperty(key, getMax(key, pc, value));
break;
case "avg": {
double[] avg = (double[]) pc.getProperty(key, new double[2]);
avg[0] += Util.toDouble(value);
avg[1] += 1;
pc.setProperty(key, avg);
if (value instanceof Number) {
double[] avg = (double[]) pc.getProperty(key, new double[2]);
avg[0] += Util.toDouble(value);
avg[1] += 1;
pc.setProperty(key, avg);
} else if (value instanceof DurationValue) {
DurationValue dv =
(DurationValue) pc.getProperty(key, DurationValue.duration(0, 0, 0, 0));
pc.setProperty(key, ((DurationValue) value).add(dv));
pc.setProperty(
key + "_count",
((Number) pc.getProperty(key + "_count", 0)).longValue() + 1);
}
break;
}
}
Expand All @@ -426,6 +462,145 @@ private void aggregate(Entity pc, Map<String, List<String>> aggregations, Map<St
});
}

private Object getMin(String key, Entity pc, Object value) {
Object prop = pc.getProperty(key);

if (prop == null) {
return value;
}

if (isComparableTypes(prop, value)) {
return compareValues(prop, value) ? prop : value;
}

return returnMinOfDifferentValues(prop, value);
}

private Object getMax(String key, Entity pc, Object value) {
Object prop = pc.getProperty(key);

if (prop == null) {
return value;
}

if (isComparableTypes(prop, value)) {
return compareValues(prop, value) ? value : prop;
}

return returnMaxOfDifferentValues(prop, value);
}

private boolean isComparableTypes(Object prop, Object value) {
return (prop instanceof ZonedDateTime && value instanceof ZonedDateTime)
|| (prop instanceof LocalDateTime && value instanceof LocalDateTime)
|| (prop instanceof LocalDate && value instanceof LocalDate)
|| (prop instanceof OffsetTime && value instanceof OffsetTime)
|| (prop instanceof LocalTime && value instanceof LocalTime)
|| (prop instanceof DurationValue && value instanceof DurationValue)
|| (prop instanceof String && value instanceof String)
|| (prop instanceof Boolean && value instanceof Boolean)
|| (prop instanceof Number && value instanceof Number)
|| ((prop instanceof Collection || prop.getClass().isArray())
&& (value instanceof Collection || value.getClass().isArray()))
|| (prop instanceof PointValue && value instanceof PointValue);
}

private boolean compareValues(Object prop, Object value) {
if (prop instanceof ZonedDateTime pZonedDateTime) {
return ((ZonedDateTime) value).isAfter(pZonedDateTime);
} else if (prop instanceof LocalDateTime pLocalDateTime) {
return ((LocalDateTime) value).isAfter(pLocalDateTime);
} else if (prop instanceof LocalDate pLocalDate) {
return ((LocalDate) value).isAfter(pLocalDate);
} else if (prop instanceof OffsetTime pOffsetTime) {
return ((OffsetTime) value).isAfter(pOffsetTime);
} else if (prop instanceof LocalTime pLocalTime) {
return ((LocalTime) value).isAfter(pLocalTime);
} else if (prop instanceof DurationValue pDurationValue) {
return pDurationValue.compareTo((DurationValue) value) < 0;
} else if (prop instanceof String pString) {
return pString.compareTo((String) value) < 0;
} else if (prop instanceof Boolean pBool) {
return !pBool; // Return `false` if `prop` is `false`
} else if (prop instanceof Number pNumber) {
return pNumber.doubleValue() < Util.toDouble(value);
} else if ((prop instanceof Collection || prop.getClass().isArray())
&& (value instanceof Collection || value.getClass().isArray())) {
return compareCollections(ConvertUtils.convertToList(prop), ConvertUtils.convertToList(value));
} else if (prop instanceof PointValue pPoint && value instanceof PointValue vPoint) {
return pPoint.compareTo(vPoint) < 0;
}
return false; // Default fallback (shouldn't reach here for comparable types)
}

private boolean compareCollections(Collection<?> col1, Collection<?> col2) {
Iterator<?> it1 = col1.iterator();
Iterator<?> it2 = col2.iterator();

while (it1.hasNext() && it2.hasNext()) {
Object elem1 = it1.next();
Object elem2 = it2.next();

// Compare elements recursively
int comparison = compareElements(elem1, elem2);
if (comparison != 0) {
return comparison < 0; // Return true if col1 < col2
}
}

// If one collection runs out of elements, it is smaller
return col1.size() < col2.size();
}

private int compareElements(Object elem1, Object elem2) {
if (elem1 == null && elem2 == null) {
return 0;
}
if (elem1 == null) {
return -1;
}
if (elem2 == null) {
return 1;
}

if (elem1 instanceof Comparable && elem2 instanceof Comparable) {
// Cast to Comparable and compare
return ((Comparable<Object>) elem1).compareTo(elem2);
}

// If elements are not directly comparable, use orderOfType
return Integer.compare(orderOfType(elem1), orderOfType(elem2));
}

private Object returnMinOfDifferentValues(Object prop, Object value) {
return orderOfType(prop) < orderOfType(value) ? prop : value;
}

private Object returnMaxOfDifferentValues(Object prop, Object value) {
return orderOfType(prop) < orderOfType(value) ? value : prop;
}

private int orderOfType(Object value) {
if (value != null && value.getClass().isArray()) {
return 0;
}
return switch (value) {
case null -> 11;
case Collection ignored -> 0;
case PointValue ignored -> 1;
case ZonedDateTime ignored -> 2;
case LocalDateTime ignored -> 3;
case LocalDate ignored -> 4;
case OffsetTime ignored -> 5;
case LocalTime ignored -> 6;
case DurationValue ignored -> 7;
case String ignored -> 8;
case Boolean ignored -> 9;
case Number ignored -> 10;
default -> 12;
};
}

/**
* Returns the properties for the given node according to the specified keys. If a node does not have a property
* assigned to given key, the value is set to {@code null}.
Expand Down
Loading

0 comments on commit 45578eb

Please sign in to comment.