diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 6962beadcbc..d752316a526 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -165,7 +165,7 @@ else if ( areDimsBelowThreshold() ) setRequiresRecompileIfNecessary(); //ensure cp exec type for single-node operations - if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST + if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST || _op == OpOpN.EINSUM //TODO: cbind/rbind of lists only support in CP right now || (_op == OpOpN.CBIND && getInput().get(0).getDataType().isList()) || (_op == OpOpN.RBIND && getInput().get(0).getDataType().isList()) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java index fdd2f8343fe..960560c254e 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java @@ -210,7 +210,7 @@ protected void optimizeMMChain(Hop hop, List mmChain, List mmOperators * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein * Introduction to Algorithms, Third Edition, MIT Press, page 395. 
*/ - private static int[][] mmChainDP(double[] dimArray, int size) + public static int[][] mmChainDP(double[] dimArray, int size) { double[][] dpMatrix = new double[size][size]; //min cost table int[][] split = new int[size][size]; //min cost index table diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java new file mode 100644 index 00000000000..29c3187c3e1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import scala.Int; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public abstract class EOpNode { + public Character c1; + public Character c2; + public Integer dim1; + public Integer dim2; + public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) { + this.c1 = c1; + this.c2 = c2; + this.dim1 = dim1; + this.dim2 = dim2; + } + + public String getOutputString() { + if(c1 == null) return "''"; + if(c2 == null) return c1.toString(); + return c1.toString() + c2.toString(); + } + public abstract List getChildren(); + + public String[] recursivePrintString(){ + ArrayList inpStrings = new ArrayList<>(); + for (EOpNode node : getChildren()) { + inpStrings.add(node.recursivePrintString()); + } + String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new); + String[] res = new String[1 + inpRes.length]; + + res[0] = this.toString(); + + for (int i=0; i inputs, int numOfThreads, Log LOG); + + public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2); +} + diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java new file mode 100644 index 00000000000..d2917121690 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.function.Predicate; + +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; + +public class EOpNodeBinary extends EOpNode { + + public enum EBinaryOperand { // upper case: char remains, lower case: summed (reduced) dimension + ////// mm: ////// + Ba_aC, // -> BC + aB_Ca, // -> CB + Ba_Ca, // -> BC + aB_aC, // -> BC + 
+ ////// element-wise multiplications and sums ////// + aB_aB,// elemwise and colsum -> B + Ab_Ab, // elemwise and rowsum ->A + Ab_bA, // elemwise, either colsum or rowsum -> A + aB_Ba, + ab_ab,//M-M sum all + ab_ba, //M-M.T sum all + aB_a,// -> B + Ab_b, // -> A + + ////// elementwise, no summations: ////// + A_A,// v-elemwise -> A + AB_AB,// M-M elemwise -> AB + AB_BA, // M-M.T elemwise -> AB + AB_A, // M-v colwise -> BA!? + AB_B, // M-v rowwise -> AB + + ////// other ////// + a_a,// dot -> + A_B, // outer mult -> AB + A_scalar, // v-scalar + AB_scalar, // m-scalar + scalar_scalar + } + public EOpNode left; + public EOpNode right; + public EBinaryOperand operand; + private boolean transposeResult; + public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand operand){ + super(null,null,null, null); + Character c1, c2; + Integer dim1, dim2; + switch(operand){ + case Ba_aC -> { + c1=left.c1; + c2=right.c2; + dim1=left.dim1; + dim2=right.dim2; + } + case aB_Ca -> { + c1=left.c2; + c2=right.c1; + dim1=left.dim2; + dim2=right.dim1; + } + case Ba_Ca -> { + c1=left.c1; + c2=right.c1; + dim1=left.dim1; + dim2=right.dim1; + } + case aB_aC -> { + c1=left.c2; + c2=right.c2; + dim1=left.dim2; + dim2=right.dim2; + } + case aB_aB, aB_Ba, aB_a -> { + c1=left.c2; + c2=null; + dim1=left.dim2; + dim2=null; + } + case Ab_Ab, Ab_bA, Ab_b, A_A, A_scalar -> { + c1=left.c1; + c2=null; + dim1=left.dim1; + dim2=null; + } + case ab_ab, ab_ba, a_a, scalar_scalar -> { + c1=null; + c2=null; + dim1=null; + dim2=null; + } + case AB_AB, AB_BA, AB_A, AB_B, AB_scalar ->{ + c1=left.c1; + c2=left.c2; + dim1=left.dim1; + dim2=left.dim2; + } + case A_B -> { + c1=left.c1; + c2=right.c1; + dim1=left.dim1; + dim2=right.dim1; + } + default -> throw new IllegalStateException("EOpNodeBinary Unexpected type: " + operand); + } + // super(c1, c2, dim1, dim2); // unavailable in JDK < 22 + this.c1 = c1; + this.c2 = c2; + this.dim1 = dim1; + this.dim2 = dim2; + this.left = left; + this.right = right; + 
this.operand = operand; + } + + public void setTransposeResult(boolean transposeResult){ + this.transposeResult = transposeResult; + } + + public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) { + if (left.c2 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_aC); } + if (left.c2 == right.c2) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_Ca); } + if (left.c1 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.aB_aC); } + if (left.c1 == right.c2) { + var res = new EOpNodeBinary(left, right, EBinaryOperand.aB_Ca); + res.setTransposeResult(true); + return res; + } + throw new RuntimeException("EOpNodeBinary::combineMatrixMultiply: invalid matrix operation"); + } + + @Override + public List getChildren() { + return List.of(this.left, this.right); + } + @Override + public String toString() { + return this.getClass().getSimpleName()+" ("+ operand.toString()+") "+getOutputString(); + } + + @Override + public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, Log LOG) { + EOpNodeBinary bin = this; + MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, LOG); + MatrixBlock right = this.right.computeEOpNode(inputs, numThreads, LOG); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + + MatrixBlock res; + + switch (bin.operand){ + case AB_AB -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case A_A -> { + ensureMatrixBlockColumnVector(left); + ensureMatrixBlockColumnVector(right); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case a_a -> { + ensureMatrixBlockColumnVector(left); + ensureMatrixBlockColumnVector(right); + res = new MatrixBlock(0.0); + res.allocateDenseBlock(); + res.getDenseBlockValues()[0] = 
LibMatrixMult.dotProduct(left.getDenseBlockValues(), right.getDenseBlockValues(), 0,0 , left.getNumRows()); + } + case Ab_Ab -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case aB_aB -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case ab_ab -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case ab_ba -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case Ab_bA -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case aB_Ba -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case AB_BA -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case Ba_aC -> { + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); + } + case aB_Ca -> { + res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), numThreads); + } + case Ba_Ca -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), 
numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); + } + case aB_aC -> { + if(false && LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), false)){ + res = new MatrixBlock(left.getNumColumns(), right.getNumColumns(),false); + res.allocateDenseBlock(); + double[] m1 = left.getDenseBlock().values(0); + double[] m2 = right.getDenseBlock().values(0); + double[] c = res.getDenseBlock().values(0); + int alen = left.getNumColumns(); + int blen = right.getNumColumns(); + for(int i =0;i { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); + } + case AB_B -> { + ensureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case Ab_b -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(), + null, numThreads); + } + case AB_A -> { + ensureMatrixBlockColumnVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case aB_a -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(), + null, numThreads); + } + case A_B -> { + ensureMatrixBlockColumnVector(left); + ensureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case scalar_scalar -> { + return new MatrixBlock(left.get(0,0)*right.get(0,0)); + } + default -> { + throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString()); + } + + } + if(transposeResult){ + ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + res = res.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + } + if(c2 == null) ensureMatrixBlockColumnVector(res); + return res; + } + + /** + * Recursively reorders the operands below this node. For aB_aC and Ba_Ca the two children are swapped when the right child carries outChar1, and the operand is switched to Ba_aC when, after recursing, left.c2 equals right.c1. For aB_aC the right child may instead be folded into a fused left child. + * + * @param parent node holding this node as a child (not read here) + * @param outChar1 index character wanted first in the result + * @param outChar2 index character wanted second in the result (not read here) + * @return this node, or the fused left child when the right child was folded into it + */ + @Override + public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { + if (this.operand ==EBinaryOperand.aB_aC){ + if(this.right.c2 == outChar1) { // result is CB so Swap aB and aC + var tmpLeft = left; left = right; right = tmpLeft; + var tmpC1 = c1; c1 = c2; c2 = tmpC1; + var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1; + } + // the byte size of the operands is computed in long so that large dimensions cannot overflow int and silently disable the fusion + if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB + && (!EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK || ((long) fuse.dim1 * fuse.dim2 * (fuse.ABs.size()+fuse.BAs.size()) + (long) right.dim1 * right.dim2) * 8 > 6L * 1024 * 1024) + && LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, false)) { + fuse.AZs.add(right); + fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ; + // the fused output is BZ: B moves to the first position and Z comes from the right operand, for the labels and for the sizes alike + fuse.c1 = fuse.c2; + fuse.c2 = right.c2; + fuse.dim1 = fuse.dim2; + fuse.dim2 = right.dim2; + return fuse; + } + + left = left.reorderChildrenAndOptimize(this, left.c2, left.c1); // maybe can be reordered + if(left.c2 == right.c1) { // check if change happened: + this.operand = EBinaryOperand.Ba_aC; + } + right = right.reorderChildrenAndOptimize(this, right.c1, right.c2); + }else if (this.operand ==EBinaryOperand.Ba_Ca){ + if(this.right.c1 == outChar1) { // result is CB so Swap Ba and Ca + var tmpLeft = left; left = right; right = tmpLeft; + var tmpC1 = c1; c1 = c2; c2 = tmpC1; + var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1; + } + + right = right.reorderChildrenAndOptimize(this, right.c2, right.c1); // maybe can be reordered + if(left.c2 == right.c1) { // check if change happened: + this.operand = EBinaryOperand.Ba_aC; + } + left = left.reorderChildrenAndOptimize(this, left.c1, left.c2); + }else { + left = 
left.reorderChildrenAndOptimize(this, left.c1, left.c2); // just recurse + right = right.reorderChildrenAndOptimize(this, right.c1, right.c2); + } + return this; + } + + // used in the old approach + public static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ + Predicate cannotBeSummed = (c) -> + c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; + + if(n1.c1 == null) { + // n2.c1 also has to be null + return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null)); + } + + if(n2.c1 == null) { + if(n1.c2 == null) + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2)); + } + + if(n1.c1 == n2.c1){ + if(n1.c2 != null){ + if ( n1.c2 == n2.c2){ + if( cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ab_Ab, Pair.of(n1.c1, null)); + } + + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null)); + } + + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null)); + + } + + else if(n2.c2 == null){ + if(cannotBeSummed.test(n1.c1)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2) + } + else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ + return null;// AB,AC + } + else { + return 
Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 + } + }else{ // n1.c2 = null -> c2.c2 = null + if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null)); + } + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null)); + } + + + }else{ // n1.c1 != n2.c1 + if(n1.c2 == null) { + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1)); + } + else if(n2.c2 == null) { // ab,c + if (n1.c2 == n2.c1) { + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.AB_B, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ab_b, Pair.of(n1.c1, null)); + } + return null; // AB,C + } + else if (n1.c2 == n2.c1) { + if(n1.c1 == n2.c2){ // ab,ba + if(cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ab_bA, Pair.of(n1.c1, null)); + } + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); + } + if(cannotBeSummed.test(n1.c2)){ + return null; // AB_B + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); + } + } + if(n1.c1 == n2.c2) { + if(cannotBeSummed.test(n1.c1)){ + return null; // AB_B + } + return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult + } + else if (n1.c2 == n2.c2) { + if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ + return null; // BA_CA + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 + } + } + else { // something like AB,CD + return null; + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java new file mode 100644 index 00000000000..fd710d19d1b --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import java.util.ArrayList; +import java.util.List; + +public class EOpNodeData extends EOpNode { + public int matrixIdx; + public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, int matrixIdx){ + super(c1,c2,dim1,dim2); + this.matrixIdx = matrixIdx; + } + + @Override + public List getChildren() { + return List.of(); + } + @Override + public String toString() { + return this.getClass().getSimpleName()+" ("+matrixIdx+") "+getOutputString(); + } + @Override + public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG) { + return inputs.get(matrixIdx); + } + + @Override + public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { + return this; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java new file mode 100644 index 00000000000..1107462caf9 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.codegen.SpoofRowwise; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.jetbrains.annotations.NotNull; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.function.Function; + +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; + +public class EOpNodeFuse extends EOpNode { + + private EOpNode scalar = null; + + public enum EinsumRewriteType{ + // B -> row*vec, A -> row*scalar + AB_BA_B_A__AB, + AB_BA_A__B, + AB_BA_B_A__A, + AB_BA_B_A__, + + // scalar from row(AB).dot(B) multiplied by row(AZ) + AB_BA_B_A_AZ__Z, + + // AZ: last step is outer matrix multiplication using vector Z + AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB, + } + + public EinsumRewriteType einsumRewriteType; + public List ABs; + public List BAs; + public List Bs; + public List As; + public List AZs; + @Override + public List getChildren(){ + List all = new ArrayList<>(); + all.addAll(ABs); + all.addAll(BAs); + all.addAll(Bs); + all.addAll(As); + all.addAll(AZs); + if (scalar != null) all.add(scalar); + return all; + }; + private 
EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs) { + super(c1,c2, dim1, dim2); + this.einsumRewriteType = einsumRewriteType; + this.ABs = ABs; + this.BAs = BAs; + this.Bs = Bs; + this.As = As; + this.AZs = AZs; + } + /** + * Creates a fused node and derives its output labels (c1, c2) and sizes (dim1, dim2) from the first AB operand and, for the AZ rewrite types, from the first AZ operand. + * The operand lists are stored by reference; AXsAndXs is not read by this constructor. + */ + public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs, List, List>> AXsAndXs) { + super(null,null,null, null); + switch(einsumRewriteType) { + case AB_BA_B_A__A->{ + c1 = ABs.get(0).c1; + dim1 = ABs.get(0).dim1; + }case AB_BA_A__B -> { + c1 = ABs.get(0).c2; + dim1 = ABs.get(0).dim2; + }case AB_BA_B_A__ -> { + }case AB_BA_B_A__AB -> { + c1 = ABs.get(0).c1; + dim1 = ABs.get(0).dim1; + c2 = ABs.get(0).c2; + dim2 = ABs.get(0).dim2; + }case AB_BA_B_A_AZ__Z -> { + // the output is Z, the second index of AZ, so the label and the size both come from position 2 + c1 = AZs.get(0).c2; + dim1 = AZs.get(0).dim2; + }case AB_BA_A_AZ__BZ ->{ + c1 = ABs.get(0).c2; + dim1 = ABs.get(0).dim2; + c2 = AZs.get(0).c2; + dim2 = AZs.get(0).dim2; + }case AB_BA_A_AZ__ZB ->{ + c2 = ABs.get(0).c2; + dim2 = ABs.get(0).dim2; + c1 = AZs.get(0).c2; + dim1 = AZs.get(0).dim2; + } + } + this.einsumRewriteType = einsumRewriteType; + this.ABs = ABs; + this.BAs = BAs; + this.Bs = Bs; + this.As = As; + this.AZs = AZs; + } + + @Override + public String toString() { + return this.getClass().getSimpleName()+" ("+einsumRewriteType.toString()+") "+this.getOutputString(); + } + + /** + * Registers an additional scalar operand; only the rewrite types AB_BA_B_A__A and AB_BA_B_A_AZ__Z accept one, any other type raises a RuntimeException. + */ + public void addScalarAsIntermediate(EOpNode scalar) { + if(einsumRewriteType == EinsumRewriteType.AB_BA_B_A__A || einsumRewriteType == EinsumRewriteType.AB_BA_B_A_AZ__Z) + this.scalar = scalar; + else + throw new RuntimeException("EOpNodeFuse.addScalarAsIntermediate: scalar is undefined for type "+einsumRewriteType.toString()); + } + + public static List findFuseOps(ArrayList operands, Character outChar1, Character outChar2, + HashMap charToSize, HashMap charToOccurences, ArrayList ret) { + ArrayList result = new ArrayList<>(); + HashSet matricesChars = new 
HashSet<>(); + HashMap> matricesCharsStartingWithChar = new HashMap<>(); + HashMap> charsToMatrices = new HashMap<>(); + + for(EOpNode operand1 : operands) { + String k; + + if(operand1.c2 != null) { + k = operand1.c1.toString() + operand1.c2; + matricesChars.add(k); + if(matricesCharsStartingWithChar.containsKey(operand1.c1)) { + matricesCharsStartingWithChar.get(operand1.c1).add(k); + } + else { + HashSet set = new HashSet<>(); + set.add(k); + matricesCharsStartingWithChar.put(operand1.c1, set); + } + } + else { + k = operand1.c1.toString(); + } + + if(charsToMatrices.containsKey(k)) { + charsToMatrices.get(k).add(operand1); + } + else { + ArrayList matrices = new ArrayList<>(); + matrices.add(operand1); + charsToMatrices.put(k, matrices); + } + } + ArrayList> matricesCharsSorted = new ArrayList<>(matricesChars.stream() + .map(x -> Pair.of(charsToMatrices.get(x).get(0).dim1 * charsToMatrices.get(x).get(0).dim2, x)).toList()); + matricesCharsSorted.sort(Comparator.comparing(Pair::getLeft)); + ArrayList AZs = new ArrayList<>(); + + HashSet usedMatricesChars = new HashSet<>(); + HashSet usedOperands = new HashSet<>(); + + for(String ABCandidate : matricesCharsSorted.stream().map(Pair::getRight).toList()) { + if(usedMatricesChars.contains(ABCandidate)) continue; + + char a = ABCandidate.charAt(0); + char b = ABCandidate.charAt(1); + String AB = ABCandidate; + String BA = "" + b + a; + + int BAsCount = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); + int ABsCount = charsToMatrices.get(AB).size(); + + if(BAsCount > ABsCount + 1) { + BA = "" + a + b; + AB = "" + b + a; + char tmp = a; + a = b; + b = tmp; + int tmp2 = ABsCount; + ABsCount = BAsCount; + BAsCount = tmp2; + } + String A = "" + a; + String B = "" + b; + int AsCount = (charsToMatrices.containsKey(A) && !usedMatricesChars.contains(A) ? charsToMatrices.get(A).size() : 0); + int BsCount = (charsToMatrices.containsKey(B) && !usedMatricesChars.contains(B) ? 
charsToMatrices.get(B).size() : 0); + + if(AsCount == 0 && BsCount == 0 && (ABsCount + BAsCount) < 2) { // no elementwise multiplication possible + continue; + } + + int usedBsCount = BsCount + ABsCount + BAsCount; + + boolean doSumA = false; + boolean doSumB = charToOccurences.get(b) == usedBsCount && (outChar1 == null || b != outChar1) && (outChar2 == null || b != outChar2); + HashSet AZCandidates = matricesCharsStartingWithChar.get(a); + + String AZ = null; + Character z = null; + boolean includeAZ = AZCandidates.size() == 2; + + if(includeAZ) { + for(var AZCandidate : AZCandidates) { + if(AB.equals(AZCandidate)) {continue;} + AZs = charsToMatrices.get(AZCandidate); + z = AZCandidate.charAt(1); + String Z = "" + z; + AZ = "" + a + z; + int AZsCount= AZs.size(); + int ZsCount= charsToMatrices.containsKey(Z) ? charsToMatrices.get(Z).size() : 0; + doSumA = AZsCount + ABsCount + BAsCount + AsCount == charToOccurences.get(a) && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); + // z and outChar1/outChar2 are boxed Characters, so they are compared by value rather than by reference + boolean doSumZ = AZsCount + ZsCount == charToOccurences.get(z) && (outChar1 == null || !z.equals(outChar1)) && (outChar2 == null || !z.equals(outChar2)); + if(!doSumA){ + includeAZ = false; + } else if(!doSumB && doSumZ){ // swap the order, to have only one fusion AB,...,AZ->Z + b = z; + z = AB.charAt(1); + AB = "" + a + b; + BA = "" + b + a; + A = "" + a; + B = "" + b; + AZ = "" + a + z; + AZs = charsToMatrices.get(AZ); + doSumB = true; + } else if(!doSumB && !doSumZ){ // outer between B and Z + // the skinny check needs the size of z; the first element of the unordered candidate set may be AB itself, so z is used directly + if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY + || (EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK && ((charToSize.get(a) * charToSize.get(b) *(ABsCount + BAsCount)) + (charToSize.get(a)*charToSize.get(z)*(AZsCount))) * 8 < 6 * 1024 * 1024) + || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)), charToSize.get(z),false)) { + includeAZ = false; + } + } else 
if(doSumB && doSumZ){ + // it will be two separate templates and then mutliply a vectors + } else if (doSumB && !doSumZ) { + // ->Z template OK + } + break; + } + } + + if(!includeAZ) { + doSumA = charToOccurences.get(a) == AsCount + ABsCount + BAsCount && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); + } + + ArrayList ABs = charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); + ArrayList BAs = charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>(); + ArrayList As = charsToMatrices.containsKey(A) && !usedMatricesChars.contains(A) ? charsToMatrices.get(A) : new ArrayList<>(); + ArrayList Bs = charsToMatrices.containsKey(B) && !usedMatricesChars.contains(B) ? charsToMatrices.get(B) : new ArrayList<>(); + Character c1 = null, c2 = null; + Integer dim1 = null, dim2 = null; + EinsumRewriteType type = null; + + if(includeAZ) { + if(doSumB) { + type = EinsumRewriteType.AB_BA_B_A_AZ__Z; + c1 = z; + } + else if((outChar1 != null && outChar2 != null) && outChar1 == z && outChar2 == b) { + type = EinsumRewriteType.AB_BA_A_AZ__ZB; + c1 = z; c2 = b; + } + else if((outChar1 != null && outChar2 != null) && outChar1 == b && outChar2 == z) { + type = EinsumRewriteType.AB_BA_A_AZ__BZ; + c1 = b; c2 = z; + } + else { + type = EinsumRewriteType.AB_BA_A_AZ__ZB; + c1 = z; c2 = b; + } + } + else { + AZs= new ArrayList<>(); + if(doSumA) { + if(doSumB) { + type = EinsumRewriteType.AB_BA_B_A__; + } + else { + type = EinsumRewriteType.AB_BA_A__B; + c1 = AB.charAt(1); + } + } + else if(doSumB) { + type = EinsumRewriteType.AB_BA_B_A__A; + c1 = AB.charAt(0); + } + else { + type = EinsumRewriteType.AB_BA_B_A__AB; + c1 = AB.charAt(0); c2 = AB.charAt(1); + } + } + + if(c1 != null) { + charToOccurences.put(c1, charToOccurences.get(c1) + 1); + dim1 = charToSize.get(c1); + } + if(c2 != null) { + charToOccurences.put(c2, charToOccurences.get(c2) + 1); + dim2 = charToSize.get(c2); + } + boolean includeB = type != 
EinsumRewriteType.AB_BA_A__B && type != EinsumRewriteType.AB_BA_A_AZ__BZ && type != EinsumRewriteType.AB_BA_A_AZ__ZB;
+
+			usedOperands.addAll(ABs);
+			usedOperands.addAll(BAs);
+			usedOperands.addAll(As);
+			if (includeB) usedOperands.addAll(Bs);
+			if (includeAZ) usedOperands.addAll(AZs);
+
+			usedMatricesChars.add(AB);
+			usedMatricesChars.add(BA);
+			usedMatricesChars.add(A);
+			if (includeB) usedMatricesChars.add(B);
+			if (includeAZ) usedMatricesChars.add(AZ);
+
+			var e = new EOpNodeFuse(c1, c2, dim1, dim2, type, ABs, BAs, includeB ? Bs : new ArrayList<>(), As, AZs);
+
+			result.add(e);
+		}
+
+		for(EOpNode n : operands) {
+			if(!usedOperands.contains(n)){
+				ret.add(n);
+			} else {
+				charToOccurences.put(n.c1, charToOccurences.get(n.c1) - 1);
+				if(charToOccurences.get(n.c2)!= null)
+					charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1);
+			}
+		}
+
+		return result;
+	}
+
+	/**
+	 * Runs one fused einsum template over already computed operand blocks.
+	 *
+	 * @param rewriteType template to execute; selects the row-wise aggregation type and which operands enter the kernel
+	 * @param ABsInput    operands used as they are
+	 * @param mbBAs       operands that are transposed and then appended to the AB operands
+	 * @param mbBs        B vectors; more than one are first multiplied into a single vector. For the B, BZ and ZB
+	 *                    result types they are not passed to the kernel but multiplied onto its result
+	 * @param mbAs        A vectors; more than one are first multiplied into a single vector
+	 * @param mbAZs       AZ operands, only read for the Z, BZ and ZB result types
+	 * @param scalar      optional factor, handed to the kernel as a scalar input when not null
+	 * @param numThreads  parallelism passed to the transpose operator and to the kernel execution
+	 * @return the kernel result; for the A, B and Z result types it is passed through ensureMatrixBlockColumnVector
+	 */
+	public static MatrixBlock compute(EinsumRewriteType rewriteType, List<MatrixBlock> ABsInput, List<MatrixBlock> mbBAs, List<MatrixBlock> mbBs, List<MatrixBlock> mbAs, List<MatrixBlock> mbAZs,
+		Double scalar, int numThreads){
+		boolean isResultA = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A;
+		boolean isResultB = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A__B;
+		boolean isResultZ = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z;
+		boolean isResultBZ =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ;
+		boolean isResultZB =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__ZB;
+		// B vectors enter the kernel only if b is not kept in the result; otherwise they are applied after the kernel ran
+		boolean fuseBs = !isResultBZ && !isResultZB && !isResultB;
+		boolean usesAZ = isResultZ || isResultBZ || isResultZB;
+
+		ArrayList<MatrixBlock> mbABs = new ArrayList<>(ABsInput);
+		if (!mbBAs.isEmpty()) {
+			ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+			for(MatrixBlock mb : mbBAs) //BA->AB
+				mbABs.add(mb.reorgOperations(transpose, null, 0, 0, 0));
+		}
+		// the sizes are read only after the BA operands were transposed and appended,
+		// so that a call with BA operands but without any AB operand still has a first block to read them from
+		int aSize = mbABs.get(0).getNumRows();
+		int bSize = mbABs.get(0).getNumColumns();
+
+		if(mbAs.size() > 1) mbAs = multiplyVectorsIntoOne(mbAs, aSize);
+		if(mbBs.size() > 1) mbBs = multiplyVectorsIntoOne(mbBs, bSize);
+
+		int constDim2 = -1;
+		int zSize = 0;
+		int azCount = 0;
+		if(usesAZ) {
+			constDim2 = mbAZs.get(0).getNumColumns();
+			zSize = constDim2;
+			azCount = mbAZs.size();
+		}
+
+		SpoofRowwise.RowType rowType = switch(rewriteType){
+			case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG;
+			case AB_BA_A__B -> SpoofRowwise.RowType.COL_AGG_T;
+			case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG;
+			case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG;
+			case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST;
+			case AB_BA_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T;
+			case AB_BA_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1;
+		};
+		EinsumSpoofRowwise r = new EinsumSpoofRowwise(rewriteType, rowType, constDim2,
+			mbABs.size()-1, !mbBs.isEmpty() && fuseBs, mbAs.size(), azCount, zSize);
+
+		// kernel input order: all ABs, then (if fused) the B vector, then the A vector, then the AZs
+		ArrayList<MatrixBlock> fuseInputs = new ArrayList<>();
+		fuseInputs.addAll(mbABs);
+		if(fuseBs)
+			fuseInputs.addAll(mbBs);
+		fuseInputs.addAll(mbAs);
+		if (usesAZ)
+			fuseInputs.addAll(mbAZs);
+
+		ArrayList scalarObjects = new ArrayList<>();
+		if(scalar != null){
+			scalarObjects.add(new DoubleObject(scalar));
+		}
+		MatrixBlock out = r.execute(fuseInputs, scalarObjects, new MatrixBlock(), numThreads);
+
+		// B vectors that were kept out of the kernel are multiplied onto its result
+		if(isResultB && !mbBs.isEmpty()){
+			LibMatrixMult.vectMultiply(mbBs.get(0).getDenseBlockValues(), out.getDenseBlockValues(), 0,0, bSize);
+		}
+		if(isResultBZ && !mbBs.isEmpty()){
+			ensureMatrixBlockColumnVector(mbBs.get(0));
+			out = out.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), mbBs.get(0));
+		}
+		if(isResultZB && !mbBs.isEmpty()){
+			ensureMatrixBlockRowVector(mbBs.get(0));
+			out = out.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), mbBs.get(0));
+		}
+
+		if( isResultA || isResultB || isResultZ)
+			ensureMatrixBlockColumnVector(out);
+
+		return out;
+	}
+
+	/**
+	 * Computes all child nodes (and the optional scalar node) and runs the fused template on their results.
+	 *
+	 * @param inputs     input blocks, forwarded unchanged to every child
+	 * @param numThreads parallelism, forwarded to the children and to {@code compute}
+	 * @param LOG        logger, forwarded to the children
+	 * @return result of {@code compute} for this node's rewrite type
+	 */
+	@Override
+	public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numThreads, Log LOG) {
+		final Function<EOpNode, MatrixBlock> eOpNodeToMatrixBlock = n -> n.computeEOpNode(inputs, numThreads, LOG);
+		ArrayList<MatrixBlock> mbABs = new ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList());
+		List<MatrixBlock> mbBAs = BAs.stream().map(eOpNodeToMatrixBlock).toList();
+		List<MatrixBlock> mbBs = Bs.stream().map(eOpNodeToMatrixBlock).toList();
+		List<MatrixBlock> mbAs = As.stream().map(eOpNodeToMatrixBlock).toList();
+		List<MatrixBlock> mbAZs = AZs.stream().map(eOpNodeToMatrixBlock).toList();
+		// the scalar child yields a block; only its (0,0) cell is used
+		Double scalar = this.scalar == null ? null : this.scalar.computeEOpNode(inputs, numThreads, LOG).get(0,0);
+		return EOpNodeFuse.compute(this.einsumRewriteType, mbABs, mbBAs, mbBs, mbAs, mbAZs , scalar, numThreads);
+	}
+
+	/**
+	 * Lets every child optimize itself and replaces it by the node it returns; this node itself is kept.
+	 * Each child is asked to keep its own output characters (n.c1, n.c2); the arguments of this call are not read.
+	 *
+	 * @return this node
+	 */
+	@Override
+	public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
+		ABs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+		BAs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+		As.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+		Bs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+		AZs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2));
+		return this;
+	}
+
+	/**
+	 * Multiplies the given vectors element-wise into one vector.
+	 *
+	 * @param mbs  vectors to multiply, read through the first value array of their dense block; must not be empty
+	 * @param size number of elements to multiply
+	 * @return a one-element list: the single input vector itself if only one was given, otherwise a newly
+	 *         allocated dense block with the shape of the first vector that holds the product
+	 */
+	private static @NotNull List<MatrixBlock> multiplyVectorsIntoOne(List<MatrixBlock> mbs, int size) {
+		if(mbs.size() == 1) // nothing to multiply; previously this returned an all-zero block
+			return List.of(mbs.get(0));
+
+		MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false);
+		mb.allocateDenseBlock();
+		// NOTE(review): the product is written straight into the value array - confirm whether the
+		// non-zero count of the block has to be recomputed before it is used downstream
+		double[] product = mb.getDenseBlock().values(0);
+		// product = v0 * v1, then product *= vi for the remaining vectors
+		LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), product,0,0,0, size);
+		for(int i = 2; i< mbs.size(); i++)
+			LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0), product,0,0, size);
+		return List.of(mb);
+	}
+}
+
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java new file mode 100644 index 00000000000..a94bbc95656 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.logging.Log;
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.runtime.functionobjects.DiagIndex;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceDiag;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Plan node that applies one unary operation (diagonal extraction or a sum-style aggregation)
+ * to the block produced by its single child.
+ */
+public class EOpNodeUnary extends EOpNode {
+	private final EUnaryOperand eUnaryOperand;
+	public EOpNode child;
+
+	/** Operations this node can apply to the result of its child. */
+	public enum EUnaryOperand {
+		DIAG, TRACE, SUM, SUM_COLS, SUM_ROWS
+	}
+
+	/**
+	 * @param c1            first output character, passed to the base node
+	 * @param c2            second output character, passed to the base node
+	 * @param dim1          size belonging to c1, passed to the base node
+	 * @param dim2          size belonging to c2, passed to the base node
+	 * @param child         node whose result is the operand
+	 * @param eUnaryOperand operation to apply
+	 */
+	public EOpNodeUnary(Character c1, Character c2, Integer dim1, Integer dim2, EOpNode child, EUnaryOperand eUnaryOperand) {
+		super(c1, c2, dim1, dim2);
+		this.eUnaryOperand = eUnaryOperand;
+		this.child = child;
+	}
+
+	@Override
+	public List<EOpNode> getChildren() {
+		return List.of(child);
+	}
+
+	@Override
+	public String toString() {
+		String nodeName = this.getClass().getSimpleName();
+		return nodeName+" ("+eUnaryOperand.toString()+") "+this.getOutputString();
+	}
+
+	/**
+	 * Computes the child and applies the configured operation to its result.
+	 *
+	 * @param inputs       input blocks, forwarded to the child
+	 * @param numOfThreads parallelism, forwarded to the child and to the aggregation operators
+	 * @param LOG          logger, forwarded to the child
+	 * @return for DIAG the value returned by the reorg operation, otherwise the target block handed to the aggregation
+	 */
+	@Override
+	public MatrixBlock computeEOpNode(ArrayList<MatrixBlock> inputs, int numOfThreads, Log LOG) {
+		final MatrixBlock operand = child.computeEOpNode(inputs, numOfThreads, LOG);
+		return switch(eUnaryOperand) {
+			case DIAG ->
+				operand.reorgOperations(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), new MatrixBlock(),0,0,0);
+			case TRACE -> {
+				AggregateOperator kahanSum = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.LASTCOLUMN);
+				AggregateUnaryOperator overDiagonal = new AggregateUnaryOperator(kahanSum, ReduceDiag.getReduceDiagFnObject(), numOfThreads);
+				// NOTE(review): the 10x10 initial shape looks like a placeholder (SUM starts from 1x1) -
+				// confirm that aggregateUnaryOperations resizes the target block
+				yield aggregateInto(operand, overDiagonal, new MatrixBlock(10, 10, false));
+			}
+			case SUM -> {
+				AggregateUnaryOperator overAll = new AggregateUnaryOperator(plainSum(), ReduceAll.getReduceAllFnObject(), numOfThreads);
+				yield aggregateInto(operand, overAll, new MatrixBlock(1, 1, false));
+			}
+			case SUM_COLS -> {
+				AggregateUnaryOperator overRows = new AggregateUnaryOperator(plainSum(), ReduceRow.getReduceRowFnObject(), numOfThreads);
+				yield aggregateInto(operand, overRows, new MatrixBlock(operand.getNumColumns(), 1, false));
+			}
+			case SUM_ROWS -> {
+				AggregateUnaryOperator overCols = new AggregateUnaryOperator(plainSum(), ReduceCol.getReduceColFnObject(), numOfThreads);
+				yield aggregateInto(operand, overCols, new MatrixBlock(operand.getNumRows(), 1, false));
+			}
+		};
+	}
+
+	/** Creates the plus aggregate with initial value 0 that all SUM variants use. */
+	private static AggregateOperator plainSum() {
+		return new AggregateOperator(0, Plus.getPlusFnObject());
+	}
+
+	/** Aggregates {@code operand} with {@code op} into {@code target} and hands back that same target block. */
+	private static MatrixBlock aggregateInto(MatrixBlock operand, AggregateUnaryOperator op, MatrixBlock target) {
+		operand.aggregateUnaryOperations(op, target, 0, null);
+		return target;
+	}
+
+	/** Performs no reordering: the node is returned as it is and the child is not visited. */
+	@Override
+	public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) {
+		return this;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
index 6da39e59873..55692d0109e 100644
--- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
@@ -23,155 +23,98 @@
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.Iterator;
-
 public class
EinsumContext { - public enum ContractDimensions { - CONTRACT_LEFT, - CONTRACT_RIGHT, - CONTRACT_BOTH, - } public Integer outRows; public Integer outCols; public Character outChar1; public Character outChar2; public HashMap charToDimensionSize; public String equationString; - public boolean[] diagonalInputs; - public HashSet summingChars; - public HashSet contractDimsSet; - public ContractDimensions[] contractDims; public ArrayList newEquationStringInputsSplit; - public HashMap> characterAppearanceIndexes; // for each character, this tells in which inputs it appears + public HashMap characterAppearanceCount; private EinsumContext(){}; public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ EinsumContext res = new EinsumContext(); res.equationString = eqStr; - res.charToDimensionSize = new HashMap(); - HashSet summingChars = new HashSet<>(); - ContractDimensions[] contractDims = new ContractDimensions[inputs.size()]; - boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default - HashSet contractDimsSet = new HashSet<>(); - HashMap> partsCharactersToIndices = new HashMap<>(); - ArrayList newEquationStringSplit = new ArrayList<>(); + HashMap charToDimensionSize = new HashMap<>(); + HashMap characterAppearanceCount = new HashMap<>(); + ArrayList newEquationStringSplit = new ArrayList<>(); + Character outChar1 = null; + Character outChar2 = null; Iterator it = inputs.iterator(); MatrixBlock curArr = it.next(); - int arrSizeIterator = 0; - int arrayIterator = 0; - int i; - // first iteration through string: collect information on character-size and what characters are summing characters - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if(c == '-'){ - i+=2; - break; - } - if(c == ','){ - arrayIterator++; - curArr = it.next(); - arrSizeIterator = 0; - } - else{ - if (res.charToDimensionSize.containsKey(c)) { // sanity check if dims match, this is already checked at validation - if(arrSizeIterator == 0 && 
res.charToDimensionSize.get(c) != curArr.getNumRows()) - throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); - else if(arrSizeIterator == 1 && res.charToDimensionSize.get(c) != curArr.getNumColumns()) - throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); - summingChars.add(c); - } else { - if(arrSizeIterator == 0) - res.charToDimensionSize.put(c, curArr.getNumRows()); - else if(arrSizeIterator == 1) - res.charToDimensionSize.put(c, curArr.getNumColumns()); - } - - arrSizeIterator++; - } - } - - int numOfRemainingChars = eqStr.length() - i; - - if (numOfRemainingChars > 2) - throw new RuntimeException("Einsum: dim > 2 not supported"); - - arrSizeIterator = 0; - - Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null; - Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null; - res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSize.get(outChar1) : 1); - res.outCols=(numOfRemainingChars > 1 ? 
res.charToDimensionSize.get(outChar2) : 1); - - arrayIterator=0; - // second iteration through string: collect remaining information - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if (c == '-') { - break; - } - if (c == ',') { - arrayIterator++; - arrSizeIterator = 0; - continue; - } - String s = ""; - - if(summingChars.contains(c)) { - s+=c; - if(!partsCharactersToIndices.containsKey(c)) - partsCharactersToIndices.put(c, new ArrayList<>()); - partsCharactersToIndices.get(c).add(arrayIterator); - } - else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar2)) { - s+=c; - } - else { - contractDimsSet.add(c); - contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT; - } - - if(i + 1 < eqStr.length()) { // process next character together - char c2 = eqStr.charAt(i + 1); - i++; - if (c2 == '-') { newEquationStringSplit.add(s); break;} - if (c2 == ',') { arrayIterator++; newEquationStringSplit.add(s); continue; } - - if (c2 == c){ - diagonalInputs[arrayIterator] = true; - if (contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] = ContractDimensions.CONTRACT_BOTH; - } - else{ - if(summingChars.contains(c2)) { - s+=c2; - if(!partsCharactersToIndices.containsKey(c2)) - partsCharactersToIndices.put(c2, new ArrayList<>()); - partsCharactersToIndices.get(c2).add(arrayIterator); - } - else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outChar2)) { - s+=c2; - } - else { - contractDimsSet.add(c2); - contractDims[arrayIterator] = contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? 
ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT; - } - } - } - newEquationStringSplit.add(s); - arrSizeIterator++; - } - - res.contractDims = contractDims; - res.contractDimsSet = contractDimsSet; - res.diagonalInputs = diagonalInputs; - res.summingChars = summingChars; + int i = 0; + + char c = eqStr.charAt(i); + for(i = 0; i < eqStr.length(); i++) { + StringBuilder sb = new StringBuilder(2); + for(;i < eqStr.length(); i++){ + c = eqStr.charAt(i); + if (c == ' ') continue; + if (c == ',' || c == '-' ) break; + if (!Character.isAlphabetic(c)) { + throw new RuntimeException("Einsum: only alphabetic characters are supported for dimensions: "+c); + } + sb.append(c); + if (characterAppearanceCount.containsKey(c)) characterAppearanceCount.put(c, characterAppearanceCount.get(c) + 1) ; + else characterAppearanceCount.put(c, 1); + } + String s = sb.toString(); + newEquationStringSplit.add(s); + + if(s.length() > 0){ + if (charToDimensionSize.containsKey(s.charAt(0))) + if (charToDimensionSize.get(s.charAt(0)) != curArr.getNumRows()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + charToDimensionSize.put(s.charAt(0), curArr.getNumRows()); + } + if(s.length() > 1){ + if (charToDimensionSize.containsKey(s.charAt(1))) + if (charToDimensionSize.get(s.charAt(1)) != curArr.getNumColumns()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + charToDimensionSize.put(s.charAt(1), curArr.getNumColumns()); + } + if(s.length() > 2) throw new RuntimeException("Einsum: only up-to 2D inputs strings allowed "); + + if( c==','){ + curArr = it.next(); + } + else if (c=='-') break; + + if (i == eqStr.length() - 1) {throw new RuntimeException("Einsum: missing '->' substring "+c);} + } + + if (i == eqStr.length() - 1 || eqStr.charAt(i+1) != '>') throw new RuntimeException("Einsum: missing '->' substring "+c); + i+=2; + + StringBuilder sb = new StringBuilder(2); + + for(;i < eqStr.length(); i++){ 
+ c = eqStr.charAt(i); + if (c == ' ') continue; + if (!Character.isAlphabetic(c)) { + throw new RuntimeException("Einsum: only alphabetic characters are supported for dimensions: "+c); + } + sb.append(c); + } + String s = sb.toString(); + if(s.length() > 0) outChar1 = s.charAt(0); + if(s.length() > 1) outChar2 = s.charAt(1); + if(s.length() > 2) throw new RuntimeException("Einsum: only up-to 2D output allowed "); + + res.outRows=(outChar1 == null ? 1 : charToDimensionSize.get(outChar1)); + res.outCols=(outChar2 == null ? 1 : charToDimensionSize.get(outChar2)); + res.outChar1 = outChar1; res.outChar2 = outChar2; res.newEquationStringInputsSplit = newEquationStringSplit; - res.characterAppearanceIndexes = partsCharactersToIndices; + res.characterAppearanceCount = characterAppearanceCount; + res.charToDimensionSize = charToDimensionSize; return res; } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java index 5643159ef9a..7fdce50d3ba 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java @@ -79,9 +79,13 @@ public static Triple { + genexec_AB(a, ai, b, null, c, ci, len, grix, rix); + if(scalars.length != 0) { LibMatrixMult.vectMultiplyWrite(scalars[0], c, c, ci, ci, len); } + } + case AB_BA_A__B -> { + genexec_B(a, ai, b, null, c, ci, len, grix, rix); + } + case AB_BA_B_A__A -> { + // HARDCODEDgenexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + genexec_A_or_(a, ai, b, null, c, ci, len, grix, rix); + if(scalars.length != 0) { c[rix] *= scalars[0]; } + } + case AB_BA_B_A__ -> { + genexec_A_or_(a, ai, b, null, c, ci, len, grix, rix); + if(scalars.length != 0) { c[0] *= scalars[0]; } + } + case AB_BA_B_A_AZ__Z -> { + double[] temp = {0}; + genexec_A_or_(a, ai, b, null, temp, 0, len, grix, rix); + if(scalars.length != 0) { temp[0] *= scalars[0]; } + 
if(_AZCount > 1) { + double[] temp2 = new double[_ZSize]; + int bi = _AZStartIndex; + LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix, + _ZSize * rix, 0, _ZSize); + while(bi < _AZStartIndex + _AZCount) { + LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize); + } + LibMatrixMult.vectMultiplyAdd(temp[0], temp2, c, 0, 0, _ZSize); + } + else + LibMatrixMult.vectMultiplyAdd(temp[0], b[_AZStartIndex].values(rix), c, _ZSize * rix, 0, _ZSize); + } + case AB_BA_A_AZ__BZ -> { + double[] temp = new double[len]; + genexec_B(a, ai, b, null, temp, 0, len, grix, rix); + if(scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len); + } + if(_AZCount > 1) { + double[] temp2 = new double[_ZSize]; + int bi = _AZStartIndex; + LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix, + _ZSize * rix, 0, _ZSize); + while(bi < _AZStartIndex + _AZCount) { + LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize); + } + LibSpoofPrimitives.vectOuterMultAdd(temp, temp2, c, 0, 0, 0, len, _ZSize); + } + else + LibSpoofPrimitives.vectOuterMultAdd(temp, b[_AZStartIndex].values(rix), c, 0, _ZSize * rix, 0, len, _ZSize); + } + case AB_BA_A_AZ__ZB -> { + double[] temp = new double[len]; + genexec_B(a, ai, b, null, temp, 0, len, grix, rix); + if(scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len); + } + if(_AZCount > 1) { + double[] temp2 = new double[_ZSize]; + int bi = _AZStartIndex; + LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix, + _ZSize * rix, 0, _ZSize); + while(bi < _AZStartIndex + _AZCount) { + LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize); + } + LibSpoofPrimitives.vectOuterMultAdd(temp2, temp, c, 0, 0, 0, _ZSize, len); + } + else + 
LibSpoofPrimitives.vectOuterMultAdd(b[_AZStartIndex].values(rix), temp, c, _ZSize * rix, 0, 0, _ZSize, len);
+			}
+			default -> throw new NotImplementedException();
+		}
+	}
+
+	/**
+	 * Writes, for one row, the element-wise product of the main row, all AB side inputs,
+	 * the optional B vector and the optional A-vector entry of this row into {@code c} at {@code ci}.
+	 * Nothing is written when there is neither an AB side input nor a B vector nor an A vector.
+	 *
+	 * @param a       values of the main input, the row starts at {@code ai}
+	 * @param ai      offset of the row in {@code a}; also used as offset into the AB side inputs
+	 * @param b       side inputs: first the {@code _ABCount} AB inputs, then the B vector (if supplied), then the A vector
+	 * @param scalars not read
+	 * @param c       output values
+	 * @param ci      offset of the output row in {@code c}
+	 * @param len     number of elements of the row
+	 * @param grix    not read
+	 * @param rix     row index, selects the row of the AB side inputs and the entry of the A vector
+	 */
+	private void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix,
+		int rix) {
+		int bi = 0;
+		double[] TMP1 = null;
+		if(_ABCount != 0) {
+			if(_ABCount == 1 && _ACount == 0 && !_Bsupplied) {
+				// a single product, written directly into the output
+				LibMatrixMult.vectMultiplyWrite(a, b[0].values(rix), c, ai, ai, ci, len);
+				return;
+			}
+			TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len);
+			while(bi < _ABCount) {
+				if(_ACount == 0 && !_Bsupplied && bi == _ABCount - 1) {
+					// last factor overall: write into the output instead of the temporary
+					LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), c, 0, ai, ci, len);
+				}
+				else {
+					LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len);
+				}
+			}
+		}
+
+		if(_Bsupplied) {
+			if(_ACount == 1) {
+				// b[bi] is the B vector, b[bi + 1] the A vector
+				if(TMP1 == null)
+					vectMultiplyWrite(b[bi + 1].values(0)[rix], a, b[bi].values(0), c, ai, 0, ci, len);
+				else
+					vectMultiplyWrite(b[bi + 1].values(0)[rix], TMP1, b[bi].values(0), c, 0, 0, ci, len);
+			}
+			else if(TMP1 == null) {
+				LibMatrixMult.vectMultiplyWrite(a, b[bi].values(0), c, ai, 0, ci, len);
+			}
+			else {
+				LibMatrixMult.vectMultiplyWrite(TMP1, b[bi].values(0), c, 0, 0, ci, len);
+			}
+		}
+		else if(_ACount == 1) {
+			if(TMP1 == null)
+				LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix], a, c, ai, ci, len);
+			else
+				LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix], TMP1, c, 0, ci, len);
+		}
+	}
+
+	/**
+	 * Adds, for one row, the element-wise product of the main row, all AB side inputs and the
+	 * optional A-vector entry of this row onto {@code c}, starting at index 0 of {@code c}.
+	 * Without any AB side input and without an A vector the row itself is added.
+	 *
+	 * @param a       values of the main input, the row starts at {@code ai}
+	 * @param ai      offset of the row in {@code a}; also used as offset into the AB side inputs
+	 * @param b       side inputs: first the {@code _ABCount} AB inputs, then the A vector
+	 * @param scalars not read
+	 * @param c       accumulator, updated in its first {@code len} elements
+	 * @param ci      not read
+	 * @param len     number of elements of the row
+	 * @param grix    not read
+	 * @param rix     row index, selects the row of the AB side inputs and the entry of the A vector
+	 */
+	private void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix,
+		int rix) {
+		int bi = 0;
+		double[] TMP1 = null;
+		if(_ABCount == 1 && _ACount == 0)
+			LibMatrixMult.vectMultiplyAdd(a, b[bi++].values(rix), c, ai, ai, 0, len);
+		else if(_ABCount != 0) {
+			TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len);
+			while(bi < _ABCount) {
+				if(_ACount == 0 && bi == _ABCount - 1) {
+					// last factor overall: accumulate into the output instead of the temporary
+					LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len);
+				}
+				else {
+					LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len);
+				}
+			}
+		}
+
+		if(_ACount == 1) {
+			if(TMP1 == null)
+				LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], a, c, ai, 0, len);
+			else
+				LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], TMP1, c, 0, 0, len);
+		}
+		else if(_ACount == 0 && _ABCount == 0) {
+			// nothing to multiply with: accumulate the plain row, like the vectSum fallback of genexec_A_or_;
+			// before, this case left c untouched
+			for(int j = 0; j < len; j++)
+				c[j] += a[ai + j];
+		}
+	}
+
+	/**
+	 * Computes, for one row, the sum over the element-wise product of the main row, all AB side inputs
+	 * and the optional B vector, scaled by the optional A-vector entry of this row.
+	 * For the rewrite type AB_BA_B_A__A the value is stored in {@code c[ci]}, otherwise it is added to {@code c[0]}.
+	 *
+	 * @param a       values of the main input, the row starts at {@code ai}
+	 * @param ai      offset of the row in {@code a}; also used as offset into the AB side inputs
+	 * @param b       side inputs: first the {@code _ABCount} AB inputs, then the B vector (if supplied), then the A vector
+	 * @param scalars not read
+	 * @param c       output values
+	 * @param ci      output position used for AB_BA_B_A__A
+	 * @param len     number of elements of the row
+	 * @param grix    not read
+	 * @param rix     row index, selects the row of the AB side inputs and the entry of the A vector
+	 */
+	private void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) {
+		int bi = 0;
+		double[] TMP1 = null;
+		double TMP2 = 0;
+		if(_ABCount == 1 && !_Bsupplied)
+			TMP2 = LibSpoofPrimitives.dotProduct(a, b[bi++].values(rix), ai, ai, len);
+		else if(_ABCount != 0) {
+			TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len);
+			while(bi < _ABCount) {
+				if(!_Bsupplied && bi == _ABCount - 1)
+					TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[bi++].values(rix), 0, ai, len);
+				else
+					LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len);
+			}
+		}
+
+		if(_Bsupplied) {
+			if(_ABCount != 0) TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[bi++].values(0), 0, 0, len);
+			else TMP2 = LibSpoofPrimitives.dotProduct(a, b[bi++].values(0), ai, 0, len);
+		}
+		else if(_ABCount == 0) {
+			// neither AB side inputs nor a B vector: plain row sum
+			TMP2 = LibSpoofPrimitives.vectSum(a, ai, len);
+		}
+
+		if(_ACount == 1) TMP2 *= b[bi].values(0)[rix];
+
+		if(_EinsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2;
+		else c[0] += TMP2;
+	}
+
+	/**
+	 * Fixed variant for exactly one AB side input (b[0]), a B vector (b[1]) and an A vector (b[2]);
+	 * stores the value in {@code c[rix]}. Within the visible code it is referenced only from a comment.
+	 */
+	private void HARDCODEDgenexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci,
+		int len, long grix, int rix) {
+		double[] TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[0].values(rix), ai, ai, len);
+		double TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[1].values(0), 0, ai, len);
+		TMP2 *= b[2].values(0)[rix];
+		c[rix] = TMP2;
+	}
+
+	/** Sparse rows are not supported by this operator; always throws. */
+	protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci,
+		int alen, int len, long grix, int rix) {
+		throw new RuntimeException("Sparse fused einsum not implemented");
+	}
+
+	// I am not sure if it is worth copying to LibMatrixMult so for now added it here
+	private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;
+	private static final int vLen = SPECIES.length();
+
+	/**
+	 * Computes {@code c[ci + j] = aval * b[bi + j] * a[ai + j]} for {@code j} in {@code [0, len)}.
+	 *
+	 * @param aval scalar factor
+	 * @param a    first vector, read from offset {@code ai}
+	 * @param b    second vector, read from offset {@code bi}
+	 * @param c    output, written from offset {@code ci}
+	 * @param len  number of elements
+	 */
+	public static void vectMultiplyWrite(final double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci,
+		final int len) {
+		final int bn = len % vLen;
+
+		//rest, not aligned to vLen-blocks
+		for(int j = 0; j < bn; j++, ai++, bi++, ci++)
+			c[ci] = aval * b[bi] * a[ai];
+
+		//remaining elements in vector blocks of vLen lanes
+		DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval);
+		for(int j = bn; j < len; j += vLen, ai += vLen, bi += vLen, ci += vLen) {
+			DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai);
+			DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi);
+			avalVec.mul(bVec).mul(aVec).intoArray(c, ci);
+		}
+	}
+
+	/**
+	 * Computes {@code c[ci + j] += aval * b[bi + j] * a[ai + j]} for {@code j} in {@code [0, len)}.
+	 * The leading rest elements use plain multiply and add, the vector blocks a fused multiply-add.
+	 *
+	 * @param aval scalar factor
+	 * @param a    first vector, read from offset {@code ai}
+	 * @param b    second vector, read from offset {@code bi}
+	 * @param c    accumulator, updated from offset {@code ci}
+	 * @param len  number of elements
+	 */
+	public static void vectMultiplyAdd(final double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci,
+		final int len) {
+		final int bn = len % vLen;
+
+		//rest, not aligned to vLen-blocks
+		for(int j = 0; j < bn; j++, ai++, bi++, ci++)
+			c[ci] += aval * b[bi] * a[ai];
+
+		//remaining elements in vector blocks of vLen lanes
+		DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval);
+		for(int j = bn; j < len; j += vLen, ai += vLen, bi += vLen, ci += vLen) {
+			DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai);
+			DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi);
+			DoubleVector cVec = DoubleVector.fromArray(SPECIES, c, ci);
+			DoubleVector tmp = aVec.mul(bVec);
+			tmp.fma(avalVec, cVec).intoArray(c, ci);
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
index c67dd290799..b8af61d35b4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -19,7 +19,6 @@ package org.apache.sysds.runtime.instructions.cp; -import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.logging.Log; @@ -29,31 +28,33 @@ import org.apache.sysds.common.Types.DataType; import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.hops.codegen.SpoofCompiler; import org.apache.sysds.hops.codegen.cplan.CNode; -import org.apache.sysds.hops.codegen.cplan.CNodeBinary; import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; -import org.apache.sysds.hops.codegen.cplan.CNodeRow; import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.einsum.EinsumContext; +import org.apache.sysds.runtime.einsum.*; +import org.apache.sysds.runtime.einsum.EOpNodeBinary.EBinaryOperand; import org.apache.sysds.runtime.functionobjects.*; -import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; -import org.apache.sysds.runtime.matrix.operators.SimpleOperator; +import org.apache.sysds.utils.Explain; import java.util.*; -import java.util.function.Predicate; + +import static org.apache.sysds.api.DMLScript.EXPLAIN; +import static org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP; public class 
EinsumCPInstruction extends BuiltinNaryCPInstruction { - public static boolean FORCE_CELL_TPL = false; + public static final boolean FORCE_CELL_TPL = false; // naive looped solution + + public static final boolean FUSE_OUTER_MULTIPLY = true; + public static final boolean FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK = true; + + public static final boolean PRINT_TRACE = true; + protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; private final int _numThreads; @@ -62,15 +63,13 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand out, CPOperand... inputs) { super(op, opcode, istr, out, inputs); - _numThreads = OptimizerUtils.getConstrainedNumThreads(-1); + _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2; _in = inputs; this.eqStr = inputs[0].getName(); - Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); + if (PRINT_TRACE) + Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); } - @SuppressWarnings("unused") - private EinsumContext einc = null; - @Override public void processInstruction(ExecutionContext ec) { //get input matrices and scalars, incl pinning of matrices @@ -81,119 +80,180 @@ public void processInstruction(ExecutionContext ec) { if(mb instanceof CompressedMatrixBlock){ mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); } + if(mb.getNumRows() == 1){ + ensureMatrixBlockColumnVector(mb); + } inputs.add(mb); } } EinsumContext einc = EinsumContext.getEinsumContext(eqStr, inputs); - this.einc = einc; String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? 
String.valueOf(einc.outChar1) : ""; - if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); + if( LOG.isTraceEnabled() ) LOG.trace("output: "+resultString +" "+einc.outRows+"x"+einc.outCols); ArrayList inputsChars = einc.newEquationStringInputsSplit; if(LOG.isTraceEnabled()) LOG.trace(String.join(",",einc.newEquationStringInputsSplit)); - - contractDimensionsAndComputeDiagonals(einc, inputs); + ArrayList eOpNodes = new ArrayList<>(inputsChars.size()); + ArrayList eOpNodesScalars = new ArrayList<>(inputsChars.size()); // computed separately and not included into plan until it is already created //make all vetors col vectors for(int i = 0; i < inputs.size(); i++){ - if(inputs.get(i) != null && inputsChars.get(i).length() == 1) EnsureMatrixBlockColumnVector(inputs.get(i)); + if(inputsChars.get(i).length() == 1) ensureMatrixBlockColumnVector(inputs.get(i)); } - if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){ - ArrayList a = einc.characterAppearanceIndexes.get(c); - LOG.trace(c+" count= "+a.size()); + addSumDimensionsDiagonalsAndScalars(einc, inputsChars, eOpNodes, eOpNodesScalars, einc.charToDimensionSize); + + HashMap characterToOccurences = einc.characterAppearanceCount; + + for (int i = 0; i < inputsChars.size(); i++) { + if (inputsChars.get(i) == null) continue; + Character c1 = inputsChars.get(i).isEmpty() ? null : inputsChars.get(i).charAt(0); + Character c2 = inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null; + Integer dim1 = c1 == null ? null : einc.charToDimensionSize.get(c1); + Integer dim2 = c1 == null ? 
null : einc.charToDimensionSize.get(c2); + EOpNodeData n = new EOpNodeData(c1,c2,dim1,dim2, i); + eOpNodes.add(n); } - // compute scalar by suming-all matrices: - Double scalar = null; - for(int i=0;i< inputs.size(); i++){ - String s = inputsChars.get(i); - if(s.equals("")){ - MatrixBlock mb = inputs.get(i); - if (scalar == null) scalar = mb.get(0,0); - else scalar*= mb.get(0,0); - inputs.set(i,null); - inputsChars.set(i,null); + ArrayList ret = new ArrayList<>(); + addVectorMultiplies(eOpNodes, eOpNodesScalars,characterToOccurences, einc.outChar1, einc.outChar2, ret); + eOpNodes = ret; + + List plan; + ArrayList remainingMatrices; + + if(!FORCE_CELL_TPL) { + if(true){ + plan = generateGreedyPlan(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + }else { // old way: try to do fusion first and then rest in binary fashion cost based + List fuseOps; + do { + ret = new ArrayList<>(); + fuseOps = EOpNodeFuse.findFuseOps(eOpNodes, einc.outChar1, einc.outChar2, einc.charToDimensionSize, characterToOccurences, ret); + + if(!fuseOps.isEmpty()) { + for (EOpNodeFuse fuseOp : fuseOps) { + if (fuseOp.c1 == null) { + eOpNodesScalars.add(fuseOp); + continue; + } + ret.add(fuseOp); + } + eOpNodes = ret; + } + } while(eOpNodes.size() > 1 && !fuseOps.isEmpty()); + Pair> costAndPlan = generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, + einc.outChar1, einc.outChar2); + plan = costAndPlan.getRight(); + } + if(!eOpNodesScalars.isEmpty()){ + EOpNode l = eOpNodesScalars.get(0); + for(int i = 1; i < eOpNodesScalars.size(); i++){ + l = new EOpNodeBinary(l, eOpNodesScalars.get(i), EBinaryOperand.scalar_scalar); + } + + if(plan.isEmpty()) plan.add(l); + else { + int minCost = Integer.MAX_VALUE; + EOpNode addToNode = null; + int minIdx = -1; + for(int i = 0; i < plan.size(); i++) { + EOpNode n = plan.get(i); + Pair costAndNode = addScalarToPlanFindMinCost(n, einc.charToDimensionSize); + 
if(costAndNode.getLeft() < minCost) { + minCost = costAndNode.getLeft(); + addToNode = costAndNode.getRight(); + minIdx = i; + } + } + plan.set(minIdx, mergeEOpNodeWithScalar(addToNode, l)); + } + } - } - if (scalar != null) { - inputsChars.add(""); - inputs.add(new MatrixBlock(scalar)); - } + if(plan.size() == 2 && plan.get(0).c2 == null && plan.get(1).c2 == null){ + if (plan.get(0).c1 == einc.outChar1 && plan.get(1).c1 == einc.outChar2) + plan.set(0, new EOpNodeBinary(plan.get(0), plan.get(1), EBinaryOperand.A_B)); + if (plan.get(0).c1 == einc.outChar2 && plan.get(1).c1 == einc.outChar1) + plan.set(0, new EOpNodeBinary(plan.get(1), plan.get(0), EBinaryOperand.A_B)); + plan.remove(1); + } - HashMap characterToOccurences = new HashMap<>(); - for (Character key :einc.characterAppearanceIndexes.keySet()) { - characterToOccurences.put(key, einc.characterAppearanceIndexes.get(key).size()); - } - for (Character key :einc.charToDimensionSize.keySet()) { - if(!characterToOccurences.containsKey(key)) - characterToOccurences.put(key, 1); - } + if(plan.size() == 1) + plan.set(0,plan.get(0).reorderChildrenAndOptimize(null, einc.outChar1, einc.outChar2)); - ArrayList eOpNodes = new ArrayList<>(inputsChars.size()); - for (int i = 0; i < inputsChars.size(); i++) { - if (inputsChars.get(i) == null) continue; - EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); - eOpNodes.add(n); + if (EXPLAIN != Explain.ExplainType.NONE ) { + System.out.println("Einsum plan:"); + for(int i = 0; i < plan.size(); i++) { + System.out.println((i + 1) + "."); + System.out.println("- " + String.join("\n- ", plan.get(i).recursivePrintString())); + } + } + + remainingMatrices = executePlan(plan, inputs); + }else{ + plan = eOpNodes; + remainingMatrices = inputs; } - Pair > plan = FORCE_CELL_TPL ? 
null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); - ArrayList resMatrices = FORCE_CELL_TPL ? null : executePlan(plan.getRight(), inputs); -// ArrayList resMatrices = executePlan(plan.getRight(), inputs, true); - if(!FORCE_CELL_TPL && resMatrices.size() == 1){ - EOpNode resNode = plan.getRight().get(0); + if(!FORCE_CELL_TPL && remainingMatrices.size() == 1){ + EOpNode resNode = plan.get(0); if (einc.outChar1 != null && einc.outChar2 != null){ if(resNode.c1 == einc.outChar1 && resNode.c2 == einc.outChar2){ - ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + ec.setMatrixOutput(output.getName(), remainingMatrices.get(0)); } else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ + if( LOG.isTraceEnabled()) LOG.trace("Transposing the final result"); + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - MatrixBlock resM = resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); + MatrixBlock resM = remainingMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); ec.setMatrixOutput(output.getName(), resM); }else{ - if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); - throw new RuntimeException("Einsum plan produced different result"); + if(LOG.isTraceEnabled()) LOG.trace("Einsum error, expected: "+resultString + ", got: "+resNode.c1+resNode.c2); + throw new RuntimeException("Einsum plan produced different result, expected: "+resultString + ", got: "+resNode.c1+resNode.c2); } }else if (einc.outChar1 != null){ if(resNode.c1 == einc.outChar1 && resNode.c2 == null){ - ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + ensureMatrixBlockColumnVector(remainingMatrices.get(0)); + ec.setMatrixOutput(output.getName(), remainingMatrices.get(0)); }else{ if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); throw new 
RuntimeException("Einsum plan produced different result"); } }else{ if(resNode.c1 == null && resNode.c2 == null){ - ec.setScalarOutput(output.getName(), new DoubleObject(resMatrices.get(0).get(0, 0)));; + ec.setScalarOutput(output.getName(), new DoubleObject(remainingMatrices.get(0).get(0, 0)));; } } }else{ // use cell template with loops for remaining - ArrayList mbs = resMatrices; + ArrayList mbs = remainingMatrices; ArrayList chars = new ArrayList<>(); - for (int i = 0; i < plan.getRight().size(); i++) { + for (int i = 0; i < plan.size(); i++) { String s; - if(plan.getRight().get(i).c1 == null) s = ""; - else if(plan.getRight().get(i).c2 == null) s = plan.getRight().get(i).c1.toString(); - else s = plan.getRight().get(i).c1.toString() + plan.getRight().get(i).c2; + if(plan.get(i).c1 == null) s = ""; + else if(plan.get(i).c2 == null) s = plan.get(i).c1.toString(); + else s = plan.get(i).c1.toString() + plan.get(i).c2; chars.add(s); } ArrayList summingChars = new ArrayList<>(); - for (Character c : einc.characterAppearanceIndexes.keySet()) { + for (Character c : characterToOccurences.keySet()) { if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); } if(LOG.isTraceEnabled()) LOG.trace("finishing with cell tpl: "+String.join(",", chars)); MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSize, summingChars, einc.outRows, einc.outCols); + if (einc.outChar2 == null) + ensureMatrixBlockColumnVector(res); + if (einc.outRows == 1 && einc.outCols == 1) ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); else ec.setMatrixOutput(output.getName(), res); @@ -204,102 +264,367 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ } - private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList inputs) { - for(int i = 0; i< einc.contractDims.length; i++){ - //AggregateOperator agg = new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(),Types.CorrectionLocationType.LASTCOLUMN); - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - - if(einc.diagonalInputs[i]){ - ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); - inputs.set(i, inputs.get(i).reorgOperations(op, new MatrixBlock(),0,0,0)); - } - if (einc.contractDims[i] == null) continue; - switch (einc.contractDims[i]){ - case CONTRACT_BOTH: { - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - MatrixBlock res = new MatrixBlock(1, 1, false); - inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); - inputs.set(i, res); - break; + private EOpNode mergeEOpNodeWithScalar(EOpNode addToNode, EOpNode scalar) { + if(addToNode instanceof EOpNodeFuse fuse) { + switch (fuse.einsumRewriteType) { + case AB_BA_B_A__A, AB_BA_B_A_AZ__Z -> { + fuse.addScalarAsIntermediate(scalar); + return fuse; } - case CONTRACT_RIGHT: { - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - MatrixBlock res = new MatrixBlock(inputs.get(i).getNumRows(), 1, false); - inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); - inputs.set(i, res); - break; + }; + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.AB_scalar); + } + if(addToNode.c1 == null) + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.scalar_scalar); + if(addToNode.c2 == null) + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.A_scalar); + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.AB_scalar); + } + + private static Pair addScalarToPlanFindMinCost(EOpNode plan, HashMap charToSizeMap) { + int thisSize = 0; + if(plan.c1 != null) thisSize += charToSizeMap.get(plan.c1); + if(plan.c2 != null) thisSize += charToSizeMap.get(plan.c2); + int cost = thisSize; + + if (plan instanceof EOpNodeData || plan instanceof EOpNodeUnary) return Pair.of(thisSize, plan); + + List inputs = 
List.of(); + + if (plan instanceof EOpNodeBinary bin) inputs = List.of(bin.left, bin.right); + else if(plan instanceof EOpNodeFuse fuse){ + cost = switch (fuse.einsumRewriteType) { + case AB_BA_B_A__ -> 1; + case AB_BA_B_A__AB -> thisSize; + case AB_BA_A__B -> thisSize; + case AB_BA_B_A__A -> 2; // intermediate is scalar, 2 because if there is some real scalar + case AB_BA_B_A_AZ__Z -> 2; // intermediate is scalar + case AB_BA_A_AZ__BZ -> thisSize; + case AB_BA_A_AZ__ZB -> thisSize; + }; + inputs = fuse.getChildren(); + } + + for(EOpNode inp : inputs){ + Pair min = addScalarToPlanFindMinCost(inp, charToSizeMap); + if(min.getLeft() < cost){ + cost = min.getLeft(); + plan = min.getRight(); + } + } + return Pair.of(cost, plan); + } + + private static void addVectorMultiplies(ArrayList eOpNodes, ArrayList eOpNodesScalars,HashMap charToOccurences, Character outChar1, Character outChar2,ArrayList ret) { + HashMap> vectorCharacterToIndices = new HashMap<>(); + for (int i = 0; i < eOpNodes.size(); i++) { + if (eOpNodes.get(i).c2 == null) { + if (vectorCharacterToIndices.containsKey(eOpNodes.get(i).c1)) + vectorCharacterToIndices.get(eOpNodes.get(i).c1).add(eOpNodes.get(i)); + else + vectorCharacterToIndices.put(eOpNodes.get(i).c1, new ArrayList<>(Collections.singletonList(eOpNodes.get(i)))); + } + } + HashSet usedNodes = new HashSet<>(); + for(Character c : vectorCharacterToIndices.keySet()){ + ArrayList nodes = vectorCharacterToIndices.get(c); + + if(nodes.size()==1) continue; + EOpNode left = nodes.get(0); + usedNodes.add(left); + boolean canBeSummed = c != outChar1 && c != outChar2 && charToOccurences.get(c) == nodes.size(); + + for(int i = 1; i < nodes.size(); i++){ + EOpNode right = nodes.get(i); + + if(canBeSummed && i == nodes.size()-1){ + left = new EOpNodeBinary(left,right, EBinaryOperand.a_a); + }else { + left = new EOpNodeBinary(left,right, EBinaryOperand.A_A); } - case CONTRACT_LEFT: { - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, 
ReduceRow.getReduceRowFnObject(), _numThreads); - MatrixBlock res = new MatrixBlock(inputs.get(i).getNumColumns(), 1, false); - inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); - inputs.set(i, res); - break; + usedNodes.add(right); + } + if(canBeSummed) { + eOpNodesScalars.add(left); + charToOccurences.put(c, 0); + } + else { + ret.add(left); + charToOccurences.put(c, charToOccurences.get(c) - nodes.size() + 1); + } + } + for(EOpNode inp : eOpNodes){ + if(!usedNodes.contains(inp)) ret.add(inp); + } + } + + private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, ArrayList inputStrings, + ArrayList eOpNodes, ArrayList eOpNodesScalars, + HashMap charToDimensionSize) { + for(int i = 0; i< inputStrings.size(); i++){ + String s = inputStrings.get(i); + if (s.isEmpty()){ + eOpNodesScalars.add(new EOpNodeData(null, null, null, null,i)); + inputStrings.set(i, null); + continue; + }else if (s.length() == 1){ + char c1 = s.charAt(0); + if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 != einc.outChar2) && einc.characterAppearanceCount.get(c1) == 1){ + EOpNode e0 = new EOpNodeData(c1, null, charToDimensionSize.get(c1), null, i); + eOpNodesScalars.add(new EOpNodeUnary(null, null, null, null, e0, EOpNodeUnary.EUnaryOperand.SUM)); + inputStrings.set(i, null); } - default: - break; + continue; } + + char c1 = s.charAt(0); + char c2 = s.charAt(1); + Character newC1 = null; + EOpNodeUnary.EUnaryOperand op = null; + + if(c1 == c2){ + if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 != einc.outChar2) && einc.characterAppearanceCount.get(c1) == 2){ + op = EOpNodeUnary.EUnaryOperand.TRACE; + }else { + einc.characterAppearanceCount.put(c1, einc.characterAppearanceCount.get(c1) - 1); + op = EOpNodeUnary.EUnaryOperand.DIAG; + newC1 = c1; + } + }else if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 != einc.outChar2) && einc.characterAppearanceCount.get(c1) == 1){ + if 
((einc.outChar1 == null || c2 != einc.outChar1) && (einc.outChar2 == null || c2 != einc.outChar2) && einc.characterAppearanceCount.get(c2) == 1){ + op = EOpNodeUnary.EUnaryOperand.SUM; + }else{ + newC1 = c2; + op = EOpNodeUnary.EUnaryOperand.SUM_COLS; + } + }else if((einc.outChar1 == null || c2 != einc.outChar1) && (einc.outChar2 == null || c2 != einc.outChar2) && einc.characterAppearanceCount.get(c2) == 1){ + newC1 = c1; + op = EOpNodeUnary.EUnaryOperand.SUM_ROWS; + } + + if(op == null) continue; + EOpNodeData e0 = new EOpNodeData(c1, c2, charToDimensionSize.get(c1), charToDimensionSize.get(c2), i); + Integer dim1 = newC1 == null ? null : charToDimensionSize.get(newC1); + EOpNodeUnary res = new EOpNodeUnary(newC1, null, dim1, null, e0, op); + + if(op == EOpNodeUnary.EUnaryOperand.SUM) eOpNodesScalars.add(res); + else eOpNodes.add(res); + + inputStrings.set(i, null); } } - private enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed - ////// summations: ////// - aB_a,// -> B - Ba_a, // -> B - Ba_aC, // mmult -> BC - aB_Ca, - Ba_Ca, // -> BC - aB_aC, // outer mult, possibly with transposing first -> BC - a_a,// dot -> - - ////// elementwisemult and sums, something like ij,ij->i ////// - aB_aB,// elemwise and colsum -> B - Ba_Ba, // elemwise and rowsum ->B - Ba_aB, // elemwise, either colsum or rowsum -> B -// aB_Ba, - - ////// elementwise, no summations: ////// - A_A,// v-elemwise -> A - AB_AB,// M-M elemwise -> AB - AB_BA, // M-M.T elemwise -> AB - AB_A, // M-v colwise -> BA!? 
- BA_A, // M-v rowwise -> BA - ab_ab,//M-M sum all - ab_ba, //M-M.T sum all - ////// other ////// - A_B, // outer mult -> AB - A_scalar, // v-scalar - AB_scalar, // m-scalar - scalar_scalar + private static List generateGreedyPlan(ArrayList eOpNodes, + ArrayList eOpNodesScalars, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + ArrayList ret; + int lastNumOfOperands = -1; + while(lastNumOfOperands != eOpNodes.size() && eOpNodes.size() > 1){ + lastNumOfOperands = eOpNodes.size(); + + List fuseOps; + do { + ret = new ArrayList<>(); + fuseOps = EOpNodeFuse.findFuseOps(eOpNodes, outChar1, outChar2, charToSizeMap, charToOccurences, ret); + + if(!fuseOps.isEmpty()) { + for (EOpNodeFuse fuseOp : fuseOps) { + if (fuseOp.c1 == null) { + eOpNodesScalars.add(fuseOp); + continue; + } + ret.add(fuseOp); + } + eOpNodes = ret; + } + } while(eOpNodes.size() > 1 && !fuseOps.isEmpty()); + + ret = new ArrayList<>(); + addVectorMultiplies(eOpNodes, eOpNodesScalars,charToOccurences, outChar1, outChar2, ret); + eOpNodes = ret; + + ret = new ArrayList<>(); + ArrayList> matrixMultiplies = findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences, + ret); + + for(List list : matrixMultiplies) { + EOpNodeBinary bin = optimizeMMChain(list, charToSizeMap); + ret.add(bin); + } + eOpNodes = ret; + } + + return eOpNodes; } - private abstract class EOpNode { - public Character c1; - public Character c2; // nullable - public EOpNode(Character c1, Character c2){ - this.c1 = c1; - this.c2 = c2; + + private static void reverseMMChainIfBeneficial(ArrayList mmChain){ // possibly check the cost instead of number of transposes + char c1 = mmChain.get(0).c1; + char c2 = mmChain.get(0).c2; + int noTransposes = 0; + for (int i=1; i (mmChain.size() / 2 )+1) { + Collections.reverse(mmChain); } } - private class EOpNodeBinary extends EOpNode { - public EOpNode left; - public EOpNode right; - public EBinaryOperand operand; - public 
EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ - super(c1,c2); - this.left = left; - this.right = right; - this.operand = operand; + private static EOpNodeBinary optimizeMMChain(List mmChainL, HashMap charToSizeMap) { + ArrayList mmChain = new ArrayList<>(mmChainL); + reverseMMChainIfBeneficial(mmChain); + ArrayList> dimensions = new ArrayList<>(); + + for(int i = 0; i < mmChain.size()-1; i++){ + EOpNode n1 = mmChain.get(i); + EOpNode n2 = mmChain.get(i+1); + if(n1.c2 == n2.c1 || n1.c2 == n2.c2) dimensions.add(Pair.of(charToSizeMap.get(n1.c1), charToSizeMap.get(n1.c2))); + else dimensions.add(Pair.of(charToSizeMap.get(n1.c2), charToSizeMap.get(n1.c1))); // transpose this one } + EOpNode prelast = mmChain.get(mmChain.size()-2); + EOpNode last = mmChain.get(mmChain.size()-1); + if (last.c1 == prelast.c2 || last.c1 == prelast.c1) dimensions.add(Pair.of(charToSizeMap.get(last.c1), charToSizeMap.get(last.c2))); + else dimensions.add(Pair.of(charToSizeMap.get(last.c2), charToSizeMap.get(last.c1))); + + + double[] dimsArray = new double[mmChain.size() + 1]; + getDimsArray( dimensions, dimsArray ); + + int size = mmChain.size(); + int[][] splitMatrix = mmChainDP(dimsArray, mmChain.size()); + + return (EOpNodeBinary) getBinaryFromSplit(splitMatrix,0,size-1, mmChain); } - private class EOpNodeData extends EOpNode { - public int matrixIdx; - public EOpNodeData(Character c1, Character c2, int matrixIdx){ - super(c1,c2); - this.matrixIdx = matrixIdx; + + private static EOpNode getBinaryFromSplit(int[][] splitMatrix, int i, int j, List mmChain) { + if (i==j) return mmChain.get(i); + int split = splitMatrix[i][j]; + + EOpNode left = getBinaryFromSplit(splitMatrix,i,split,mmChain); + EOpNode right = getBinaryFromSplit(splitMatrix,split+1,j,mmChain); + return EOpNodeBinary.combineMatrixMultiply(left, right); + } + + private static void getDimsArray( ArrayList> chain, double[] dimsArray ) + { + for( int i = 0; i < chain.size(); i++ ) 
{ + if (i == 0) { + dimsArray[i] = chain.get(i).getLeft(); + if (dimsArray[i] <= 0) { + throw new RuntimeException( + "EinsumCPInstruction::optimizeMMChain() : Invalid Matrix Dimension: "+ dimsArray[i]); + } + } + else if (chain.get(i - 1).getRight() != chain.get(i).getLeft()) { + throw new RuntimeException( + "EinsumCPInstruction::optimizeMMChain() : Matrix Dimension Mismatch: " + + chain.get(i - 1).getRight()+" != "+chain.get(i).getLeft()); + } + + dimsArray[i + 1] = chain.get(i).getRight(); + if( dimsArray[i + 1] <= 0 ) { + throw new RuntimeException( + "EinsumCPInstruction::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]); + } } } + private static ArrayList> findMatrixMultiplicationChains(ArrayList inpOperands, Character outChar1, Character outChar2, HashMap charToOccurences, + ArrayList ret) { + HashSet charactersThatCanBeContracted = new HashSet<>(); + HashMap> characterToNodes = new HashMap<>(); + ArrayList operandsTodo = new ArrayList<>(); + for(EOpNode op : inpOperands) { + if(op.c2 == null || op.c1 == null) continue; + + if (characterToNodes.containsKey(op.c1)) characterToNodes.get(op.c1).add(op); + else characterToNodes.put(op.c1, new ArrayList<>(Collections.singletonList(op))); + if (characterToNodes.containsKey(op.c2)) characterToNodes.get(op.c2).add(op); + else characterToNodes.put(op.c2, new ArrayList<>(Collections.singletonList(op))); + + boolean todo = false; + if (charToOccurences.get(op.c1) == 2 && op.c1 != outChar1 && op.c1 != outChar2) { + charactersThatCanBeContracted.add(op.c1); + todo = true; + } + if (charToOccurences.get(op.c2) == 2 && op.c2 != outChar1 && op.c2 != outChar2) { + charactersThatCanBeContracted.add(op.c2); + todo = true; + } + if (todo) operandsTodo.add(op); + } + ArrayList> res = new ArrayList<>(); + + HashSet doneNodes = new HashSet<>(); + + for(int i = 0; i < operandsTodo.size(); i++){ + EOpNode iterateNode = operandsTodo.get(i); + + if (doneNodes.contains(iterateNode)) continue; // was added previously 
+ doneNodes.add(iterateNode); + + LinkedList multiplies = new LinkedList<>(); + multiplies.add(iterateNode); + + EOpNode nextNode = iterateNode; + Character nextC = iterateNode.c2; + // add to right using c2 + while(charactersThatCanBeContracted.contains(nextC)) { + EOpNode one = characterToNodes.get(nextC).get(0); + EOpNode two = characterToNodes.get(nextC).get(1); + if (nextNode == one){ + multiplies.addLast(two); + nextNode = two; + }else{ + multiplies.addLast(one); + nextNode = one; + } + if(nextNode.c1 == nextC) nextC = nextNode.c2; + else nextC = nextNode.c1; + doneNodes.add(nextNode); + } + + // add to left using c1 + nextNode = iterateNode; + nextC = iterateNode.c1; + while(charactersThatCanBeContracted.contains(nextC)) { + EOpNode one = characterToNodes.get(nextC).get(0); + EOpNode two = characterToNodes.get(nextC).get(1); + if (nextNode == one){ + multiplies.addFirst(two); + nextNode = two; + }else{ + multiplies.addFirst(one); + nextNode = one; + } + if(nextNode.c1 == nextC) nextC = nextNode.c2; + else nextC = nextNode.c1; + doneNodes.add(nextNode); + } + + res.add(multiplies); + } - private Pair /* ideally with one element */> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + for(EOpNode op : inpOperands) { + if (doneNodes.contains(op)) continue; + ret.add(op); + } + + return res; + } + + // old way, DFS finds all paths + private Pair> generateBinaryPlanCostBased(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { Integer minCost = cost; List minNodes = operands; @@ -307,16 +632,15 @@ public EOpNodeData(Character c1, Character c2, int matrixIdx){ boolean swap = (operands.get(0).c2 == null && operands.get(1).c2 != null) || operands.get(0).c1 == null; EOpNode n1 = operands.get(!swap ? 0 : 1); EOpNode n2 = operands.get(!swap ? 
1 : 0); - Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null) { - EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + EOpNodeBinary newNode = new EOpNodeBinary(n1, n2, t.getMiddle()); int thisCost = cost + t.getLeft(); return Pair.of(thisCost, Arrays.asList(newNode)); } return Pair.of(cost, operands); } else if (operands.size() == 1){ - // check for transpose return Pair.of(cost, operands); } @@ -326,17 +650,15 @@ else if (operands.size() == 1){ EOpNode n1 = operands.get(!swap ? i : j); EOpNode n2 = operands.get(!swap ? j : i); - - Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null){ - EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + EOpNodeBinary newNode = new EOpNodeBinary(n1, n2, t.getMiddle()); int thisCost = cost + t.getLeft(); if(n1.c1 != null) charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)-1); if(n1.c2 != null) charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)-1); if(n2.c1 != null) charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)-1); if(n2.c2 != null) charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)-1); - if(newNode.c1 != null) charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)+1); if(newNode.c2 != null) charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)+1); @@ -346,8 +668,8 @@ else if (operands.size() == 1){ } newOperands.add(newNode); - Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); - if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ + Pair> 
furtherPlan = generateBinaryPlanCostBased(thisCost, newOperands, charToSizeMap, charToOccurences, outChar1, outChar2); + if(furtherPlan.getRight().size() < minNodes.size() || furtherPlan.getLeft() < minCost){ minCost = furtherPlan.getLeft(); minNodes = furtherPlan.getRight(); } @@ -365,325 +687,28 @@ else if (operands.size() == 1){ return Pair.of(minCost, minNodes); } - private static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ - Predicate cannotBeSummed = (c) -> - c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; - - if(n1.c1 == null) { - // n2.c1 also has to be null - return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null)); - } - - if(n2.c1 == null) { - if(n1.c2 == null) - return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null)); - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2)); - } - - if(n1.c1 == n2.c1){ - if(n1.c2 != null){ - if ( n1.c2 == n2.c2){ - if( cannotBeSummed.test(n1.c1)){ - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null)); - } - - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null)); - } - - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null)); - - } - - else if(n2.c2 == null){ - if(cannotBeSummed.test(n1.c1)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, 
n1.c2) - } - else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ - return null;// AB,AC - } - else { - return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 - } - }else{ // n1.c2 = null -> c2.c2 = null - if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ - return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null)); - } - return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null)); - } - - - }else{ // n1.c1 != n2.c1 - if(n1.c2 == null) { - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1)); - } - else if(n2.c2 == null) { // ab,c - if (n1.c2 == n2.c1) { - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); - } - return null; // AB,C - } - else if (n1.c2 == n2.c1) { - if(n1.c1 == n2.c2){ // ab,ba - if(cannotBeSummed.test(n1.c1)){ - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); - } - if(cannotBeSummed.test(n1.c2)){ - return null; // AB_B - }else{ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); -// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ -// return null; // AB_B -// } -// return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); - } - } - if(n1.c1 == n2.c2) { - if(cannotBeSummed.test(n1.c1)){ - return null; // AB_B - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult - } - else if (n1.c2 == n2.c2) { - if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ - return null; // BA_CA - }else{ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 - } - } - else { // we have something like ab,cd - return null; - } - } - } - - private ArrayList executePlan(List plan, ArrayList inputs){ - return executePlan(plan, inputs, false); - } - private ArrayList executePlan(List plan, ArrayList inputs, boolean codegen) { + private ArrayList executePlan(List plan, ArrayList inputs) { ArrayList res = new ArrayList<>(plan.size()); for(EOpNode p : plan){ - if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs)); - else res.add(ComputeEOpNode(p, inputs)); - } - return res; - } - - private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList inputs){ - if(eOpNode instanceof EOpNodeData eOpNodeData){ - return inputs.get(eOpNodeData.matrixIdx); - } - EOpNodeBinary bin = (EOpNodeBinary) eOpNode; - MatrixBlock left = ComputeEOpNode(bin.left, inputs); - MatrixBlock right = ComputeEOpNode(bin.right, inputs); - - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - - MatrixBlock res; - switch (bin.operand){ - case AB_AB -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - } - case A_A -> { - EnsureMatrixBlockColumnVector(left); - EnsureMatrixBlockColumnVector(right); - res = 
MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - } - case a_a -> { - EnsureMatrixBlockColumnVector(left); - EnsureMatrixBlockColumnVector(right); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - //////////// - case Ba_Ba -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case aB_aB -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - EnsureMatrixBlockColumnVector(res); - } - case ab_ab -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case ab_ba -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - 
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case Ba_aB -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - - ///////// - case AB_BA -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - } - case Ba_aC -> { - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - } - case aB_Ca -> { - res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads); - } - case Ba_Ca -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - } - case aB_aC -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - 
} - case A_scalar, AB_scalar -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); - } - case BA_A -> { - EnsureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case Ba_a -> { - EnsureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - - case AB_A -> { - EnsureMatrixBlockColumnVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case aB_a -> { - EnsureMatrixBlockColumnVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - EnsureMatrixBlockColumnVector(res); - } - - case A_B -> { - EnsureMatrixBlockColumnVector(left); - EnsureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case scalar_scalar -> { - return new MatrixBlock(left.get(0,0)*right.get(0,0)); - } - default -> { - throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString()); - } - + res.add(p.computeEOpNode(inputs, _numThreads, LOG)); } return res; } - private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs){ - return rComputeEOpNodeCodegen(eOpNode, inputs); -// throw new NotImplementedException(); - } - private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){ - return new 
CNodeData("ce"+id, id, mb.getNumRows(), mb.getNumColumns(), DataType.MATRIX); - } - private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs) { - if (eOpNode instanceof EOpNodeData eOpNodeData){ - return inputs.get(eOpNodeData.matrixIdx); -// return new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); - } - - EOpNodeBinary bin = (EOpNodeBinary) eOpNode; -// CNodeData dataLeft = null; -// if (bin.left instanceof EOpNodeData eOpNodeData) dataLeft = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); -// CNodeData dataRight = null; -// if (bin.right instanceof EOpNodeData eOpNodeData) dataRight = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); - - if(bin.operand == EBinaryOperand.AB_AB){ - if (bin.right instanceof EOpNodeBinary rBinary && rBinary.operand == EBinaryOperand.AB_AB){ - MatrixBlock left = rComputeEOpNodeCodegen(bin.left, inputs); - - MatrixBlock right1 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).left, inputs); - MatrixBlock right2 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).right, inputs); - - CNodeData d0 = MatrixBlockToCNodeData(left, 0); - CNodeData d1 = MatrixBlockToCNodeData(right1, 1); - CNodeData d2 = MatrixBlockToCNodeData(right2, 2); -// CNodeNary nary = new CNodeNary(cnodeIn, CNodeNary.NaryType.) 
- CNodeBinary rightBinary = new CNodeBinary(d1, d2, CNodeBinary.BinType.VECT_MULT); - CNodeBinary cNodeBinary = new CNodeBinary(d0, rightBinary, CNodeBinary.BinType.VECT_MULT); - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(d0); - cnodeIn.add(d1); - cnodeIn.add(d2); - - CNodeRow cnode = new CNodeRow(cnodeIn, cNodeBinary); - - cnode.setRowType(SpoofRowwise.RowType.NO_AGG); - cnode.renameInputs(); - - - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - - ArrayList scalars = new ArrayList<>(); - ArrayList mbs = new ArrayList<>(3); - mbs.add(left); - mbs.add(right1); - mbs.add(right2); - MatrixBlock out = op.execute(mbs, scalars, mb, 6); - - return out; - } - } - - throw new NotImplementedException(); - } - - private void releaseMatrixInputs(ExecutionContext ec){ for (CPOperand input : _in) if(input.getDataType()==DataType.MATRIX) ec.releaseMatrixInput(input.getName()); //todo release other } - private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){ + public static void ensureMatrixBlockColumnVector(MatrixBlock mb){ if(mb.getNumColumns() > 1){ mb.setNumRows(mb.getNumColumns()); mb.setNumColumns(1); mb.getDenseBlock().resetNoFill(mb.getNumRows(),1); } } - private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ + public static void ensureMatrixBlockRowVector(MatrixBlock mb){ if(mb.getNumRows() > 1){ mb.setNumColumns(mb.getNumRows()); mb.setNumRows(1); @@ -699,9 +724,9 @@ private static void indent(StringBuilder sb, int level) { private MatrixBlock computeCellSummation(ArrayList inputs, List inputsChars, String resultString, HashMap charToDimensionSizeInt, List summingChars, int outRows, int outCols){ - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new 
LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeCell cnode = new CNodeCell(cnodeIn, null); + ArrayList dummyIn = new ArrayList<>(); + dummyIn.add(new CNodeData(new LiteralOp(0), 0, 0, DataType.SCALAR)); + CNodeCell cnode = new CNodeCell(dummyIn, null); StringBuilder sb = new StringBuilder(); int indent = 2; @@ -784,7 +809,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { sb.append(itVar0); sb.append(inputsChars.get(i).charAt(1)); sb.append(")"); - } else if (resultString.length() >= 1 &&inputsChars.get(i).charAt(1) == resultString.charAt(0)) { + } else if (resultString.length() >= 1 && inputsChars.get(i).charAt(1) == resultString.charAt(0)) { sb.append("rix)"); } else if (resultString.length() == 2 && inputsChars.get(i).charAt(1) == resultString.charAt(1)) { sb.append("cix)"); @@ -806,7 +831,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { indent--; sb.append("}\n"); } - String src = CNodeCell.JAVA_TEMPLATE;// + String src = CNodeCell.JAVA_TEMPLATE; src = src.replace("%TMP%", cnode.createVarname()); src = src.replace("%TYPE%", "NO_AGG"); src = src.replace("%SPARSE_SAFE%", "false"); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 5753fbbadbe..52d03725fc8 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -3982,6 +3982,25 @@ public static void vectMultiplyWrite( double[] a, double[] b, double[] c, int ai aVec.mul(bVec).intoArray(c, ci); } } + + //note: public for use by codegen for consistency + public static void vectMultiplyAdd( double[] a, double[] b, double[] c, int ai, int bi, int ci, final int len ){ + final int bn = len%vLen; + + //rest, not aligned to vLen-blocks + for( int j = 0; j < bn; j++, ai++, bi++, ci++) + c[ ci ] += a[ ai ] * b[ bi ]; + + //unrolled vLen-block (for better instruction-level 
parallelism) + for( int j = bn; j < len; j+=vLen, ai+=vLen, bi+=vLen, ci+=vLen) + { + DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai); + DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi); + DoubleVector cVec = DoubleVector.fromArray(SPECIES, c, ci); + cVec = aVec.fma(bVec, cVec); + cVec.intoArray(c, ci); + } + } public static void vectMultiplyWrite( final double[] a, double[] b, double[] c, int[] bix, final int ai, final int bi, final int ci, final int len ) { final int bn = len%8; diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index dbf9047968f..97e1eb83bf9 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -37,65 +37,109 @@ import java.io.IOException; import java.nio.file.Files; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Map; @RunWith(Parameterized.class) public class EinsumTest extends AutomatedTestBase { - final private static List TEST_CONFIGS = List.of( - new Config("ij,jk->ik", List.of(shape(50, 600), shape(600, 10))), // mm - new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), - new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), - new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), - - new Config("ji,jk->i", List.of(shape(600, 5), shape(600, 10))), - new Config("ij,jk->i", List.of(shape(5, 600), shape(600, 10))), - - new Config("ji,jk->k", List.of(shape(600, 5), shape(600, 10))), - new Config("ij,jk->k", List.of(shape(5, 600), shape(600, 10))), - - new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), - - new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult - new Config("ji,ji,ji->ji", List.of(shape(600, 5),shape(600, 5), shape(600, 5)), - 
List.of(0.0001, 0.0005, 0.001)), - new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult - - - new Config("ij,i->ij", List.of(shape(100, 50), shape(100))), // col mult - new Config("ji,i->ij", List.of(shape(50, 100), shape(100))), // row mult - new Config("ij,i->i", List.of(shape(100, 50), shape(100))), - new Config("ij,i->j", List.of(shape(100, 50), shape(100))), - - new Config("i,i->", List.of(shape(50), shape(50))), - new Config("i,j->", List.of(shape(50), shape(80))), - new Config("i,j->ij", List.of(shape(50), shape(80))), // outer vect mult - new Config("i,j->ji", List.of(shape(50), shape(80))), // outer vect mult - - new Config("ij->", List.of(shape(100, 50))), // sum - new Config("ij->i", List.of(shape(100, 50))), // sum(1) - new Config("ij->j", List.of(shape(100, 50))), // sum(0) - new Config("ij->ji", List.of(shape(100, 50))), // T - - new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), - new Config("ab,cd,g->ba", List.of(shape( 600, 10), shape(6, 5), shape(3))), - - new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm - - new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), - new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), - new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) - List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), - - new Config("i->", List.of(shape(100))), - new Config("i->i", List.of(shape(100))) - ); - + final private static List TEST_CONFIGS = List.of( + new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // mm + new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))), + new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))), + new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))), + new Config("ij,jk->ki", List.of(shape(5, 6), shape(6, 5))), // mm t + new Config("ji,jk->ki", 
List.of(shape(6, 5), shape(6, 10))), + new Config("ji,kj->ki", List.of(shape(6, 5), shape(10, 6))), + new Config("ij,kj->ki", List.of(shape(5, 6), shape(10, 6))), + new Config("ij,kp,pj->ki", List.of(shape(5,6), shape(5,4), shape(4, 6))), // reordering + new Config("ab,bc,cd,de->ae", List.of(shape(5, 6), shape(6, 5),shape(5, 6), shape(6, 5))), // mm chain + new Config("de,cd,bc,ab->ae", List.of(shape(6, 5), shape(5, 6),shape(6, 5), shape(5, 6))), // mm chain + new Config("ab,cb,de,cd->ae", List.of(shape(5, 6), shape(5,6), shape(6, 5),shape(5, 6))), // mm chain + + new Config("ji,jk->i", List.of(shape(6, 5), shape(6, 4))), + new Config("ij,jk->i", List.of(shape(5, 6), shape(6, 4))), + new Config("ji,jk->k", List.of(shape(6, 5), shape(6, 4))), + new Config("ij,jk->k", List.of(shape(5, 6), shape(6, 4))), + new Config("ji,jk->j", List.of(shape(6, 5), shape(6, 4))), + + new Config("ji,ji->ji", List.of(shape(60, 5), shape(60, 5))), // elemwise mult + new Config("ji,ji->j", List.of(shape(60, 5), shape(60, 5))), + new Config("ji,ji->i", List.of(shape(60, 5), shape(60, 5))), + new Config("ji,ij->ji", List.of(shape(60, 5), shape(5, 60))), // elemwise mult + new Config("ji,ij->i", List.of(shape(60, 5), shape(5, 60))), + new Config("ji,ij->j", List.of(shape(60, 5), shape(5, 60))), + + new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult + new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult + new Config("ij,i->i", List.of(shape(10, 5), shape(10))), + new Config("ij,i->j", List.of(shape(10, 5), shape(10))), + + new Config("i,i->", List.of(shape(5), shape(5))), // dot + new Config("i,j->", List.of(shape(5), shape(80))), // sum + new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult + new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult + + new Config("ij->", List.of(shape(10, 5))), // sum + new Config("i->", List.of(shape(10))), // sum + new Config("ij->i", List.of(shape(10, 5))), // sum(1) + new 
Config("ij->j", List.of(shape(10, 5))), // sum(0) + new Config("ij->ji", List.of(shape(10, 5))), // T + new Config("ij->ij", List.of(shape(10, 5))), + new Config("i->i", List.of(shape(10))), + new Config("ii->i", List.of(shape(10, 10))), // Diag + new Config("ii->", List.of(shape(10, 10))), // Trace + new Config("ii,i->i", List.of(shape(10, 10),shape(10))), // Diag*vec + + new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6, 5))), // sum cd to scalar and multiply ab + + new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (the cell tpl fallback) + List.of(shape(5, 6), shape(5, 3), shape(5, 10), shape(6, 3), shape(10, 6), shape(10, 3))), + + // test fused: + new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->", List.of(shape(10, 5), 
shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6))), + new Config("ij,ij,ij,i,j,iz,z->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6),shape(6))), + + new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))), + new Config("ij,i,j,iz,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51),shape(10, 51))), + new Config("ij,i,j,iz,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 4),shape(10, 4))), // order swapped because sizeof(iz) < sizeof(ij), but should still produce the same tmpl + new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))), + + new Config("ij,ij,ji,j,i, ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), + new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // // no skinny right + // takes 2 mins each due to R package being super slow so commented out: +// new Config("ij,ij,i,j,iz->jz", Map.of('i',100000, 'j',50,'z', 40), List.of(0.0000001,0.000002, 0.000003,0.000004,0.000004)), // skinny right with outer mm +// new Config("ij,ij,i,iz->jz", Map.of('i',100000, 'j',50,'z', 40), List.of(0.0000001,0.000002, 0.000003,0.000004)), // skinny right with outer mm +// new Config("ij,j,i,iz->zj", Map.of('i',100000, 'j',50,'z', 40), List.of(0.0000001,0.000002, 0.000003,0.000004)), // skinny right with outer mm + new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5), shape(10, 5),shape(10),shape(10,200),shape(200,7))) + ,new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5), shape(10, 5),shape(10),shape(10,20),shape(20,7))) + ,new Config("ab,ab,bc,bc->bc", List.of(shape(10, 5), shape(10, 5),shape(5,20),shape(5,20))) + ); private final int id; private final String einsumStr; - //private final List shapes; private final File dmlFile; private final File 
rFile; private final boolean outputScalar; @@ -103,7 +147,6 @@ public class EinsumTest extends AutomatedTestBase public EinsumTest(String einsumStr, List shapes, File dmlFile, File rFile, boolean outputScalar, int id){ this.id = id; this.einsumStr = einsumStr; - //this.shapes = shapes; this.dmlFile = dmlFile; this.rFile = rFile; this.outputScalar = outputScalar; @@ -116,7 +159,6 @@ public static Collection data() throws IOException { int counter = 1; for (Config config : TEST_CONFIGS) { - //List files = new ArrayList<>(); String fullDMLScriptName = "SystemDS_einsum_test" + counter; File dmlFile = File.createTempFile(fullDMLScriptName, ".dml"); @@ -153,12 +195,12 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("A"); sb.append(i); - if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + if (dims.length == 1) { // e.g. A1 = seq(1,100) * 0.0001 sb.append(" = seq(1,"); sb.append(dims[0]); sb.append(") * "); sb.append(factor); - } else { // A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 + } else { // e.g. A0 = matrix(seq(1,5000), 100, 5) * 0.0001 sb.append(" = matrix(seq(1, "); sb.append(dims[0]*dims[1]); sb.append("), "); @@ -172,7 +214,6 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("\n"); } sb.append("\n"); - sb.append("R = einsum(\""); sb.append(config.einsumStr); sb.append("\", "); @@ -202,17 +243,17 @@ private static StringBuilder createRFile(Config config, boolean outputScalar) { for (int i = 0; i < config.shapes.size(); i++) { int[] dims = config.shapes.get(i); - + double factor = config.factors != null ? config.factors.get(i) : 0.0001; sb.append("A"); sb.append(i); - if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + if (dims.length == 1) { // e.g. A1 = seq(1,100) * 0.0001 sb.append(" = seq(1,"); sb.append(dims[0]); sb.append(") * "); sb.append(factor); - } else { // A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 + } else { // e.g. 
A0 = matrix(seq(1,5000), 100, 5, byrow=TRUE) * 0.0001 sb.append(" = matrix(seq(1, "); sb.append(dims[0]*dims[1]); sb.append("), "); @@ -252,7 +293,7 @@ private static StringBuilder createRFile(Config config, boolean outputScalar) { @Test public void testEinsumWithFiles() { System.out.println("Testing einsum: " + this.einsumStr); - testCodegenIntegration(TEST_NAME_EINSUM+this.id); + test(TEST_NAME_EINSUM+this.id); } @After public void cleanUp() { @@ -271,19 +312,36 @@ public void cleanUp() { } private static class Config { - public List factors; + public List factors; String einsumStr; List shapes; - Config(String einsum, List shapes) { - this.einsumStr = einsum; - this.shapes = shapes; - this.factors = null; - } + Config(String einsum, List shapes) { + this(einsum,shapes,null); + } + Config(String einsum, Map charToSize){ + this(einsum, charToSize, null); + } + + Config(String einsum, Map charToSize, List factors) { + this.einsumStr = einsum; + String leftPart = einsum.split("->")[0]; + List shapes = new ArrayList<>(); + for(String op : Arrays.stream(leftPart.split(",")).map(x->x.trim()).toList()){ + if (op.length() == 1) { + shapes.add(new int[]{charToSize.get(op.charAt(0))}); + }else{ + shapes.add(new int[]{charToSize.get(op.charAt(0)),charToSize.get(op.charAt(1))}); + } + + } + this.shapes = shapes; + this.factors = factors; + } Config(String einsum, List shapes, List factors) { this.einsumStr = einsum; this.shapes = shapes; - this.factors = factors; + this.factors = factors; } } @@ -295,7 +353,7 @@ private static int[] shape(int... 
dims) { private static final String TEST_NAME_EINSUM = "einsum"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; - private final static String TEST_CONF = "SystemDS-config-codegen.xml"; + private final static String TEST_CONF = "SystemDS-config-einsum.xml"; private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); private static double eps = Math.pow(10, -10); @@ -307,7 +365,7 @@ public void setUp() { addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } - private void testCodegenIntegration( String testname) + private void test(String testname) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; ExecMode platformOld = setExecMode(ExecType.CP); @@ -329,16 +387,17 @@ private void testCodegenIntegration( String testname) runTest(true, false, null, -1); runRScript(true); + HashMap dmlfile; + HashMap rfile; if(outputScalar){ - HashMap dmlfile = readDMLScalarFromOutputDir("S"); - HashMap rfile = readRScalarFromExpectedDir("S"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + dmlfile = readDMLScalarFromOutputDir("S"); + rfile = readRScalarFromExpectedDir("S"); }else { //compare matrices - HashMap dmlfile = readDMLMatrixFromOutputDir("S"); - HashMap rfile = readRMatrixFromExpectedDir("S"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + dmlfile = readDMLMatrixFromOutputDir("S"); + rfile = readRMatrixFromExpectedDir("S"); } + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); } finally { resetExecMode(platformOld); diff --git a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml b/src/test/scripts/functions/einsum/SystemDS-config-einsum.xml similarity index 83% rename from src/test/scripts/functions/einsum/SystemDS-config-codegen.xml rename to 
src/test/scripts/functions/einsum/SystemDS-config-einsum.xml index 626b31ebd76..f6640593c42 100644 --- a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml +++ b/src/test/scripts/functions/einsum/SystemDS-config-einsum.xml @@ -23,9 +23,6 @@ 2 true 1 - - - 16 - - auto - \ No newline at end of file + 16 + auto +