diff --git a/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java b/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java
index 438e17e9f6..c1ae060f90 100644
--- a/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java
+++ b/fluss-common/src/main/java/org/apache/fluss/utils/InternalRowUtils.java
@@ -24,13 +24,46 @@
 
 import org.apache.fluss.row.BinaryString;
 import org.apache.fluss.row.Decimal;
+import org.apache.fluss.row.InternalArray;
+import org.apache.fluss.row.InternalMap;
+import org.apache.fluss.row.InternalRow;
 import org.apache.fluss.row.TimestampLtz;
 import org.apache.fluss.row.TimestampNtz;
+import org.apache.fluss.types.ArrayType;
+import org.apache.fluss.types.DataType;
 import org.apache.fluss.types.DataTypeRoot;
+import org.apache.fluss.types.MapType;
+import org.apache.fluss.types.RowType;
 
 /** Utility class for {@link org.apache.fluss.row.InternalRow} related operations. */
 public class InternalRowUtils {
 
+    /**
+     * Compares two objects based on their data type, including nested ARRAY and ROW types.
+     *
+     * <p>ARRAY and ROW values are compared element by element (field by field), ordering null
+     * entries first; every other type is delegated to the {@link DataTypeRoot}-based overload.
+     *
+     * @param x the first object
+     * @param y the second object
+     * @param type the data type of both objects
+     * @return a negative integer, zero, or a positive integer as x is less than, equal to, or
+     *     greater than y
+     * @throws IllegalArgumentException if {@code type} is a MAP type, which is not comparable
+     */
+    public static int compare(Object x, Object y, DataType type) {
+        switch (type.getTypeRoot()) {
+            case ARRAY:
+                return compareArray((InternalArray) x, (InternalArray) y, (ArrayType) type);
+            case ROW:
+                return compareRow((InternalRow) x, (InternalRow) y, (RowType) type);
+            case MAP:
+                return compareMap((InternalMap) x, (InternalMap) y, (MapType) type);
+            default:
+                return compare(x, y, type.getTypeRoot());
+        }
+    }
+
     /**
      * Compares two objects based on their data type.
      *
@@ -92,6 +125,78 @@ public static int compare(Object x, Object y, DataTypeRoot type) {
         return ret;
     }
 
+    /**
+     * Compares two arrays lexicographically, element by element.
+     *
+     * <p>A null element orders before a non-null one; if one array is a prefix of the other, the
+     * shorter array orders first.
+     */
+    private static int compareArray(InternalArray a1, InternalArray a2, ArrayType type) {
+        int size1 = a1.size();
+        int size2 = a2.size();
+        int size = Math.min(size1, size2);
+        // Resolve the element type once instead of on every iteration.
+        DataType elementType = type.getElementType();
+        InternalArray.ElementGetter getter = InternalArray.createElementGetter(elementType);
+
+        for (int i = 0; i < size; i++) {
+            int cmp =
+                    compareNullsFirst(
+                            getter.getElementOrNull(a1, i),
+                            getter.getElementOrNull(a2, i),
+                            elementType);
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+        return Integer.compare(size1, size2);
+    }
+
+    /**
+     * Compares two rows field by field, in the field order declared by {@code type}.
+     *
+     * <p>A null field orders before a non-null one; the first non-equal field decides the result.
+     */
+    private static int compareRow(InternalRow r1, InternalRow r2, RowType type) {
+        int count = type.getFieldCount();
+        for (int i = 0; i < count; i++) {
+            DataType fieldType = type.getTypeAt(i);
+            InternalRow.FieldGetter getter = InternalRow.createFieldGetter(fieldType, i);
+            int cmp =
+                    compareNullsFirst(
+                            getter.getFieldOrNull(r1), getter.getFieldOrNull(r2), fieldType);
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+        return 0;
+    }
+
+    /**
+     * Rejects MAP values: this utility defines no ordering for maps.
+     *
+     * @throws IllegalArgumentException always
+     */
+    private static int compareMap(InternalMap m1, InternalMap m2, MapType type) {
+        throw new IllegalArgumentException("Map type is not comparable: " + type);
+    }
+
+    /**
+     * Compares two possibly-null values of the given type, ordering null before non-null.
+     *
+     * @return 0 if both are null, -1 if only {@code o1} is null, 1 if only {@code o2} is null,
+     *     otherwise the result of {@link #compare(Object, Object, DataType)}
+     */
+    private static int compareNullsFirst(Object o1, Object o2, DataType type) {
+        if (o1 == null) {
+            return o2 == null ? 0 : -1;
+        }
+        if (o2 == null) {
+            return 1;
+        }
+        return compare(o1, o2, type);
+    }
+
     private static int byteArrayCompare(byte[] array1, byte[] array2) {
         for (int i = 0, j = 0; i < array1.length && j < array2.length; i++, j++) {
             int a = (array1[i] & 0xff);
diff --git a/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMaxAgg.java b/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMaxAgg.java
index 812dc84306..fbb4cffcbb 100644
--- 
a/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMaxAgg.java +++ b/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMaxAgg.java @@ -39,7 +39,7 @@ public Object agg(Object accumulator, Object inputField) { if (accumulator == null || inputField == null) { return accumulator == null ? inputField : accumulator; } - return InternalRowUtils.compare(accumulator, inputField, typeRoot) < 0 + return InternalRowUtils.compare(accumulator, inputField, fieldType) < 0 ? inputField : accumulator; } diff --git a/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMinAgg.java b/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMinAgg.java index eba6f49f5b..2b4da38d22 100644 --- a/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMinAgg.java +++ b/fluss-server/src/main/java/org/apache/fluss/server/kv/rowmerger/aggregate/functions/FieldMinAgg.java @@ -40,7 +40,7 @@ public Object agg(Object accumulator, Object inputField) { return accumulator == null ? inputField : accumulator; } - return InternalRowUtils.compare(accumulator, inputField, typeRoot) < 0 + return InternalRowUtils.compare(accumulator, inputField, fieldType) < 0 ? accumulator : inputField; } diff --git a/fluss-server/src/test/java/org/apache/fluss/server/kv/rowmerger/aggregate/ComplexTypeAggregationTest.java b/fluss-server/src/test/java/org/apache/fluss/server/kv/rowmerger/aggregate/ComplexTypeAggregationTest.java new file mode 100644 index 0000000000..ec98a4b4cb --- /dev/null +++ b/fluss-server/src/test/java/org/apache/fluss/server/kv/rowmerger/aggregate/ComplexTypeAggregationTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.server.kv.rowmerger.aggregate;
+
+import org.apache.fluss.config.Configuration;
+import org.apache.fluss.config.TableConfig;
+import org.apache.fluss.metadata.AggFunctions;
+import org.apache.fluss.metadata.KvFormat;
+import org.apache.fluss.metadata.Schema;
+import org.apache.fluss.metadata.SchemaInfo;
+import org.apache.fluss.record.BinaryValue;
+import org.apache.fluss.record.TestingSchemaGetter;
+import org.apache.fluss.row.BinaryArray;
+import org.apache.fluss.row.BinaryRow;
+import org.apache.fluss.row.BinaryString;
+import org.apache.fluss.row.encode.RowEncoder;
+import org.apache.fluss.server.kv.rowmerger.AggregateRowMerger;
+import org.apache.fluss.types.DataTypes;
+import org.apache.fluss.types.RowType;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+class ComplexTypeAggregationTest {
+
+    private static final short SCHEMA_ID = (short) 1;
+
+    @Test
+    void testArrayLastValue() {
+        Schema schema =
+                Schema.newBuilder()
+                        .column("id", DataTypes.INT())
+                        .column("arr", DataTypes.ARRAY(DataTypes.INT()), AggFunctions.LAST_VALUE())
+                        .primaryKey("id")
+                        .build();
+
+        TableConfig tableConfig = new TableConfig(new Configuration());
+        AggregateRowMerger merger = createMerger(schema, tableConfig);
+
+        BinaryArray arr1 = BinaryArray.fromPrimitiveArray(new int[] {1, 2, 3});
+        BinaryArray arr2 = BinaryArray.fromPrimitiveArray(new int[] {4, 5, 6});
+
+        BinaryRow row1 = compactedRow(schema.getRowType(), new Object[] {1, arr1});
+        BinaryRow row2 = compactedRow(schema.getRowType(), new Object[] {1, arr2});
+
+        BinaryValue merged = merger.merge(toBinaryValue(row1), toBinaryValue(row2));
+
+        BinaryArray resultArr = (BinaryArray) merged.row.getArray(1);
+        assertThat(resultArr).isEqualTo(arr2);
+    }
+
+    @Test
+    void testArrayMinMax() {
+        Schema schema =
+                Schema.newBuilder()
+                        .column("id", DataTypes.INT())
+                        .column("min_arr", DataTypes.ARRAY(DataTypes.INT()), AggFunctions.MIN())
+                        .column("max_arr", DataTypes.ARRAY(DataTypes.INT()), AggFunctions.MAX())
+                        .primaryKey("id")
+                        .build();
+
+        TableConfig tableConfig = new TableConfig(new Configuration());
+        AggregateRowMerger merger = createMerger(schema, tableConfig);
+
+        // arr1 < arr2: the arrays first differ at index 2 (3 < 4)
+        BinaryArray arr1 = BinaryArray.fromPrimitiveArray(new int[] {1, 2, 3});
+        BinaryArray arr2 = BinaryArray.fromPrimitiveArray(new int[] {1, 2, 4});
+
+        BinaryRow row1 = compactedRow(schema.getRowType(), new Object[] {1, arr1, arr1});
+        BinaryRow row2 = compactedRow(schema.getRowType(), new Object[] {1, arr2, arr2});
+
+        BinaryValue merged = merger.merge(toBinaryValue(row1), toBinaryValue(row2));
+
+        BinaryArray minResult = (BinaryArray) merged.row.getArray(1);
+        BinaryArray maxResult = (BinaryArray) merged.row.getArray(2);
+
+        assertThat(minResult).isEqualTo(arr1);
+        assertThat(maxResult).isEqualTo(arr2);
+
+        // Different sizes: arr3 is a strict prefix of arr1,
+        // so the shorter array orders first (arr3 < arr1)
+        BinaryArray arr3 = BinaryArray.fromPrimitiveArray(new int[] {1, 2});
+        BinaryRow row3 = compactedRow(schema.getRowType(), new Object[] {1, arr3, arr3});
+
+        merged = merger.merge(toBinaryValue(row1), toBinaryValue(row3));
+        minResult = (BinaryArray) merged.row.getArray(1);
+        maxResult = (BinaryArray) merged.row.getArray(2);
+
+        assertThat(minResult).isEqualTo(arr3);
+        assertThat(maxResult).isEqualTo(arr1);
+    }
+
+    @Test
+    void testRowMinMax() {
+        RowType nestedRowType = RowType.of(DataTypes.INT(), DataTypes.STRING());
+        Schema schema =
+                Schema.newBuilder()
+                        .column("id", DataTypes.INT())
+                        .column("min_row", nestedRowType, AggFunctions.MIN())
+                        .column("max_row", nestedRowType, AggFunctions.MAX())
+                        .primaryKey("id")
+                        .build();
+
+        TableConfig tableConfig = new TableConfig(new Configuration());
+        AggregateRowMerger merger = createMerger(schema, tableConfig);
+
+        // nestedRow1 < nestedRow2: field 0 ties (1 == 1), field 1 decides ("a" < "b")
+        BinaryRow nestedRow1 =
+                compactedRow(nestedRowType, new Object[] {1, BinaryString.fromString("a")});
+        BinaryRow nestedRow2 =
+                compactedRow(nestedRowType, new Object[] {1, BinaryString.fromString("b")});
+
+        BinaryRow row1 =
+                compactedRow(schema.getRowType(), new Object[] {1, nestedRow1, nestedRow1});
+        BinaryRow row2 =
+                compactedRow(schema.getRowType(), new Object[] {1, nestedRow2, nestedRow2});
+
+        BinaryValue merged = merger.merge(toBinaryValue(row1), toBinaryValue(row2));
+
+        BinaryRow minResult = (BinaryRow) merged.row.getRow(1, 2);
+        BinaryRow maxResult = (BinaryRow) merged.row.getRow(2, 2);
+
+        assertThat(minResult).isEqualTo(nestedRow1);
+        assertThat(maxResult).isEqualTo(nestedRow2);
+    }
+
+    private BinaryValue toBinaryValue(BinaryRow row) {
+        return new BinaryValue(SCHEMA_ID, row);
+    }
+
+    private AggregateRowMerger createMerger(Schema schema, TableConfig tableConfig) {
+        TestingSchemaGetter schemaGetter =
+                new TestingSchemaGetter(new SchemaInfo(schema, SCHEMA_ID));
+        AggregateRowMerger merger =
+                new AggregateRowMerger(tableConfig, tableConfig.getKvFormat(), schemaGetter);
+        merger.configureTargetColumns(null, SCHEMA_ID, schema);
+        return merger;
+    }
+
+    private BinaryRow compactedRow(RowType rowType, Object[] values) {
+        RowEncoder encoder = RowEncoder.create(KvFormat.COMPACTED, rowType);
+        encoder.startNewRow();
+        for (int i = 0; i < values.length; i++) {
+            encoder.encodeField(i, values[i]);
+        }
+        return encoder.finishRow();
+    }
+}