From a465d3a6250ef4352666c904e11b35765911dc4b Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 21 Sep 2025 00:20:09 +0200 Subject: [PATCH 01/13] saving work before my ssd dies --- .../apache/sysds/runtime/einsum/EOpNode.java | 19 + .../sysds/runtime/einsum/EOpNodeBinary.java | 43 ++ .../sysds/runtime/einsum/EOpNodeData.java | 9 + .../runtime/einsum/EOpNodeEinsumFuse.java | 311 +++++++++ .../sysds/runtime/einsum/EOpNodeFused.java | 8 + .../einsum/EinsumEquationValidator.java | 4 + .../instructions/cp/EinsumCPInstruction.java | 624 +++++++++++++++--- 7 files changed, 940 insertions(+), 78 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java new file mode 100644 index 00000000000..c93527f8ca5 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -0,0 +1,19 @@ +package org.apache.sysds.runtime.einsum; + +public abstract class EOpNode { + public Character c1; + public Character c2; // nullable + public EOpNode(Character c1, Character c2){ + this.c1 = c1; + this.c2 = c2; + } + + @Override + public String toString() { + if(c1 == null) return "-"; + + if(c2 == null) return c1.toString(); + return c1.toString() + c2.toString(); + } +} + diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java new file mode 100644 index 00000000000..244dee347dc --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -0,0 +1,43 @@ +package 
org.apache.sysds.runtime.einsum; + +public class EOpNodeBinary extends EOpNode { + public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed + ////// summations: ////// + aB_a,// -> B + Ba_a, // -> B + Ba_aC, // mmult -> BC + aB_Ca, + Ba_Ca, // -> BC + aB_aC, // outer mult, possibly with transposing first -> BC + a_a,// dot -> + + ////// elementwisemult and sums, something like ij,ij->i ////// + aB_aB,// elemwise and colsum -> B + Ba_Ba, // elemwise and rowsum ->B + Ba_aB, // elemwise, either colsum or rowsum -> B + aB_Ba, + + ////// elementwise, no summations: ////// + A_A,// v-elemwise -> A + AB_AB,// M-M elemwise -> AB + AB_BA, // M-M.T elemwise -> AB + AB_A, // M-v colwise -> BA!? + BA_A, // M-v rowwise -> BA + ab_ab,//M-M sum all + ab_ba, //M-M.T sum all + ////// other ////// + A_B, // outer mult -> AB + A_scalar, // v-scalar + AB_scalar, // m-scalar + scalar_scalar + } + public EOpNode left; + public EOpNode right; + public EBinaryOperand operand; + public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ + super(c1,c2); + this.left = left; + this.right = right; + this.operand = operand; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java new file mode 100644 index 00000000000..e7b75236eda --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -0,0 +1,9 @@ +package org.apache.sysds.runtime.einsum; + +public class EOpNodeData extends EOpNode { + public int matrixIdx; + public EOpNodeData(Character c1, Character c2, int matrixIdx){ + super(c1,c2); + this.matrixIdx = matrixIdx; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java new file mode 100644 index 00000000000..76be0b56b7d --- /dev/null +++ 
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java @@ -0,0 +1,311 @@ +package org.apache.sysds.runtime.einsum; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +public class EOpNodeEinsumFuse extends EOpNode { + public static final int AB_index=0; + public static final int BA_index=1; + public static final int B_index=2; + public static final int XB_index=3; + public static final int BX_index=4; + public static final int A_index=5; + public static final int XA_index=6; + public static final int AX_index=7; + public static final int AZ_index=8; + public enum EinsumRewriteType{ + // inputops__output 'X' = simplySumDim + AB_BA_B_XB_BX_A_XA_AX__AB, + AB_BA_B_XB_BX_A_XA_AX__B, + AB_BA_B_XB_BX_A_XA_AX__A, + AB_BA_B_XB_BX_A_XA_AX__, + + AB_BA_B_XB_BX_A_XA_AX_AZ__Z + } + public enum EinsumRewriteType_v2{ // option 2 without X dims + AB_BA_B_A__AB, + AB_BA_B_A__B, + AB_BA_B_A__A, + + AB_BA_B_A_AZ__Z + } + public final EinsumRewriteType einsumRewriteType; + public final List> operands; + + private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List... 
operands) { + super(c1,c2); + this.einsumRewriteType = einsumRewriteType; + this.operands = Arrays.asList(operands); + } + + public static EOpNodeEinsumFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ ArrayList ret, HashMap charToOccurences){ + //precompute + HashSet matricesChars = new HashSet<>(); + HashMap> charsToMatrices = new HashMap<>(); + HashMap charsToNumberOfOperands = new HashMap<>(); + + for (EOpNode operand1 : operands) { + String k; +//todo remove and use input map charToOccurences + if (charsToNumberOfOperands.containsKey(operand1.c1)) { + charsToNumberOfOperands.put(operand1.c1, charsToNumberOfOperands.get(operand1.c1) + 1); + } else { + charsToNumberOfOperands.put(operand1.c1, 1); + } + + if (operand1.c2 != null) { + k = operand1.c1.toString() + operand1.c2; + matricesChars.add(k); + if (charsToNumberOfOperands.containsKey(operand1.c2)) { + charsToNumberOfOperands.put(operand1.c2, charsToNumberOfOperands.get(operand1.c2) + 1); + } else { + charsToNumberOfOperands.put(operand1.c2, 1); + } + } else { + k = operand1.c1.toString(); + } + + if (charsToMatrices.containsKey(k)) { + charsToMatrices.get(k).add(operand1); + } else { + ArrayList matrices = new ArrayList<>(); + matrices.add(operand1); + charsToMatrices.put(k, matrices); + } + } + + ArrayList AXs = new ArrayList<>(); + ArrayList XAs = new ArrayList<>(); + ArrayList BXs = new ArrayList<>(); + ArrayList XBs = new ArrayList<>(); + ArrayList AZs = new ArrayList<>(); + boolean pass = false; + + String AB = null; + String BA = null; + boolean doSumA=false; + boolean doSumB=false; + + for (String ABcandidate : matricesChars) { + char a = ABcandidate.charAt(0); + char b = ABcandidate.charAt(1); + BA = "" + b + a; + + AXs = new ArrayList<>(); + XAs = new ArrayList<>(); + BXs = new ArrayList<>(); + XBs = new ArrayList<>(); + AZs = new ArrayList<>(); + + pass=true; + + + for (String chars : charsToMatrices.keySet()) { + if (chars.equals(ABcandidate) || 
chars.equals(BA)) { +// ABsCounter++; + continue; + } + + if(chars.length()==1){ + if(chars.charAt(0)==a){ +// AsCounter++; + }else if(chars.charAt(0)==b){ +// BsCounter++; + } + continue; + //always ok + }else{ + if(a==chars.charAt(1) && b==chars.charAt(0)){ +// ABsCounter++; + //BA + continue; + } + if(chars.charAt(0)==a){ + if(charsToNumberOfOperands.get(chars.charAt(1))==1){ + if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) { + AXs.addAll(charsToMatrices.get(chars)); +// AsCounter++; + continue; + }else{ + if(AZs.size()==0){ + AZs = charsToMatrices.get(chars); + continue; + } + pass = false; + break; + } + }else{ + //dont allow for now, in theory AZ,Z or AZ,AZ would also work, but for now do them separately + pass = false; + break; + } + } + else if(chars.charAt(0)==b){ + if(charsToNumberOfOperands.get(chars.charAt(1))==1){ + if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) { + BXs.addAll(charsToMatrices.get(chars)); +// BsCounter++; + continue; + }else{ + pass = false; // no BZ, maybe experiment later + break; + } + }else{ + pass = false; + break; + } + } + else if(chars.charAt(1)==a){ + if(charsToNumberOfOperands.get(chars.charAt(0))==1){ + if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) { + XAs.addAll(charsToMatrices.get(chars)); +// AsCounter++; + continue; + }else{ + pass = false; + break; + } + }else{ + pass = false; + break; + } + } + else if(chars.charAt(1)==b){ + if(charsToNumberOfOperands.get(chars.charAt(0))==1){ + if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) { + XBs.addAll(charsToMatrices.get(chars)); +// BsCounter++; + continue; + }else{ + pass = false; + break; + } + }else{ + pass = false; + break; + } + } + } + } + if(pass){ + AB = ABcandidate; + String A = ""+a; + String B = ""+b; + int ABsCounter = charsToMatrices.get(ABcandidate).size()+(charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); + int AsCounter = (charsToMatrices.containsKey(A) ? 
charsToMatrices.get(A).size() : 0) +AXs.size()+XAs.size(); + int BsCounter = (charsToMatrices.containsKey(B) ? charsToMatrices.get(B).size() : 0)+BXs.size()+XBs.size(); + if(AsCounter==0 && BsCounter==0 && ABsCounter<2){ + pass=false; + continue; + } + int usedAsCount = AsCounter+ABsCounter; + int usedBsCount = BsCounter+ABsCounter; + doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); + doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2); + if(AZs.size()!=0) { // invalidate AZ fusion + if (outChar1 != null) { + if (a == outChar1 || b == outChar1) { + pass=false; + continue; + } + } + if (outChar2 != null) { + if (a == outChar2 || b == outChar2) { + pass=false; + continue; + } + } + if(!doSumA || !doSumB){ + pass=false; + continue; + } + } + break; + } + } + + if(!pass){ + return null; + } + String B = AB.substring(1,2); + String A = AB.substring(0,1); + Character c1 = null; + Character c2 = null; + EinsumRewriteType t; + + if(AZs.size()!=0){ + c1=AZs.get(0).c2; + t=EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z; + } + else if(doSumA){ + if(doSumB) { + t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__; + } + else { + t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B; + c1 = AB.charAt(1); + } + } + else if(doSumB){ + t= EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A; + c1= AB.charAt(0); + } + else { + t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB; + c1 = AB.charAt(0); + c2 = AB.charAt(1); + } + if(c1 != null){ + charToOccurences.put(c1, charToOccurences.get(c1)+1); + } + if(c2 != null){ + charToOccurences.put(c2, charToOccurences.get(c2)+1); + } + HashSet usedOperands = new HashSet<>(); + + ArrayList ABs=charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); + ArrayList BAs=charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>(); + ArrayList Bs=charsToMatrices.containsKey(B) ? 
charsToMatrices.get(B) : new ArrayList<>(); + ArrayList As=charsToMatrices.containsKey(A) ? charsToMatrices.get(A) : new ArrayList<>(); + + usedOperands.addAll(ABs); + usedOperands.addAll(BAs); + usedOperands.addAll(Bs); + usedOperands.addAll(As); + usedOperands.addAll(XBs); + usedOperands.addAll(BXs); + usedOperands.addAll(XAs); + usedOperands.addAll(AXs); + usedOperands.addAll(AZs); + + for(EOpNode n : operands){ + if(!usedOperands.contains(n)){ + ret.add(n); + }else{ + if(charToOccurences != null){ + charToOccurences.put(n.c1, charToOccurences.get(n.c1)-1); + if(charToOccurences.get(n.c2)!= null) + charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1); + } + } + } + + var e = new EOpNodeEinsumFuse(c1, c2, t, + ABs, + BAs, + Bs, + XBs, + BXs, + As, + XAs, + AXs, + AZs + ); + ret.add(e); + return e; + } +} + diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java new file mode 100644 index 00000000000..4554b9c8334 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java @@ -0,0 +1,8 @@ +package org.apache.sysds.runtime.einsum; + +public class EOpNodeFused extends EOpNode { + public EOpNodeFused(Character c1, Character c2){ + super(c1,c2); + + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java index 5643159ef9a..7fdce50d3ba 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java @@ -79,9 +79,13 @@ public static Triple a = einc.characterAppearanceIndexes.get(c); LOG.trace(c+" count= "+a.size()); } +// var simplySummableChars = einc.characterAppearanceIndexes.entrySet() +// .stream() +// .filter(e -> e.getValue().size() == 1) +// .map(Map.Entry::getKey) +// .collect(Collectors.toSet()); // compute scalar by suming-all 
matrices: Double scalar = null; @@ -141,7 +153,7 @@ public void processInstruction(ExecutionContext ec) { EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); eOpNodes.add(n); } - Pair > plan = FORCE_CELL_TPL ? null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + Pair > plan = FORCE_CELL_TPL ? null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2, FUSED); ArrayList resMatrices = FORCE_CELL_TPL ? null : executePlan(plan.getRight(), inputs); @@ -242,64 +254,7 @@ private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList } } - private enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed - ////// summations: ////// - aB_a,// -> B - Ba_a, // -> B - Ba_aC, // mmult -> BC - aB_Ca, - Ba_Ca, // -> BC - aB_aC, // outer mult, possibly with transposing first -> BC - a_a,// dot -> - - ////// elementwisemult and sums, something like ij,ij->i ////// - aB_aB,// elemwise and colsum -> B - Ba_Ba, // elemwise and rowsum ->B - Ba_aB, // elemwise, either colsum or rowsum -> B -// aB_Ba, - - ////// elementwise, no summations: ////// - A_A,// v-elemwise -> A - AB_AB,// M-M elemwise -> AB - AB_BA, // M-M.T elemwise -> AB - AB_A, // M-v colwise -> BA!? 
- BA_A, // M-v rowwise -> BA - ab_ab,//M-M sum all - ab_ba, //M-M.T sum all - ////// other ////// - A_B, // outer mult -> AB - A_scalar, // v-scalar - AB_scalar, // m-scalar - scalar_scalar - } - private abstract class EOpNode { - public Character c1; - public Character c2; // nullable - public EOpNode(Character c1, Character c2){ - this.c1 = c1; - this.c2 = c2; - } - } - private class EOpNodeBinary extends EOpNode { - public EOpNode left; - public EOpNode right; - public EBinaryOperand operand; - public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ - super(c1,c2); - this.left = left; - this.right = right; - this.operand = operand; - } - } - private class EOpNodeData extends EOpNode { - public int matrixIdx; - public EOpNodeData(Character c1, Character c2, int matrixIdx){ - super(c1,c2); - this.matrixIdx = matrixIdx; - } - } - - private Pair /* ideally with one element */> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + private Pair /* ideally with one element */> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2, boolean fused) { Integer minCost = cost; List minNodes = operands; @@ -320,6 +275,22 @@ else if (operands.size() == 1){ return Pair.of(cost, operands); } + if(fused){ + ArrayList ret = new ArrayList<>(); + EOpNodeEinsumFuse fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences); + if(fuse != null){ + minNodes = operands = ret; + } + while(ret.size() > 2 && fuse!=null){ + ret = new ArrayList<>(); + fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences); + if(fuse != null){ + operands= ret; + operands.add(fuse); + } + } + } + for(int i = 0; i < operands.size()-1; i++){ for (int j = i+1; j < operands.size(); j++){ boolean swap = (operands.get(i).c2 == null && operands.get(j).c2 != null) || 
operands.get(i).c1 == null; @@ -346,7 +317,7 @@ else if (operands.size() == 1){ } newOperands.add(newNode); - Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); + Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2, fused); if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ minCost = furtherPlan.getLeft(); minNodes = furtherPlan.getRight(); @@ -439,6 +410,9 @@ else if (n1.c2 == n2.c1) { } return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); } + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null)); + } return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); } if(cannotBeSummed.test(n1.c2)){ @@ -455,7 +429,8 @@ else if (n1.c2 == n2.c1) { if(cannotBeSummed.test(n1.c1)){ return null; // AB_B } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult + return null;//todo remove. 
+// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult } else if (n1.c2 == n2.c2) { if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ @@ -470,14 +445,23 @@ else if (n1.c2 == n2.c2) { } } - private ArrayList executePlan(List plan, ArrayList inputs){ - return executePlan(plan, inputs, false); - } - private ArrayList executePlan(List plan, ArrayList inputs, boolean codegen) { +// private ArrayList executePlan(List plan, ArrayList inputs){ +// return executePlan(plan, inputs); +// } + private ArrayList executePlan(List plan, ArrayList inputs) { ArrayList res = new ArrayList<>(plan.size()); for(EOpNode p : plan){ - if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs)); - else res.add(ComputeEOpNode(p, inputs)); + /// ////////////// ^^^^^^^^^^^^^^^^^^^^^ //////////////// +// if((true) && plan.size()== 1 && plan.get(0) instanceof EOpNodeBinary bin1 && bin1.operand == EBinaryOperand.Ba_Ca && bin1.right instanceof EOpNodeBinary r && r.operand == EBinaryOperand.Ba_Ca){ +// res.add(ComputeEOpNodeCodegen(p, inputs)); +//// var other = ComputeEOpNode(p, inputs); +//// ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); +//// var other2 = other.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); +// continue; +// } +// if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs)); +// else + res.add(ComputeEOpNode(p, inputs)); } return res; } @@ -485,7 +469,10 @@ else if (n1.c2 == n2.c2) { private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList inputs){ if(eOpNode instanceof EOpNodeData eOpNodeData){ return inputs.get(eOpNodeData.matrixIdx); - } + }else if(eOpNode instanceof EOpNodeEinsumFuse eOpNodeEinsumFuse){ + var mbs = eOpNodeEinsumFuse.operands.stream().map(l->l.stream().map(n->ComputeEOpNode(n, inputs)).collect(Collectors.toList())).toList(); + return ComputeEOpNodeFuse(eOpNodeEinsumFuse, 
mbs); + } EOpNodeBinary bin = (EOpNodeBinary) eOpNode; MatrixBlock left = ComputeEOpNode(bin.left, inputs); MatrixBlock right = ComputeEOpNode(bin.right, inputs); @@ -493,6 +480,9 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList input AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); MatrixBlock res; + + LOG.trace("computing binary "+bin.left+","+bin.right+"->"+bin); + switch (bin.operand){ case AB_AB -> { res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); @@ -540,6 +530,13 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList input AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); } + case aB_Ba -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } ///////// case AB_BA -> { @@ -605,14 +602,432 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList input return res; } - private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs){ - return rComputeEOpNodeCodegen(eOpNode, inputs); -// throw new NotImplementedException(); - } + private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List> mbs) { + //prepare matrices + //EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB;\ + boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == 
EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB; + boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A; + boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B; + boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__; + boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z; + List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); + List AZs = isResultZ ? mbs.get(8) : null; + int bSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c2); + int aSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c1); + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + for(MatrixBlock mb: BAs){//BA->AB + ABs.add(mb.reorgOperations(transpose, null,0,0,0)); + } + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + for(MatrixBlock mb: XBs){//XB->B + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + for(MatrixBlock mb: XAs){//XA->A + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + for(MatrixBlock mb: BXs){//BX->B + MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + for(MatrixBlock mb: AXs){//AX->B // todo remove all X + MatrixBlock 
res = new MatrixBlock(mb.getNumRows(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + if(As.size()>1){ + MatrixBlock mb = As.get(0); + for(int i=1;i cnodeIn = new ArrayList<>(); + cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); + CNodeRow cnode = new CNodeRow(cnodeIn, null); + String src = +// "import org.apache.sysds.runtime.matrix.data.LibMatrixMult;\n"+ +// CNodeRow.JAVA_TEMPLATE; +// src= + "package codegen;\n" + +"import org.apache.sysds.runtime.matrix.data.LibMatrixMult;\n" + + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" + + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" + + "import org.apache.sysds.runtime.codegen.SpoofRowwise;\n" + + "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n" + + "import org.apache.sysds.runtime.data.SparseRowVector;\n" + + "import org.apache.commons.math3.util.FastMath;\n" + + "\n" + + "public final class %TMP% extends SpoofRowwise { \n" + + " public %TMP%() {\n" + + " super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + + " }\n" + + " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n" + + "%BODY_dense%" + + " }\n" + + " protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n" + + "%BODY_sparse%" + + " }\n" + + "}\n"; + StringBuilder body = new StringBuilder(); + int midx = -1; + boolean assignedArow = false; + String varnameArow = null; + + String className = cnode.createVarname(); + IDSequence _seqVar = new IDSequence(); + _seqVar.getNextID(); + Supplier getNewVarname = () -> "TMP" + _seqVar.getNextID(); + + boolean writeToC = (isResultAB || isResultB) && As.isEmpty() && Bs.isEmpty(); + + if(ABs.size()>1){ + body.append("// multiplying ABs: \n"); + assignedArow=true; + varnameArow = getNewVarname.get(); + for(int i=1;i1){ + 
body.append("// multiplying Bs: \n"); + + assignedB=true; + varnameB = getNewVarname.get(); + body.append("double[] "); + for(int i=1;i { + src = src.replace("%TYPE%", "NO_AGG"); + } + case AB_BA_B_XB_BX_A_XA_AX__B -> { + src = src.replace("%TYPE%", "COL_AGG_T"); + } + case AB_BA_B_XB_BX_A_XA_AX__A -> { + src = src.replace("%TYPE%", "ROW_AGG"); + } + case AB_BA_B_XB_BX_A_XA_AX__ -> { + src = src.replace("%TYPE%", "FULL_AGG"); + } + case AB_BA_B_XB_BX_A_XA_AX_AZ__Z -> { + src = src.replace("%TYPE%", "COL_AGG_T"); + } + } + src = src.replace("%CONST_DIM2%", "-1"); + src = src.replace("%TB1%", "false"); + src = src.replace("%VECT_MEM%", "0"); + src = src.replace("%BODY_dense%", body); + src = src.replace("%BODY_sparse%", "throw new RuntimeException(\"Sparse einsum not implemented\");"); + + if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); + Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); + SpoofOperator op = CodegenUtils.createInstance(cla); +// SpoofOperator op = CodegenUtils.createInstance(AAA1.class); +// MatrixBlock resBlock = new MatrixBlock(); + +// resBlock.reset(einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1), +// eOpNodeEinsumFuse.c2 == null?1:einc.charToDimensionSize.get(eOpNodeEinsumFuse.c2)); + ArrayList inputs = new ArrayList<>(); +// inputs.add(resBlock); + + inputs.addAll(ABs); + inputs.addAll(Bs); + inputs.addAll(As); + MatrixBlock out = op.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); + + return out; + } + +// private MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs){ +// return rComputeEOpNodeCodegen(eOpNode, inputs); +//// throw new NotImplementedException(); +// } private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){ return new CNodeData("ce"+id, id, mb.getNumRows(), mb.getNumColumns(), DataType.MATRIX); } - private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs) { + /* + private MatrixBlock 
rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs) { if (eOpNode instanceof EOpNodeData eOpNodeData){ return inputs.get(eOpNodeData.matrixIdx); // return new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); @@ -623,7 +1038,60 @@ private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList cla = TMP1.class; + +// Class cla = CodegenUtils.compileClass("codegen." + "TMP1", src); + long end = System.currentTimeMillis(); + long duration = end - start; // duration in milliseconds + System.out.println("Time taken: " + duration + " ms"); + + SpoofOperator op = CodegenUtils.createInstance(cla); + MatrixBlock mb = new MatrixBlock(); + mb.reset(einc.outRows, einc.outCols , false); + mb.allocateDenseBlock(); + ArrayList scalars = new ArrayList<>(); + ArrayList mbs = new ArrayList<>(3); + mbs.add(inputs.get(((EOpNodeData) r.left).matrixIdx)); + mbs.add(inputs.get(((EOpNodeData) r.right).matrixIdx)); + mbs.add(inputs.get(((EOpNodeData) bin.left).matrixIdx)); + MatrixBlock out = op.execute(mbs, scalars, mb, _numThreads); + var tmp = bin.c1; + bin.c1=bin.c2; + bin.c2=tmp; + return out; + } if(bin.operand == EBinaryOperand.AB_AB){ if (bin.right instanceof EOpNodeBinary rBinary && rBinary.operand == EBinaryOperand.AB_AB){ MatrixBlock left = rComputeEOpNodeCodegen(bin.left, inputs); @@ -668,7 +1136,7 @@ private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList Date: Sun, 21 Sep 2025 02:23:27 +0200 Subject: [PATCH 02/13] +1 --- .../instructions/cp/EinsumCPInstruction.java | 200 +++++++++++++----- 1 file changed, 148 insertions(+), 52 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 7638ba38ece..246ea01b5b3 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -684,9 +684,9 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List Supplier getNewVarname = () -> "TMP" + _seqVar.getNextID(); boolean writeToC = (isResultAB || isResultB) && As.isEmpty() && Bs.isEmpty(); - + body.append(" "); if(ABs.size()>1){ - body.append("// multiplying ABs: \n"); + body.append("// multiplying ABs: \n "); assignedArow=true; varnameArow = getNewVarname.get(); for(int i=1;i1){ - body.append("// multiplying Bs: \n"); + if(Bs.size()>1){ //todo use arow name if possible + body.append("// multiplying Bs: \n "); assignedB=true; varnameB = getNewVarname.get(); @@ -759,14 +759,14 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List body.append("b["); midx++; body.append(midx); - body.append("].values(0),0,0,len);\n"); + body.append("].values(0),0,0,len);\n "); } else { body.append(varnameB); body.append(",b["); midx++; body.append(midx); - body.append("].values(0),0,0,len);\n"); + body.append("].values(0),0,0,len);\n "); } } } @@ -779,25 +779,63 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List String resultVarname = null; if(!Bs.isEmpty()){ - body.append("// multiplying B with row: \n"); - - - resultVarname = assignedArow ? varnameArow : getNewVarname.get(); + body.append("// multiplying B with row: \n "); + if(sumB) + resultVarname = writeToC ? null : getNewVarname.get(); + else + resultVarname = writeToC ? "c" : assignedB ? varnameB : assignedArow ? 
varnameArow : getNewVarname.get(); - if(sumB) body.append("double "); - else if(!assignedArow) body.append("double[] "); - - body.append(resultVarname); - if(!assignedArow){ - if(sumB) body.append(" = LibSpoofPrimitives.dotProduct(a,"); - else body.append(" = LibSpoofPrimitives.vectMultWrite(a,"); + if(sumB) { + if(writeToC && isResultA) { + body.append("c[rix] = LibSpoofPrimitives.dotProduct("); + }else if(writeToC && isResult_) { + body.append("c[0] += LibSpoofPrimitives.dotProduct("); + }else { + body.append("double "); + body.append(resultVarname); + body.append(" = LibSpoofPrimitives.dotProduct("); + } + if(assignedArow) { + body.append(varnameArow); + body.append(","); + } + else { + body.append("a,"); + } }else{ - if(sumB) body.append(" = LibSpoofPrimitives.dotProduct("); - else body.append(" = LibSpoofPrimitives.vectMultWrite("); - body.append(varnameArow); - body.append(","); + if(writeToC) { + if(assignedArow || assignedB) { + body.append("LibMatrixMult.vectMultiplyWrite("); + body.append(resultVarname); + body.append(","); + }else{ + body.append("LibMatrixMult.vectMultiplyWrite(a,"); + } + } + else if(!assignedArow && !assignedB) { + body.append("double[] "); + body.append(resultVarname); + body.append(" = LibSpoofPrimitives.vectMultWrite(a,"); + }else{ + body.append("LibMatrixMult.vectMultiplyWrite("); + body.append(resultVarname); + body.append(","); + } } +// if(!assignedArow){ +// if(sumB) body.append(" = LibSpoofPrimitives.dotProduct(a,"); +// else { +// body.append("double[] "); +// body.append(resultVarname); +// body.append(" = LibSpoofPrimitives.vectMultWrite(a,"); +// } +// }else{ +// if(sumB) body.append(" = LibSpoofPrimitives.dotProduct("); +// else body.append("LibMatrixMult.vectMultiplyWrite("); +// body.append(varnameArow); +// body.append(","); +// } if(!assignedB){ body.append("b["); midx++; @@ -807,16 +845,72 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List body.append(varnameB); body.append(","); } - 
if(!assignedArow) - body.append("ai,0,len);\n"); - else body.append("0,0,len);\n"); + if(sumB){ + if(!assignedArow) + body.append("ai,0,len);\n "); + else + body.append("0,0,len);\n "); + }else{ + //where to write + if(writeToC) { + body.append("c,"); + + }else if(assignedArow || assignedB) { + body.append(resultVarname); + body.append(","); + }// else its new [] + + //ai: + if(!assignedArow && !assignedB) + body.append("ai,0"); + else + body.append("0,0"); + + //ci + if(writeToC) { + if(isResultAB) + body.append(",ai,len);\n "); + else//B + body.append(",0,len);\n "); + }else if(assignedArow || assignedB){ + body.append(",0,len);\n "); + }else{ + body.append(",len);\n "); + } + +// else if(!assignedArow && !assignedB) { +// body.append("double[] "); +// body.append(resultVarname); +// body.append(" = LibSpoofPrimitives.vectMultWrite(a,"); +// }else{ +// body.append("LibMatrixMult.vectMultiplyWrite("); +// body.append(resultVarname); +// body.append(","); +// } +// if(!assignedArow/* && !assignedB*/){ +// body.append("ai,0,len);\n"); +// }else{ +// body.append(resultVarname); +// } + } +// if(!assignedArow) +// body.append("ai,0,len);\n"); +// else { +// if(sumB) +// body.append("0,0,len);\n"); +// else{ +// body.append(resultVarname); +// body.append(",0,0,0,len);\n"); +// +// } +// } } writeToC |= !isResultZ; if(!As.isEmpty()){ // multiply with value of A - body.append("// multiplying current result with value of A: \n"); + body.append("// multiplying current result with value of A: \n "); if(sumB) { if(resultVarname == null) { @@ -829,9 +923,9 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List if(assignedArow) { body.append(varnameArow); - body.append(", 0, len);\n"); + body.append(", 0, len);\n "); } else { - body.append("a, ai, len);\n"); + body.append("a, ai, len);\n "); } } if(writeToC && isResult_){ @@ -848,12 +942,14 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List } midx++; 
body.append(midx); - body.append("].values(0)[rix];\n"); + body.append("].values(0)[rix];\n "); } else { boolean resultVarnameNull = resultVarname == null; - if(writeToC) { + if(writeToC && isResultAB) { body.append("LibMatrixMult.vectMultiplyWrite(b["); + }else if(writeToC && isResultB){ + body.append("LibMatrixMult.vectMultiplyAdd(b["); }else{ if (resultVarnameNull) { // did vectmult previously resultVarname = getNewVarname.get(); @@ -869,39 +965,39 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List if(assignedArow) { body.append(varnameArow); if (writeToC && isResultB) - body.append(",c,0,0,len);\n"); + body.append(",c,0,0,len);\n "); else if (writeToC && isResultAB) - body.append(",c,0,ci,len);\n"); + body.append(",c,0,ci,len);\n "); else - body.append(",0,len);\n"); + body.append(",0,len);\n "); }else{ if (writeToC && isResultB) - body.append("a,c,ai,0,len);\n"); + body.append("a,c,ai,0,len);\n "); else if (writeToC && isResultAB) - body.append("a,c,ai,ci,len);\n"); + body.append("a,c,ai,ci,len);\n "); else - body.append("a,ai,len);\n"); + body.append("a,ai,len);\n "); } }else { body.append(resultVarname); if (writeToC && isResultB) - body.append(",c,0,0,len);\n"); + body.append(",c,0,0,len);\n "); else if (writeToC && isResultAB) - body.append(",c,0,ci,len);\n"); + body.append(",c,0,ci,len);\n "); else - body.append(",0,len);\n"); + body.append(",0,len);\n "); } } } - body.append("// write part: \n"); + body.append("// write part: \n "); if(!writeToC) { if (isResultZ) { if (AZs.isEmpty()) throw new RuntimeException("Einsum runtime exception: Invalid rewrite type"); int zSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1); - body.append("// AZ part: \n"); + body.append("// AZ part: \n "); boolean resultVarnameNull = resultVarname == null; @@ -912,9 +1008,9 @@ else if (writeToC && isResultAB) body.append(" = LibSpoofPrimitives.vectSum("); if (assignedArow) { body.append(varnameArow); - body.append(", 0, len);\n"); + 
body.append(", 0, len);\n "); } else { - body.append("a, ai, len);\n"); + body.append("a, ai, len);\n "); } } body.append("LibSpoofPrimitives.vectMultAdd(b["); @@ -925,21 +1021,21 @@ else if (writeToC && isResultAB) body.append("c,rix*"); body.append(zSize); - body.append(",0,len);\n"); + body.append(",0,len);\n "); } else if (isResultAB) { //vectWrite(double[] a, double[] c, int ai, int ci, int len) body.append("LibSpoofPrimitives.vectWrite("); if (resultVarname == null) { if (assignedArow) { body.append(varnameArow); - body.append(",c,0,ai,len);\n"); + body.append(",c,0,ai,len);\n "); } else { //should never happen throw new RuntimeException("Einsum runtime exception: Invalid rewrite type"); } } else { body.append(resultVarname); - body.append(",c,0,ai,len);\n"); + body.append(",c,0,ai,len);\n "); } } else if (isResultA) { if (resultVarname == null) { @@ -948,21 +1044,21 @@ else if (writeToC && isResultAB) } body.append("c[rix] = "); body.append(resultVarname); - body.append(";\n"); + body.append(";\n "); } else if (isResultB) { body.append("LibMatrixMult.vectAdd(");//public static void vectAdd( double[] a, double[] c, int ai, int ci, final int len ) if (resultVarname == null) { if (assignedArow) { body.append(varnameArow); - body.append(",c,0,0,len);\n"); + body.append(",c,0,0,len);\n "); } else { //should never happen throw new RuntimeException("Einsum runtime exception: Invalid rewrite type"); } } else { body.append(resultVarname); - body.append(",c,0,0,len);\n"); + body.append(",c,0,0,len);\n "); } } else if (isResult_) { @@ -972,7 +1068,7 @@ else if (writeToC && isResultAB) } body.append("c[0] += "); body.append(resultVarname); - body.append(";\n"); + body.append(";\n "); } } @@ -1003,7 +1099,7 @@ else if (writeToC && isResultAB) if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); Class cla = CodegenUtils.compileClass("codegen." 
+ cnode.getClassname(), src); SpoofOperator op = CodegenUtils.createInstance(cla); -// SpoofOperator op = CodegenUtils.createInstance(AAA1.class); +// SpoofOperator op = CodegenUtils.createInstance(AAA2.class); // MatrixBlock resBlock = new MatrixBlock(); // resBlock.reset(einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1), From 5176cc053c407924500ab6ee8d9ae6746a32f322 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 3 Oct 2025 19:55:22 +0200 Subject: [PATCH 03/13] fix in einsum codegen and added no-codegen fused op --- .../runtime/einsum/EOpNodeEinsumFuse.java | 234 +++++++------- .../runtime/einsum/EinsumSpoofRowwise.java | 160 ++++++++++ .../instructions/cp/EinsumCPInstruction.java | 296 ++++++++++++++---- .../runtime/matrix/data/LibMatrixMult.java | 19 ++ 4 files changed, 535 insertions(+), 174 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java index 76be0b56b7d..ae10c93115c 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java @@ -1,5 +1,8 @@ package org.apache.sysds.runtime.einsum; +import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; + import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -17,21 +20,24 @@ public class EOpNodeEinsumFuse extends EOpNode { public static final int AX_index=7; public static final int AZ_index=8; public enum EinsumRewriteType{ - // inputops__output 'X' = simplySumDim - AB_BA_B_XB_BX_A_XA_AX__AB, - AB_BA_B_XB_BX_A_XA_AX__B, - AB_BA_B_XB_BX_A_XA_AX__A, - AB_BA_B_XB_BX_A_XA_AX__, - - AB_BA_B_XB_BX_A_XA_AX_AZ__Z - } - public enum EinsumRewriteType_v2{ // option 2 without X dims + // B -> row*row, A -> row*scalar 
AB_BA_B_A__AB, AB_BA_B_A__B, AB_BA_B_A__A, + AB_BA_B_A__, - AB_BA_B_A_AZ__Z + // scalar from row(AB).dot(B) multiplied by row(AZ) + AB_BA_B_A_AZ__Z, + + // AC: last step is outer matrix multiplication using vector C + AB_BA_B_A_AZ__BZ, + AB_BA_B_A_AZ__ZB, + +// // outer matrix multiplication using vector C and vector Z +// AB_BA_B_A_AZ_AC__ZC, +// AB_BA_B_A_AZ_AC__CZ, } + public final EinsumRewriteType einsumRewriteType; public final List> operands; @@ -41,29 +47,17 @@ private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRe this.operands = Arrays.asList(operands); } - public static EOpNodeEinsumFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ ArrayList ret, HashMap charToOccurences){ + public static EOpNodeEinsumFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ ArrayList ret, HashMap charToOccurences, HashMap charToSize){ //precompute HashSet matricesChars = new HashSet<>(); HashMap> charsToMatrices = new HashMap<>(); - HashMap charsToNumberOfOperands = new HashMap<>(); for (EOpNode operand1 : operands) { String k; -//todo remove and use input map charToOccurences - if (charsToNumberOfOperands.containsKey(operand1.c1)) { - charsToNumberOfOperands.put(operand1.c1, charsToNumberOfOperands.get(operand1.c1) + 1); - } else { - charsToNumberOfOperands.put(operand1.c1, 1); - } if (operand1.c2 != null) { k = operand1.c1.toString() + operand1.c2; matricesChars.add(k); - if (charsToNumberOfOperands.containsKey(operand1.c2)) { - charsToNumberOfOperands.put(operand1.c2, charsToNumberOfOperands.get(operand1.c2) + 1); - } else { - charsToNumberOfOperands.put(operand1.c2, 1); - } } else { k = operand1.c1.toString(); } @@ -82,13 +76,14 @@ public static EOpNodeEinsumFuse match(ArrayList operands, Character out ArrayList BXs = new ArrayList<>(); ArrayList XBs = new ArrayList<>(); ArrayList AZs = new ArrayList<>(); +// ArrayList ACs = new ArrayList<>(); + 
ArrayList Zs = new ArrayList<>(); boolean pass = false; String AB = null; String BA = null; boolean doSumA=false; boolean doSumB=false; - for (String ABcandidate : matricesChars) { char a = ABcandidate.charAt(0); char b = ABcandidate.charAt(1); @@ -99,9 +94,10 @@ public static EOpNodeEinsumFuse match(ArrayList operands, Character out BXs = new ArrayList<>(); XBs = new ArrayList<>(); AZs = new ArrayList<>(); - + Character z = null; pass=true; - + int AZsCounter = 0; + HashSet AZCandidates = new HashSet<>(); for (String chars : charsToMatrices.keySet()) { if (chars.equals(ABcandidate) || chars.equals(BA)) { @@ -118,111 +114,74 @@ public static EOpNodeEinsumFuse match(ArrayList operands, Character out continue; //always ok }else{ - if(a==chars.charAt(1) && b==chars.charAt(0)){ + if(a==chars.charAt(1) && b==chars.charAt(0)){ //BA // ABsCounter++; - //BA continue; } if(chars.charAt(0)==a){ - if(charsToNumberOfOperands.get(chars.charAt(1))==1){ - if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) { - AXs.addAll(charsToMatrices.get(chars)); -// AsCounter++; - continue; - }else{ - if(AZs.size()==0){ - AZs = charsToMatrices.get(chars); - continue; - } - pass = false; - break; - } - }else{ - //dont allow for now, in theory AZ,Z or AZ,AZ would also work, but for now do them separately - pass = false; - break; - } + //AZ + AZsCounter++; + AZCandidates.add(chars); } else if(chars.charAt(0)==b){ - if(charsToNumberOfOperands.get(chars.charAt(1))==1){ - if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) { - BXs.addAll(charsToMatrices.get(chars)); -// BsCounter++; - continue; - }else{ - pass = false; // no BZ, maybe experiment later - break; - } - }else{ - pass = false; - break; - } + // BZ, todo, maybe transpose ab into ba + pass = false; + break; } else if(chars.charAt(1)==a){ - if(charsToNumberOfOperands.get(chars.charAt(0))==1){ - if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) { - XAs.addAll(charsToMatrices.get(chars)); -// AsCounter++; - 
continue; - }else{ - pass = false; - break; - } - }else{ + //ZA, maybe its small enough that it can be tranposed? but then not impactful as the bigger A, the more sense to fuse AZ? pass = false; break; - } } else if(chars.charAt(1)==b){ - if(charsToNumberOfOperands.get(chars.charAt(0))==1){ - if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) { - XBs.addAll(charsToMatrices.get(chars)); -// BsCounter++; - continue; - }else{ - pass = false; - break; - } - }else{ - pass = false; - break; - } + // ZB + pass = false; + break; } } } if(pass){ + AB = ABcandidate; String A = ""+a; String B = ""+b; int ABsCounter = charsToMatrices.get(ABcandidate).size()+(charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); - int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0) +AXs.size()+XAs.size(); - int BsCounter = (charsToMatrices.containsKey(B) ? charsToMatrices.get(B).size() : 0)+BXs.size()+XBs.size(); +// int AZsCounter = AZs.size(); + int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0); + int BsCounter = (charsToMatrices.containsKey(B) ? 
charsToMatrices.get(B).size() : 0); if(AsCounter==0 && BsCounter==0 && ABsCounter<2){ pass=false; continue; } - int usedAsCount = AsCounter+ABsCounter; int usedBsCount = BsCounter+ABsCounter; - doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2); - if(AZs.size()!=0) { // invalidate AZ fusion - if (outChar1 != null) { - if (a == outChar1 || b == outChar1) { - pass=false; - continue; - } - } - if (outChar2 != null) { - if (a == outChar2 || b == outChar2) { - pass=false; - continue; - } + + if(AZCandidates.size()==1){ +// if(!doSumB){ +// pass=false; +// continue; +// } + int usedAsCount = AsCounter+ABsCounter+AZsCounter; + doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); + if(!doSumA){ // cant do AZ + break;// just do AB,B,A ->AB / A + }else { + AZs = charsToMatrices.get(AZCandidates.iterator().next()); + break;//ok } - if(!doSumA || !doSumB){ - pass=false; - continue; + } else if (AZCandidates.size()>=2) { + doSumA = false; + if(doSumB){ + pass=true; + break; // can do it, it will create AB,B,A -> A, that will be consumed by some AZ later } + pass=false; + continue; + } + int usedAsCount = AsCounter+ABsCounter; + doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); + break; } } @@ -232,31 +191,62 @@ else if(chars.charAt(1)==b){ } String B = AB.substring(1,2); String A = AB.substring(0,1); + char a = A.charAt(0); + char b = B.charAt(0); Character c1 = null; Character c2 = null; - EinsumRewriteType t; + EinsumRewriteType t = null; - if(AZs.size()!=0){ - c1=AZs.get(0).c2; - t=EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z; - } - else if(doSumA){ + if(!AZs.isEmpty()){ +// Character azC1 = AZs.get(0).c1; + Character azC2 = AZs.get(0).c2; +// c1 = 
AZs.get(0).c2; if(doSumB) { - t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__; + t = EinsumRewriteType.AB_BA_B_A_AZ__Z; + c1 = azC2; + } - else { - t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B; - c1 = AB.charAt(1); + else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { + if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(azC2), charToSize.get(AB.charAt(1)),false)|| + LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), azC2, charToSize.get(AB.charAt(1)), charToSize.get(azC2),false)) { + // ideally this can be changed later by parent,depending on need + if (outChar1 == azC2 && outChar2 == b) { + t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; + c1 = azC2; + c2 = b; + } else if (outChar2 == azC2 && outChar1 == b) { + t = EinsumRewriteType.AB_BA_B_A_AZ__BZ; + c1 = b; + c2 = azC2; + } else { + t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; + c1 = azC2; + c2 = b; + } + + } + } + + if(charsToMatrices.containsKey(azC2.toString())) { + Zs = charsToMatrices.get(azC2.toString()); } } - else if(doSumB){ - t= EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A; - c1= AB.charAt(0); - } - else { - t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB; - c1 = AB.charAt(0); - c2 = AB.charAt(1); + if(t==null) { + if (doSumA) { + if (doSumB) { + t = EinsumRewriteType.AB_BA_B_A__; + } else { + t = EinsumRewriteType.AB_BA_B_A__B; + c1 = AB.charAt(1); + } + } else if (doSumB) { + t = EinsumRewriteType.AB_BA_B_A__A; + c1 = AB.charAt(0); + } else { + t = EinsumRewriteType.AB_BA_B_A__AB; + c1 = AB.charAt(0); + c2 = AB.charAt(1); + } } if(c1 != null){ charToOccurences.put(c1, charToOccurences.get(c1)+1); @@ -280,6 +270,7 @@ else if(doSumB){ usedOperands.addAll(XAs); usedOperands.addAll(AXs); usedOperands.addAll(AZs); + usedOperands.addAll(Zs); for(EOpNode n : operands){ if(!usedOperands.contains(n)){ @@ -302,7 +293,8 @@ else if(doSumB){ As, XAs, AXs, - AZs + AZs, + Zs ); ret.add(e); return e; diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java new file mode 100644 index 00000000000..5298a8db176 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java @@ -0,0 +1,160 @@ +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; +import org.apache.sysds.runtime.codegen.SpoofRowwise; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; + +public final class EinsumSpoofRowwise extends SpoofRowwise { + private final int _ABCount; + private final int _BCount; + private final int _ACount; + private final int _ZCount; + private final int _AZCount; + private final int _ZSize; + + private final int _uptoBCumCount; + private final int _uptoZCumCount; + + private final EOpNodeEinsumFuse.EinsumRewriteType _EinsumRewriteType; + + public EinsumSpoofRowwise(EOpNodeEinsumFuse.EinsumRewriteType einsumRewriteType, RowType rowType, long constDim2, boolean tb1, int reqVectMem, int abCount, int bCount, int aCount, int zCount, int azCount, int zSize) { + super(rowType, constDim2, tb1, reqVectMem); + _ABCount = abCount; + _BCount = bCount; + _uptoBCumCount = bCount+ abCount; + _ACount = aCount; + _ZCount = zCount; + _uptoZCumCount = _uptoBCumCount + aCount; + _AZCount = azCount; + _EinsumRewriteType = einsumRewriteType; + _ZSize = zSize; + } + protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + switch (_EinsumRewriteType) { + case AB_BA_B_A__AB -> genexec_AB(a,ai,b,scalars,c,ci,len,grix,rix); + case AB_BA_B_A__B -> genexec_B(a,ai,b,scalars,c,ci,len,grix,rix); + case AB_BA_B_A__A -> genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + case AB_BA_B_A__ -> genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + case AB_BA_B_A_AZ__Z -> { + double[] temp = {0}; + 
genexec_A_or_(a,ai,b,scalars,temp,0,len,grix,rix); + LibMatrixMult.vectMultiplyAdd(temp[0], b[_uptoZCumCount].values(rix), c, _ZSize*rix,0, _ZSize); + } + case AB_BA_B_A_AZ__BZ -> { + double[] temp = new double[len]; + genexec_B(a,ai,b,scalars,temp,0,len,grix,rix); + LibSpoofPrimitives.vectOuterMultAdd(temp, b[_uptoZCumCount].values(rix), c,0, _ZSize*rix, 0, len,_ZSize); + } + case AB_BA_B_A_AZ__ZB -> { + double[] temp = new double[len]; + genexec_B(a,ai,b,scalars,temp,0,len,grix,rix); + + LibSpoofPrimitives.vectOuterMultAdd(b[_uptoZCumCount].values(rix),temp , c,_ZSize*rix,0, 0, _ZSize, len); + } +// case AB_BA_B_XB_BX_A_XA_AX_AZ_AC__CZ -> + default -> throw new NotImplementedException(); + } + } + protected void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + int bi = 0; + double[] TMP1 = null; + if (_ABCount != 0){ + TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); + while (bi < _ABCount) { + if(_ACount == 0 && _BCount == 0 && bi == _ABCount-1) { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), c, 0, ai, ci, len); + }else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); + } + } + } + + if(_BCount > 0 && TMP1 == null) { + TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(0),ai,0,len); + } + while(bi < _uptoBCumCount) { + if (_ACount == 0 && bi == _uptoBCumCount - 1) { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), c, 0, 0, ci, len); + } else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), TMP1, 0, 0, 0, len); + } + } + + if(_ACount == 1) { + LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix],TMP1,c,0,ci,len); + } + } + + protected void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + int bi = 0; + double[] TMP1 = null; + if (_ABCount != 0){ + TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); + while 
(bi < _ABCount) { + if(_ACount == 0 && _BCount == 0 && bi == _ABCount-1) { + LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len); + }else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); + } + } + } + + if(_BCount > 0 && TMP1 == null) { + TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(0),ai,0,len); + } + while(bi < _uptoBCumCount) { + if (_ACount == 0 && bi == _uptoBCumCount - 1) { + LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(0), c, 0, 0, 0, len); + } else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), TMP1, 0, 0, 0, len); + } + } + + if(_ACount == 1) { + LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix],TMP1,c,0,0,len); + } + } + + protected void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + int bi = 0; + double[] TMP1 = null; + Double TMP2 = null; + if (_ABCount == 0 && _BCount == 0){ + TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(rix),ai,ai,len); + } + else if (_ABCount != 0){ + TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); + while (bi < _ABCount) { + if(_BCount == 0 && bi == _ABCount - 1) { + TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(rix),0,ai,len); + }else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); + } + } + } + + if(_BCount == 1 && TMP1 == null) { + TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(0),ai,0,len); + } + else if(_BCount > 0 && TMP1 == null) { + TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(0),ai,0,len); + } + while(bi < _uptoBCumCount) { + if(bi == _uptoBCumCount -1){ + TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(0),0,0,len); + } + else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), TMP1, 0, 0, 0, len); + } + } + + if(_ACount == 1) { + TMP2 *= b[bi].values(0)[rix]; + } + if (_EinsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A) 
c[ci] = TMP2; + else c[0] += TMP2; + } + protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { + throw new RuntimeException("Sparse einsum not implemented"); } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 246ea01b5b3..254c9b829fe 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.instructions.cp; -import codegen.AAA1; -import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.logging.Log; @@ -30,9 +28,7 @@ import org.apache.sysds.common.Types.DataType; import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.hops.codegen.SpoofCompiler; import org.apache.sysds.hops.codegen.cplan.CNode; -import org.apache.sysds.hops.codegen.cplan.CNodeBinary; import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeRow; @@ -51,6 +47,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import org.apache.sysds.runtime.matrix.operators.SimpleOperator; +import org.jetbrains.annotations.NotNull; import java.util.*; import java.util.function.Predicate; @@ -58,8 +55,9 @@ import java.util.stream.Collectors; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { - public static boolean FORCE_CELL_TPL = false; - public static boolean FUSED = true; + public static final boolean FORCE_CELL_TPL = false; + public 
static final boolean FUSED = true; + public static final boolean FUSE_OUTER_MULTIPLY = true; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; private final int _numThreads; @@ -68,11 +66,12 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand out, CPOperand... inputs) { super(op, opcode, istr, out, inputs); -// _numThreads = OptimizerUtils.getConstrainedNumThreads(-1); - _numThreads = 3; + _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2; +// _numThreads = 6; _in = inputs; this.eqStr = inputs[0].getName(); - Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); + Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); +// Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.WARN); } @SuppressWarnings("unused") @@ -88,6 +87,9 @@ public void processInstruction(ExecutionContext ec) { if(mb instanceof CompressedMatrixBlock){ mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); } + if(mb.getNumRows() == 1){ + EnsureMatrixBlockColumnVector(mb); + } inputs.add(mb); } } @@ -166,6 +168,8 @@ public void processInstruction(ExecutionContext ec) { ec.setMatrixOutput(output.getName(), resMatrices.get(0)); } else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ + if( LOG.isTraceEnabled()) LOG.trace("Transposing the final result"); + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); MatrixBlock resM = resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); ec.setMatrixOutput(output.getName(), resM); @@ -277,18 +281,18 @@ else if (operands.size() == 1){ if(fused){ ArrayList ret = new ArrayList<>(); - EOpNodeEinsumFuse fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences); + EOpNodeEinsumFuse fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences, 
einc.charToDimensionSize); if(fuse != null){ minNodes = operands = ret; } while(ret.size() > 2 && fuse!=null){ ret = new ArrayList<>(); - fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences); + fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences, einc.charToDimensionSize); if(fuse != null){ - operands= ret; - operands.add(fuse); + minNodes = operands = ret; } } + fused = false; } for(int i = 0; i < operands.size()-1; i++){ @@ -601,15 +605,132 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList input } return res; } - private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List> mbs) { + if( LOG.isTraceEnabled()) { + String x = eOpNodeEinsumFuse.operands.stream() + .flatMap(List::stream) + .map(o -> o.c1.toString() + (o.c2 == null ? "" : o.c2)) + .collect(Collectors.joining(",")); + String res = (eOpNodeEinsumFuse.c1 == null ? "" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString()); + LOG.trace("ComputeEOpNodeFuse " + eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res); + } + boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB; + boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A; + boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B; + boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__; + boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; + boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; + boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; +// boolean isResultBC = eOpNodeEinsumFuse.einsumRewriteType 
== EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AC__BC; +// boolean isResultCB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; + List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); + List AZs = mbs.get(8); + List Zs = mbs.get(9); +// List ACs = isResultBC || isResultCB ? mbs.get(10) : null; + int bSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c2); + int aSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c1); + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); + for(MatrixBlock mb: BAs){//BA->AB + ABs.add(mb.reorgOperations(transpose, null,0,0,0)); + } + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); + for(MatrixBlock mb: XBs){//XB->B + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + for(MatrixBlock mb: XAs){//XA->A + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); + for(MatrixBlock mb: BXs){//BX->B + MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + for(MatrixBlock mb: AXs){//AX->B // todo remove all X + MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + if(As.size()>1){ + As = MultiplyVectorsIntoOne(As, aSize); + } + if(Bs.size() > 1){ + Bs = MultiplyVectorsIntoOne(Bs, bSize); + } + if(Zs != null && Zs.size() > 1){ + Zs = 
MultiplyVectorsIntoOne(Zs, AZs.get(0).getNumColumns()); + } + int constDim2 = -1; + int zSize = 0; + int azCount = 0; + int zCount = 0; + switch(eOpNodeEinsumFuse.einsumRewriteType){ + case AB_BA_B_A_AZ__Z -> { + constDim2 = AZs.get(0).getNumColumns(); + zSize = AZs.get(0).getNumColumns(); + azCount = AZs.size(); + if (Zs != null) zCount = Zs.size(); + } + case AB_BA_B_A_AZ__BZ, AB_BA_B_A_AZ__ZB -> { + constDim2 = AZs.get(0).getNumColumns(); + zSize = AZs.get(0).getNumColumns(); + azCount = AZs.size(); + } + } + + SpoofRowwise.RowType rowType = switch(eOpNodeEinsumFuse.einsumRewriteType){ + case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG; + case AB_BA_B_A__B -> SpoofRowwise.RowType.COL_AGG_T; + case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG; + case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG; + case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST; + case AB_BA_B_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T; + case AB_BA_B_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1; + }; + EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 0, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize); + + + ArrayList inputs = new ArrayList<>(); +// inputs.add(resBlock); + + inputs.addAll(ABs); + inputs.addAll(Bs); + inputs.addAll(As); + if (isResultZ || isResultBZ || isResultZB) + inputs.addAll(AZs); + MatrixBlock out = r.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); + if( isResultA || isResultB || isResultZ) + EnsureMatrixBlockColumnVector(out); + return out; + + + } + + private static @NotNull List MultiplyVectorsIntoOne(List mbs, int size) { + MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false); + mb.allocateDenseBlock(); + MatrixBlock l = mbs.get(0); + for(int i = 1; i< mbs.size(); i++) { // multiply Bs + if(i==1){ + LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), 
mb.getDenseBlock().values(0),0,0,0, size); + }else{ + LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0, size); + } + } + return List.of(mb); + } + + private MatrixBlock ComputeEOpNodeFuseCodegen(EOpNodeEinsumFuse eOpNodeEinsumFuse, List> mbs) { + //prepare matrices //EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB;\ - boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB; - boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A; - boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B; - boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__; - boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z; + //region Description + boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB; + boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A; + boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B; + boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__; + boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); List AZs = isResultZ ? 
mbs.get(8) : null; int bSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c2); @@ -638,11 +759,7 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); } if(As.size()>1){ - MatrixBlock mb = As.get(0); - for(int i=1;i getNewVarname = () -> "TMP" + _seqVar.getNextID(); - boolean writeToC = (isResultAB || isResultB) && As.isEmpty() && Bs.isEmpty(); + boolean writeToC = (!isResultZ) && As.isEmpty() && Bs.isEmpty(); body.append(" "); if(ABs.size()>1){ body.append("// multiplying ABs: \n "); @@ -697,21 +814,62 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List String writeToVarname=null;/// = i == 1 ? null : writeToCInThisIteration ? "c" : varnameArow; String writeToVarIdx=null; if(writeToCInThisIteration){ - writeToVarname = "c"; - if(isResultAB) + if(isResultAB) { + writeToVarname = "c"; writeToVarIdx = "ai"; - else // B + }else if (isResultB) {// +// writeToVarname = varnameArow; + writeToVarname = "c"; writeToVarIdx = "0"; + }else{ //A or _ = scalar + //does not apply +// writeToVarname =varnameArow; + } }else if(i!=1){ writeToVarname =varnameArow; writeToVarIdx = "0"; } - if(writeToVarname != null) { + if(writeToCInThisIteration){ + if(isResultB) { + if (i == 1){ // never + body.append("LibMatrixMult.vectMultiplyAdd(a");// + }else { + body.append("LibMatrixMult.vectMultiplyAdd(");// + body.append(varnameArow); + } + } + else if(isResultAB || isResultB) { + if (i == 1){ + body.append("LibMatrixMult.vectMultiplyWrite(a"); + }else { + body.append("LibMatrixMult.vectMultiplyWrite("); + body.append(varnameArow); + } + } + else if(isResultA) { + if (i == 1) + body.append("c[rix] = LibSpoofPrimitives.dotProduct(a"); + else { + body.append("c[rix] = LibSpoofPrimitives.dotProduct("); + body.append(varnameArow); + } + }else{ // _ + if (i == 1) + body.append("c[0] += LibSpoofPrimitives.dotProduct(a"); + else { + 
body.append("c[0] += LibSpoofPrimitives.dotProduct("); + body.append(varnameArow); + } + } + } + else if(writeToVarname != null) { if (i == 1){ body.append("LibMatrixMult.vectMultiplyWrite(a"); } - body.append("LibMatrixMult.vectMultiplyWrite("); - body.append(varnameArow); + else { + body.append("LibMatrixMult.vectMultiplyWrite("); + body.append(varnameArow); + } }else { body.append("double[] "); body.append(varnameArow); @@ -721,7 +879,7 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List midx++; body.append(midx); body.append("].values(rix),"); - if(writeToVarname != null){ + if(writeToVarname != null && !(writeToCInThisIteration && (isResultA || isResult_))){ body.append(writeToVarname); body.append(","); } @@ -735,7 +893,11 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List }else{ body.append(",len);\n "); } - +//if(writeToCInThisIteration&& isResultB){ +// body.append("LibMatrixMult.vectAdd("); +// body.append(varnameArow); +// body.append(",c,0,0,len);\n "); +//} } } @@ -805,12 +967,24 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List } }else{ if(writeToC) { - if(assignedArow || assignedB) { - body.append("LibMatrixMult.vectMultiplyWrite("); - body.append(resultVarname); +// if(assignedB){ +// body.append("LibMatrixMult.vectMultiplyWrite("); +// body.append(varnameB); +// body.append(","); +// +// } + if(assignedArow) { + if(isResultB) + body.append("LibMatrixMult.vectMultiplyAdd("); + else + body.append("LibMatrixMult.vectMultiplyWrite("); + body.append(varnameArow); body.append(","); }else{ - body.append("LibMatrixMult.vectMultiplyWrite(a,"); + if(isResultB) + body.append("LibMatrixMult.vectMultiplyAdd(a,"); + else + body.append("LibMatrixMult.vectMultiplyWrite(a,"); } } else if(!assignedArow && !assignedB) { @@ -861,7 +1035,7 @@ else if(!assignedArow && !assignedB) { }// else its new [] //ai: - if(!assignedArow && !assignedB) + if(!assignedArow) 
body.append("ai,0"); else body.append("0,0"); @@ -929,7 +1103,7 @@ else if(!assignedArow && !assignedB) { } } if(writeToC && isResult_){ - body.append("c[0] = "); + body.append("c[0] += "); body.append(resultVarname); body.append(" * b["); }else if(writeToC && isResultA){ @@ -942,7 +1116,7 @@ else if(!assignedArow && !assignedB) { } midx++; body.append(midx); - body.append("].values(0)[rix];\n "); + body.append("].values(rix)[rix];\n "); } else { boolean resultVarnameNull = resultVarname == null; @@ -991,11 +1165,11 @@ else if (writeToC && isResultAB) } } body.append("// write part: \n "); - + int zSize = -1; if(!writeToC) { if (isResultZ) { if (AZs.isEmpty()) throw new RuntimeException("Einsum runtime exception: Invalid rewrite type"); - int zSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1); + zSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1); body.append("// AZ part: \n "); @@ -1018,10 +1192,17 @@ else if (writeToC && isResultAB) body.append(midx); body.append("].values(rix),"); body.append(resultVarname); - body.append("c,rix*"); + body.append(",c,rix*"); +// body.append(",c,0,0,"); body.append(zSize); +// body.append(",c,ai,0,len);\n "); +// body.append(",c,ai,0,len);\n "); - body.append(",0,len);\n "); +// body.append(");\n "); +// body.append(",0,len);\n "); + body.append(",0,"); + body.append(zSize); + body.append(");\n "); } else if (isResultAB) { //vectWrite(double[] a, double[] c, int ai, int ci, int len) body.append("LibSpoofPrimitives.vectWrite("); @@ -1071,26 +1252,31 @@ else if (writeToC && isResultAB) body.append(";\n "); } } + //endregion src = src.replace("%TMP%", className); switch(eOpNodeEinsumFuse.einsumRewriteType){ - case AB_BA_B_XB_BX_A_XA_AX__AB -> { + case AB_BA_B_A__AB -> { src = src.replace("%TYPE%", "NO_AGG"); } - case AB_BA_B_XB_BX_A_XA_AX__B -> { + case AB_BA_B_A__B -> { src = src.replace("%TYPE%", "COL_AGG_T"); } - case AB_BA_B_XB_BX_A_XA_AX__A -> { + case AB_BA_B_A__A -> { src = src.replace("%TYPE%", "ROW_AGG"); } - 
case AB_BA_B_XB_BX_A_XA_AX__ -> { + case AB_BA_B_A__ -> { src = src.replace("%TYPE%", "FULL_AGG"); } - case AB_BA_B_XB_BX_A_XA_AX_AZ__Z -> { - src = src.replace("%TYPE%", "COL_AGG_T"); + case AB_BA_B_A_AZ__Z -> { +// src = src.replace("%TYPE%", "COL_AGG_T"); + src = src.replace("%TYPE%", "COL_AGG_CONST"); } } - src = src.replace("%CONST_DIM2%", "-1"); + if(isResultZ) + src = src.replace("%CONST_DIM2%",""+zSize); + else + src = src.replace("%CONST_DIM2%", "-1"); src = src.replace("%TB1%", "false"); src = src.replace("%VECT_MEM%", "0"); src = src.replace("%BODY_dense%", body); @@ -1099,7 +1285,7 @@ else if (writeToC && isResultAB) if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); SpoofOperator op = CodegenUtils.createInstance(cla); -// SpoofOperator op = CodegenUtils.createInstance(AAA2.class); +// SpoofOperator op = CodegenUtils.createInstance(AAA3.class); // MatrixBlock resBlock = new MatrixBlock(); // resBlock.reset(einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1), @@ -1110,8 +1296,11 @@ else if (writeToC && isResultAB) inputs.addAll(ABs); inputs.addAll(Bs); inputs.addAll(As); + if (isResultZ) + inputs.addAll(AZs); MatrixBlock out = op.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); - + if( isResultA || isResultB ||isResultZ) + EnsureMatrixBlockColumnVector(out); return out; } @@ -1370,6 +1559,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { indent--; sb.append("}\n"); } + //endregion String src = CNodeCell.JAVA_TEMPLATE;// src = src.replace("%TMP%", cnode.createVarname()); src = src.replace("%TYPE%", "NO_AGG"); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 5753fbbadbe..52d03725fc8 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -3982,6 +3982,25 @@ public static void vectMultiplyWrite( double[] a, double[] b, double[] c, int ai aVec.mul(bVec).intoArray(c, ci); } } + + //note: public for use by codegen for consistency + public static void vectMultiplyAdd( double[] a, double[] b, double[] c, int ai, int bi, int ci, final int len ){ + final int bn = len%vLen; + + //rest, not aligned to vLen-blocks + for( int j = 0; j < bn; j++, ai++, bi++, ci++) + c[ ci ] += a[ ai ] * b[ bi ]; + + //unrolled vLen-block (for better instruction-level parallelism) + for( int j = bn; j < len; j+=vLen, ai+=vLen, bi+=vLen, ci+=vLen) + { + DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai); + DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi); + DoubleVector cVec = DoubleVector.fromArray(SPECIES, c, ci); + cVec = aVec.fma(bVec, cVec); + cVec.intoArray(c, ci); + } + } public static void vectMultiplyWrite( final double[] a, double[] b, double[] c, int[] bix, final int ai, final int bi, final int ci, final int len ) { final int bn = len%8; From 54bd4ffa50a419eb0032497f2c5d9c9d00845cb8 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 3 Oct 2025 19:57:26 +0200 Subject: [PATCH 04/13] delete einsum codegen code --- .../instructions/cp/EinsumCPInstruction.java | 709 +----------------- 1 file changed, 1 insertion(+), 708 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 254c9b829fe..519c1e01b8e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -455,17 +455,7 @@ else if (n1.c2 == n2.c2) { private ArrayList executePlan(List plan, ArrayList inputs) { ArrayList res = new ArrayList<>(plan.size()); for(EOpNode p : plan){ - /// ////////////// 
^^^^^^^^^^^^^^^^^^^^^ //////////////// -// if((true) && plan.size()== 1 && plan.get(0) instanceof EOpNodeBinary bin1 && bin1.operand == EBinaryOperand.Ba_Ca && bin1.right instanceof EOpNodeBinary r && r.operand == EBinaryOperand.Ba_Ca){ -// res.add(ComputeEOpNodeCodegen(p, inputs)); -//// var other = ComputeEOpNode(p, inputs); -//// ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); -//// var other2 = other.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); -// continue; -// } -// if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs)); -// else - res.add(ComputeEOpNode(p, inputs)); + res.add(ComputeEOpNode(p, inputs)); } return res; } @@ -721,707 +711,10 @@ private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List return List.of(mb); } - private MatrixBlock ComputeEOpNodeFuseCodegen(EOpNodeEinsumFuse eOpNodeEinsumFuse, List> mbs) { - //prepare matrices - //EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB;\ - //region Description - boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB; - boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A; - boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B; - boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__; - boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; - List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); - List AZs = isResultZ ? 
mbs.get(8) : null; - int bSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c2); - int aSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c1); - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - for(MatrixBlock mb: BAs){//BA->AB - ABs.add(mb.reorgOperations(transpose, null,0,0,0)); - } - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - for(MatrixBlock mb: XBs){//XB->B - MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); - Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - for(MatrixBlock mb: XAs){//XA->A - MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - for(MatrixBlock mb: BXs){//BX->B - MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - for(MatrixBlock mb: AXs){//AX->B // todo remove all X - MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - if(As.size()>1){ - As = MultiplyVectorsIntoOne(As, aSize); - } - - - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeRow cnode = new CNodeRow(cnodeIn, null); - String src = -// "import org.apache.sysds.runtime.matrix.data.LibMatrixMult;\n"+ -// CNodeRow.JAVA_TEMPLATE; -// src= - "package codegen;\n" - +"import org.apache.sysds.runtime.matrix.data.LibMatrixMult;\n" - + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n" - + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n" - + "import 
org.apache.sysds.runtime.codegen.SpoofRowwise;\n" - + "import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\n" - + "import org.apache.sysds.runtime.data.SparseRowVector;\n" - + "import org.apache.commons.math3.util.FastMath;\n" - + "\n" - + "public final class %TMP% extends SpoofRowwise { \n" - + " public %TMP%() {\n" - + " super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" - + " }\n" - + " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n" - + "%BODY_dense%" - + " }\n" - + " protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n" - + "%BODY_sparse%" - + " }\n" - + "}\n"; - StringBuilder body = new StringBuilder(); - int midx = -1; - boolean assignedArow = false; - String varnameArow = null; - - String className = cnode.createVarname(); - IDSequence _seqVar = new IDSequence(); - _seqVar.getNextID(); - Supplier getNewVarname = () -> "TMP" + _seqVar.getNextID(); - - boolean writeToC = (!isResultZ) && As.isEmpty() && Bs.isEmpty(); - body.append(" "); - if(ABs.size()>1){ - body.append("// multiplying ABs: \n "); - assignedArow=true; - varnameArow = getNewVarname.get(); - for(int i=1;i1){ //todo use arow name if possible - body.append("// multiplying Bs: \n "); - - assignedB=true; - varnameB = getNewVarname.get(); - body.append("double[] "); - for(int i=1;i { - src = src.replace("%TYPE%", "NO_AGG"); - } - case AB_BA_B_A__B -> { - src = src.replace("%TYPE%", "COL_AGG_T"); - } - case AB_BA_B_A__A -> { - src = src.replace("%TYPE%", "ROW_AGG"); - } - case AB_BA_B_A__ -> { - src = src.replace("%TYPE%", "FULL_AGG"); - } - case AB_BA_B_A_AZ__Z -> { -// src = src.replace("%TYPE%", "COL_AGG_T"); - src = src.replace("%TYPE%", "COL_AGG_CONST"); - } - } - if(isResultZ) - src = src.replace("%CONST_DIM2%",""+zSize); - else - src = src.replace("%CONST_DIM2%", "-1"); - src = 
src.replace("%TB1%", "false"); - src = src.replace("%VECT_MEM%", "0"); - src = src.replace("%BODY_dense%", body); - src = src.replace("%BODY_sparse%", "throw new RuntimeException(\"Sparse einsum not implemented\");"); - - if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - SpoofOperator op = CodegenUtils.createInstance(cla); -// SpoofOperator op = CodegenUtils.createInstance(AAA3.class); -// MatrixBlock resBlock = new MatrixBlock(); - -// resBlock.reset(einc.charToDimensionSize.get(eOpNodeEinsumFuse.c1), -// eOpNodeEinsumFuse.c2 == null?1:einc.charToDimensionSize.get(eOpNodeEinsumFuse.c2)); - ArrayList inputs = new ArrayList<>(); -// inputs.add(resBlock); - - inputs.addAll(ABs); - inputs.addAll(Bs); - inputs.addAll(As); - if (isResultZ) - inputs.addAll(AZs); - MatrixBlock out = op.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); - if( isResultA || isResultB ||isResultZ) - EnsureMatrixBlockColumnVector(out); - return out; - } - -// private MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs){ -// return rComputeEOpNodeCodegen(eOpNode, inputs); -//// throw new NotImplementedException(); -// } private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){ return new CNodeData("ce"+id, id, mb.getNumRows(), mb.getNumColumns(), DataType.MATRIX); } - /* - private MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList inputs) { - if (eOpNode instanceof EOpNodeData eOpNodeData){ - return inputs.get(eOpNodeData.matrixIdx); -// return new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); - } - - EOpNodeBinary bin = (EOpNodeBinary) eOpNode; -// CNodeData dataLeft = null; -// if (bin.left instanceof EOpNodeData eOpNodeData) dataLeft = new CNodeData("ce"+eOpNodeData.matrixIdx, 
eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); -// CNodeData dataRight = null; -// if (bin.right instanceof EOpNodeData eOpNodeData) dataRight = new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX); - if(bin.operand == EBinaryOperand.Ba_Ca && bin.right instanceof EOpNodeBinary r &&r.operand == EBinaryOperand.Ba_Ca ) { -// String src = """ -// package codegen; -// import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; -// import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput; -// import org.apache.sysds.runtime.codegen.SpoofRowwise; -// import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType; -// import org.apache.sysds.runtime.data.SparseRowVector; -// import org.apache.commons.math3.util.FastMath; -// -// public final class TMP1 extends SpoofRowwise { -// public TMP1() { -// super(RowType.NO_AGG, 600, false, 1); -// } -// protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { -// //System.out.println("rix:"+rix+", len:" + len+", a len:"+a.length+", b len:"+b[0].values(0).length+" "+b[1].values(0).length); -// double[] TMP1 = LibSpoofPrimitives.vectMatrixMult(a,b[0].values(0),ai,0,len); -// //System.out.println("tmp1 len:" + TMP1.length); -// //System.out.println("len2:"+b[0].values(0).length/b[0].clen); -// double[] TMP2 = LibSpoofPrimitives.vectMatrixMult(TMP1,b[1].values(0),0,0,b[0].values(0).length/b[0].clen); -// //System.out.println("tmp2 len:" + TMP2.length); -// //System.out.println("ci:" + ci+", myci:"+rix*(b[1].values(0).length/b[1].clen)); -// LibSpoofPrimitives.vectWrite(TMP2, c,ci, TMP2.length); -// } -// protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, 
int rix) { -// LibSpoofPrimitives.vectOuterMultAdd(avals, b[0].values(rix), c, aix, ai, b[0].pos(rix), 0, alen, len, b[0].clen); -// } -// } -// """; - long start = System.currentTimeMillis(); - - - Class cla = TMP1.class; - -// Class cla = CodegenUtils.compileClass("codegen." + "TMP1", src); - long end = System.currentTimeMillis(); - long duration = end - start; // duration in milliseconds - System.out.println("Time taken: " + duration + " ms"); - - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - mb.reset(einc.outRows, einc.outCols , false); - mb.allocateDenseBlock(); - ArrayList scalars = new ArrayList<>(); - ArrayList mbs = new ArrayList<>(3); - mbs.add(inputs.get(((EOpNodeData) r.left).matrixIdx)); - mbs.add(inputs.get(((EOpNodeData) r.right).matrixIdx)); - mbs.add(inputs.get(((EOpNodeData) bin.left).matrixIdx)); - MatrixBlock out = op.execute(mbs, scalars, mb, _numThreads); - var tmp = bin.c1; - bin.c1=bin.c2; - bin.c2=tmp; - return out; - } - if(bin.operand == EBinaryOperand.AB_AB){ - if (bin.right instanceof EOpNodeBinary rBinary && rBinary.operand == EBinaryOperand.AB_AB){ - MatrixBlock left = rComputeEOpNodeCodegen(bin.left, inputs); - - MatrixBlock right1 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).left, inputs); - MatrixBlock right2 = rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).right, inputs); - - CNodeData d0 = MatrixBlockToCNodeData(left, 0); - CNodeData d1 = MatrixBlockToCNodeData(right1, 1); - CNodeData d2 = MatrixBlockToCNodeData(right2, 2); -// CNodeNary nary = new CNodeNary(cnodeIn, CNodeNary.NaryType.) 
- CNodeBinary rightBinary = new CNodeBinary(d1, d2, CNodeBinary.BinType.VECT_MULT); - CNodeBinary cNodeBinary = new CNodeBinary(d0, rightBinary, CNodeBinary.BinType.VECT_MULT); - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(d0); - cnodeIn.add(d1); - cnodeIn.add(d2); - - CNodeRow cnode = new CNodeRow(cnodeIn, cNodeBinary); - - cnode.setRowType(SpoofRowwise.RowType.NO_AGG); - cnode.renameInputs(); - - - String src = cnode.codegen(false, SpoofCompiler.GeneratorAPI.JAVA); - if( LOG.isTraceEnabled()) LOG.trace(CodegenUtils.printWithLineNumber(src)); - Class cla = CodegenUtils.compileClass("codegen." + cnode.getClassname(), src); - - SpoofOperator op = CodegenUtils.createInstance(cla); - MatrixBlock mb = new MatrixBlock(); - - ArrayList scalars = new ArrayList<>(); - ArrayList mbs = new ArrayList<>(3); - mbs.add(left); - mbs.add(right1); - mbs.add(right2); - MatrixBlock out = op.execute(mbs, scalars, mb, 6); - - return out; - } - } - - throw new NotImplementedException(); - } -*/ private void releaseMatrixInputs(ExecutionContext ec){ for (CPOperand input : _in) From 824824167754cccc485f2f5c92317d750e442ebe Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 3 Oct 2025 20:26:21 +0200 Subject: [PATCH 05/13] move eopnode compute impl. 
to respective classes --- .../apache/sysds/runtime/einsum/EOpNode.java | 7 + .../sysds/runtime/einsum/EOpNodeBinary.java | 167 +++++++++- .../sysds/runtime/einsum/EOpNodeData.java | 10 + .../runtime/einsum/EOpNodeEinsumFuse.java | 142 ++++++++- .../sysds/runtime/einsum/EOpNodeFused.java | 8 - .../instructions/cp/EinsumCPInstruction.java | 284 +----------------- 6 files changed, 319 insertions(+), 299 deletions(-) delete mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java index c93527f8ca5..32fe42801fc 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -1,5 +1,10 @@ package org.apache.sysds.runtime.einsum; +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import java.util.ArrayList; + public abstract class EOpNode { public Character c1; public Character c2; // nullable @@ -15,5 +20,7 @@ public String toString() { if(c2 == null) return c1.toString(); return c1.toString() + c2.toString(); } + + public abstract MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG); } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index 244dee347dc..50bb576c0be 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -1,6 +1,29 @@ package org.apache.sysds.runtime.einsum; +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ReduceAll; +import org.apache.sysds.runtime.functionobjects.ReduceCol; +import 
org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; + +import java.util.ArrayList; + +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; + public class EOpNodeBinary extends EOpNode { + public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed ////// summations: ////// aB_a,// -> B @@ -31,13 +54,145 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b AB_scalar, // m-scalar scalar_scalar } - public EOpNode left; - public EOpNode right; - public EBinaryOperand operand; + public EOpNode _left; + public EOpNode _right; + public EBinaryOperand _operand; public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ super(c1,c2); - this.left = left; - this.right = right; - this.operand = operand; + this._left = left; + this._right = right; + this._operand = operand; } + + @Override + public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, Log LOG) { + EOpNodeBinary bin = this; + MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG); + MatrixBlock right = _right.computeEOpNode(inputs, numThreads, LOG); + + AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject()); + + MatrixBlock res; + + LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin); + + switch (bin._operand){ + case AB_AB -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case A_A -> { + ensureMatrixBlockColumnVector(left); + ensureMatrixBlockColumnVector(right); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case a_a -> { + ensureMatrixBlockColumnVector(left); + ensureMatrixBlockColumnVector(right); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + //////////// + case Ba_Ba -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case aB_aB -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + ensureMatrixBlockColumnVector(res); + } + case ab_ab -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new 
ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case ab_ba -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case Ba_aB -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + case aB_Ba -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + + ///////// + case AB_BA -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), 
numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); + } + case Ba_aC -> { + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); + } + case aB_Ca -> { + res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), numThreads); + } + case Ba_Ca -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); + } + case aB_aC -> { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); + } + case A_scalar, AB_scalar -> { + res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); + } + case BA_A -> { + ensureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case Ba_a -> { + ensureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + } + + case AB_A -> { + ensureMatrixBlockColumnVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case aB_a -> { + ensureMatrixBlockColumnVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + 
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); + res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + ensureMatrixBlockColumnVector(res); + } + + case A_B -> { + ensureMatrixBlockColumnVector(left); + ensureMatrixBlockRowVector(right); + res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); + } + case scalar_scalar -> { + return new MatrixBlock(left.get(0,0)*right.get(0,0)); + } + default -> { + throw new IllegalArgumentException("Unexpected value: " + bin._operand.toString()); + } + + } + return res; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java index e7b75236eda..50a352be6c2 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -1,9 +1,19 @@ package org.apache.sysds.runtime.einsum; +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import java.util.ArrayList; + public class EOpNodeData extends EOpNode { public int matrixIdx; public EOpNodeData(Character c1, Character c2, int matrixIdx){ super(c1,c2); this.matrixIdx = matrixIdx; } + + @Override + public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG) { + return inputs.get(matrixIdx); + } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java index ae10c93115c..583b6ff8bff 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java @@ -1,24 +1,30 @@ package org.apache.sysds.runtime.einsum; +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.codegen.SpoofRowwise; +import 
org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ReduceCol; +import org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.jetbrains.annotations.NotNull; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; public class EOpNodeEinsumFuse extends EOpNode { - public static final int AB_index=0; - public static final int BA_index=1; - public static final int B_index=2; - public static final int XB_index=3; - public static final int BX_index=4; - public static final int A_index=5; - public static final int XA_index=6; - public static final int AX_index=7; - public static final int AZ_index=8; + public enum EinsumRewriteType{ // B -> row*row, A -> row*scalar AB_BA_B_A__AB, @@ -299,5 +305,123 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { ret.add(e); return e; } + + @Override + public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, Log LOG) { + List> mbs = operands.stream().map(l -> l.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).collect(Collectors.toList())).toList(); + var eOpNodeEinsumFuse = this; + + if( LOG.isTraceEnabled()) { + String x = eOpNodeEinsumFuse.operands.stream() + .flatMap(List::stream) + .map(o -> o.c1.toString() + (o.c2 == null ? 
"" : o.c2)) + .collect(Collectors.joining(",")); + String res = (eOpNodeEinsumFuse.c1 == null ? "" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString()); + LOG.trace("ComputeEOpNodeFuse " + eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res); + } + boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB; + boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A; + boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B; + boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__; + boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; + boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; + boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; +// boolean isResultBC = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AC__BC; +// boolean isResultCB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; + List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); + List AZs = mbs.get(8); + List Zs = mbs.get(9); +// List ACs = isResultBC || isResultCB ? 
mbs.get(10) : null; + int bSize = ABs.get(0).getNumColumns(); + int aSize = ABs.get(0).getNumRows(); + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + for(MatrixBlock mb: BAs){//BA->AB + ABs.add(mb.reorgOperations(transpose, null,0,0,0)); + } + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); + for(MatrixBlock mb: XBs){//XB->B + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + for(MatrixBlock mb: XAs){//XA->A + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); + for(MatrixBlock mb: BXs){//BX->B (row sums are B-indexed, so they join the B vectors) + MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); + Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + for(MatrixBlock mb: AXs){//AX->A // todo remove all X + MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); + As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); + } + if(As.size()>1){ + As = multiplyVectorsIntoOne(As, aSize); + } + if(Bs.size() > 1){ + Bs = multiplyVectorsIntoOne(Bs, bSize); + } + if(Zs != null && Zs.size() > 1){ + Zs = multiplyVectorsIntoOne(Zs, AZs.get(0).getNumColumns()); + } + int constDim2 = -1; + int zSize = 0; + int azCount = 0; + int zCount = 0; + switch(eOpNodeEinsumFuse.einsumRewriteType){ + case AB_BA_B_A_AZ__Z -> { + constDim2 = AZs.get(0).getNumColumns(); + zSize = AZs.get(0).getNumColumns(); + azCount = AZs.size(); + if (Zs != null) zCount = Zs.size(); + } + case AB_BA_B_A_AZ__BZ, AB_BA_B_A_AZ__ZB -> { + constDim2 = AZs.get(0).getNumColumns(); + zSize = AZs.get(0).getNumColumns(); + azCount = AZs.size(); + } + } + + 
SpoofRowwise.RowType rowType = switch(eOpNodeEinsumFuse.einsumRewriteType){ + case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG; + case AB_BA_B_A__B -> SpoofRowwise.RowType.COL_AGG_T; + case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG; + case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG; + case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST; + case AB_BA_B_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T; + case AB_BA_B_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1; + }; + EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 0, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize); + + + ArrayList fuseInputs = new ArrayList<>(); +// inputs.add(resBlock); + + fuseInputs.addAll(ABs); + fuseInputs.addAll(Bs); + fuseInputs.addAll(As); + if (isResultZ || isResultBZ || isResultZB) + fuseInputs.addAll(AZs); + MatrixBlock out = r.execute(fuseInputs, new ArrayList<>(), new MatrixBlock(), numThreads); + if( isResultA || isResultB || isResultZ) + ensureMatrixBlockColumnVector(out); + return out; + + } + + private static @NotNull List multiplyVectorsIntoOne(List mbs, int size) { + MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false); + mb.allocateDenseBlock(); + for(int i = 1; i< mbs.size(); i++) { // multiply Bs + if(i==1){ + LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size); + }else{ + LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0, size); + } + } + return List.of(mb); + } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java deleted file mode 100644 index 4554b9c8334..00000000000 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFused.java +++ /dev/null @@ -1,8 +0,0 @@ -package org.apache.sysds.runtime.einsum; - -public class 
EOpNodeFused extends EOpNode { - public EOpNodeFused(Character c1, Character c2){ - super(c1,c2); - - } -} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 519c1e01b8e..aba60b6b333 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -31,28 +31,20 @@ import org.apache.sysds.hops.codegen.cplan.CNode; import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; -import org.apache.sysds.hops.codegen.cplan.CNodeRow; import org.apache.sysds.runtime.codegen.*; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.einsum.*; import org.apache.sysds.runtime.einsum.EOpNodeBinary.EBinaryOperand; import org.apache.sysds.runtime.functionobjects.*; -import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; -import org.apache.sysds.runtime.matrix.operators.SimpleOperator; -import org.jetbrains.annotations.NotNull; import java.util.*; import java.util.function.Predicate; -import java.util.function.Supplier; -import java.util.stream.Collectors; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static final boolean FORCE_CELL_TPL = false; @@ -67,7 +59,6 @@ public EinsumCPInstruction(Operator 
op, String opcode, String istr, CPOperand ou { super(op, opcode, istr, out, inputs); _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2; -// _numThreads = 6; _in = inputs; this.eqStr = inputs[0].getName(); Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); @@ -88,7 +79,7 @@ public void processInstruction(ExecutionContext ec) { mb = ((CompressedMatrixBlock) mb).getUncompressed("Spoof instruction"); } if(mb.getNumRows() == 1){ - EnsureMatrixBlockColumnVector(mb); + ensureMatrixBlockColumnVector(mb); } inputs.add(mb); } @@ -109,7 +100,7 @@ public void processInstruction(ExecutionContext ec) { //make all vetors col vectors for(int i = 0; i < inputs.size(); i++){ - if(inputs.get(i) != null && inputsChars.get(i).length() == 1) EnsureMatrixBlockColumnVector(inputs.get(i)); + if(inputs.get(i) != null && inputsChars.get(i).length() == 1) ensureMatrixBlockColumnVector(inputs.get(i)); } if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){ @@ -159,7 +150,6 @@ public void processInstruction(ExecutionContext ec) { ArrayList resMatrices = FORCE_CELL_TPL ? 
null : executePlan(plan.getRight(), inputs); -// ArrayList resMatrices = executePlan(plan.getRight(), inputs, true); if(!FORCE_CELL_TPL && resMatrices.size() == 1){ EOpNode resNode = plan.getRight().get(0); @@ -258,7 +248,8 @@ private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList } } - private Pair /* ideally with one element */> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2, boolean fused) { + // ideally the return list contains only one final element + private Pair> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2, boolean fused) { Integer minCost = cost; List minNodes = operands; @@ -449,287 +440,28 @@ else if (n1.c2 == n2.c2) { } } -// private ArrayList executePlan(List plan, ArrayList inputs){ -// return executePlan(plan, inputs); -// } - private ArrayList executePlan(List plan, ArrayList inputs) { + private ArrayList executePlan(List plan, ArrayList inputs) { ArrayList res = new ArrayList<>(plan.size()); for(EOpNode p : plan){ - res.add(ComputeEOpNode(p, inputs)); + res.add(p.computeEOpNode(inputs, _numThreads, LOG)); } return res; } - private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList inputs){ - if(eOpNode instanceof EOpNodeData eOpNodeData){ - return inputs.get(eOpNodeData.matrixIdx); - }else if(eOpNode instanceof EOpNodeEinsumFuse eOpNodeEinsumFuse){ - var mbs = eOpNodeEinsumFuse.operands.stream().map(l->l.stream().map(n->ComputeEOpNode(n, inputs)).collect(Collectors.toList())).toList(); - return ComputeEOpNodeFuse(eOpNodeEinsumFuse, mbs); - } - EOpNodeBinary bin = (EOpNodeBinary) eOpNode; - MatrixBlock left = ComputeEOpNode(bin.left, inputs); - MatrixBlock right = ComputeEOpNode(bin.right, inputs); - - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - - MatrixBlock res; - - LOG.trace("computing binary 
"+bin.left+","+bin.right+"->"+bin); - - switch (bin.operand){ - case AB_AB -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - } - case A_A -> { - EnsureMatrixBlockColumnVector(left); - EnsureMatrixBlockColumnVector(right); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - } - case a_a -> { - EnsureMatrixBlockColumnVector(left); - EnsureMatrixBlockColumnVector(right); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - //////////// - case Ba_Ba -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case aB_aB -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - EnsureMatrixBlockColumnVector(res); - } - case ab_ab -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case ab_ba -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case Ba_aB -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case aB_Ba -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - - ///////// - case AB_BA -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 
0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - } - case Ba_aC -> { - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - } - case aB_Ca -> { - res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads); - } - case Ba_Ca -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - } - case aB_aC -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads); - } - case A_scalar, AB_scalar -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); - } - case BA_A -> { - EnsureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case Ba_a -> { - EnsureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - - case AB_A -> { - EnsureMatrixBlockColumnVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case aB_a -> { - EnsureMatrixBlockColumnVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, 
ReduceRow.getReduceRowFnObject(), _numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - EnsureMatrixBlockColumnVector(res); - } - - case A_B -> { - EnsureMatrixBlockColumnVector(left); - EnsureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case scalar_scalar -> { - return new MatrixBlock(left.get(0,0)*right.get(0,0)); - } - default -> { - throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString()); - } - - } - return res; - } - private MatrixBlock ComputeEOpNodeFuse(EOpNodeEinsumFuse eOpNodeEinsumFuse, List> mbs) { - if( LOG.isTraceEnabled()) { - String x = eOpNodeEinsumFuse.operands.stream() - .flatMap(List::stream) - .map(o -> o.c1.toString() + (o.c2 == null ? "" : o.c2)) - .collect(Collectors.joining(",")); - String res = (eOpNodeEinsumFuse.c1 == null ? "" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString()); - LOG.trace("ComputeEOpNodeFuse " + eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res); - } - boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB; - boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A; - boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B; - boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__; - boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; - boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; - boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; -// boolean isResultBC = eOpNodeEinsumFuse.einsumRewriteType == 
EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AC__BC; -// boolean isResultCB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; - List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); - List AZs = mbs.get(8); - List Zs = mbs.get(9); -// List ACs = isResultBC || isResultCB ? mbs.get(10) : null; - int bSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c2); - int aSize = einc.charToDimensionSize.get(eOpNodeEinsumFuse.operands.get(0).get(0).c1); - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - for(MatrixBlock mb: BAs){//BA->AB - ABs.add(mb.reorgOperations(transpose, null,0,0,0)); - } - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - for(MatrixBlock mb: XBs){//XB->B - MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); - Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - for(MatrixBlock mb: XAs){//XA->A - MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - for(MatrixBlock mb: BXs){//BX->B - MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - for(MatrixBlock mb: AXs){//AX->B // todo remove all X - MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - if(As.size()>1){ - As = MultiplyVectorsIntoOne(As, aSize); - } - if(Bs.size() > 1){ - Bs = MultiplyVectorsIntoOne(Bs, bSize); - } - if(Zs != null && Zs.size() > 1){ - Zs = 
MultiplyVectorsIntoOne(Zs, AZs.get(0).getNumColumns()); - } - int constDim2 = -1; - int zSize = 0; - int azCount = 0; - int zCount = 0; - switch(eOpNodeEinsumFuse.einsumRewriteType){ - case AB_BA_B_A_AZ__Z -> { - constDim2 = AZs.get(0).getNumColumns(); - zSize = AZs.get(0).getNumColumns(); - azCount = AZs.size(); - if (Zs != null) zCount = Zs.size(); - } - case AB_BA_B_A_AZ__BZ, AB_BA_B_A_AZ__ZB -> { - constDim2 = AZs.get(0).getNumColumns(); - zSize = AZs.get(0).getNumColumns(); - azCount = AZs.size(); - } - } - - SpoofRowwise.RowType rowType = switch(eOpNodeEinsumFuse.einsumRewriteType){ - case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG; - case AB_BA_B_A__B -> SpoofRowwise.RowType.COL_AGG_T; - case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG; - case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG; - case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST; - case AB_BA_B_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T; - case AB_BA_B_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1; - }; - EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 0, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize); - - - ArrayList inputs = new ArrayList<>(); -// inputs.add(resBlock); - - inputs.addAll(ABs); - inputs.addAll(Bs); - inputs.addAll(As); - if (isResultZ || isResultBZ || isResultZB) - inputs.addAll(AZs); - MatrixBlock out = r.execute(inputs, new ArrayList<>(), new MatrixBlock(), _numThreads); - if( isResultA || isResultB || isResultZ) - EnsureMatrixBlockColumnVector(out); - return out; - - - } - - private static @NotNull List MultiplyVectorsIntoOne(List mbs, int size) { - MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false); - mb.allocateDenseBlock(); - MatrixBlock l = mbs.get(0); - for(int i = 1; i< mbs.size(); i++) { // multiply Bs - if(i==1){ - LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), 
mb.getDenseBlock().values(0),0,0,0, size); - }else{ - LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0, size); - } - } - return List.of(mb); - } - - - private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){ - return new CNodeData("ce"+id, id, mb.getNumRows(), mb.getNumColumns(), DataType.MATRIX); - } - private void releaseMatrixInputs(ExecutionContext ec){ for (CPOperand input : _in) if(input.getDataType()==DataType.MATRIX) ec.releaseMatrixInput(input.getName()); //todo release other } - private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){ + public static void ensureMatrixBlockColumnVector(MatrixBlock mb){ if(mb.getNumColumns() > 1){ mb.setNumRows(mb.getNumColumns()); mb.setNumColumns(1); mb.getDenseBlock().resetNoFill(mb.getNumRows(),1); } } - private static void EnsureMatrixBlockRowVector(MatrixBlock mb){ + public static void ensureMatrixBlockRowVector(MatrixBlock mb){ if(mb.getNumRows() > 1){ mb.setNumColumns(mb.getNumRows()); mb.setNumRows(1); From 8cb77dc6802424e487d87074a0ea29b983a49ca8 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 30 Oct 2025 23:49:26 +0100 Subject: [PATCH 06/13] minor changes --- .../apache/sysds/runtime/einsum/EOpNode.java | 2 + .../sysds/runtime/einsum/EOpNodeBinary.java | 49 +++++++-- .../sysds/runtime/einsum/EOpNodeData.java | 5 + ...OpNodeEinsumFuse.java => EOpNodeFuse.java} | 99 ++++++++++++------- .../runtime/einsum/EinsumSpoofRowwise.java | 14 +-- .../instructions/cp/EinsumCPInstruction.java | 84 +++++++++------- 6 files changed, 167 insertions(+), 86 deletions(-) rename src/main/java/org/apache/sysds/runtime/einsum/{EOpNodeEinsumFuse.java => EOpNodeFuse.java} (81%) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java index 32fe42801fc..bdaf0d31c85 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java +++ 
b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -22,5 +22,7 @@ public String toString() { } public abstract MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG); + + public abstract void reorderChildren(Character outChar1, Character outChar2); } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index 50bb576c0be..f0abfaf3ed2 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -1,6 +1,7 @@ package org.apache.sysds.runtime.einsum; import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ReduceAll; @@ -74,7 +75,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, MatrixBlock res; - LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin); + if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin); switch (bin._operand){ case AB_AB -> { @@ -88,11 +89,10 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, case a_a -> { ensureMatrixBlockColumnVector(left); ensureMatrixBlockColumnVector(right); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); + res = new MatrixBlock(0.0); + res.allocateDenseBlock(); + res.getDenseBlockValues()[0] = LibMatrixMult.dotProduct(left.getDenseBlockValues(), right.getDenseBlockValues(), 0,0 , left.getNumRows()); } - //////////// case 
Ba_Ba -> { res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); @@ -131,7 +131,6 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); } - ///////// case AB_BA -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); @@ -149,9 +148,22 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); } case aB_aC -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); - left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); + if(false && LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), true)){ + res = new MatrixBlock(left.getNumColumns(), right.getNumColumns(),false); + res.allocateDenseBlock(); + double[] m1 = left.getDenseBlock().values(0); + double[] m2 = right.getDenseBlock().values(0); + double[] c = res.getDenseBlock().values(0); + int alen = left.getNumColumns(); + int blen = right.getNumColumns(); + for(int i =0;i { res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); @@ -195,4 +207,23 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, return res; } + @Override + public void reorderChildren(Character outChar1, Character outChar2) { + if (this._operand==EBinaryOperand.aB_aC){ + if(this._right.c2 == outChar1) { + var 
tmp = _left; + _left = _right; + _right = tmp; + var tmp2 = c1; + c1 = c2; + c2 = tmp2; + } + _left.reorderChildren(_left.c2, _left.c1); + // check if change happened: + if(_left.c2 == _right.c1) { + this._operand = EBinaryOperand.Ba_aC; + } + } + } + } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java index 50a352be6c2..43feb5cd7cb 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -16,4 +16,9 @@ public EOpNodeData(Character c1, Character c2, int matrixIdx){ public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG) { return inputs.get(matrixIdx); } + + @Override + public void reorderChildren(Character outChar1, Character outChar2) { + + } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java similarity index 81% rename from src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java rename to src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java index 583b6ff8bff..eca0c465e95 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeEinsumFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -23,7 +23,7 @@ import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; -public class EOpNodeEinsumFuse extends EOpNode { +public class EOpNodeFuse extends EOpNode { public enum EinsumRewriteType{ // B -> row*row, A -> row*scalar @@ -47,13 +47,13 @@ public enum EinsumRewriteType{ public final EinsumRewriteType einsumRewriteType; public final List> operands; - private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List... operands) { + private EOpNodeFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List... 
operands) { super(c1,c2); this.einsumRewriteType = einsumRewriteType; this.operands = Arrays.asList(operands); } - public static EOpNodeEinsumFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ ArrayList ret, HashMap charToOccurences, HashMap charToSize){ + public static EOpNodeFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ ArrayList ret, HashMap charToOccurences, HashMap charToSize){ //precompute HashSet matricesChars = new HashSet<>(); HashMap> charsToMatrices = new HashMap<>(); @@ -131,27 +131,29 @@ public static EOpNodeEinsumFuse match(ArrayList operands, Character out } else if(chars.charAt(0)==b){ // BZ, todo, maybe transpose ab into ba - pass = false; - break; +// pass = false; +// break; } else if(chars.charAt(1)==a){ //ZA, maybe its small enough that it can be tranposed? but then not impactful as the bigger A, the more sense to fuse AZ? - pass = false; - break; +// pass = false; +// break; } else if(chars.charAt(1)==b){ // ZB - pass = false; - break; +// pass = false; +// break; } } } - if(pass){ + + if(pass){ // final checks for current AB candidate AB = ABcandidate; String A = ""+a; String B = ""+b; - int ABsCounter = charsToMatrices.get(ABcandidate).size()+(charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); + int BAsCounter = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); + int ABsCounter = charsToMatrices.get(ABcandidate).size()+BAsCounter; // int AZsCounter = AZs.size(); int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0); int BsCounter = (charsToMatrices.containsKey(B) ? 
charsToMatrices.get(B).size() : 0); @@ -162,18 +164,27 @@ else if(chars.charAt(1)==b){ int usedBsCount = BsCounter+ABsCounter; doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2); + boolean includeAz = true; if(AZCandidates.size()==1){ // if(!doSumB){ // pass=false; // continue; // } - int usedAsCount = AsCounter+ABsCounter+AZsCounter; - doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); - if(!doSumA){ // cant do AZ - break;// just do AB,B,A ->AB / A - }else { - AZs = charsToMatrices.get(AZCandidates.iterator().next()); - break;//ok + if(!doSumB) { + // check if outer is possible AB,...,AZ->BZ + if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) { + includeAz=false; + } + } + if(includeAz){ + int usedAsCount = AsCounter+ABsCounter+AZsCounter; + doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); + if(!doSumA){ // cant do AZ + break;// just do AB,B,A ->AB / A + }else { + AZs = charsToMatrices.get(AZCandidates.iterator().next()); + break;//ok + } } } else if (AZCandidates.size()>=2) { doSumA = false; @@ -195,6 +206,17 @@ else if(chars.charAt(1)==b){ if(!pass){ return null; } + ArrayList ABs=charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); + ArrayList BAs=charsToMatrices.containsKey(BA) ? 
charsToMatrices.get(BA) : new ArrayList<>(); + if (ABs.size() < BAs.size() - 1) { + String tmp = AB; + + AB=BA; + BA=tmp; + ArrayList tmp2 = ABs; + BAs=ABs; + ABs=tmp2; + } String B = AB.substring(1,2); String A = AB.substring(0,1); char a = A.charAt(0); @@ -213,8 +235,9 @@ else if(chars.charAt(1)==b){ } else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { - if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(azC2), charToSize.get(AB.charAt(1)),false)|| - LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), azC2, charToSize.get(AB.charAt(1)), charToSize.get(azC2),false)) { +// if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(azC2), charToSize.get(AB.charAt(1)),false)|| +// LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), azC2, charToSize.get(AB.charAt(1)), charToSize.get(azC2),false)) { + if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){ // ideally this can be changed later by parent,depending on need if (outChar1 == azC2 && outChar2 == b) { t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; @@ -230,7 +253,13 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { c2 = b; } + }else{ + t=null; + AZs=new ArrayList<>(); } + }else{ + t=null; + AZs=new ArrayList<>(); } if(charsToMatrices.containsKey(azC2.toString())) { @@ -262,8 +291,7 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { } HashSet usedOperands = new HashSet<>(); - ArrayList ABs=charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); - ArrayList BAs=charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>(); + ArrayList Bs=charsToMatrices.containsKey(B) ? charsToMatrices.get(B) : new ArrayList<>(); ArrayList As=charsToMatrices.containsKey(A) ? 
charsToMatrices.get(A) : new ArrayList<>(); @@ -290,7 +318,7 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { } } - var e = new EOpNodeEinsumFuse(c1, c2, t, + var e = new EOpNodeFuse(c1, c2, t, ABs, BAs, Bs, @@ -316,16 +344,16 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, .flatMap(List::stream) .map(o -> o.c1.toString() + (o.c2 == null ? "" : o.c2)) .collect(Collectors.joining(",")); - String res = (eOpNodeEinsumFuse.c1 == null ? "" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString()); - LOG.trace("ComputeEOpNodeFuse " + eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res); + String res = (eOpNodeEinsumFuse.c1 == null ? "AB=" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString()); + LOG.trace("ComputeEOpNodeFuse " + operands.get(0).get(0).c1+operands.get(0).get(0).c2 +" "+eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res); } - boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__AB; - boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A; - boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__B; - boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__; - boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; - boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; - boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; + boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB; + boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == 
EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A; + boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__B; + boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__; + boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; + boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; + boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; // boolean isResultBC = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AC__BC; // boolean isResultCB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); @@ -393,7 +421,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, case AB_BA_B_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T; case AB_BA_B_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1; }; - EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 0, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize); + EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 1, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize); ArrayList fuseInputs = new ArrayList<>(); @@ -411,6 +439,11 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, } + @Override + public void reorderChildren(Character outChar1, Character outChar2) { + + } + private static @NotNull List multiplyVectorsIntoOne(List mbs, int size) { MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false); mb.allocateDenseBlock(); diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java index 5298a8db176..1e0e6136221 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java @@ -16,9 +16,9 @@ public final class EinsumSpoofRowwise extends SpoofRowwise { private final int _uptoBCumCount; private final int _uptoZCumCount; - private final EOpNodeEinsumFuse.EinsumRewriteType _EinsumRewriteType; + private final EOpNodeFuse.EinsumRewriteType _EinsumRewriteType; - public EinsumSpoofRowwise(EOpNodeEinsumFuse.EinsumRewriteType einsumRewriteType, RowType rowType, long constDim2, boolean tb1, int reqVectMem, int abCount, int bCount, int aCount, int zCount, int azCount, int zSize) { + public EinsumSpoofRowwise(EOpNodeFuse.EinsumRewriteType einsumRewriteType, RowType rowType, long constDim2, boolean tb1, int reqVectMem, int abCount, int bCount, int aCount, int zCount, int azCount, int zSize) { super(rowType, constDim2, tb1, reqVectMem); _ABCount = abCount; _BCount = bCount; @@ -119,7 +119,7 @@ protected void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, do protected void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { int bi = 0; double[] TMP1 = null; - Double TMP2 = null; + double TMP2 = 0; if (_ABCount == 0 && _BCount == 0){ TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(rix),ai,ai,len); } @@ -152,9 +152,11 @@ else if(_BCount > 0 && TMP1 == null) { if(_ACount == 1) { TMP2 *= b[bi].values(0)[rix]; } - if (_EinsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2; + if (_EinsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2; else c[0] += TMP2; } + protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long 
grix, int rix) { - throw new RuntimeException("Sparse einsum not implemented"); } -} \ No newline at end of file + throw new RuntimeException("Sparse fused einsum not implemented"); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index aba60b6b333..f5ae1f2a538 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -132,10 +132,10 @@ public void processInstruction(ExecutionContext ec) { } HashMap characterToOccurences = new HashMap<>(); - for (Character key :einc.characterAppearanceIndexes.keySet()) { + for (Character key : einc.characterAppearanceIndexes.keySet()) { characterToOccurences.put(key, einc.characterAppearanceIndexes.get(key).size()); } - for (Character key :einc.charToDimensionSize.keySet()) { + for (Character key : einc.charToDimensionSize.keySet()) { if(!characterToOccurences.containsKey(key)) characterToOccurences.put(key, 1); } @@ -146,42 +146,67 @@ public void processInstruction(ExecutionContext ec) { EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); eOpNodes.add(n); } - Pair > plan = FORCE_CELL_TPL ? 
null : generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2, FUSED); + Pair > plan; + ArrayList remainingMatrices; + if(!FORCE_CELL_TPL) { + if(FUSED) { + ArrayList ret = new ArrayList<>(); + EOpNodeFuse fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, ret, characterToOccurences, einc.charToDimensionSize); + if(fuse != null){ + eOpNodes = ret; + } + while(ret.size() > 2 && fuse != null){ + ret = new ArrayList<>(); + fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, ret, characterToOccurences, einc.charToDimensionSize); + if(fuse != null){ + eOpNodes = ret; + } + } + + } + + plan = generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + if(plan.getRight().size() == 1) + plan.getRight().get(0).reorderChildren(einc.outChar1, einc.outChar2); - ArrayList resMatrices = FORCE_CELL_TPL ? null : executePlan(plan.getRight(), inputs); + remainingMatrices = executePlan(plan.getRight(), inputs); + }else{ + plan = Pair.of(0, eOpNodes); + remainingMatrices = inputs; + } - if(!FORCE_CELL_TPL && resMatrices.size() == 1){ + if(!FORCE_CELL_TPL && remainingMatrices.size() == 1){ EOpNode resNode = plan.getRight().get(0); if (einc.outChar1 != null && einc.outChar2 != null){ if(resNode.c1 == einc.outChar1 && resNode.c2 == einc.outChar2){ - ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + ec.setMatrixOutput(output.getName(), remainingMatrices.get(0)); } else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ if( LOG.isTraceEnabled()) LOG.trace("Transposing the final result"); ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads); - MatrixBlock resM = resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); + MatrixBlock resM = remainingMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0); ec.setMatrixOutput(output.getName(), resM); }else{ - if(LOG.isTraceEnabled()) 
LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); - throw new RuntimeException("Einsum plan produced different result"); + if(LOG.isTraceEnabled()) LOG.trace("Einsum error, expected: "+resultString + ", got: "+resNode.c1+resNode.c2); + throw new RuntimeException("Einsum plan produced different result, expected: "+resultString + ", got: "+resNode.c1+resNode.c2); } }else if (einc.outChar1 != null){ if(resNode.c1 == einc.outChar1 && resNode.c2 == null){ - ec.setMatrixOutput(output.getName(), resMatrices.get(0)); + ec.setMatrixOutput(output.getName(), remainingMatrices.get(0)); }else{ if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); throw new RuntimeException("Einsum plan produced different result"); } }else{ if(resNode.c1 == null && resNode.c2 == null){ - ec.setScalarOutput(output.getName(), new DoubleObject(resMatrices.get(0).get(0, 0)));; + ec.setScalarOutput(output.getName(), new DoubleObject(remainingMatrices.get(0).get(0, 0)));; } } }else{ // use cell template with loops for remaining - ArrayList mbs = resMatrices; + ArrayList mbs = remainingMatrices; ArrayList chars = new ArrayList<>(); for (int i = 0; i < plan.getRight().size(); i++) { @@ -249,7 +274,7 @@ private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList } // ideally the return list contains only one final element - private Pair> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2, boolean fused) { + private Pair> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { Integer minCost = cost; List minNodes = operands; @@ -270,21 +295,6 @@ else if (operands.size() == 1){ return Pair.of(cost, operands); } - if(fused){ - ArrayList ret = new ArrayList<>(); - EOpNodeEinsumFuse fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, 
charToOccurences, einc.charToDimensionSize); - if(fuse != null){ - minNodes = operands = ret; - } - while(ret.size() > 2 && fuse!=null){ - ret = new ArrayList<>(); - fuse = EOpNodeEinsumFuse.match(operands,outChar1,outChar2,ret, charToOccurences, einc.charToDimensionSize); - if(fuse != null){ - minNodes = operands = ret; - } - } - fused = false; - } for(int i = 0; i < operands.size()-1; i++){ for (int j = i+1; j < operands.size(); j++){ @@ -312,7 +322,7 @@ else if (operands.size() == 1){ } newOperands.add(newNode); - Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2, fused); + Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ minCost = furtherPlan.getLeft(); minNodes = furtherPlan.getRight(); @@ -374,7 +384,7 @@ else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ return null;// AB,AC } else { - return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 + return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 } }else{ // n1.c2 = null -> c2.c2 = null if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ @@ -424,8 +434,7 @@ else if (n1.c2 == n2.c1) { if(cannotBeSummed.test(n1.c1)){ return null; // AB_B } - return null;//todo remove. 
-// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult } else if (n1.c2 == n2.c2) { if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ @@ -477,9 +486,9 @@ private static void indent(StringBuilder sb, int level) { private MatrixBlock computeCellSummation(ArrayList inputs, List inputsChars, String resultString, HashMap charToDimensionSizeInt, List summingChars, int outRows, int outCols){ - ArrayList cnodeIn = new ArrayList<>(); - cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, DataType.SCALAR)); - CNodeCell cnode = new CNodeCell(cnodeIn, null); + ArrayList dummyIn = new ArrayList<>(); + dummyIn.add(new CNodeData(new LiteralOp(0), 0, 0, DataType.SCALAR)); + CNodeCell cnode = new CNodeCell(dummyIn, null); StringBuilder sb = new StringBuilder(); int indent = 2; @@ -562,7 +571,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { sb.append(itVar0); sb.append(inputsChars.get(i).charAt(1)); sb.append(")"); - } else if (resultString.length() >= 1 &&inputsChars.get(i).charAt(1) == resultString.charAt(0)) { + } else if (resultString.length() >= 1 && inputsChars.get(i).charAt(1) == resultString.charAt(0)) { sb.append("rix)"); } else if (resultString.length() == 2 && inputsChars.get(i).charAt(1) == resultString.charAt(1)) { sb.append("cix)"); @@ -584,8 +593,7 @@ else if (summingChars.contains(inputsChars.get(i).charAt(1))) { indent--; sb.append("}\n"); } - //endregion - String src = CNodeCell.JAVA_TEMPLATE;// + String src = CNodeCell.JAVA_TEMPLATE; src = src.replace("%TMP%", cnode.createVarname()); src = src.replace("%TYPE%", "NO_AGG"); src = src.replace("%SPARSE_SAFE%", "false"); From cda71b736e3f0307ae5371560898d89f5ceda69e Mon Sep 17 00:00:00 2001 
From: Hubert Krawczyk Date: Wed, 12 Nov 2025 23:25:14 +0100 Subject: [PATCH 07/13] new approach to fuse --- .../apache/sysds/runtime/einsum/EOpNode.java | 24 +- .../sysds/runtime/einsum/EOpNodeBinary.java | 59 +- .../sysds/runtime/einsum/EOpNodeData.java | 26 +- .../sysds/runtime/einsum/EOpNodeFuse.java | 201 +++---- .../sysds/runtime/einsum/EOpNodeUnary.java | 95 ++++ .../sysds/runtime/einsum/EinsumContext.java | 197 +++---- .../runtime/einsum/EinsumSpoofRowwise.java | 239 +++++--- .../instructions/cp/EinsumCPInstruction.java | 532 ++++++++++++++---- .../test/functions/einsum/EinsumTest.java | 200 +++++-- 9 files changed, 1105 insertions(+), 468 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java index bdaf0d31c85..f803dc9e21d 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.sysds.runtime.einsum; import org.apache.commons.logging.Log; @@ -15,12 +34,13 @@ public EOpNode(Character c1, Character c2){ @Override public String toString() { - if(c1 == null) return "-"; - + if(c1 == null) return "''"; if(c2 == null) return c1.toString(); return c1.toString() + c2.toString(); } + public abstract String[] recursivePrintString(); + public abstract MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG); public abstract void reorderChildren(Character outChar1, Character outChar2); diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index f0abfaf3ed2..0bb03737a66 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.sysds.runtime.einsum; import org.apache.commons.logging.Log; @@ -25,7 +44,8 @@ public class EOpNodeBinary extends EOpNode { - public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed + + public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed ////// summations: ////// aB_a,// -> B Ba_a, // -> B @@ -58,14 +78,45 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b public EOpNode _left; public EOpNode _right; public EBinaryOperand _operand; + private boolean transposeResult; public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ super(c1,c2); this._left = left; this._right = right; this._operand = operand; } + public void setTransposeResult(boolean transposeResult){ + this.transposeResult = transposeResult; + } - @Override + public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) { + if (left.c2 == right.c1) { return new EOpNodeBinary(left.c1, right.c2, left, right, EBinaryOperand.Ba_aC); } + if (left.c2 == right.c2) { return new EOpNodeBinary(left.c1, right.c1, left, right, EBinaryOperand.Ba_Ca); } + if (left.c1 == right.c1) { return new EOpNodeBinary(left.c2, right.c2, left, right, EBinaryOperand.aB_aC); } + if (left.c1 == right.c2) { + var res = new EOpNodeBinary(left.c2, right.c1, left, right, EBinaryOperand.aB_Ca); + res.setTransposeResult(true); + return res; + } + throw new RuntimeException("EOpNodeBinary::combineMatrixMultiply: invalid matrix operation"); + } + + @Override + public String[] recursivePrintString() { + String[] left = _left.recursivePrintString(); + String[] right = _right.recursivePrintString(); + String[] res = new String[left.length + right.length+1]; + res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString(); + for (int i=0; i inputs, int numThreads, Log LOG) { EOpNodeBinary bin = this; MatrixBlock left = 
_left.computeEOpNode(inputs, numThreads, LOG); @@ -204,6 +255,10 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, } } + if(transposeResult){ + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + res = res.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); + } return res; } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java index 43feb5cd7cb..d9b61b29514 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.sysds.runtime.einsum; import org.apache.commons.logging.Log; @@ -11,7 +30,12 @@ public EOpNodeData(Character c1, Character c2, int matrixIdx){ super(c1,c2); this.matrixIdx = matrixIdx; } - + @Override + public String[] recursivePrintString() { + String[] res = new String[1]; + res[0] = this.getClass().getSimpleName()+" ("+matrixIdx+") "+this.toString(); + return res; + } @Override public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG) { return inputs.get(matrixIdx); diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java index eca0c465e95..536456fc744 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -1,12 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.sysds.runtime.einsum; import org.apache.commons.logging.Log; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.codegen.cplan.CNode; +import org.apache.sysds.hops.codegen.cplan.CNodeData; +import org.apache.sysds.hops.codegen.cplan.CNodeRow; +import org.apache.sysds.runtime.codegen.CodegenUtils; +import org.apache.sysds.runtime.codegen.SpoofOperator; import org.apache.sysds.runtime.codegen.SpoofRowwise; +import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; @@ -19,14 +48,19 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.function.Supplier; import java.util.stream.Collectors; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; public class EOpNodeFuse extends EOpNode { - public enum EinsumRewriteType{ - // B -> row*row, A -> row*scalar + private EOpNode scalar = null; + + + + public enum EinsumRewriteType{ + // B -> row*vec, A -> row*scalar AB_BA_B_A__AB, AB_BA_B_A__B, AB_BA_B_A__A, @@ -35,13 +69,9 @@ public enum EinsumRewriteType{ // scalar from row(AB).dot(B) multiplied by row(AZ) AB_BA_B_A_AZ__Z, - // AC: last step is outer matrix multiplication using vector C + // AZ: last step is outer matrix multiplication using vector Z AB_BA_B_A_AZ__BZ, AB_BA_B_A_AZ__ZB, - -// 
// outer matrix multiplication using vector C and vector Z -// AB_BA_B_A_AZ_AC__ZC, -// AB_BA_B_A_AZ_AC__CZ, } public final EinsumRewriteType einsumRewriteType; @@ -52,8 +82,37 @@ private EOpNodeFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteT this.einsumRewriteType = einsumRewriteType; this.operands = Arrays.asList(operands); } - - public static EOpNodeFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ ArrayList ret, HashMap charToOccurences, HashMap charToSize){ + @Override + public String[] recursivePrintString() { + ArrayList inpStrings = new ArrayList<>(); + for (List list : operands) { + for (EOpNode node : list) { + inpStrings.add(node.recursivePrintString()); + } + } + String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new); + String[] scalarRes = this.scalar==null ? new String[]{} : this.scalar.recursivePrintString(); + String[] res = new String[1 + inpRes.length + scalarRes.length]; + + res[0] = this.getClass().getSimpleName()+" ("+einsumRewriteType.toString()+") "+this.toString(); + + for (int i=0; i operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ + HashMap charToSize, HashMap charToOccurences, ArrayList ret){ //precompute HashSet matricesChars = new HashSet<>(); HashMap> charsToMatrices = new HashMap<>(); @@ -77,12 +136,7 @@ public static EOpNodeFuse match(ArrayList operands, Character outChar1, } } - ArrayList AXs = new ArrayList<>(); - ArrayList XAs = new ArrayList<>(); - ArrayList BXs = new ArrayList<>(); - ArrayList XBs = new ArrayList<>(); ArrayList AZs = new ArrayList<>(); -// ArrayList ACs = new ArrayList<>(); ArrayList Zs = new ArrayList<>(); boolean pass = false; @@ -95,10 +149,6 @@ public static EOpNodeFuse match(ArrayList operands, Character outChar1, char b = ABcandidate.charAt(1); BA = "" + b + a; - AXs = new ArrayList<>(); - XAs = new ArrayList<>(); - BXs = new ArrayList<>(); - XBs = new ArrayList<>(); AZs = new 
ArrayList<>(); Character z = null; pass=true; @@ -107,42 +157,27 @@ public static EOpNodeFuse match(ArrayList operands, Character outChar1, for (String chars : charsToMatrices.keySet()) { if (chars.equals(ABcandidate) || chars.equals(BA)) { -// ABsCounter++; continue; } if(chars.length()==1){ - if(chars.charAt(0)==a){ -// AsCounter++; - }else if(chars.charAt(0)==b){ -// BsCounter++; - } - continue; //always ok }else{ if(a==chars.charAt(1) && b==chars.charAt(0)){ //BA -// ABsCounter++; continue; } - if(chars.charAt(0)==a){ - //AZ + if(chars.charAt(0)==a){ //AZ AZsCounter++; AZCandidates.add(chars); } else if(chars.charAt(0)==b){ - // BZ, todo, maybe transpose ab into ba -// pass = false; -// break; + // BZ } else if(chars.charAt(1)==a){ - //ZA, maybe its small enough that it can be tranposed? but then not impactful as the bigger A, the more sense to fuse AZ? -// pass = false; -// break; + //ZA } else if(chars.charAt(1)==b){ // ZB -// pass = false; -// break; } } } @@ -154,7 +189,6 @@ else if(chars.charAt(1)==b){ String B = ""+b; int BAsCounter = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); int ABsCounter = charsToMatrices.get(ABcandidate).size()+BAsCounter; -// int AZsCounter = AZs.size(); int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0); int BsCounter = (charsToMatrices.containsKey(B) ? 
charsToMatrices.get(B).size() : 0); if(AsCounter==0 && BsCounter==0 && ABsCounter<2){ @@ -166,10 +200,6 @@ else if(chars.charAt(1)==b){ boolean includeAz = true; if(AZCandidates.size()==1){ -// if(!doSumB){ -// pass=false; -// continue; -// } if(!doSumB) { // check if outer is possible AB,...,AZ->BZ if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) { @@ -186,16 +216,17 @@ else if(chars.charAt(1)==b){ break;//ok } } - } else if (AZCandidates.size()>=2) { - doSumA = false; - if(doSumB){ - pass=true; - break; // can do it, it will create AB,B,A -> A, that will be consumed by some AZ later - } - pass=false; - continue; - } +// else if (AZCandidates.size() >= 2) { +// doSumA = false; +// if(doSumB){ +// pass=true; +// break; // can do it, it will create AB,B,A -> A, that will be consumed by some AZ later +// } +// pass=false; +// continue; +// +// } int usedAsCount = AsCounter+ABsCounter; doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); @@ -216,6 +247,7 @@ else if(chars.charAt(1)==b){ ArrayList tmp2 = ABs; BAs=ABs; ABs=tmp2; + AZs.clear(); } String B = AB.substring(1,2); String A = AB.substring(0,1); @@ -226,19 +258,13 @@ else if(chars.charAt(1)==b){ EinsumRewriteType t = null; if(!AZs.isEmpty()){ -// Character azC1 = AZs.get(0).c1; Character azC2 = AZs.get(0).c2; -// c1 = AZs.get(0).c2; if(doSumB) { t = EinsumRewriteType.AB_BA_B_A_AZ__Z; c1 = azC2; - } else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { -// if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(azC2), charToSize.get(AB.charAt(1)),false)|| -// LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), azC2, charToSize.get(AB.charAt(1)), charToSize.get(azC2),false)) { 
if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){ - // ideally this can be changed later by parent,depending on need if (outChar1 == azC2 && outChar2 == b) { t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; c1 = azC2; @@ -299,10 +325,6 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { usedOperands.addAll(BAs); usedOperands.addAll(Bs); usedOperands.addAll(As); - usedOperands.addAll(XBs); - usedOperands.addAll(BXs); - usedOperands.addAll(XAs); - usedOperands.addAll(AXs); usedOperands.addAll(AZs); usedOperands.addAll(Zs); @@ -322,21 +344,16 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { ABs, BAs, Bs, - XBs, - BXs, As, - XAs, - AXs, AZs, Zs ); - ret.add(e); return e; } @Override public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, Log LOG) { - List> mbs = operands.stream().map(l -> l.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).collect(Collectors.toList())).toList(); + List> mbs = operands.stream().map(l -> l.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).collect(Collectors.toList())).toList(); var eOpNodeEinsumFuse = this; if( LOG.isTraceEnabled()) { @@ -344,8 +361,9 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, .flatMap(List::stream) .map(o -> o.c1.toString() + (o.c2 == null ? "" : o.c2)) .collect(Collectors.joining(",")); - String res = (eOpNodeEinsumFuse.c1 == null ? "AB=" : eOpNodeEinsumFuse.c1.toString())+(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString()); - LOG.trace("ComputeEOpNodeFuse " + operands.get(0).get(0).c1+operands.get(0).get(0).c2 +" "+eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + " -> " + res); + String res = eOpNodeEinsumFuse.c1 == null ? "" : (eOpNodeEinsumFuse.c1.toString() +(eOpNodeEinsumFuse.c2 == null ? 
"" : eOpNodeEinsumFuse.c2.toString())); + + LOG.trace("ComputeEOpNodeFuse AB=" + operands.get(0).get(0).c1+operands.get(0).get(0).c2 +" "+eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + "->" + res); } boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB; boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A; @@ -354,38 +372,21 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; -// boolean isResultBC = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AC__BC; -// boolean isResultCB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeEinsumFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; - List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), XBs = mbs.get(3), BXs = mbs.get(4), As = mbs.get(5), XAs = mbs.get(6), AXs = mbs.get(7); - List AZs = mbs.get(8); - List Zs = mbs.get(9); -// List ACs = isResultBC || isResultCB ? 
mbs.get(10) : null; + List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), As = mbs.get(3); + List AZs = mbs.get(4); + List Zs = mbs.get(5); int bSize = ABs.get(0).getNumColumns(); int aSize = ABs.get(0).getNumRows(); - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); - for(MatrixBlock mb: BAs){//BA->AB - ABs.add(mb.reorgOperations(transpose, null,0,0,0)); - } + if (!BAs.isEmpty()) { + ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); + for(MatrixBlock mb : BAs) {//BA->AB + ABs.add(mb.reorgOperations(transpose, null, 0, 0, 0)); + } + } AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); - for(MatrixBlock mb: XBs){//XB->B - MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); - Bs.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - for(MatrixBlock mb: XAs){//XA->A - MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); - for(MatrixBlock mb: BXs){//BX->B - MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - for(MatrixBlock mb: AXs){//AX->B // todo remove all X - MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); - As.add((MatrixBlock)mb.aggregateUnaryOperations(aggun, res, 0, null)); - } - if(As.size()>1){ + + if(As.size() > 1){ As = multiplyVectorsIntoOne(As, aSize); } if(Bs.size() > 1){ @@ -425,14 +426,18 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, ArrayList fuseInputs = new ArrayList<>(); -// inputs.add(resBlock); fuseInputs.addAll(ABs); fuseInputs.addAll(Bs); fuseInputs.addAll(As); if (isResultZ || isResultBZ || 
isResultZB) fuseInputs.addAll(AZs); - MatrixBlock out = r.execute(fuseInputs, new ArrayList<>(), new MatrixBlock(), numThreads); + ArrayList scalarObjects = new ArrayList<>(); + if(this.scalar != null){ + MatrixBlock scMb = this.scalar.computeEOpNode(inputs,numThreads,LOG); + scalarObjects.add(new DoubleObject(scMb.get(0,0))); + } + MatrixBlock out = r.execute(fuseInputs, scalarObjects, new MatrixBlock(), numThreads); if( isResultA || isResultB || isResultZ) ensureMatrixBlockColumnVector(out); return out; diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java new file mode 100644 index 00000000000..b48877625cf --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.einsum; + +import org.apache.commons.logging.Log; +import org.apache.sysds.runtime.functionobjects.DiagIndex; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ReduceAll; +import org.apache.sysds.runtime.functionobjects.ReduceCol; +import org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; + +import java.util.ArrayList; + +public class EOpNodeUnary extends EOpNode { + private final EUnaryOperand eUnaryOperand; + public EOpNode child; + + public enum EUnaryOperand { + DIAG, SUM, SUM_ROWS, SUM_COLS + } + public EOpNodeUnary(Character c1, Character c2, EOpNode child, EUnaryOperand eUnaryOperand) { + super(c1, c2); + this.child = child; + this.eUnaryOperand = eUnaryOperand; + } + + @Override + public String[] recursivePrintString() { + String[] childResult = child.recursivePrintString(); + String[] res = new String[1+childResult.length]; + res[0] = this.getClass().getSimpleName()+" ("+eUnaryOperand.toString()+") "+this.toString(); + for (int i=0; i inputs, int numOfThreads, Log LOG) { + MatrixBlock mb = child.computeEOpNode(inputs, numOfThreads, LOG); + return switch(eUnaryOperand) { + case DIAG->{ + ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); + yield mb.reorgOperations(op, new MatrixBlock(),0,0,0); + } + case SUM->{ + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numOfThreads); + MatrixBlock res = new MatrixBlock(1, 1, false); + mb.aggregateUnaryOperations(aggun, res, 0, null); + yield res; + } + case SUM_ROWS->{ + AggregateOperator agg = new 
AggregateOperator(0, Plus.getPlusFnObject()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numOfThreads); + MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); + mb.aggregateUnaryOperations(aggun, res, 0, null); + yield res; + } + case SUM_COLS->{ + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numOfThreads); + MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); + mb.aggregateUnaryOperations(aggun, res, 0, null); + yield res; + } + }; + } + + @Override + public void reorderChildren(Character outChar1, Character outChar2) { + + } +} diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java index 6da39e59873..55692d0109e 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java @@ -23,155 +23,98 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; - public class EinsumContext { - public enum ContractDimensions { - CONTRACT_LEFT, - CONTRACT_RIGHT, - CONTRACT_BOTH, - } public Integer outRows; public Integer outCols; public Character outChar1; public Character outChar2; public HashMap charToDimensionSize; public String equationString; - public boolean[] diagonalInputs; - public HashSet summingChars; - public HashSet contractDimsSet; - public ContractDimensions[] contractDims; public ArrayList newEquationStringInputsSplit; - public HashMap> characterAppearanceIndexes; // for each character, this tells in which inputs it appears + public HashMap characterAppearanceCount; private EinsumContext(){}; public static EinsumContext getEinsumContext(String eqStr, ArrayList inputs){ EinsumContext res = new EinsumContext(); res.equationString = 
eqStr; - res.charToDimensionSize = new HashMap(); - HashSet summingChars = new HashSet<>(); - ContractDimensions[] contractDims = new ContractDimensions[inputs.size()]; - boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by default - HashSet contractDimsSet = new HashSet<>(); - HashMap> partsCharactersToIndices = new HashMap<>(); - ArrayList newEquationStringSplit = new ArrayList<>(); + HashMap charToDimensionSize = new HashMap<>(); + HashMap characterAppearanceCount = new HashMap<>(); + ArrayList newEquationStringSplit = new ArrayList<>(); + Character outChar1 = null; + Character outChar2 = null; Iterator it = inputs.iterator(); MatrixBlock curArr = it.next(); - int arrSizeIterator = 0; - int arrayIterator = 0; - int i; - // first iteration through string: collect information on character-size and what characters are summing characters - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if(c == '-'){ - i+=2; - break; - } - if(c == ','){ - arrayIterator++; - curArr = it.next(); - arrSizeIterator = 0; - } - else{ - if (res.charToDimensionSize.containsKey(c)) { // sanity check if dims match, this is already checked at validation - if(arrSizeIterator == 0 && res.charToDimensionSize.get(c) != curArr.getNumRows()) - throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); - else if(arrSizeIterator == 1 && res.charToDimensionSize.get(c) != curArr.getNumColumns()) - throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); - summingChars.add(c); - } else { - if(arrSizeIterator == 0) - res.charToDimensionSize.put(c, curArr.getNumRows()); - else if(arrSizeIterator == 1) - res.charToDimensionSize.put(c, curArr.getNumColumns()); - } - - arrSizeIterator++; - } - } - - int numOfRemainingChars = eqStr.length() - i; - - if (numOfRemainingChars > 2) - throw new RuntimeException("Einsum: dim > 2 not supported"); - - arrSizeIterator = 0; - - Character outChar1 = numOfRemainingChars > 0 ? 
eqStr.charAt(i) : null; - Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : null; - res.outRows=(numOfRemainingChars > 0 ? res.charToDimensionSize.get(outChar1) : 1); - res.outCols=(numOfRemainingChars > 1 ? res.charToDimensionSize.get(outChar2) : 1); - - arrayIterator=0; - // second iteration through string: collect remaining information - for (i = 0; true; i++) { - char c = eqStr.charAt(i); - if (c == '-') { - break; - } - if (c == ',') { - arrayIterator++; - arrSizeIterator = 0; - continue; - } - String s = ""; - - if(summingChars.contains(c)) { - s+=c; - if(!partsCharactersToIndices.containsKey(c)) - partsCharactersToIndices.put(c, new ArrayList<>()); - partsCharactersToIndices.get(c).add(arrayIterator); - } - else if((outChar1 != null && c == outChar1) || (outChar2 != null && c == outChar2)) { - s+=c; - } - else { - contractDimsSet.add(c); - contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT; - } - - if(i + 1 < eqStr.length()) { // process next character together - char c2 = eqStr.charAt(i + 1); - i++; - if (c2 == '-') { newEquationStringSplit.add(s); break;} - if (c2 == ',') { arrayIterator++; newEquationStringSplit.add(s); continue; } - - if (c2 == c){ - diagonalInputs[arrayIterator] = true; - if (contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] = ContractDimensions.CONTRACT_BOTH; - } - else{ - if(summingChars.contains(c2)) { - s+=c2; - if(!partsCharactersToIndices.containsKey(c2)) - partsCharactersToIndices.put(c2, new ArrayList<>()); - partsCharactersToIndices.get(c2).add(arrayIterator); - } - else if((outChar1 != null && c2 == outChar1) || (outChar2 != null && c2 == outChar2)) { - s+=c2; - } - else { - contractDimsSet.add(c2); - contractDims[arrayIterator] = contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? 
ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT; - } - } - } - newEquationStringSplit.add(s); - arrSizeIterator++; - } - - res.contractDims = contractDims; - res.contractDimsSet = contractDimsSet; - res.diagonalInputs = diagonalInputs; - res.summingChars = summingChars; + int i = 0; + + char c = eqStr.charAt(i); + for(i = 0; i < eqStr.length(); i++) { + StringBuilder sb = new StringBuilder(2); + for(;i < eqStr.length(); i++){ + c = eqStr.charAt(i); + if (c == ' ') continue; + if (c == ',' || c == '-' ) break; + if (!Character.isAlphabetic(c)) { + throw new RuntimeException("Einsum: only alphabetic characters are supported for dimensions: "+c); + } + sb.append(c); + if (characterAppearanceCount.containsKey(c)) characterAppearanceCount.put(c, characterAppearanceCount.get(c) + 1) ; + else characterAppearanceCount.put(c, 1); + } + String s = sb.toString(); + newEquationStringSplit.add(s); + + if(s.length() > 0){ + if (charToDimensionSize.containsKey(s.charAt(0))) + if (charToDimensionSize.get(s.charAt(0)) != curArr.getNumRows()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + charToDimensionSize.put(s.charAt(0), curArr.getNumRows()); + } + if(s.length() > 1){ + if (charToDimensionSize.containsKey(s.charAt(1))) + if (charToDimensionSize.get(s.charAt(1)) != curArr.getNumColumns()) + throw new RuntimeException("Einsum: character "+c+" has multiple conflicting sizes"); + charToDimensionSize.put(s.charAt(1), curArr.getNumColumns()); + } + if(s.length() > 2) throw new RuntimeException("Einsum: only up-to 2D inputs strings allowed "); + + if( c==','){ + curArr = it.next(); + } + else if (c=='-') break; + + if (i == eqStr.length() - 1) {throw new RuntimeException("Einsum: missing '->' substring "+c);} + } + + if (i == eqStr.length() - 1 || eqStr.charAt(i+1) != '>') throw new RuntimeException("Einsum: missing '->' substring "+c); + i+=2; + + StringBuilder sb = new StringBuilder(2); + + for(;i < eqStr.length(); i++){ 
+ c = eqStr.charAt(i); + if (c == ' ') continue; + if (!Character.isAlphabetic(c)) { + throw new RuntimeException("Einsum: only alphabetic characters are supported for dimensions: "+c); + } + sb.append(c); + } + String s = sb.toString(); + if(s.length() > 0) outChar1 = s.charAt(0); + if(s.length() > 1) outChar2 = s.charAt(1); + if(s.length() > 2) throw new RuntimeException("Einsum: only up-to 2D output allowed "); + + res.outRows=(outChar1 == null ? 1 : charToDimensionSize.get(outChar1)); + res.outCols=(outChar2 == null ? 1 : charToDimensionSize.get(outChar2)); + res.outChar1 = outChar1; res.outChar2 = outChar2; res.newEquationStringInputsSplit = newEquationStringSplit; - res.characterAppearanceIndexes = partsCharactersToIndices; + res.characterAppearanceCount = characterAppearanceCount; + res.charToDimensionSize = charToDimensionSize; return res; } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java index 1e0e6136221..f12667685b8 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java @@ -1,5 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.runtime.einsum; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; import org.apache.sysds.runtime.codegen.SpoofRowwise; @@ -32,31 +53,58 @@ public EinsumSpoofRowwise(EOpNodeFuse.EinsumRewriteType einsumRewriteType, RowTy } protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { switch (_EinsumRewriteType) { - case AB_BA_B_A__AB -> genexec_AB(a,ai,b,scalars,c,ci,len,grix,rix); - case AB_BA_B_A__B -> genexec_B(a,ai,b,scalars,c,ci,len,grix,rix); - case AB_BA_B_A__A -> genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); - case AB_BA_B_A__ -> genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + case AB_BA_B_A__AB -> { + genexec_AB(a,ai,b,scalars,c,ci,len,grix,rix); + if (scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], c,c,ci,ci, len); + } + } + case AB_BA_B_A__B -> { + genexec_B(a,ai,b,scalars,c,ci,len,grix,rix); + } + case AB_BA_B_A__A -> { +// HARDCODEDgenexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + if (scalars.length != 0) { + c[rix] *= scalars[0]; + } + } + case AB_BA_B_A__ -> { + genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + if (scalars.length != 0) { + c[0] *= scalars[0]; + } + } case AB_BA_B_A_AZ__Z -> { double[] temp = {0}; genexec_A_or_(a,ai,b,scalars,temp,0,len,grix,rix); + if (scalars.length != 0) { + temp[0] *= scalars[0]; + } LibMatrixMult.vectMultiplyAdd(temp[0], b[_uptoZCumCount].values(rix), c, _ZSize*rix,0, _ZSize); } case AB_BA_B_A_AZ__BZ -> { double[] temp = new double[len]; genexec_B(a,ai,b,scalars,temp,0,len,grix,rix); + if (scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], temp,temp,0,0,len); + } 
LibSpoofPrimitives.vectOuterMultAdd(temp, b[_uptoZCumCount].values(rix), c,0, _ZSize*rix, 0, len,_ZSize); } case AB_BA_B_A_AZ__ZB -> { double[] temp = new double[len]; genexec_B(a,ai,b,scalars,temp,0,len,grix,rix); - + if (scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], temp,temp,0,0,len); + } LibSpoofPrimitives.vectOuterMultAdd(b[_uptoZCumCount].values(rix),temp , c,_ZSize*rix,0, 0, _ZSize, len); } -// case AB_BA_B_XB_BX_A_XA_AX_AZ_AC__CZ -> default -> throw new NotImplementedException(); } + } - protected void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + private void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, + int rix) { int bi = 0; double[] TMP1 = null; if (_ABCount != 0){ @@ -70,93 +118,142 @@ protected void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, d } } - if(_BCount > 0 && TMP1 == null) { - TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(0),ai,0,len); - } - while(bi < _uptoBCumCount) { - if (_ACount == 0 && bi == _uptoBCumCount - 1) { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), c, 0, 0, ci, len); - } else { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), TMP1, 0, 0, 0, len); - } - } - - if(_ACount == 1) { - LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix],TMP1,c,0,ci,len); - } + if(_BCount == 1) { + if(_ACount == 1) + if(TMP1 == null) + vectMultiplyWrite(b[bi + 1].values(0)[rix], a, b[bi].values(0), c, ai, 0, ci, len); + else + vectMultiplyWrite(b[bi + 1].values(0)[rix], TMP1, b[bi].values(0), c, 0, 0, ci, len); + else if(TMP1 == null) + LibMatrixMult.vectMultiplyWrite(a, b[bi].values(0), c, ai, 0, ci, len); + else + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi].values(0), c, 0, 0, ci, len); + } else if(_ACount == 1) { + if(TMP1 == null) + LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix],a,c, ai, ci, len); + else + 
LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix],TMP1,c, 0, ci, len); + } } - protected void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { - int bi = 0; - double[] TMP1 = null; - if (_ABCount != 0){ - TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); - while (bi < _ABCount) { - if(_ACount == 0 && _BCount == 0 && bi == _ABCount-1) { - LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len); - }else { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); - } - } + private void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, + int rix) { + int bi = 0; + double[] TMP1 = null; + if(_ABCount != 0) { + TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len); + while(bi < _ABCount) { + if(_ACount == 0 && _BCount == 0 && bi == _ABCount - 1) { + LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len); } - - if(_BCount > 0 && TMP1 == null) { - TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(0),ai,0,len); - } - while(bi < _uptoBCumCount) { - if (_ACount == 0 && bi == _uptoBCumCount - 1) { - LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(0), c, 0, 0, 0, len); - } else { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), TMP1, 0, 0, 0, len); - } + else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); } + } + } - if(_ACount == 1) { - LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix],TMP1,c,0,0,len); - } + if(_BCount == 1) { + if(_ACount == 1) + if(TMP1 == null) + vectMultiplyAdd(b[bi + 1].values(0)[rix], a, b[bi].values(0), c, ai, 0, 0, len); + else + vectMultiplyAdd(b[bi + 1].values(0)[rix], TMP1, b[bi].values(0), c, 0, 0, 0, len); + else if(TMP1 == null) + LibMatrixMult.vectMultiplyAdd(a, b[bi].values(0), c, ai, 0, 0, len); + else + LibMatrixMult.vectMultiplyAdd(TMP1, b[bi].values(0), 
c, 0, 0, 0, len); + } + else if(_ACount == 1) { + if(TMP1 == null) + LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], a, c, ai, 0, len); + else + LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], TMP1, c, 0, 0, len); } + } - protected void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + private void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, + long grix, int rix) { int bi = 0; double[] TMP1 = null; double TMP2 = 0; - if (_ABCount == 0 && _BCount == 0){ - TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(rix),ai,ai,len); - } - else if (_ABCount != 0){ + if (_ABCount == 1 && _BCount == 0) + TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(rix),0,ai,len); + else if (_ABCount != 0) { TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); while (bi < _ABCount) { - if(_BCount == 0 && bi == _ABCount - 1) { + if(_BCount == 0 && bi == _ABCount - 1) TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(rix),0,ai,len); - }else { + else LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); - } } } - if(_BCount == 1 && TMP1 == null) { - TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(0),ai,0,len); - } - else if(_BCount > 0 && TMP1 == null) { - TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(0),ai,0,len); - } - while(bi < _uptoBCumCount) { - if(bi == _uptoBCumCount -1){ - TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(0),0,0,len); - } - else { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(0), TMP1, 0, 0, 0, len); - } - } + if(_BCount == 1) + if(_ABCount != 0) + TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(0),0,0,len); + else + TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(0),ai,0,len); + else if(_ABCount == 0) + TMP2 = LibSpoofPrimitives.vectSum(a, ai, len); + + if(_ACount == 1) + TMP2 *= b[bi].values(0)[rix]; - if(_ACount == 1) { - TMP2 
*= b[bi].values(0)[rix]; - } if (_EinsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2; else c[0] += TMP2; } + private void HARDCODEDgenexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, + int len, long grix, int rix) { + double[] TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[0].values(rix),ai,ai,len); + double TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[1].values(0),0,ai,len); + TMP2 *= b[2].values(0)[rix]; + c[rix] = TMP2; + } + protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { throw new RuntimeException("Sparse fused einsum not implemented"); } + + + // I am not sure if it is worth copying to LibMatrixMult so for now added it here + private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; + private static final int vLen = SPECIES.length(); + public static void vectMultiplyWrite( final double aval, double[] a, double[] b, double[] c,int ai, int bi, int ci, final int len ) + { + final int bn = len%vLen; + + //rest, not aligned to vLen-blocks + for( int j = 0; j < bn; j++, ai++, bi++, ci++) + c[ ci ] = aval * b[ bi ] * a[ ai ]; + + //unrolled vLen-block (for better instruction-level parallelism) + DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval); + for( int j = bn; j < len; j+=vLen, ai+=vLen, bi+=vLen, ci+=vLen) + { + DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai); + DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi); + avalVec.mul(bVec).mul(aVec).intoArray(c, ci); + } + } + + public static void vectMultiplyAdd( final double aval, double[] a, double[] b, double[] c,int ai, int bi, int ci, final int len ) + { + final int bn = len%vLen; + + //rest, not aligned to vLen-blocks + for( int j = 0; j < bn; j++, ai++, bi++, ci++) + c[ ci ] += aval * b[ bi ] * a[ ai ]; + + //unrolled vLen-block (for better instruction-level parallelism) + DoubleVector avalVec = 
DoubleVector.broadcast(SPECIES, aval); + for( int j = bn; j < len; j+=vLen, ai+=vLen, bi+=vLen, ci+=vLen) + { + DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai); + DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi); + DoubleVector cVec = DoubleVector.fromArray(SPECIES, c, ci); + DoubleVector tmp = aVec.mul(bVec); + tmp.fma(avalVec, cVec).intoArray(c, ci); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index f5ae1f2a538..13748ca980a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -38,18 +38,25 @@ import org.apache.sysds.runtime.einsum.EOpNodeBinary.EBinaryOperand; import org.apache.sysds.runtime.functionobjects.*; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.utils.Explain; import java.util.*; import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static org.apache.sysds.api.DMLScript.EXPLAIN; +import static org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static final boolean FORCE_CELL_TPL = false; public static final boolean FUSED = true; public static final boolean FUSE_OUTER_MULTIPLY = true; + + + public static final boolean PRINT_TRACE = true; + protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; private final int _numThreads; @@ -61,8 +68,12 @@ public EinsumCPInstruction(Operator op, String 
opcode, String istr, CPOperand ou _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2; _in = inputs; this.eqStr = inputs[0].getName(); - Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); -// Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.WARN); + if (PRINT_TRACE) { +// System.out.println("fusing outer mult:"+FUSE_OUTER_MULTIPLY); + Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); + } + else + Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.WARN); } @SuppressWarnings("unused") @@ -90,94 +101,113 @@ public void processInstruction(ExecutionContext ec) { this.einc = einc; String resultString = einc.outChar2 != null ? String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? String.valueOf(einc.outChar1) : ""; - if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", outcols:"+einc.outCols); + if( LOG.isTraceEnabled() ) LOG.trace("output: "+resultString +" "+einc.outRows+"x"+einc.outCols); ArrayList inputsChars = einc.newEquationStringInputsSplit; if(LOG.isTraceEnabled()) LOG.trace(String.join(",",einc.newEquationStringInputsSplit)); - - contractDimensionsAndComputeDiagonals(einc, inputs); + ArrayList eOpNodes = new ArrayList<>(inputsChars.size()); + ArrayList eOpNodesScalars = new ArrayList<>(inputsChars.size()); // computed separately and not included into plan until it is already created //make all vetors col vectors for(int i = 0; i < inputs.size(); i++){ - if(inputs.get(i) != null && inputsChars.get(i).length() == 1) ensureMatrixBlockColumnVector(inputs.get(i)); - } - - if(LOG.isTraceEnabled()) for(Character c : einc.characterAppearanceIndexes.keySet()){ - ArrayList a = einc.characterAppearanceIndexes.get(c); - LOG.trace(c+" count= "+a.size()); - } -// var simplySummableChars = einc.characterAppearanceIndexes.entrySet() -// .stream() -// .filter(e -> e.getValue().size() == 1) -// .map(Map.Entry::getKey) -// .collect(Collectors.toSet()); - - // compute scalar by suming-all 
matrices: - Double scalar = null; - for(int i=0;i< inputs.size(); i++){ - String s = inputsChars.get(i); - if(s.equals("")){ - MatrixBlock mb = inputs.get(i); - if (scalar == null) scalar = mb.get(0,0); - else scalar*= mb.get(0,0); - inputs.set(i,null); - inputsChars.set(i,null); - } + if(inputsChars.get(i).length() == 1) ensureMatrixBlockColumnVector(inputs.get(i)); } - if (scalar != null) { - inputsChars.add(""); - inputs.add(new MatrixBlock(scalar)); - } + addSumDimensionsDiagonalsAndScalars(einc, inputsChars, eOpNodes, eOpNodesScalars); - HashMap characterToOccurences = new HashMap<>(); - for (Character key : einc.characterAppearanceIndexes.keySet()) { - characterToOccurences.put(key, einc.characterAppearanceIndexes.get(key).size()); - } - for (Character key : einc.charToDimensionSize.keySet()) { - if(!characterToOccurences.containsKey(key)) - characterToOccurences.put(key, 1); - } + HashMap characterToOccurences = einc.characterAppearanceCount; - ArrayList eOpNodes = new ArrayList<>(inputsChars.size()); for (int i = 0; i < inputsChars.size(); i++) { if (inputsChars.get(i) == null) continue; - EOpNodeData n = new EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); + EOpNodeData n = new EOpNodeData(!inputsChars.get(i).isEmpty() ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? 
inputsChars.get(i).charAt(1) : null, i); eOpNodes.add(n); } - Pair > plan; + + ArrayList ret = new ArrayList<>(); + addVectorMultiplies(eOpNodes, eOpNodesScalars,characterToOccurences, einc.outChar1, einc.outChar2, ret); + eOpNodes = ret; + + List plan; ArrayList remainingMatrices; + if(!FORCE_CELL_TPL) { - if(FUSED) { - ArrayList ret = new ArrayList<>(); - EOpNodeFuse fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, ret, characterToOccurences, einc.charToDimensionSize); - if(fuse != null){ - eOpNodes = ret; - } - while(ret.size() > 2 && fuse != null){ - ret = new ArrayList<>(); - fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, ret, characterToOccurences, einc.charToDimensionSize); - if(fuse != null){ - eOpNodes = ret; - } - } + if(true){ + plan = generatePlanFusionAndMM(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + }else { // old way: try to do fusion first and then rest in binary fashion cost based + if(FUSED) { + ret = new ArrayList<>(); + EOpNodeFuse fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, + einc.charToDimensionSize, characterToOccurences, ret); + if(fuse != null) { + ret.add(fuse); + eOpNodes = ret; + } + while(ret.size() > 2 && fuse != null) { + ret = new ArrayList<>(); + fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, einc.charToDimensionSize, + characterToOccurences, ret); + if(fuse != null) { + ret.add(fuse); + eOpNodes = ret; + } + } - } + } - plan = generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + Pair> costAndPlan = generatePlanBinaryCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, + einc.outChar1, einc.outChar2); + plan = costAndPlan.getRight(); + } + if(!eOpNodesScalars.isEmpty()){ + EOpNode l = eOpNodesScalars.get(0); + for(int i = 1; i < eOpNodesScalars.size(); i++){ + l = new EOpNodeBinary(null,null,l, eOpNodesScalars.get(i), 
EBinaryOperand.scalar_scalar); + } - if(plan.getRight().size() == 1) - plan.getRight().get(0).reorderChildren(einc.outChar1, einc.outChar2); + if(plan.isEmpty()) plan.add(l); + else { + int minCost = Integer.MAX_VALUE; + EOpNode addToNode = null; + int minIdx = -1; + for(int i = 0; i < plan.size(); i++) { + EOpNode n = plan.get(i); + Pair costAndNode = addScalarToPlanFindMinCost(n, einc.charToDimensionSize); + if(costAndNode.getLeft() < minCost) { + minCost = costAndNode.getLeft(); + addToNode = costAndNode.getRight(); + minIdx = i; + } + } + plan.set(minIdx, mergeEOpNodeWithScalar(addToNode, l)); + } + + } + if(plan.size() == 1) + plan.get(0).reorderChildren(einc.outChar1, einc.outChar2); + + if(plan.size() == 2 && plan.get(0).c2 == null && plan.get(1).c2 == null){ + if (plan.get(0).c1 == einc.outChar1 && plan.get(1).c1 == einc.outChar2) + plan.set(0, new EOpNodeBinary(plan.get(0).c1, plan.get(1).c1, plan.get(0), plan.get(1), EBinaryOperand.A_B)); + if (plan.get(0).c1 == einc.outChar2 && plan.get(1).c1 == einc.outChar1) + plan.set(0, new EOpNodeBinary(plan.get(1).c1, plan.get(0).c1, plan.get(1), plan.get(0), EBinaryOperand.A_B)); + } + if (EXPLAIN != Explain.ExplainType.NONE ) + System.out.println("Einsum plan:"); + for(var pl : plan){ + System.out.println("- "+String.join("\n- ", pl.recursivePrintString())); + } - remainingMatrices = executePlan(plan.getRight(), inputs); + remainingMatrices = executePlan(plan, inputs); }else{ - plan = Pair.of(0, eOpNodes); + plan = eOpNodes; remainingMatrices = inputs; } + + if(!FORCE_CELL_TPL && remainingMatrices.size() == 1){ - EOpNode resNode = plan.getRight().get(0); + EOpNode resNode = plan.get(0); if (einc.outChar1 != null && einc.outChar2 != null){ if(resNode.c1 == einc.outChar1 && resNode.c2 == einc.outChar2){ ec.setMatrixOutput(output.getName(), remainingMatrices.get(0)); @@ -209,16 +239,16 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ ArrayList mbs = remainingMatrices; ArrayList chars = new 
ArrayList<>(); - for (int i = 0; i < plan.getRight().size(); i++) { + for (int i = 0; i < plan.size(); i++) { String s; - if(plan.getRight().get(i).c1 == null) s = ""; - else if(plan.getRight().get(i).c2 == null) s = plan.getRight().get(i).c1.toString(); - else s = plan.getRight().get(i).c1.toString() + plan.getRight().get(i).c2; + if(plan.get(i).c1 == null) s = ""; + else if(plan.get(i).c2 == null) s = plan.get(i).c1.toString(); + else s = plan.get(i).c1.toString() + plan.get(i).c2; chars.add(s); } ArrayList summingChars = new ArrayList<>(); - for (Character c : einc.characterAppearanceIndexes.keySet()) { + for (Character c : characterToOccurences.keySet()) { if (c != einc.outChar1 && c != einc.outChar2) summingChars.add(c); } if(LOG.isTraceEnabled()) LOG.trace("finishing with cell tpl: "+String.join(",", chars)); @@ -235,46 +265,338 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ } - private void contractDimensionsAndComputeDiagonals(EinsumContext einc, ArrayList inputs) { - for(int i = 0; i< einc.contractDims.length; i++){ - //AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(),Types.CorrectionLocationType.LASTCOLUMN); - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - - if(einc.diagonalInputs[i]){ - ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); - inputs.set(i, inputs.get(i).reorgOperations(op, new MatrixBlock(),0,0,0)); - } - if (einc.contractDims[i] == null) continue; - switch (einc.contractDims[i]){ - case CONTRACT_BOTH: { - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads); - MatrixBlock res = new MatrixBlock(1, 1, false); - inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); - inputs.set(i, res); - break; + private EOpNode mergeEOpNodeWithScalar(EOpNode addToNode, EOpNode scalar) { + if(addToNode instanceof EOpNodeFuse fuse) { + switch (fuse.einsumRewriteType) { + case AB_BA_B_A__A, 
AB_BA_B_A_AZ__Z -> { + fuse.addScalarAsIntermediate(scalar); + return fuse; + } + }; + return new EOpNodeBinary(addToNode.c1,addToNode.c2,addToNode,scalar,EBinaryOperand.AB_scalar); + } + if(addToNode.c1 == null) + return new EOpNodeBinary(null,null,addToNode,scalar,EBinaryOperand.scalar_scalar); + if(addToNode.c2 == null) + return new EOpNodeBinary(addToNode.c1,null,addToNode,scalar,EBinaryOperand.A_scalar); + return new EOpNodeBinary(addToNode.c1,addToNode.c2,addToNode,scalar,EBinaryOperand.AB_scalar); + } + + private static Pair addScalarToPlanFindMinCost(EOpNode plan, HashMap charToSizeMap) { + int thisSize = 0; + if(plan.c1 != null) thisSize += charToSizeMap.get(plan.c1); + if(plan.c2 != null) thisSize += charToSizeMap.get(plan.c2); + int cost = thisSize; + + if (plan instanceof EOpNodeData || plan instanceof EOpNodeUnary) return Pair.of(thisSize, plan); + + List inputs = List.of(); + + if (plan instanceof EOpNodeBinary bin) inputs = List.of(bin._left, bin._right); + else if(plan instanceof EOpNodeFuse fuse){ + cost = switch (fuse.einsumRewriteType) { + case AB_BA_B_A__ -> 1; // thisSize + case AB_BA_B_A__AB -> thisSize; + case AB_BA_B_A__B -> thisSize; + case AB_BA_B_A__A -> 2; // intermediate is scalar, 2 because if there is some real scalar + case AB_BA_B_A_AZ__Z -> 2; // intermediate is scalar + case AB_BA_B_A_AZ__BZ -> thisSize; + case AB_BA_B_A_AZ__ZB -> thisSize; + }; + inputs = fuse.operands.stream().flatMap(List::stream).collect(Collectors.toList()); + } + + for(EOpNode inp : inputs){ + Pair min = addScalarToPlanFindMinCost(inp, charToSizeMap); + if(min.getLeft() < cost){ + cost = min.getLeft(); + plan = min.getRight(); + } + } + return Pair.of(cost, plan); + } + + private static void addVectorMultiplies(ArrayList eOpNodes, ArrayList eOpNodesScalars,HashMap charToOccurences, Character outChar1, Character outChar2,ArrayList ret) { + HashMap> vectorCharacterToIndices = new HashMap<>(); + for (int i = 0; i < eOpNodes.size(); i++) { + if 
(eOpNodes.get(i).c2 == null) { + if (vectorCharacterToIndices.containsKey(eOpNodes.get(i).c1)) + vectorCharacterToIndices.get(eOpNodes.get(i).c1).add(eOpNodes.get(i)); + else + vectorCharacterToIndices.put(eOpNodes.get(i).c1, new ArrayList<>(Collections.singletonList(eOpNodes.get(i)))); + } + } + HashSet usedNodes = new HashSet<>(); + for(Character c : vectorCharacterToIndices.keySet()){ + ArrayList nodes = vectorCharacterToIndices.get(c); + + if(nodes.size()==1) continue; + EOpNode left = nodes.get(0); + usedNodes.add(left); + boolean canBeSummed = c != outChar1 && c != outChar2 && charToOccurences.get(c) == nodes.size(); + + for(int i = 1; i < nodes.size(); i++){ + EOpNode right = nodes.get(i); + + if(canBeSummed && i == nodes.size()-1){ + left = new EOpNodeBinary(null, null, left, right, EBinaryOperand.a_a); + }else { + left = new EOpNodeBinary(c, null, left, right, EBinaryOperand.A_A); } - case CONTRACT_RIGHT: { - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads); - MatrixBlock res = new MatrixBlock(inputs.get(i).getNumRows(), 1, false); - inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); - inputs.set(i, res); - break; + usedNodes.add(right); + } + if(canBeSummed) { + eOpNodesScalars.add(left); + charToOccurences.put(c, 0); + } + else { + ret.add(left); + charToOccurences.put(c, charToOccurences.get(c) - nodes.size() + 1); + } + } + for(EOpNode inp : eOpNodes){ + if(!usedNodes.contains(inp)) ret.add(inp); + } + } + + private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, ArrayList inputStrings, + ArrayList eOpNodes, ArrayList eOpNodesScalars) { + for(int i = 0; i< inputStrings.size(); i++){ + String s = inputStrings.get(i); + if (s.length() == 0){ + eOpNodesScalars.add(new EOpNodeData(null, null,i)); + inputStrings.set(i, null); + continue; + }else if (s.length() == 1){ + char c1 = s.charAt(0); + if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 
!= einc.outChar2) && einc.characterAppearanceCount.get(c1) == 1){ + EOpNode e0 = new EOpNodeData(c1, null,i); + eOpNodesScalars.add(new EOpNodeUnary(null, null, e0, EOpNodeUnary.EUnaryOperand.SUM)); + inputStrings.set(i, null); } - case CONTRACT_LEFT: { - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads); - MatrixBlock res = new MatrixBlock(inputs.get(i).getNumColumns(), 1, false); - inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null); - inputs.set(i, res); - break; + continue; + } + + char c1 = s.charAt(0); + char c2 = s.charAt(1); + Character newC1 = null; + EOpNodeUnary.EUnaryOperand op = null; + + if(c1 == c2){ + if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 != einc.outChar2) && einc.characterAppearanceCount.get(c1) == 2){ + op = EOpNodeUnary.EUnaryOperand.SUM; + }else { + einc.characterAppearanceCount.put(c1, einc.characterAppearanceCount.get(c1) - 1); + op = EOpNodeUnary.EUnaryOperand.DIAG; + newC1 = c1; + } + }else if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 != einc.outChar2) && einc.characterAppearanceCount.get(c1) == 1){ + if ((einc.outChar1 == null || c2 != einc.outChar1) && (einc.outChar2 == null || c2 != einc.outChar2) && einc.characterAppearanceCount.get(c2) == 1){ + op = EOpNodeUnary.EUnaryOperand.SUM; + }else{ + newC1 = c2; + op = EOpNodeUnary.EUnaryOperand.SUM_ROWS; } - default: - break; + }else if((einc.outChar1 == null || c2 != einc.outChar1) && (einc.outChar2 == null || c2 != einc.outChar2) && einc.characterAppearanceCount.get(c2) == 1){ + newC1 = c1; + op = EOpNodeUnary.EUnaryOperand.SUM_COLS; } + + if(op == null) continue; + + EOpNodeData e0 = new EOpNodeData(c1, c2, i); + EOpNodeUnary res = new EOpNodeUnary(newC1, null, e0, op); + + if(op == EOpNodeUnary.EUnaryOperand.SUM) eOpNodesScalars.add(res); + else eOpNodes.add(res); + + inputStrings.set(i, null); + } + } + + private static List 
generatePlanFusionAndMM(ArrayList eOpNodes, + ArrayList eOpNodesScalars, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + ArrayList ret; + int lastNumOfOperands = -1; + while(lastNumOfOperands != eOpNodes.size() && eOpNodes.size() > 1){ + lastNumOfOperands = eOpNodes.size(); + + EOpNodeFuse fuse = null; + + do { + ret = new ArrayList<>(); + fuse = EOpNodeFuse.match(eOpNodes, outChar1, outChar2, charToSizeMap, charToOccurences, ret); + if(fuse != null) { + if(fuse.c1 == null) eOpNodesScalars.add(fuse); + else ret.add(fuse); + eOpNodes = ret; + } + } while(eOpNodes.size() > 1 && fuse != null); + + ret = new ArrayList<>(); + addVectorMultiplies(eOpNodes, eOpNodesScalars,charToOccurences, outChar1, outChar2, ret); + eOpNodes = ret; + + ret = new ArrayList<>(); + ArrayList> matrixMultiplies = findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences, charToSizeMap, ret); + + for(List list : matrixMultiplies) { + EOpNodeBinary bin = optimizeMMChain(list, charToSizeMap); + ret.add(bin); + } + eOpNodes = ret; + + } + + return eOpNodes; + } + + private static EOpNodeBinary optimizeMMChain(List mmChain, HashMap charToSizeMap) { + ArrayList> dimensions = new ArrayList<>(); + + for(int i = 0; i < mmChain.size()-1; i++){ + EOpNode n1 = mmChain.get(i); + EOpNode n2 = mmChain.get(i+1); + if(n1.c2 == n2.c1 || n1.c2 == n2.c2) dimensions.add(Pair.of(charToSizeMap.get(n1.c1), charToSizeMap.get(n1.c2))); + else dimensions.add(Pair.of(charToSizeMap.get(n1.c2), charToSizeMap.get(n1.c1))); // transpose this one + } + EOpNode prelast = mmChain.get(mmChain.size()-2); + EOpNode last = mmChain.get(mmChain.size()-1); + if (last.c1 == prelast.c2 || last.c1 == prelast.c1) dimensions.add(Pair.of(charToSizeMap.get(last.c1), charToSizeMap.get(last.c2))); + else dimensions.add(Pair.of(charToSizeMap.get(last.c2), charToSizeMap.get(last.c1))); + + + double[] dimsArray = new double[mmChain.size() + 1]; + getDimsArray( dimensions, 
dimsArray ); + + int size = mmChain.size(); + int[][] splitMatrix = mmChainDP(dimsArray, mmChain.size()); + + EOpNodeBinary res = (EOpNodeBinary) getBinaryFromSplit(splitMatrix,0,size-1, mmChain); + return res; + } + + private static EOpNode getBinaryFromSplit(int[][] splitMatrix, int i, int j, List mmChain) { + if (i==j) return mmChain.get(i); + int split = splitMatrix[i][j]; + + EOpNode left = getBinaryFromSplit(splitMatrix,i,split,mmChain); + EOpNode right = getBinaryFromSplit(splitMatrix,split+1,j,mmChain); + return EOpNodeBinary.combineMatrixMultiply(left, right); + } + + private static void getDimsArray( ArrayList> chain, double[] dimsArray ) + { + for( int i = 0; i < chain.size(); i++ ) { + if (i == 0) { + dimsArray[i] = chain.get(i).getLeft(); + if (dimsArray[i] <= 0) { + throw new RuntimeException( + "EinsumCPInstruction::optimizeMMChain() : Invalid Matrix Dimension: "+ dimsArray[i]); + } + } + else if (chain.get(i - 1).getRight() != chain.get(i).getLeft()) { + throw new RuntimeException( + "EinsumCPInstruction::optimizeMMChain() : Matrix Dimension Mismatch: " + + chain.get(i - 1).getRight()+" != "+chain.get(i).getLeft()); + } + + dimsArray[i + 1] = chain.get(i).getRight(); + if( dimsArray[i + 1] <= 0 ) { + throw new RuntimeException( + "EinsumCPInstruction::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]); + } + } + } + private static ArrayList> findMatrixMultiplicationChains(ArrayList inpOperands, Character outChar1, Character outChar2, HashMap charToOccurences, HashMap charToSizeMap, ArrayList ret) { + HashSet charactersThatCanBeContracted = new HashSet<>(); + HashMap> characterToNodes = new HashMap<>(); + ArrayList operandsTodo = new ArrayList<>(); + for(EOpNode op : inpOperands) { + if(op.c2 == null || op.c1 == null) continue; + + if (characterToNodes.containsKey(op.c1)) characterToNodes.get(op.c1).add(op); + else characterToNodes.put(op.c1, new ArrayList<>(Collections.singletonList(op))); + if (characterToNodes.containsKey(op.c2)) 
characterToNodes.get(op.c2).add(op); + else characterToNodes.put(op.c2, new ArrayList<>(Collections.singletonList(op))); + + boolean todo = false; + if (charToOccurences.get(op.c1) == 2 && op.c1 != outChar1 && op.c1 != outChar2) { + charactersThatCanBeContracted.add(op.c1); + todo = true; + } + if (charToOccurences.get(op.c2) == 2 && op.c2 != outChar1 && op.c2 != outChar2) { + charactersThatCanBeContracted.add(op.c2); + todo = true; + } + if (todo) operandsTodo.add(op); + } + ArrayList> res = new ArrayList<>(); + + HashSet doneNodes = new HashSet<>(); + + for(int i = 0; i < operandsTodo.size(); i++){ + EOpNode iterateNode = operandsTodo.get(i); + +// if (iterateNode == null) continue; // was added previously somewhere + if (doneNodes.contains(iterateNode)) continue;// was added previously somewhere + doneNodes.add(iterateNode); + + LinkedList multiplies = new LinkedList<>(); + multiplies.add(iterateNode); + + EOpNode nextNode = iterateNode; + Character nextC = iterateNode.c2; + // add to right using c2 + while(charactersThatCanBeContracted.contains(nextC)) { + EOpNode one = characterToNodes.get(nextC).get(0); + EOpNode two = characterToNodes.get(nextC).get(1); + if (nextNode == one){ + multiplies.addLast(two); + nextNode = two; + }else{ + multiplies.addLast(one); + nextNode = one; + } + if(nextNode.c1 == nextC) nextC = nextNode.c2; + else nextC = nextNode.c1; + doneNodes.add(nextNode); + } + + // add to left using c1 + nextNode = iterateNode; + nextC = iterateNode.c1; + while(charactersThatCanBeContracted.contains(nextC)) { + EOpNode one = characterToNodes.get(nextC).get(0); + EOpNode two = characterToNodes.get(nextC).get(1); + if (nextNode == one){ + multiplies.addFirst(two); + nextNode = two; + }else{ + multiplies.addFirst(one); + nextNode = one; + } + if(nextNode.c1 == nextC) nextC = nextNode.c2; + else nextC = nextNode.c1; + doneNodes.add(nextNode); + } + + res.add(multiplies); } + + + for(EOpNode op : inpOperands) { + if (doneNodes.contains(op)) continue; + 
ret.add(op); + } + + + + return res; } - // ideally the return list contains only one final element - private Pair> generatePlan(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + // old way + private Pair> generatePlanBinaryCostBased(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { Integer minCost = cost; List minNodes = operands; @@ -322,7 +644,7 @@ else if (operands.size() == 1){ } newOperands.add(newNode); - Pair> furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); + Pair> furtherPlan = generatePlanBinaryCostBased(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ minCost = furtherPlan.getLeft(); minNodes = furtherPlan.getRight(); diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index dbf9047968f..98d7628b82c 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -37,62 +37,119 @@ import java.io.IOException; import java.nio.file.Files; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Map; @RunWith(Parameterized.class) public class EinsumTest extends AutomatedTestBase { - final private static List TEST_CONFIGS = List.of( - new Config("ij,jk->ik", List.of(shape(50, 600), shape(600, 10))), // mm - new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), - new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), - new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), - - new Config("ji,jk->i", List.of(shape(600, 5), shape(600, 
10))), - new Config("ij,jk->i", List.of(shape(5, 600), shape(600, 10))), - - new Config("ji,jk->k", List.of(shape(600, 5), shape(600, 10))), - new Config("ij,jk->k", List.of(shape(5, 600), shape(600, 10))), - - new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), - - new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult - new Config("ji,ji,ji->ji", List.of(shape(600, 5),shape(600, 5), shape(600, 5)), - List.of(0.0001, 0.0005, 0.001)), - new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult - - - new Config("ij,i->ij", List.of(shape(100, 50), shape(100))), // col mult - new Config("ji,i->ij", List.of(shape(50, 100), shape(100))), // row mult - new Config("ij,i->i", List.of(shape(100, 50), shape(100))), - new Config("ij,i->j", List.of(shape(100, 50), shape(100))), - - new Config("i,i->", List.of(shape(50), shape(50))), - new Config("i,j->", List.of(shape(50), shape(80))), - new Config("i,j->ij", List.of(shape(50), shape(80))), // outer vect mult - new Config("i,j->ji", List.of(shape(50), shape(80))), // outer vect mult - - new Config("ij->", List.of(shape(100, 50))), // sum - new Config("ij->i", List.of(shape(100, 50))), // sum(1) - new Config("ij->j", List.of(shape(100, 50))), // sum(0) - new Config("ij->ji", List.of(shape(100, 50))), // T - - new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), - new Config("ab,cd,g->ba", List.of(shape( 600, 10), shape(6, 5), shape(3))), - - new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm - - new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), - new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), - new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) - List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), - - new Config("i->", List.of(shape(100))), - new 
Config("i->i", List.of(shape(100))) - ); - + final private static List TEST_CONFIGS = List.of( + + new Config("ij,jk->ik", List.of(shape(50, 60), shape(60, 50))), // mm + new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), +// new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), +// new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), +//// new Config("ab,cb,zc->az", List.of(shape(500, 900), shape(1000, 900), shape(400, 1000))), +// +// new Config("ji,jk->i", List.of(shape(600, 5), shape(600, 10))), +// new Config("ij,jk->i", List.of(shape(5, 600), shape(600, 10))), +// +// new Config("ji,jk->k", List.of(shape(600, 5), shape(600, 10))), +// new Config("ij,jk->k", List.of(shape(5, 600), shape(600, 10))), +// +// new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), +// +// new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult +// new Config("ji,ji,ji->ji", List.of(shape(600, 5),shape(600, 5), shape(600, 5)), +// List.of(0.0001, 0.0005, 0.001)), +// new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult +// +// +// new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult +// new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult +// new Config("ij,i->i", List.of(shape(10, 5), shape(10))), +// new Config("ij,i->j", List.of(shape(10, 5), shape(10))), +// + new Config("i,i->", List.of(shape(5), shape(5))), +// new Config("i,j->", List.of(shape(5), shape(80))), +// new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult +// new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult +// +// new Config("ij->", List.of(shape(10, 5))), // sum +// new Config("ij->i", List.of(shape(10, 5))), // sum(1) +// new Config("ij->j", List.of(shape(10, 5))), // sum(0) +// new Config("ij->ji", List.of(shape(10, 5))), // T +// +// new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), +// new Config("ab,cd,g->ba", 
List.of(shape( 600, 10), shape(6, 5), shape(3))), +// + new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm +// +// new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), +// new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), + new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) + List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), +// +// new Config("i->", List.of(shape(10))), +// new Config("i->i", List.of(shape(10))), +// +//// test fused +// new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), +// new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), +// new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), +// new Config("ij,i,j->ij", List.of(shape(10, 5), shape(10),shape(5))), +// new Config("ij,i,i->ij", List.of(shape(10, 5), shape(10),shape(10)), List.of(0.01,0.02,0.1)), +// new Config("ij,j,j->ij", List.of(shape(10, 5), shape(5),shape(5))), +// new Config("ij,i,j->i", List.of(shape(10, 5), shape(10),shape(5))), + new Config("ij,i,i->i", List.of(shape(10, 5), shape(10),shape(10))), +// new Config("ij,j,j->i", List.of(shape(10, 5), shape(5),shape(5))), + new Config("ij,i,j->j", List.of(shape(10, 5), shape(10),shape(5))), +// new Config("ij,i,i->j", List.of(shape(10, 5), shape(10),shape(10))), +// new Config("ij,j,j->j", List.of(shape(10, 5), shape(5),shape(5))), + new Config("ij,i,j->", List.of(shape(10, 5), shape(10),shape(5))), +// new Config("ij,i,i->", List.of(shape(10, 5), shape(10),shape(10))), +// new Config("ij,j,j->", List.of(shape(10, 5), shape(5),shape(5))), + + // test fuesed: + new Config("ij,ij,ji,i,j->i", List.of(shape(7, 5), shape(7, 5),shape(5, 7),shape(7),shape(5))), + new Config("ij,i,i,j,j->i", List.of(shape(7, 50), 
shape(7),shape(7),shape(50),shape(50))), + new Config("ij,i,i,j,j,z->i", List.of(shape(7, 50), shape(7),shape(7),shape(50),shape(50),shape(2)),List.of(1.0,1.0,1.0,1.0,1.0,1.0) ), // include scalar to tmpl + new Config("ij,ij,ij,i,j->j", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5))), + new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))), +// new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 600))), + + + new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))), + new Config("ij,i,j,iz->z", List.of(shape(100, 10),shape(100),shape(10),shape(100, 10))), + new Config("ij,i,j->j", List.of(shape(100, 5),shape(100),shape(5))), + new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), + new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), + new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), + new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), + new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,i,j->i", 
List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,j,i,ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), + //skinny right: + new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // with outer mm + // no skinny right: + new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',100, 'j',100,'z', 100)), // with outer mm + new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',600, 'j',10,'z', 6)) // with outer mm + ,new Config("ij,ij,ij,jk->ik", List.of(shape(100, 50), shape(100, 50),shape(100, 50),shape(50, 10))) + +// ,new Config("ij,ij,ji->ij", List.of(shape(100, 50), shape(100, 50),shape(50, 100)), List.of(0.1,1.0,1.0)) + + ); private final int id; private final String einsumStr; //private final List shapes; @@ -153,12 +210,12 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("A"); sb.append(i); - if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + if (dims.length == 1) { // A1 = seq(1,100) * 0.0001 sb.append(" = seq(1,"); sb.append(dims[0]); sb.append(") * "); sb.append(factor); - } else { // A0 = matrix(seq(1,50000), 1000, 50) * 0.0001 + } else { // A0 = matrix(seq(1,5000), 100, 5) * 0.0001 sb.append(" = matrix(seq(1, "); sb.append(dims[0]*dims[1]); sb.append("), "); @@ -172,7 +229,7 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("\n"); } sb.append("\n"); - +// sb.append("for (i in 1:5) {\n"); sb.append("R = einsum(\""); sb.append(config.einsumStr); sb.append("\", "); @@ -185,6 +242,7 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("A"); sb.append(config.shapes.size() - 1); sb.append(")"); +// sb.append("\n}\n"); sb.append("\n\n"); sb.append("write(R, $1)\n"); @@ -202,17 +260,17 @@ private static StringBuilder createRFile(Config config, boolean outputScalar) { for (int i = 0; i < config.shapes.size(); i++) { int[] dims = config.shapes.get(i); - + double factor = config.factors != null 
? config.factors.get(i) : 0.0001; sb.append("A"); sb.append(i); - if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001 + if (dims.length == 1) { // A1 = seq(1,100) * 0.0001 sb.append(" = seq(1,"); sb.append(dims[0]); sb.append(") * "); sb.append(factor); - } else { // A0 = matrix(seq(1,50000), 1000, 50, byrow=TRUE) * 0.0001 + } else { // A0 = matrix(seq(1,5000), 100, 5, byrow=TRUE) * 0.0001 sb.append(" = matrix(seq(1, "); sb.append(dims[0]*dims[1]); sb.append("), "); @@ -271,19 +329,36 @@ public void cleanUp() { } private static class Config { - public List factors; + public List factors; String einsumStr; List shapes; - Config(String einsum, List shapes) { - this.einsumStr = einsum; - this.shapes = shapes; - this.factors = null; - } + Config(String einsum, List shapes) { + this(einsum,shapes,null); + } + Config(String einsum, Map charToSize){ + this(einsum, charToSize, null); + } + + Config(String einsum, Map charToSize, List factors) { + this.einsumStr = einsum; + String leftPart = einsum.split("->")[0]; + List shapes = new ArrayList<>(); + for(String op : Arrays.stream(leftPart.split(",")).map(x->x.trim()).toList()){ + if (op.length() == 1) { + shapes.add(new int[]{charToSize.get(op.charAt(0))}); + }else{ + shapes.add(new int[]{charToSize.get(op.charAt(0)),charToSize.get(op.charAt(1))}); + } + + } + this.shapes = shapes; + this.factors = factors; + } Config(String einsum, List shapes, List factors) { this.einsumStr = einsum; this.shapes = shapes; - this.factors = factors; + this.factors = factors; } } @@ -327,6 +402,7 @@ private void testCodegenIntegration( String testname) OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false; runTest(true, false, null, -1); +// if(true) throw new RuntimeException("aa"); runRScript(true); if(outputScalar){ From c0d061bb3845fe71442b4647dfcacda4caca9c12 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 13 Nov 2025 00:43:25 +0100 Subject: [PATCH 08/13] bugfixes and move code to other places --- 
.../RewriteMatrixMultChainOptimization.java | 2 +- .../sysds/runtime/einsum/EOpNodeBinary.java | 112 +++++++++++++++ .../runtime/einsum/EinsumSpoofRowwise.java | 4 + .../instructions/cp/EinsumCPInstruction.java | 133 ++---------------- .../test/functions/einsum/EinsumTest.java | 108 +++++++------- 5 files changed, 184 insertions(+), 175 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java index fdd2f8343fe..960560c254e 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java @@ -210,7 +210,7 @@ protected void optimizeMMChain(Hop hop, List mmChain, List mmOperators * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein * Introduction to Algorithms, Third Edition, MIT Press, page 395. */ - private static int[][] mmChainDP(double[] dimArray, int size) + public static int[][] mmChainDP(double[] dimArray, int size) { double[][] dpMatrix = new double[size][size]; //min cost table int[][] split = new int[size][size]; //min cost index table diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index 0bb03737a66..866d346fdd0 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.einsum; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.logging.Log; import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -38,6 +40,8 @@ import org.apache.sysds.runtime.matrix.operators.SimpleOperator; import java.util.ArrayList; +import 
java.util.HashMap; +import java.util.function.Predicate; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; @@ -281,4 +285,112 @@ public void reorderChildren(Character outChar1, Character outChar2) { } } + // used in old method + public static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ + Predicate cannotBeSummed = (c) -> + c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; + + if(n1.c1 == null) { + // n2.c1 also has to be null + return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null)); + } + + if(n2.c1 == null) { + if(n1.c2 == null) + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2)); + } + + if(n1.c1 == n2.c1){ + if(n1.c2 != null){ + if ( n1.c2 == n2.c2){ + if( cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null)); + } + + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null)); + } + + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null)); + + } + + else if(n2.c2 == null){ + if(cannotBeSummed.test(n1.c1)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in 
theory (null, n1.c2) + } + else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ + return null;// AB,AC + } + else { + return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 + } + }else{ // n1.c2 = null -> c2.c2 = null + if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null)); + } + return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null)); + } + + + }else{ // n1.c1 != n2.c1 + if(n1.c2 == null) { + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1)); + } + else if(n2.c2 == null) { // ab,c + if (n1.c2 == n2.c1) { + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); + } + return null; // AB,C + } + else if (n1.c2 == n2.c1) { + if(n1.c1 == n2.c2){ // ab,ba + if(cannotBeSummed.test(n1.c1)){ + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); + } + if(cannotBeSummed.test(n1.c2)){ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null)); + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); + } + if(cannotBeSummed.test(n1.c2)){ + return null; // AB_B + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), 
EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); + // if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ + // return null; // AB_B + // } + // return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); + } + } + if(n1.c1 == n2.c2) { + if(cannotBeSummed.test(n1.c1)){ + return null; // AB_B + } + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult + } + else if (n1.c2 == n2.c2) { + if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ + return null; // BA_CA + }else{ + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 + } + } + else { // something like ab,cd + return null; + } + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java index f12667685b8..d0f63832423 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java @@ -108,6 +108,10 @@ private void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, dou int bi = 0; double[] TMP1 = null; if (_ABCount != 0){ + if(_ABCount == 1 & _ACount == 0 && _BCount == 0){ + LibMatrixMult.vectMultiplyWrite(a, b[0].values(rix), c, ai, ai, ci, len); + return; + } TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); while (bi < _ABCount) { if(_ACount == 0 && _BCount == 0 && bi == _ABCount-1) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 13748ca980a..6f6f6fc1114 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -43,7 +43,6 @@ import org.apache.sysds.utils.Explain; import java.util.*; -import java.util.function.Predicate; import java.util.stream.Collectors; import static org.apache.sysds.api.DMLScript.EXPLAIN; @@ -51,7 +50,7 @@ public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static final boolean FORCE_CELL_TPL = false; - public static final boolean FUSED = true; +// public static final boolean FUSED = true; public static final boolean FUSE_OUTER_MULTIPLY = true; @@ -132,10 +131,10 @@ public void processInstruction(ExecutionContext ec) { ArrayList remainingMatrices; if(!FORCE_CELL_TPL) { - if(true){ + if(true){ // new way: search for fusions and matrix-multiplications chain in a loop plan = generatePlanFusionAndMM(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); }else { // old way: try to do fusion first and then rest in binary fashion cost based - if(FUSED) { + if(true /*FUSED*/) { ret = new ArrayList<>(); EOpNodeFuse fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, einc.charToDimensionSize, characterToOccurences, ret); @@ -152,10 +151,9 @@ public void processInstruction(ExecutionContext ec) { eOpNodes = ret; } } - } - Pair> costAndPlan = generatePlanBinaryCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, + Pair> costAndPlan = generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); plan = costAndPlan.getRight(); } @@ -191,6 +189,7 @@ public void processInstruction(ExecutionContext ec) { plan.set(0, new EOpNodeBinary(plan.get(0).c1, plan.get(1).c1, plan.get(0), plan.get(1), EBinaryOperand.A_B)); if (plan.get(0).c1 == einc.outChar2 && plan.get(1).c1 == einc.outChar1) plan.set(0, new EOpNodeBinary(plan.get(1).c1, 
plan.get(0).c1, plan.get(1), plan.get(0), EBinaryOperand.A_B)); + plan.remove(1); } if (EXPLAIN != Explain.ExplainType.NONE ) System.out.println("Einsum plan:"); @@ -224,6 +223,7 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ } }else if (einc.outChar1 != null){ if(resNode.c1 == einc.outChar1 && resNode.c2 == null){ + ensureMatrixBlockColumnVector(remainingMatrices.get(0)); ec.setMatrixOutput(output.getName(), remainingMatrices.get(0)); }else{ if(LOG.isTraceEnabled()) LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2); @@ -255,6 +255,9 @@ else if(resNode.c1 == einc.outChar2 && resNode.c2 == einc.outChar1){ MatrixBlock res = computeCellSummation(mbs, chars, resultString, einc.charToDimensionSize, summingChars, einc.outRows, einc.outCols); + if (einc.outChar2 == null) + ensureMatrixBlockColumnVector(res); + if (einc.outRows == 1 && einc.outCols == 1) ec.setScalarOutput(output.getName(), new DoubleObject(res.get(0, 0))); else ec.setMatrixOutput(output.getName(), res); @@ -422,7 +425,6 @@ private static List generatePlanFusionAndMM(ArrayList eOpNodes lastNumOfOperands = eOpNodes.size(); EOpNodeFuse fuse = null; - do { ret = new ArrayList<>(); fuse = EOpNodeFuse.match(eOpNodes, outChar1, outChar2, charToSizeMap, charToOccurences, ret); @@ -445,7 +447,6 @@ private static List generatePlanFusionAndMM(ArrayList eOpNodes ret.add(bin); } eOpNodes = ret; - } return eOpNodes; @@ -596,7 +597,7 @@ private static ArrayList> findMatrixMultiplicationChains(ArrayList } // old way - private Pair> generatePlanBinaryCostBased(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { + private Pair> generateBinaryPlanCostBased(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { Integer minCost = cost; List minNodes = operands; @@ -604,7 +605,7 @@ private Pair> generatePlanBinaryCostBased(int cost, Array 
boolean swap = (operands.get(0).c2 == null && operands.get(1).c2 != null) || operands.get(0).c1 == null; EOpNode n1 = operands.get(!swap ? 0 : 1); EOpNode n2 = operands.get(!swap ? 1 : 0); - Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null) { EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); int thisCost = cost + t.getLeft(); @@ -625,7 +626,7 @@ else if (operands.size() == 1){ EOpNode n2 = operands.get(!swap ? j : i); - Triple> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); + Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null){ EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); int thisCost = cost + t.getLeft(); @@ -644,7 +645,7 @@ else if (operands.size() == 1){ } newOperands.add(newNode); - Pair> furtherPlan = generatePlanBinaryCostBased(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); + Pair> furtherPlan = generateBinaryPlanCostBased(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ minCost = furtherPlan.getLeft(); minNodes = furtherPlan.getRight(); @@ -663,114 +664,6 @@ else if (operands.size() == 1){ return Pair.of(minCost, minNodes); } - private static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ - Predicate cannotBeSummed = (c) -> - c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; - - if(n1.c1 == null) { - // n2.c1 also has to be null - return Triple.of(1, EBinaryOperand.scalar_scalar, Pair.of(null, null)); - } - - 
if(n2.c1 == null) { - if(n1.c2 == null) - return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_scalar, Pair.of(n1.c1, null)); - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2)); - } - - if(n1.c1 == n2.c1){ - if(n1.c2 != null){ - if ( n1.c2 == n2.c2){ - if( cannotBeSummed.test(n1.c1)){ - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null)); - } - - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_aB, Pair.of(n1.c2, null)); - } - - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ab, Pair.of(null, null)); - - } - - else if(n2.c2 == null){ - if(cannotBeSummed.test(n1.c1)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2) - } - else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ - return null;// AB,AC - } - else { - return Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)), EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2 - } - }else{ // n1.c2 = null -> c2.c2 = null - if(n1.c1 ==outChar1 || n1.c1==outChar2 || charToOccurences.get(n1.c1) > 2){ - return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null)); - } - return Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.a_a, Pair.of(null, null)); - } - - - }else{ // n1.c1 != n2.c1 - if(n1.c2 == null) { - return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1)); - } - else if(n2.c2 == null) { // ab,c - if (n1.c2 == n2.c1) { - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); - } - return null; // AB,C - } - else if (n1.c2 == n2.c1) { - if(n1.c1 == n2.c2){ // ab,ba - if(cannotBeSummed.test(n1.c1)){ - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); - } - if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null)); - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.ab_ba, Pair.of(null, null)); - } - if(cannotBeSummed.test(n1.c2)){ - return null; // AB_B - }else{ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); -// if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ -// return null; // AB_B -// } -// return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); - } - } - if(n1.c1 == n2.c2) { - if(cannotBeSummed.test(n1.c1)){ - return null; // AB_B - } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1), EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult - } - else if (n1.c2 == n2.c2) { - if(n1.c2 ==outChar1 || n1.c2==outChar2|| charToOccurences.get(n1.c2) > 2){ - return null; // BA_CA - }else{ - return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 - } - } - else { // we have something like ab,cd - return null; - } - } - } - private ArrayList executePlan(List plan, ArrayList inputs) { ArrayList res = new ArrayList<>(plan.size()); for(EOpNode p : plan){ diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 98d7628b82c..76944cfb991 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -48,62 +48,62 @@ public class EinsumTest extends AutomatedTestBase { final private static List TEST_CONFIGS = List.of( - new Config("ij,jk->ik", List.of(shape(50, 60), shape(60, 50))), // mm - new Config("ji,jk->ik", List.of(shape(600, 5), shape(600, 10))), -// new Config("ji,kj->ik", List.of(shape(600, 5), shape(10, 600))), -// new Config("ij,kj->ik", List.of(shape(5, 600), shape(10, 600))), -//// new Config("ab,cb,zc->az", List.of(shape(500, 900), shape(1000, 900), shape(400, 1000))), -// -// new Config("ji,jk->i", List.of(shape(600, 5), shape(600, 10))), -// new Config("ij,jk->i", List.of(shape(5, 600), shape(600, 10))), -// -// new Config("ji,jk->k", List.of(shape(600, 5), shape(600, 10))), -// new Config("ij,jk->k", List.of(shape(5, 600), shape(600, 10))), -// -// new Config("ji,jk->j", List.of(shape(600, 5), shape(600, 10))), -// -// new Config("ji,ji->ji", List.of(shape(600, 5), shape(600, 5))), // elemwise mult -// new Config("ji,ji,ji->ji", List.of(shape(600, 5),shape(600, 5), shape(600, 5)), -// List.of(0.0001, 0.0005, 0.001)), -// new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 600))), // elemwise mult -// -// -// new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult -// new 
Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult -// new Config("ij,i->i", List.of(shape(10, 5), shape(10))), -// new Config("ij,i->j", List.of(shape(10, 5), shape(10))), + new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // mm + new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))), + new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))), + new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))), +// new Config("ab,cb,zc->az", List.of(shape(500, 900), shape(1000, 900), shape(400, 1000))), + + new Config("ji,jk->i", List.of(shape(60, 5), shape(60, 10))), + new Config("ij,jk->i", List.of(shape(5, 60), shape(60, 10))), + + new Config("ji,jk->k", List.of(shape(60, 5), shape(60, 10))), + new Config("ij,jk->k", List.of(shape(5, 60), shape(60, 10))), + + new Config("ji,jk->j", List.of(shape(60, 5), shape(60, 10))), + + new Config("ji,ji->ji", List.of(shape(60, 5), shape(60, 5))), // elemwise mult + new Config("ji,ji,ji->ji", List.of(shape(60, 5),shape(60, 5), shape(60, 5)), + List.of(0.0001, 0.0005, 0.001)), + new Config("ji,ij->ji", List.of(shape(60, 5), shape(5, 60))), // elemwise mult + + + new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult + new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult + new Config("ij,i->i", List.of(shape(10, 5), shape(10))), + new Config("ij,i->j", List.of(shape(10, 5), shape(10))), // new Config("i,i->", List.of(shape(5), shape(5))), -// new Config("i,j->", List.of(shape(5), shape(80))), -// new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult -// new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult -// -// new Config("ij->", List.of(shape(10, 5))), // sum -// new Config("ij->i", List.of(shape(10, 5))), // sum(1) -// new Config("ij->j", List.of(shape(10, 5))), // sum(0) -// new Config("ij->ji", List.of(shape(10, 5))), // T -// -// new Config("ab,cd->ba", List.of(shape( 600, 10), shape(6, 5))), -// new Config("ab,cd,g->ba", 
List.of(shape( 600, 10), shape(6, 5), shape(3))), + new Config("i,j->", List.of(shape(5), shape(80))), + new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult + new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult + + new Config("ij->", List.of(shape(10, 5))), // sum + new Config("ij->i", List.of(shape(10, 5))), // sum(1) + new Config("ij->j", List.of(shape(10, 5))), // sum(0) + new Config("ij->ji", List.of(shape(10, 5))), // T + + new Config("ab,cd->ba", List.of(shape( 60, 10), shape(6, 5))), + new Config("ab,cd,g->ba", List.of(shape( 60, 10), shape(6, 5), shape(3))), // - new Config("ab,bc,cd,de->ae", List.of(shape(5, 600), shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm + new Config("ab,bc,cd,de->ae", List.of(shape(5, 60), shape(60, 10), shape(10, 5), shape(5, 4))), // chain of mm // -// new Config("ji,jz,zx->ix", List.of(shape(600, 5), shape( 600, 10), shape(10, 2))), -// new Config("fx,fg,fz,xg->z", List.of(shape(600, 5), shape( 600, 10), shape(600, 6), shape(5, 10))), +// new Config("ji,jz,zx->ix", List.of(shape(60, 5), shape( 60, 10), shape(10, 2))), +// new Config("fx,fg,fz,xg->z", List.of(shape(60, 5), shape( 60, 10), shape(60, 6), shape(5, 10))), new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), // -// new Config("i->", List.of(shape(10))), -// new Config("i->i", List.of(shape(10))), -// -//// test fused -// new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), -// new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), -// new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), -// new Config("ij,i,j->ij", List.of(shape(10, 5), shape(10),shape(5))), -// new Config("ij,i,i->ij", List.of(shape(10, 5), shape(10),shape(10)), List.of(0.01,0.02,0.1)), -// new Config("ij,j,j->ij", 
List.of(shape(10, 5), shape(5),shape(5))), -// new Config("ij,i,j->i", List.of(shape(10, 5), shape(10),shape(5))), + new Config("i->", List.of(shape(10))), + new Config("i->i", List.of(shape(10))), + +// test fused + new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,i,j->ij", List.of(shape(10, 5), shape(10),shape(5))), + new Config("ij,i,i->ij", List.of(shape(10, 5), shape(10),shape(10)), List.of(0.01,0.02,0.1)), + new Config("ij,j,j->ij", List.of(shape(10, 5), shape(5),shape(5))), + new Config("ij,i,j->i", List.of(shape(10, 5), shape(10),shape(5))), new Config("ij,i,i->i", List.of(shape(10, 5), shape(10),shape(10))), // new Config("ij,j,j->i", List.of(shape(10, 5), shape(5),shape(5))), new Config("ij,i,j->j", List.of(shape(10, 5), shape(10),shape(5))), @@ -119,11 +119,11 @@ public class EinsumTest extends AutomatedTestBase new Config("ij,i,i,j,j,z->i", List.of(shape(7, 50), shape(7),shape(7),shape(50),shape(50),shape(2)),List.of(1.0,1.0,1.0,1.0,1.0,1.0) ), // include scalar to tmpl new Config("ij,ij,ij,i,j->j", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5))), new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))), -// new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 600))), +// new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))), new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))), - new Config("ij,i,j,iz->z", List.of(shape(100, 10),shape(100),shape(10),shape(100, 10))), + new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))), new Config("ij,i,j->j", List.of(shape(100, 
5),shape(100),shape(5))), new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), @@ -143,9 +143,9 @@ public class EinsumTest extends AutomatedTestBase //skinny right: new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // with outer mm // no skinny right: - new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',100, 'j',100,'z', 100)), // with outer mm - new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',600, 'j',10,'z', 6)) // with outer mm - ,new Config("ij,ij,ij,jk->ik", List.of(shape(100, 50), shape(100, 50),shape(100, 50),shape(50, 10))) + new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // with outer mm + new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)) // with outer mm + ,new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10))) // ,new Config("ij,ij,ji->ij", List.of(shape(100, 50), shape(100, 50),shape(50, 100)), List.of(0.1,1.0,1.0)) From b20b927b03488bd683a42c9136a631311309c38d Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Thu, 13 Nov 2025 18:32:13 +0100 Subject: [PATCH 09/13] clean tests and improve explain prints --- .../sysds/runtime/einsum/EOpNodeBinary.java | 4 +- .../instructions/cp/EinsumCPInstruction.java | 9 +- .../test/functions/einsum/EinsumTest.java | 136 +++++++----------- 3 files changed, 57 insertions(+), 92 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index 866d346fdd0..18d22ae795c 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -112,10 +112,10 @@ public String[] recursivePrintString() { String[] res = new String[left.length + right.length+1]; res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString(); 
for (int i=0; i TEST_CONFIGS = List.of( - new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // mm new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))), new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))), new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))), -// new Config("ab,cb,zc->az", List.of(shape(500, 900), shape(1000, 900), shape(400, 1000))), - - new Config("ji,jk->i", List.of(shape(60, 5), shape(60, 10))), - new Config("ij,jk->i", List.of(shape(5, 60), shape(60, 10))), - - new Config("ji,jk->k", List.of(shape(60, 5), shape(60, 10))), - new Config("ij,jk->k", List.of(shape(5, 60), shape(60, 10))), + new Config("ab,bc,cd,de->ae", List.of(shape(5, 6), shape(6, 5),shape(5, 6), shape(6, 5))), // mm chain - new Config("ji,jk->j", List.of(shape(60, 5), shape(60, 10))), + new Config("ji,jk->i", List.of(shape(6, 5), shape(6, 4))), + new Config("ij,jk->i", List.of(shape(5, 6), shape(6, 4))), + new Config("ji,jk->k", List.of(shape(6, 5), shape(6, 4))), + new Config("ij,jk->k", List.of(shape(5, 6), shape(6, 4))), + new Config("ji,jk->j", List.of(shape(6, 5), shape(6, 4))), new Config("ji,ji->ji", List.of(shape(60, 5), shape(60, 5))), // elemwise mult - new Config("ji,ji,ji->ji", List.of(shape(60, 5),shape(60, 5), shape(60, 5)), - List.of(0.0001, 0.0005, 0.001)), new Config("ji,ij->ji", List.of(shape(60, 5), shape(5, 60))), // elemwise mult - new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult new Config("ij,i->i", List.of(shape(10, 5), shape(10))), new Config("ij,i->j", List.of(shape(10, 5), shape(10))), -// - new Config("i,i->", List.of(shape(5), shape(5))), - new Config("i,j->", List.of(shape(5), shape(80))), + + new Config("i,i->", List.of(shape(5), shape(5))), // dot + new Config("i,j->", List.of(shape(5), shape(80))), // sum new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult new Config("i,j->ji", List.of(shape(5), 
shape(80))), // outer vect mult new Config("ij->", List.of(shape(10, 5))), // sum + new Config("i->", List.of(shape(10))), // sum new Config("ij->i", List.of(shape(10, 5))), // sum(1) new Config("ij->j", List.of(shape(10, 5))), // sum(0) - new Config("ij->ji", List.of(shape(10, 5))), // T - - new Config("ab,cd->ba", List.of(shape( 60, 10), shape(6, 5))), - new Config("ab,cd,g->ba", List.of(shape( 60, 10), shape(6, 5), shape(3))), -// - new Config("ab,bc,cd,de->ae", List.of(shape(5, 60), shape(60, 10), shape(10, 5), shape(5, 4))), // chain of mm -// -// new Config("ji,jz,zx->ix", List.of(shape(60, 5), shape( 60, 10), shape(10, 2))), -// new Config("fx,fg,fz,xg->z", List.of(shape(60, 5), shape( 60, 10), shape(60, 6), shape(5, 10))), - new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl) - List.of(shape(5, 60), shape(5, 30), shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))), -// - new Config("i->", List.of(shape(10))), - new Config("i->i", List.of(shape(10))), - -// test fused - new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,i,j->ij", List.of(shape(10, 5), shape(10),shape(5))), - new Config("ij,i,i->ij", List.of(shape(10, 5), shape(10),shape(10)), List.of(0.01,0.02,0.1)), - new Config("ij,j,j->ij", List.of(shape(10, 5), shape(5),shape(5))), - new Config("ij,i,j->i", List.of(shape(10, 5), shape(10),shape(5))), - new Config("ij,i,i->i", List.of(shape(10, 5), shape(10),shape(10))), -// new Config("ij,j,j->i", List.of(shape(10, 5), shape(5),shape(5))), - new Config("ij,i,j->j", List.of(shape(10, 5), shape(10),shape(5))), -// new Config("ij,i,i->j", List.of(shape(10, 5), shape(10),shape(10))), -// new Config("ij,j,j->j", List.of(shape(10, 5), shape(5),shape(5))), - new Config("ij,i,j->", 
List.of(shape(10, 5), shape(10),shape(5))), -// new Config("ij,i,i->", List.of(shape(10, 5), shape(10),shape(10))), -// new Config("ij,j,j->", List.of(shape(10, 5), shape(5),shape(5))), - - // test fuesed: - new Config("ij,ij,ji,i,j->i", List.of(shape(7, 5), shape(7, 5),shape(5, 7),shape(7),shape(5))), - new Config("ij,i,i,j,j->i", List.of(shape(7, 50), shape(7),shape(7),shape(50),shape(50))), - new Config("ij,i,i,j,j,z->i", List.of(shape(7, 50), shape(7),shape(7),shape(50),shape(50),shape(2)),List.of(1.0,1.0,1.0,1.0,1.0,1.0) ), // include scalar to tmpl - new Config("ij,ij,ij,i,j->j", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5))), - new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))), -// new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 60))), - - - new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))), - new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))), - new Config("ij,i,j->j", List.of(shape(100, 5),shape(100),shape(5))), - new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), - new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), - new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), - new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5),shape(5, 10))), - new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 
5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,j,i,ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), - //skinny right: - new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // with outer mm - // no skinny right: - new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // with outer mm - new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)) // with outer mm - ,new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10))) - -// ,new Config("ij,ij,ji->ij", List.of(shape(100, 50), shape(100, 50),shape(50, 100)), List.of(0.1,1.0,1.0)) + new Config("ij->ji", List.of(shape(10, 5))), // T + new Config("ij->ij", List.of(shape(10, 5))), + new Config("i->i", List.of(shape(10))), + new Config("ii->i", List.of(shape(10, 10))), // Diag + new Config("ii,i->i", List.of(shape(10, 10),shape(10))), // Diag*vec + + new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6, 5))), // sum cd to scalar and multiply ab + new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl fallback) + List.of(shape(5, 6), shape(5, 3), shape(5, 10), shape(6, 3), shape(10, 6), shape(10, 3))), + + // test fused: + new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5), 
shape(5, 10))), + new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6))), + new Config("ij,ij,ij,i,j,iz,z->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6),shape(6))), + + new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))), + new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))), + + new Config("ij,ij,ji,j,i, ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), + new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // //skinny right with outer mm + new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // // no skinny right + new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)), + new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10))) ); private final int id; private final String einsumStr; @@ -229,7 +199,6 @@ private static StringBuilder 
createDmlFile(Config config, boolean outputScalar) sb.append("\n"); } sb.append("\n"); -// sb.append("for (i in 1:5) {\n"); sb.append("R = einsum(\""); sb.append(config.einsumStr); sb.append("\", "); @@ -242,7 +211,6 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("A"); sb.append(config.shapes.size() - 1); sb.append(")"); -// sb.append("\n}\n"); sb.append("\n\n"); sb.append("write(R, $1)\n"); From c19b2c0e72787fac98899a87555eb0327a274397 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Tue, 18 Nov 2025 20:33:13 +0100 Subject: [PATCH 10/13] include dimsize in EOpNode, implement the reordering of binary in the final plan --- .../apache/sysds/runtime/einsum/EOpNode.java | 11 +- .../sysds/runtime/einsum/EOpNodeBinary.java | 184 +++++++++++++----- .../sysds/runtime/einsum/EOpNodeData.java | 8 +- .../sysds/runtime/einsum/EOpNodeFuse.java | 22 ++- .../sysds/runtime/einsum/EOpNodeUnary.java | 16 +- .../instructions/cp/EinsumCPInstruction.java | 74 +++---- .../test/functions/einsum/EinsumTest.java | 10 +- 7 files changed, 217 insertions(+), 108 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java index f803dc9e21d..b402ec634bf 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -21,15 +21,20 @@ import org.apache.commons.logging.Log; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import scala.Int; import java.util.ArrayList; public abstract class EOpNode { public Character c1; - public Character c2; // nullable - public EOpNode(Character c1, Character c2){ + public Character c2; + public Integer dim1; + public Integer dim2; + public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) { this.c1 = c1; this.c2 = c2; + this.dim1 = dim1; + this.dim2 = dim2; } @Override @@ -43,6 +48,6 @@ public String toString() { public abstract 
MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG); - public abstract void reorderChildren(Character outChar1, Character outChar2); + public abstract EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2); } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index 18d22ae795c..4f49df04a12 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -48,22 +49,22 @@ public class EOpNodeBinary extends EOpNode { - public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed - ////// summations: ////// - aB_a,// -> B - Ba_a, // -> B - Ba_aC, // mmult -> BC - aB_Ca, + ////// mm: ////// + Ba_aC, // -> BC + aB_Ca, // -> CB Ba_Ca, // -> BC - aB_aC, // outer mult, possibly with transposing first -> BC - a_a,// dot -> + aB_aC, // -> BC - ////// elementwisemult and sums, something like ij,ij->i ////// + ////// elementwisemult and sums ////// aB_aB,// elemwise and colsum -> B Ba_Ba, // elemwise and rowsum ->B Ba_aB, // elemwise, either colsum or rowsum -> B aB_Ba, + ab_ab,//M-M sum all + ab_ba, //M-M.T sum all + aB_a,// -> B + Ba_a, // -> B ////// elementwise, no summations: ////// A_A,// v-elemwise -> A @@ -71,34 +72,99 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b AB_BA, // M-M.T elemwise -> AB AB_A, // M-v colwise -> BA!? 
BA_A, // M-v rowwise -> BA - ab_ab,//M-M sum all - ab_ba, //M-M.T sum all + ////// other ////// + a_a,// dot -> A_B, // outer mult -> AB A_scalar, // v-scalar AB_scalar, // m-scalar scalar_scalar } - public EOpNode _left; - public EOpNode _right; - public EBinaryOperand _operand; + public EOpNode left; + public EOpNode right; + public EBinaryOperand operand; private boolean transposeResult; - public EOpNodeBinary(Character c1, Character c2, EOpNode left, EOpNode right, EBinaryOperand operand){ - super(c1,c2); - this._left = left; - this._right = right; - this._operand = operand; - } + public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand operand){ + super(null,null,null, null); + Character c1, c2; + Integer dim1, dim2; + switch(operand){ + case Ba_aC -> { + c1=left.c1; + c2=right.c2; + dim1=left.dim1; + dim2=right.dim2; + } + case aB_Ca -> { + c1=left.c2; + c2=right.c1; + dim1=left.dim2; + dim2=right.dim1; + } + case Ba_Ca -> { + c1=left.c1; + c2=right.c1; + dim1=left.dim1; + dim2=right.dim1; + } + case aB_aC -> { + c1=left.c2; + c2=right.c2; + dim1=left.dim2; + dim2=right.dim2; + } + case aB_aB, aB_Ba, aB_a -> { + c1=left.c2; + c2=null; + dim1=left.dim2; + dim2=null; + } + case Ba_Ba, Ba_aB, Ba_a, A_A, A_scalar -> { + c1=left.c1; + c2=null; + dim1=left.dim1; + dim2=null; + } + case ab_ab, ab_ba, a_a, scalar_scalar -> { + c1=null; + c2=null; + dim1=null; + dim2=null; + } + case AB_AB, AB_BA, AB_A, BA_A, AB_scalar ->{ + c1=left.c1; + c2=left.c2; + dim1=left.dim1; + dim2=left.dim2; + } + case A_B -> { + c1=left.c1; + c2=right.c1; + dim1=left.dim1; + dim2=right.dim1; + } + default -> throw new IllegalStateException("EOpNodeBinary Unexpected type: " + operand); + } + // super(c1, c2, dim1, dim2); // unavailable in JDK < 22 + this.c1 = c1; + this.c2 = c2; + this.dim1 = dim1; + this.dim2 = dim2; + this.left = left; + this.right = right; + this.operand = operand; + } + public void setTransposeResult(boolean transposeResult){ this.transposeResult = 
transposeResult; } public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) { - if (left.c2 == right.c1) { return new EOpNodeBinary(left.c1, right.c2, left, right, EBinaryOperand.Ba_aC); } - if (left.c2 == right.c2) { return new EOpNodeBinary(left.c1, right.c1, left, right, EBinaryOperand.Ba_Ca); } - if (left.c1 == right.c1) { return new EOpNodeBinary(left.c2, right.c2, left, right, EBinaryOperand.aB_aC); } + if (left.c2 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_aC); } + if (left.c2 == right.c2) { return new EOpNodeBinary(left, right, EBinaryOperand.Ba_Ca); } + if (left.c1 == right.c1) { return new EOpNodeBinary(left, right, EBinaryOperand.aB_aC); } if (left.c1 == right.c2) { - var res = new EOpNodeBinary(left.c2, right.c1, left, right, EBinaryOperand.aB_Ca); + var res = new EOpNodeBinary(left, right, EBinaryOperand.aB_Ca); res.setTransposeResult(true); return res; } @@ -107,10 +173,10 @@ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) { @Override public String[] recursivePrintString() { - String[] left = _left.recursivePrintString(); - String[] right = _right.recursivePrintString(); + String[] left = this.left.recursivePrintString(); + String[] right = this.right.recursivePrintString(); String[] res = new String[left.length + right.length+1]; - res[0] = this.getClass().getSimpleName()+" ("+_operand.toString()+") "+this.toString(); + res[0] = this.getClass().getSimpleName()+" ("+ operand.toString()+") "+this.toString(); for (int i=0; i inputs, int numThreads, Log LOG) { EOpNodeBinary bin = this; - MatrixBlock left = _left.computeEOpNode(inputs, numThreads, LOG); - MatrixBlock right = _right.computeEOpNode(inputs, numThreads, LOG); + MatrixBlock left = this.left.computeEOpNode(inputs, numThreads, LOG); + MatrixBlock right = this.right.computeEOpNode(inputs, numThreads, LOG); AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); MatrixBlock res; - 
if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin._left +","+bin._right +"->"+bin); + if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin.left +","+bin.right +"->"+bin); - switch (bin._operand){ + switch (bin.operand){ case AB_AB -> { res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); } @@ -255,7 +321,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, return new MatrixBlock(left.get(0,0)*right.get(0,0)); } default -> { - throw new IllegalArgumentException("Unexpected value: " + bin._operand.toString()); + throw new IllegalArgumentException("Unexpected value: " + bin.operand.toString()); } } @@ -267,25 +333,47 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, } @Override - public void reorderChildren(Character outChar1, Character outChar2) { - if (this._operand==EBinaryOperand.aB_aC){ - if(this._right.c2 == outChar1) { - var tmp = _left; - _left = _right; - _right = tmp; - var tmp2 = c1; - c1 = c2; - c2 = tmp2; + public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { + if (this.operand ==EBinaryOperand.aB_aC){ + if(this.right.c2 == outChar1) { // result is CB so Swap aB and aC + var tmpLeft = left; left = right; right = tmpLeft; + var tmpC1 = c1; c1 = c2; c2 = tmpC1; + var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1; } - _left.reorderChildren(_left.c2, _left.c1); - // check if change happened: - if(_left.c2 == _right.c1) { - this._operand = EBinaryOperand.Ba_aC; + if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB && + LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, true)) { + fuse.operands.get(4).add(right); + fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; + fuse.c1 = fuse.c2; + fuse.c2 = right.c2; + 
return fuse; + } + + left = left.reorderChildrenAndOptimize(this, left.c2, left.c1); // maybe can be reordered + if(left.c2 == right.c1) { // check if change happened: + this.operand = EBinaryOperand.Ba_aC; } - } + right = right.reorderChildrenAndOptimize(this, right.c1, right.c2); + }else if (this.operand ==EBinaryOperand.Ba_Ca){ + if(this.right.c1 == outChar1) { // result is CB so Swap Ba and Ca + var tmpLeft = left; left = right; right = tmpLeft; + var tmpC1 = c1; c1 = c2; c2 = tmpC1; + var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1; + } + + right = right.reorderChildrenAndOptimize(this, right.c2, right.c1); // maybe can be reordered + if(left.c2 == right.c1) { // check if change happened: + this.operand = EBinaryOperand.Ba_aC; + } + left = left.reorderChildrenAndOptimize(this, left.c1, left.c2); + }else { + left = left.reorderChildrenAndOptimize(this, left.c1, left.c2); // just recurse + right = right.reorderChildrenAndOptimize(this, right.c1, right.c2); + } + return this; } - // used in old method + // used in the old approach public static Triple> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2){ Predicate cannotBeSummed = (c) -> c == outChar1 || c == outChar2 || charToOccurences.get(c) > 2; @@ -388,7 +476,7 @@ else if (n1.c2 == n2.c2) { return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) +(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1 } } - else { // something like ab,cd + else { // something like AB,CD return null; } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java index d9b61b29514..d352586a21e 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -26,8 +26,8 @@ public class 
EOpNodeData extends EOpNode { public int matrixIdx; - public EOpNodeData(Character c1, Character c2, int matrixIdx){ - super(c1,c2); + public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, int matrixIdx){ + super(c1,c2,dim1,dim2); this.matrixIdx = matrixIdx; } @Override @@ -42,7 +42,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThread } @Override - public void reorderChildren(Character outChar1, Character outChar2) { - + public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { + return this; } } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java index 536456fc744..fa65f8bf7fa 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -74,11 +74,11 @@ public enum EinsumRewriteType{ AB_BA_B_A_AZ__ZB, } - public final EinsumRewriteType einsumRewriteType; + public EinsumRewriteType einsumRewriteType; public final List> operands; - private EOpNodeFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List... operands) { - super(c1,c2); + private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List... 
operands) { + super(c1,c2, dim1, dim2); this.einsumRewriteType = einsumRewriteType; this.operands = Arrays.asList(operands); } @@ -202,7 +202,7 @@ else if(chars.charAt(1)==b){ if(AZCandidates.size()==1){ if(!doSumB) { // check if outer is possible AB,...,AZ->BZ - if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) { + if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),true)) { includeAz=false; } } @@ -253,8 +253,8 @@ else if(chars.charAt(1)==b){ String A = AB.substring(0,1); char a = A.charAt(0); char b = B.charAt(0); - Character c1 = null; - Character c2 = null; + Character c1 = null, c2 = null; + Integer dim1 = null, dim2 = null; EinsumRewriteType t = null; if(!AZs.isEmpty()){ @@ -311,9 +311,11 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { } if(c1 != null){ charToOccurences.put(c1, charToOccurences.get(c1)+1); + dim1 = charToSize.get(c1); } if(c2 != null){ charToOccurences.put(c2, charToOccurences.get(c2)+1); + dim2 = charToSize.get(c2); } HashSet usedOperands = new HashSet<>(); @@ -340,7 +342,7 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { } } - var e = new EOpNodeFuse(c1, c2, t, + var e = new EOpNodeFuse(c1, c2, dim1, dim2, t, ABs, BAs, Bs, @@ -445,8 +447,10 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, } @Override - public void reorderChildren(Character outChar1, Character outChar2) { - + public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { + for(List list : operands) + for(int i = 0; i < list.size(); i++) list.set(i,list.get(i).reorderChildrenAndOptimize(this, list.get(i).c1, list.get(i).c2)); + return this; } 
private static @NotNull List multiplyVectorsIntoOne(List mbs, int size) { diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java index b48877625cf..7f61bd6fb62 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java @@ -37,10 +37,10 @@ public class EOpNodeUnary extends EOpNode { public EOpNode child; public enum EUnaryOperand { - DIAG, SUM, SUM_ROWS, SUM_COLS + DIAG, SUM, SUM_COLS, SUM_ROWS } - public EOpNodeUnary(Character c1, Character c2, EOpNode child, EUnaryOperand eUnaryOperand) { - super(c1, c2); + public EOpNodeUnary(Character c1, Character c2, Integer dim1, Integer dim2, EOpNode child, EUnaryOperand eUnaryOperand) { + super(c1, c2, dim1, dim2); this.child = child; this.eUnaryOperand = eUnaryOperand; } @@ -71,14 +71,14 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThread mb.aggregateUnaryOperations(aggun, res, 0, null); yield res; } - case SUM_ROWS->{ + case SUM_COLS ->{ AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numOfThreads); MatrixBlock res = new MatrixBlock(mb.getNumColumns(), 1, false); mb.aggregateUnaryOperations(aggun, res, 0, null); yield res; } - case SUM_COLS->{ + case SUM_ROWS ->{ AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numOfThreads); MatrixBlock res = new MatrixBlock(mb.getNumRows(), 1, false); @@ -89,7 +89,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThread } @Override - public void reorderChildren(Character outChar1, Character outChar2) { - - } + public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { + return this; + } } diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index a94e5c7e4fc..05a2e8519b2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -109,13 +109,17 @@ public void processInstruction(ExecutionContext ec) { if(inputsChars.get(i).length() == 1) ensureMatrixBlockColumnVector(inputs.get(i)); } - addSumDimensionsDiagonalsAndScalars(einc, inputsChars, eOpNodes, eOpNodesScalars); + addSumDimensionsDiagonalsAndScalars(einc, inputsChars, eOpNodes, eOpNodesScalars, einc.charToDimensionSize); HashMap characterToOccurences = einc.characterAppearanceCount; for (int i = 0; i < inputsChars.size(); i++) { if (inputsChars.get(i) == null) continue; - EOpNodeData n = new EOpNodeData(!inputsChars.get(i).isEmpty() ? inputsChars.get(i).charAt(0) : null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i); + Character c1 = inputsChars.get(i).isEmpty() ? null : inputsChars.get(i).charAt(0); + Character c2 = inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null; + Integer dim1 = c1 == null ? null : einc.charToDimensionSize.get(c1); + Integer dim2 = c1 == null ? 
null : einc.charToDimensionSize.get(c2); + EOpNodeData n = new EOpNodeData(c1,c2,dim1,dim2, i); eOpNodes.add(n); } @@ -156,7 +160,7 @@ public void processInstruction(ExecutionContext ec) { if(!eOpNodesScalars.isEmpty()){ EOpNode l = eOpNodesScalars.get(0); for(int i = 1; i < eOpNodesScalars.size(); i++){ - l = new EOpNodeBinary(null,null,l, eOpNodesScalars.get(i), EBinaryOperand.scalar_scalar); + l = new EOpNodeBinary(l, eOpNodesScalars.get(i), EBinaryOperand.scalar_scalar); } if(plan.isEmpty()) plan.add(l); @@ -177,22 +181,25 @@ public void processInstruction(ExecutionContext ec) { } } - if(plan.size() == 1) - plan.get(0).reorderChildren(einc.outChar1, einc.outChar2); if(plan.size() == 2 && plan.get(0).c2 == null && plan.get(1).c2 == null){ if (plan.get(0).c1 == einc.outChar1 && plan.get(1).c1 == einc.outChar2) - plan.set(0, new EOpNodeBinary(plan.get(0).c1, plan.get(1).c1, plan.get(0), plan.get(1), EBinaryOperand.A_B)); + plan.set(0, new EOpNodeBinary(plan.get(0), plan.get(1), EBinaryOperand.A_B)); if (plan.get(0).c1 == einc.outChar2 && plan.get(1).c1 == einc.outChar1) - plan.set(0, new EOpNodeBinary(plan.get(1).c1, plan.get(0).c1, plan.get(1), plan.get(0), EBinaryOperand.A_B)); + plan.set(0, new EOpNodeBinary(plan.get(1), plan.get(0), EBinaryOperand.A_B)); plan.remove(1); } - if (EXPLAIN != Explain.ExplainType.NONE ) + + if(plan.size() == 1) + plan.set(0,plan.get(0).reorderChildrenAndOptimize(null, einc.outChar1, einc.outChar2)); + + if (EXPLAIN != Explain.ExplainType.NONE ) { System.out.println("Einsum plan:"); for(int i = 0; i < plan.size(); i++) { - System.out.println((i+1)+"."); - System.out.println("- "+String.join("\n- ", plan.get(i).recursivePrintString())); + System.out.println((i + 1) + "."); + System.out.println("- " + String.join("\n- ", plan.get(i).recursivePrintString())); } + } remainingMatrices = executePlan(plan, inputs); }else{ @@ -273,13 +280,13 @@ private EOpNode mergeEOpNodeWithScalar(EOpNode addToNode, EOpNode scalar) { return fuse; } }; - 
return new EOpNodeBinary(addToNode.c1,addToNode.c2,addToNode,scalar,EBinaryOperand.AB_scalar); + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.AB_scalar); } if(addToNode.c1 == null) - return new EOpNodeBinary(null,null,addToNode,scalar,EBinaryOperand.scalar_scalar); + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.scalar_scalar); if(addToNode.c2 == null) - return new EOpNodeBinary(addToNode.c1,null,addToNode,scalar,EBinaryOperand.A_scalar); - return new EOpNodeBinary(addToNode.c1,addToNode.c2,addToNode,scalar,EBinaryOperand.AB_scalar); + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.A_scalar); + return new EOpNodeBinary(addToNode,scalar,EBinaryOperand.AB_scalar); } private static Pair addScalarToPlanFindMinCost(EOpNode plan, HashMap charToSizeMap) { @@ -292,7 +299,7 @@ private static Pair addScalarToPlanFindMinCost(EOpNode plan, H List inputs = List.of(); - if (plan instanceof EOpNodeBinary bin) inputs = List.of(bin._left, bin._right); + if (plan instanceof EOpNodeBinary bin) inputs = List.of(bin.left, bin.right); else if(plan instanceof EOpNodeFuse fuse){ cost = switch (fuse.einsumRewriteType) { case AB_BA_B_A__ -> 1; // thisSize @@ -339,9 +346,9 @@ private static void addVectorMultiplies(ArrayList eOpNodes, ArrayList eOpNodes, ArrayList inputStrings, - ArrayList eOpNodes, ArrayList eOpNodesScalars) { + ArrayList eOpNodes, ArrayList eOpNodesScalars, + HashMap charToDimensionSize) { for(int i = 0; i< inputStrings.size(); i++){ String s = inputStrings.get(i); if (s.length() == 0){ - eOpNodesScalars.add(new EOpNodeData(null, null,i)); + eOpNodesScalars.add(new EOpNodeData(null, null, null, null,i)); inputStrings.set(i, null); continue; }else if (s.length() == 1){ char c1 = s.charAt(0); if((einc.outChar1 == null || c1 != einc.outChar1) && (einc.outChar2 == null || c1 != einc.outChar2) && einc.characterAppearanceCount.get(c1) == 1){ - EOpNode e0 = new EOpNodeData(c1, null,i); - eOpNodesScalars.add(new EOpNodeUnary(null, null, e0, 
EOpNodeUnary.EUnaryOperand.SUM)); + EOpNode e0 = new EOpNodeData(c1, null, charToDimensionSize.get(c1), null, i); + eOpNodesScalars.add(new EOpNodeUnary(null, null, null, null, e0, EOpNodeUnary.EUnaryOperand.SUM)); inputStrings.set(i, null); } continue; @@ -395,17 +403,17 @@ private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, ArrayList generatePlanFusionAndMM(ArrayList eOpNodes eOpNodes = ret; ret = new ArrayList<>(); - ArrayList> matrixMultiplies = findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences, charToSizeMap, ret); + ArrayList> matrixMultiplies = findMatrixMultiplicationChains(eOpNodes, outChar1, outChar2, charToOccurences, + ret); for(List list : matrixMultiplies) { EOpNodeBinary bin = optimizeMMChain(list, charToSizeMap); @@ -506,7 +515,8 @@ else if (chain.get(i - 1).getRight() != chain.get(i).getLeft()) { } } } - private static ArrayList> findMatrixMultiplicationChains(ArrayList inpOperands, Character outChar1, Character outChar2, HashMap charToOccurences, HashMap charToSizeMap, ArrayList ret) { + private static ArrayList> findMatrixMultiplicationChains(ArrayList inpOperands, Character outChar1, Character outChar2, HashMap charToOccurences, + ArrayList ret) { HashSet charactersThatCanBeContracted = new HashSet<>(); HashMap> characterToNodes = new HashMap<>(); ArrayList operandsTodo = new ArrayList<>(); @@ -536,7 +546,6 @@ private static ArrayList> findMatrixMultiplicationChains(ArrayList for(int i = 0; i < operandsTodo.size(); i++){ EOpNode iterateNode = operandsTodo.get(i); -// if (iterateNode == null) continue; // was added previously somewhere if (doneNodes.contains(iterateNode)) continue;// was added previously somewhere doneNodes.add(iterateNode); @@ -582,14 +591,11 @@ private static ArrayList> findMatrixMultiplicationChains(ArrayList res.add(multiplies); } - for(EOpNode op : inpOperands) { if (doneNodes.contains(op)) continue; ret.add(op); } - - return res; } @@ -604,7 +610,7 @@ private Pair> 
generateBinaryPlanCostBased(int cost, Array EOpNode n2 = operands.get(!swap ? 1 : 0); Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null) { - EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + EOpNodeBinary newNode = new EOpNodeBinary(n1, n2, t.getMiddle()); int thisCost = cost + t.getLeft(); return Pair.of(thisCost, Arrays.asList(newNode)); } @@ -625,7 +631,7 @@ else if (operands.size() == 1){ Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null){ - EOpNodeBinary newNode = new EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, t.getMiddle()); + EOpNodeBinary newNode = new EOpNodeBinary(n1, n2, t.getMiddle()); int thisCost = cost + t.getLeft(); if(n1.c1 != null) charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)-1); diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index cc28bbae0f3..95dfd39aaeb 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -51,6 +51,11 @@ public class EinsumTest extends AutomatedTestBase new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))), new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))), new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))), + new Config("ij,jk->ki", List.of(shape(5, 6), shape(6, 5))), // mm t + new Config("ji,jk->ki", List.of(shape(6, 5), shape(6, 10))), + new Config("ji,kj->ki", List.of(shape(6, 5), shape(10, 6))), + new Config("ij,kj->ki", List.of(shape(5, 6), shape(10, 6))), + new Config("ij,kp,pj->ki", List.of(shape(5,6), shape(5,4), shape(4, 6))), // reordering new Config("ab,bc,cd,de->ae", List.of(shape(5, 6), shape(6, 5),shape(5, 6), shape(6, 5))), // mm chain 
new Config("ji,jk->i", List.of(shape(6, 5), shape(6, 4))), @@ -117,8 +122,9 @@ public class EinsumTest extends AutomatedTestBase new Config("ij,ij,ji,j,i, ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // //skinny right with outer mm new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // // no skinny right - new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)), - new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10))) + new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)), // //skinny right with outer mm + new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10))), + new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5), shape(10, 5),shape(10),shape(10,200),shape(200,7))) ); private final int id; private final String einsumStr; From 288f80d728f98f0830972fcfaf07ac8e7b3fec58 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Mon, 24 Nov 2025 00:22:51 +0100 Subject: [PATCH 11/13] bugfix and optimize outer product decision --- .../java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java | 4 ++-- .../java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index 4f49df04a12..dbcfcdae6a5 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -269,7 +269,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), numThreads); } case aB_aC -> { - if(false && LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), true)){ + if(false && 
LibMatrixMult.isSkinnyRightHandSide(left.getNumRows(), left.getNumColumns(), right.getNumRows(), right.getNumColumns(), false)){ res = new MatrixBlock(left.getNumColumns(), right.getNumColumns(),false); res.allocateDenseBlock(); double[] m1 = left.getDenseBlock().values(0); @@ -341,7 +341,7 @@ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Ch var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1; } if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB && - LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, true)) { + left.dim1 * left.dim2 * 8 > LibMatrixMult.L3_CACHESIZE && LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, false)) { fuse.operands.get(4).add(right); fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; fuse.c1 = fuse.c2; diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java index fa65f8bf7fa..5cfdb3b78ea 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -202,7 +202,7 @@ else if(chars.charAt(1)==b){ if(AZCandidates.size()==1){ if(!doSumB) { // check if outer is possible AB,...,AZ->BZ - if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),true)) { + if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) { includeAz=false; } } @@ -264,7 +264,7 @@ else if(chars.charAt(1)==b){ c1 = azC2; } else if 
(EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { - if(LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){ + if(charToSize.get(AB.charAt(0)) * charToSize.get(AB.charAt(1)) * 8 > LibMatrixMult.L3_CACHESIZE && LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){ if (outChar1 == azC2 && outChar2 == b) { t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; c1 = azC2; @@ -280,10 +280,12 @@ else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { } }else{ + doSumA=false; t=null; AZs=new ArrayList<>(); } }else{ + doSumA=false; t=null; AZs=new ArrayList<>(); } From 42a2e2df9da3203d13c2bb84e365562630ce017b Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Fri, 28 Nov 2025 16:00:07 +0100 Subject: [PATCH 12/13] saving work done --- .../sysds/runtime/einsum/EOpNodeBinary.java | 102 +-- .../sysds/runtime/einsum/EOpNodeFuse.java | 689 +++++++++--------- .../runtime/einsum/EinsumSpoofRowwise.java | 343 ++++----- .../instructions/cp/EinsumCPInstruction.java | 70 +- 4 files changed, 602 insertions(+), 602 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index dbcfcdae6a5..eba00ae442b 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -25,9 +25,6 @@ import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.functionobjects.ReduceAll; -import org.apache.sysds.runtime.functionobjects.ReduceCol; -import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.functionobjects.SwapIndex; import 
org.apache.sysds.runtime.instructions.cp.DoubleObject; import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; @@ -35,13 +32,13 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import org.apache.sysds.runtime.matrix.operators.SimpleOperator; import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.function.Predicate; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; @@ -58,20 +55,20 @@ public enum EBinaryOperand { // upper case: char has to remain, lower case: to b ////// elementwisemult and sums ////// aB_aB,// elemwise and colsum -> B - Ba_Ba, // elemwise and rowsum ->B - Ba_aB, // elemwise, either colsum or rowsum -> B + Ab_Ab, // elemwise and rowsum ->A + Ab_bA, // elemwise, either colsum or rowsum -> A aB_Ba, ab_ab,//M-M sum all ab_ba, //M-M.T sum all aB_a,// -> B - Ba_a, // -> B + Ab_b, // -> A ////// elementwise, no summations: ////// A_A,// v-elemwise -> A AB_AB,// M-M elemwise -> AB AB_BA, // M-M.T elemwise -> AB AB_A, // M-v colwise -> BA!? 
- BA_A, // M-v rowwise -> BA + AB_B, // M-v rowwise -> AB ////// other ////// a_a,// dot -> @@ -119,7 +116,7 @@ public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand operand){ dim1=left.dim2; dim2=null; } - case Ba_Ba, Ba_aB, Ba_a, A_A, A_scalar -> { + case Ab_Ab, Ab_bA, Ab_b, A_A, A_scalar -> { c1=left.c1; c2=null; dim1=left.dim1; @@ -131,7 +128,7 @@ public EOpNodeBinary(EOpNode left, EOpNode right, EBinaryOperand operand){ dim1=null; dim2=null; } - case AB_AB, AB_BA, AB_A, BA_A, AB_scalar ->{ + case AB_AB, AB_BA, AB_A, AB_B, AB_scalar ->{ c1=left.c1; c2=left.c2; dim1=left.dim1; @@ -214,44 +211,24 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, res.allocateDenseBlock(); res.getDenseBlockValues()[0] = LibMatrixMult.dotProduct(left.getDenseBlockValues(), right.getDenseBlockValues(), 0,0 , left.getNumRows()); } - case Ba_Ba -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } + case Ab_Ab -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + } case aB_aB -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - ensureMatrixBlockColumnVector(res); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new 
ArrayList<>(), new ArrayList<>(),null,numThreads); } case ab_ab -> { - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + } case ab_ba -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - case Ba_aB -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); - right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + } + case Ab_bA -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), 
List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + } case aB_Ba -> { - ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); - left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); - res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - } - + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + } case AB_BA -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); @@ -289,29 +266,20 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, case A_scalar, AB_scalar -> { res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock()); } - case BA_A -> { - ensureMatrixBlockRowVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - } - case Ba_a -> { + case AB_B -> { ensureMatrixBlockRowVector(right); res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); } - + case Ab_b -> { + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new 
ArrayList<>(),null,numThreads); + } case AB_A -> { ensureMatrixBlockColumnVector(right); res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); } case aB_a -> { - ensureMatrixBlockColumnVector(right); - res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); - res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null); - ensureMatrixBlockColumnVector(res); - } - + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),null,numThreads); + } case A_B -> { ensureMatrixBlockColumnVector(left); ensureMatrixBlockRowVector(right); @@ -329,6 +297,7 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); res = res.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0); } + if(c2 == null) ensureMatrixBlockColumnVector(res); return res; } @@ -340,10 +309,11 @@ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Ch var tmpC1 = c1; c1 = c2; c2 = tmpC1; var tmpDim1 = dim1; dim1 = dim2; dim2 = tmpDim1; } - if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB && - left.dim1 * left.dim2 * 8 > LibMatrixMult.L3_CACHESIZE && LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, false)) { - fuse.operands.get(4).add(right); - fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; + if(EinsumCPInstruction.FUSE_OUTER_MULTIPLY && left instanceof EOpNodeFuse fuse && fuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB + && (!EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK || ((fuse.dim1 * 
fuse.dim2 *(fuse.ABs.size()+fuse.BAs.size())) + (right.dim1*right.dim2)) * 8 > 6 * 1024 * 1024) + && LibMatrixMult.isSkinnyRightHandSide(left.dim1, left.dim2, right.dim1, right.dim2, false)) { + fuse.AZs.add(right); + fuse.einsumRewriteType = EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ; fuse.c1 = fuse.c2; fuse.c2 = right.c2; return fuse; @@ -396,7 +366,7 @@ public static Triple> TryCom if(cannotBeSummed.test(n1.c2)){ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2)); } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ab_Ab, Pair.of(n1.c1, null)); } if(cannotBeSummed.test(n1.c2)){ @@ -434,9 +404,9 @@ else if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ else if(n2.c2 == null) { // ab,c if (n1.c2 == n2.c1) { if(cannotBeSummed.test(n1.c2)){ - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.AB_B, Pair.of(n1.c1, n1.c2)); } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), EBinaryOperand.Ab_b, Pair.of(n1.c1, null)); } return null; // AB,C } @@ -446,7 +416,7 @@ else if (n1.c2 == n2.c1) { if(cannotBeSummed.test(n1.c2)){ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2)); } - return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_aB, Pair.of(n1.c1, null)); + return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ab_bA, Pair.of(n1.c1, null)); } if(cannotBeSummed.test(n1.c2)){ return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.aB_Ba, Pair.of(n1.c2, null)); diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java index 5cfdb3b78ea..b7237301d59 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -19,50 +19,38 @@ package org.apache.sysds.runtime.einsum; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; -import org.apache.sysds.common.Types; -import org.apache.sysds.hops.LiteralOp; -import org.apache.sysds.hops.codegen.cplan.CNode; -import org.apache.sysds.hops.codegen.cplan.CNodeData; -import org.apache.sysds.hops.codegen.cplan.CNodeRow; -import org.apache.sysds.runtime.codegen.CodegenUtils; -import org.apache.sysds.runtime.codegen.SpoofOperator; import org.apache.sysds.runtime.codegen.SpoofRowwise; -import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; -import org.apache.sysds.runtime.functionobjects.Plus; -import org.apache.sysds.runtime.functionobjects.ReduceCol; -import org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.cp.DoubleObject; import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import org.jetbrains.annotations.NotNull; import java.util.ArrayList; import 
java.util.Arrays; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.function.Supplier; -import java.util.stream.Collectors; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; +import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; public class EOpNodeFuse extends EOpNode { private EOpNode scalar = null; - - public enum EinsumRewriteType{ // B -> row*vec, A -> row*scalar AB_BA_B_A__AB, - AB_BA_B_A__B, + AB_BA_A__B, AB_BA_B_A__A, AB_BA_B_A__, @@ -70,25 +58,80 @@ public enum EinsumRewriteType{ AB_BA_B_A_AZ__Z, // AZ: last step is outer matrix multiplication using vector Z - AB_BA_B_A_AZ__BZ, - AB_BA_B_A_AZ__ZB, + AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB, } public EinsumRewriteType einsumRewriteType; - public final List> operands; - - private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List... 
operands) { + public List ABs; + public List BAs; + public List Bs; + public List As; + public List AZs; +// public List Zs; +// public final List> operands; + public List getAllOps(){ + List all = new ArrayList<>(); + all.addAll(ABs); + all.addAll(BAs); + all.addAll(Bs); + all.addAll(As); + all.addAll(AZs); + return all; + }; + private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs) { super(c1,c2, dim1, dim2); this.einsumRewriteType = einsumRewriteType; - this.operands = Arrays.asList(operands); + this.ABs = ABs; + this.BAs = BAs; + this.Bs = Bs; + this.As = As; + this.AZs = AZs; +// this.Zs = Zs; } + public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs) { + super(null,null,null, null); + switch(einsumRewriteType) { + case AB_BA_B_A__A->{ + c1 = ABs.get(0).c1; + dim1 = ABs.get(0).dim1; + }case AB_BA_A__B -> { + c1 = ABs.get(0).c2; + dim1 = ABs.get(0).dim2; + }case AB_BA_B_A__ -> { + }case AB_BA_B_A__AB -> { + c1 = ABs.get(0).c1; + dim1 = ABs.get(0).dim1; + c2 = ABs.get(0).c2; + dim2 = ABs.get(0).dim2; + }case AB_BA_B_A_AZ__Z -> { + c1 = AZs.get(0).c1; + dim1 = AZs.get(0).dim2; + }case AB_BA_A_AZ__BZ ->{ + c1 = ABs.get(0).c2; + dim1 = ABs.get(0).dim2; + c2 = AZs.get(0).c2; + dim2 = AZs.get(0).dim2; + }case AB_BA_A_AZ__ZB ->{ + c2 = ABs.get(0).c2; + dim2 = ABs.get(0).dim2; + c1 = AZs.get(0).c2; + dim1 = AZs.get(0).dim2; + } + } + this.einsumRewriteType = einsumRewriteType; + this.ABs = ABs; + this.BAs = BAs; + this.Bs = Bs; + this.As = As; + this.AZs = AZs; +// this.Zs = Zs; +// this.operands = Arrays.asList(operands); + } @Override public String[] recursivePrintString() { ArrayList inpStrings = new ArrayList<>(); - for (List list : operands) { - for (EOpNode node : list) { - inpStrings.add(node.recursivePrintString()); - } + for (EOpNode node : getAllOps()) { + inpStrings.add(node.recursivePrintString()); } 
String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new); String[] scalarRes = this.scalar==null ? new String[]{} : this.scalar.recursivePrintString(); @@ -111,347 +154,313 @@ public void addScalarAsIntermediate(EOpNode scalar) { throw new RuntimeException("EOpNodeFuse.addScalarAsIntermediate: scalar is undefined for type "+einsumRewriteType.toString()); } - public static EOpNodeFuse match(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ - HashMap charToSize, HashMap charToOccurences, ArrayList ret){ - //precompute - HashSet matricesChars = new HashSet<>(); - HashMap> charsToMatrices = new HashMap<>(); + public static List findFuseOps(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ + HashMap charToSize, HashMap charToOccurences, ArrayList ret) { + ArrayList result = new ArrayList<>(); + HashSet matricesChars = new HashSet<>(); + HashMap> matricesCharsStartingWithChar = new HashMap<>(); + HashMap> charsToMatrices = new HashMap<>(); + + for(EOpNode operand1 : operands) { + String k; + + if(operand1.c2 != null) { + k = operand1.c1.toString() + operand1.c2; + matricesChars.add(k); + if(matricesCharsStartingWithChar.containsKey(operand1.c1)) { + matricesCharsStartingWithChar.get(operand1.c1).add(k); + } + else { + HashSet set = new HashSet<>(); + set.add(k); + matricesCharsStartingWithChar.put(operand1.c1, set); + } + } + else { + k = operand1.c1.toString(); + } - for (EOpNode operand1 : operands) { - String k; + if(charsToMatrices.containsKey(k)) { + charsToMatrices.get(k).add(operand1); + } + else { + ArrayList matrices = new ArrayList<>(); + matrices.add(operand1); + charsToMatrices.put(k, matrices); + } + } + ArrayList> matricesCharsSorted = new ArrayList<>(matricesChars.stream() + .map(x -> Pair.of(charsToMatrices.get(x).get(0).dim1 * charsToMatrices.get(x).get(0).dim2, x)).toList()); + matricesCharsSorted.sort(Comparator.comparing(Pair::getLeft)); + 
ArrayList AZs = new ArrayList<>(); + ArrayList Zs = new ArrayList<>(); + + HashSet usedMatricesChars = new HashSet<>(); + HashSet usedOperands = new HashSet<>(); + + for(String ABCandidate : matricesCharsSorted.stream().map(Pair::getRight).toList()) { + if(usedMatricesChars.contains(ABCandidate)) continue; + + char a = ABCandidate.charAt(0); + char b = ABCandidate.charAt(1); + String AB = ABCandidate; + String BA = "" + b + a; + + int BAsCounter = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); + int ABsCounter = charsToMatrices.get(AB).size(); + + if(BAsCounter > ABsCounter + 1) { + BA = "" + a + b; + AB = "" + b + a; + char tmp = a; + a = b; + b = tmp; + int tmp2 = ABsCounter; + ABsCounter = BAsCounter; + BAsCounter = tmp2; + } + String A = "" + a; + String B = "" + b; + ArrayList Bs = !charsToMatrices.containsKey(B) || usedMatricesChars.contains(B) ? new ArrayList<>() : charsToMatrices.get(B); + ArrayList As = !charsToMatrices.containsKey(A) || usedMatricesChars.contains(A) ? 
new ArrayList<>() : charsToMatrices.get(A); + int AsCounter = As.size(); + int BsCounter = Bs.size(); + + if(AsCounter == 0 && BsCounter == 0 && (ABsCounter + BAsCounter) < 2) { // no elementwise multiplication possible + continue; + } - if (operand1.c2 != null) { - k = operand1.c1.toString() + operand1.c2; - matricesChars.add(k); - } else { - k = operand1.c1.toString(); - } + int usedBsCount = BsCounter + ABsCounter + BAsCounter; + + boolean doSumA = false; + boolean doSumB = charToOccurences.get(b) == usedBsCount && (outChar1 == null || b != outChar1) && (outChar2 == null || b != outChar2); +// boolean doSumZ = false; // there could be multiple AZ-s if Z is summed but for now it is limited to one + HashSet AZCandidates = matricesCharsStartingWithChar.get(a); + boolean includeAZ = AZCandidates.size() == 2; // 2 because it also contains AB + + String AZ = null; + Character z = null; + if(includeAZ) { + var it = AZCandidates.iterator(); AZ = it.next(); + if(AZ.charAt(1) == b) AZ = it.next(); // AB was chosen instead of AZ + AZs = charsToMatrices.get(AZ); + z = AZ.charAt(1); +// String Z = "" + z; +// Zs = charsToMatrices.get(Z); + if(usedMatricesChars.contains(AZ)) { includeAZ = false; } + int AZsCounter = AZs.size(); + doSumA = charToOccurences.get(a) == AsCounter + ABsCounter + BAsCounter + AZsCounter && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); +// doSumZ = charToOccurences.get(z) == AZsCounter + Zs.size(); + if(!doSumA) { + includeAZ = false; + } + else if(!doSumB) { // check if outer is possible AB,...,AZ->BZ + if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY + || (EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK && ((charToSize.get(a) * charToSize.get(b) *(ABsCounter + BAsCounter)) + (charToSize.get(a)*charToSize.get(z)*(AZsCounter))) * 8 < 6 * 1024 * 1024) + || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), + charToSize.get(AB.charAt(0)), 
charToSize.get(AZCandidates.iterator().next().charAt(1)), + false)) { + includeAZ = false; + } + } + // else AB,...,AZ-> Z possible + } - if (charsToMatrices.containsKey(k)) { - charsToMatrices.get(k).add(operand1); - } else { - ArrayList matrices = new ArrayList<>(); - matrices.add(operand1); - charsToMatrices.put(k, matrices); - } - } + if(!includeAZ) { + doSumA = charToOccurences.get(a) == AsCounter + ABsCounter + BAsCounter && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); + } - ArrayList AZs = new ArrayList<>(); - ArrayList Zs = new ArrayList<>(); - boolean pass = false; - - String AB = null; - String BA = null; - boolean doSumA=false; - boolean doSumB=false; - for (String ABcandidate : matricesChars) { - char a = ABcandidate.charAt(0); - char b = ABcandidate.charAt(1); - BA = "" + b + a; - - AZs = new ArrayList<>(); - Character z = null; - pass=true; - int AZsCounter = 0; - HashSet AZCandidates = new HashSet<>(); - - for (String chars : charsToMatrices.keySet()) { - if (chars.equals(ABcandidate) || chars.equals(BA)) { - continue; - } - - if(chars.length()==1){ - //always ok - }else{ - if(a==chars.charAt(1) && b==chars.charAt(0)){ //BA - continue; - } - if(chars.charAt(0)==a){ //AZ - AZsCounter++; - AZCandidates.add(chars); - } - else if(chars.charAt(0)==b){ - // BZ - } - else if(chars.charAt(1)==a){ - //ZA - } - else if(chars.charAt(1)==b){ - // ZB - } - } - } + ArrayList ABs = charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); + ArrayList BAs = charsToMatrices.containsKey(BA) ? 
charsToMatrices.get(BA) : new ArrayList<>(); + + Character c1 = null, c2 = null; + Integer dim1 = null, dim2 = null; + EinsumRewriteType type = null; + + if(includeAZ) { + if(doSumB) { + type = EinsumRewriteType.AB_BA_B_A_AZ__Z; + c1 = z; + } + else if((outChar1 != null && outChar2 != null) && outChar1 == z && outChar2 == b) { + type = EinsumRewriteType.AB_BA_A_AZ__ZB; + c1 = z; c2 = b; + } + else if((outChar1 != null && outChar2 != null) && outChar1 == b && outChar2 == z) { + type = EinsumRewriteType.AB_BA_A_AZ__BZ; + c1 = b; c2 = z; + } + else { + type = EinsumRewriteType.AB_BA_A_AZ__ZB; + c1 = z; c2 = b; + } + } + else { + AZs= new ArrayList<>(); + if(doSumA) { + if(doSumB) { + type = EinsumRewriteType.AB_BA_B_A__; + } + else { + type = EinsumRewriteType.AB_BA_A__B; + c1 = AB.charAt(1); + } + } + else if(doSumB) { + type = EinsumRewriteType.AB_BA_B_A__A; + c1 = AB.charAt(0); + } + else { + type = EinsumRewriteType.AB_BA_B_A__AB; + c1 = AB.charAt(0); c2 = AB.charAt(1); + } + } - if(pass){ // final checks for current AB candidate - - AB = ABcandidate; - String A = ""+a; - String B = ""+b; - int BAsCounter = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); - int ABsCounter = charsToMatrices.get(ABcandidate).size()+BAsCounter; - int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0); - int BsCounter = (charsToMatrices.containsKey(B) ? 
charsToMatrices.get(B).size() : 0); - if(AsCounter==0 && BsCounter==0 && ABsCounter<2){ - pass=false; - continue; - } - int usedBsCount = BsCounter+ABsCounter; - doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2); - - boolean includeAz = true; - if(AZCandidates.size()==1){ - if(!doSumB) { - // check if outer is possible AB,...,AZ->BZ - if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) { - includeAz=false; - } - } - if(includeAz){ - int usedAsCount = AsCounter+ABsCounter+AZsCounter; - doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); - if(!doSumA){ // cant do AZ - break;// just do AB,B,A ->AB / A - }else { - AZs = charsToMatrices.get(AZCandidates.iterator().next()); - break;//ok - } - } - } -// else if (AZCandidates.size() >= 2) { -// doSumA = false; -// if(doSumB){ -// pass=true; -// break; // can do it, it will create AB,B,A -> A, that will be consumed by some AZ later -// } -// pass=false; -// continue; -// -// } - int usedAsCount = AsCounter+ABsCounter; - doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); - - break; - } - } + if(c1 != null) { + charToOccurences.put(c1, charToOccurences.get(c1) + 1); + dim1 = charToSize.get(c1); + } + if(c2 != null) { + charToOccurences.put(c2, charToOccurences.get(c2) + 1); + dim2 = charToSize.get(c2); + } + boolean includeB = type != EinsumRewriteType.AB_BA_A__B && type != EinsumRewriteType.AB_BA_A_AZ__BZ && type != EinsumRewriteType.AB_BA_A_AZ__ZB; - if(!pass){ - return null; - } - ArrayList ABs=charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); - ArrayList BAs=charsToMatrices.containsKey(BA) ? 
charsToMatrices.get(BA) : new ArrayList<>(); - if (ABs.size() < BAs.size() - 1) { - String tmp = AB; - - AB=BA; - BA=tmp; - ArrayList tmp2 = ABs; - BAs=ABs; - ABs=tmp2; - AZs.clear(); - } - String B = AB.substring(1,2); - String A = AB.substring(0,1); - char a = A.charAt(0); - char b = B.charAt(0); - Character c1 = null, c2 = null; - Integer dim1 = null, dim2 = null; - EinsumRewriteType t = null; - - if(!AZs.isEmpty()){ - Character azC2 = AZs.get(0).c2; - if(doSumB) { - t = EinsumRewriteType.AB_BA_B_A_AZ__Z; - c1 = azC2; - } - else if (EinsumCPInstruction.FUSE_OUTER_MULTIPLY) { - if(charToSize.get(AB.charAt(0)) * charToSize.get(AB.charAt(1)) * 8 > LibMatrixMult.L3_CACHESIZE && LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)),charToSize.get(azC2),false)){ - if (outChar1 == azC2 && outChar2 == b) { - t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; - c1 = azC2; - c2 = b; - } else if (outChar2 == azC2 && outChar1 == b) { - t = EinsumRewriteType.AB_BA_B_A_AZ__BZ; - c1 = b; - c2 = azC2; - } else { - t = EinsumRewriteType.AB_BA_B_A_AZ__ZB; - c1 = azC2; - c2 = b; - } - - }else{ - doSumA=false; - t=null; - AZs=new ArrayList<>(); - } - }else{ - doSumA=false; - t=null; - AZs=new ArrayList<>(); - } + usedOperands.addAll(ABs); + usedOperands.addAll(BAs); + usedOperands.addAll(As); + if (includeB) usedOperands.addAll(Bs); + if (includeAZ) usedOperands.addAll(AZs); - if(charsToMatrices.containsKey(azC2.toString())) { - Zs = charsToMatrices.get(azC2.toString()); - } - } - if(t==null) { - if (doSumA) { - if (doSumB) { - t = EinsumRewriteType.AB_BA_B_A__; - } else { - t = EinsumRewriteType.AB_BA_B_A__B; - c1 = AB.charAt(1); - } - } else if (doSumB) { - t = EinsumRewriteType.AB_BA_B_A__A; - c1 = AB.charAt(0); - } else { - t = EinsumRewriteType.AB_BA_B_A__AB; - c1 = AB.charAt(0); - c2 = AB.charAt(1); - } - } - if(c1 != null){ - charToOccurences.put(c1, charToOccurences.get(c1)+1); - dim1 = charToSize.get(c1); - } - 
if(c2 != null){ - charToOccurences.put(c2, charToOccurences.get(c2)+1); - dim2 = charToSize.get(c2); - } - HashSet usedOperands = new HashSet<>(); +// if(type == EinsumRewriteType.AB_BA_B_A_AZ__Z && AZs.size() > 1){ // multiply all AZs if multiple +// EOpNodeFuse fuseAZs = new EOpNodeFuse(EinsumRewriteType.AB_BA_B_A__AB, new ArrayList<>(AZs), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>()); +// AZs = new ArrayList<>(); +// AZs.add(fuseAZs); +// } + usedMatricesChars.add(AB); + usedMatricesChars.add(BA); + usedMatricesChars.add(A); + if (includeB) usedMatricesChars.add(B); + if (includeAZ) usedMatricesChars.add(AZ); - ArrayList Bs=charsToMatrices.containsKey(B) ? charsToMatrices.get(B) : new ArrayList<>(); - ArrayList As=charsToMatrices.containsKey(A) ? charsToMatrices.get(A) : new ArrayList<>(); + var e = new EOpNodeFuse(c1, c2, dim1, dim2, type, ABs, BAs, includeB ? Bs : new ArrayList<>(), As, AZs); - usedOperands.addAll(ABs); - usedOperands.addAll(BAs); - usedOperands.addAll(Bs); - usedOperands.addAll(As); - usedOperands.addAll(AZs); - usedOperands.addAll(Zs); + result.add(e); + } - for(EOpNode n : operands){ + for(EOpNode n : operands) { if(!usedOperands.contains(n)){ ret.add(n); }else{ - if(charToOccurences != null){ - charToOccurences.put(n.c1, charToOccurences.get(n.c1)-1); - if(charToOccurences.get(n.c2)!= null) - charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1); - } - } + charToOccurences.put(n.c1, charToOccurences.get(n.c1) - 1); + if(charToOccurences.get(n.c2)!= null) + charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1); + } } - var e = new EOpNodeFuse(c1, c2, dim1, dim2, t, - ABs, - BAs, - Bs, - As, - AZs, - Zs - ); - return e; + return result; } - - @Override - public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, Log LOG) { - List> mbs = operands.stream().map(l -> l.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).collect(Collectors.toList())).toList(); - var eOpNodeEinsumFuse = this; 
- - if( LOG.isTraceEnabled()) { - String x = eOpNodeEinsumFuse.operands.stream() - .flatMap(List::stream) - .map(o -> o.c1.toString() + (o.c2 == null ? "" : o.c2)) - .collect(Collectors.joining(",")); - String res = eOpNodeEinsumFuse.c1 == null ? "" : (eOpNodeEinsumFuse.c1.toString() +(eOpNodeEinsumFuse.c2 == null ? "" : eOpNodeEinsumFuse.c2.toString())); - - LOG.trace("ComputeEOpNodeFuse AB=" + operands.get(0).get(0).c1+operands.get(0).get(0).c2 +" "+eOpNodeEinsumFuse.einsumRewriteType.toString() +" "+ x + "->" + res); - } - boolean isResultAB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB; - boolean isResultA = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A; - boolean isResultB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__B; - boolean isResult_ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__; - boolean isResultZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; - boolean isResultBZ = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__BZ; - boolean isResultZB = eOpNodeEinsumFuse.einsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__ZB; - List ABs = mbs.get(0), BAs = mbs.get(1), Bs = mbs.get(2), As = mbs.get(3); - List AZs = mbs.get(4); - List Zs = mbs.get(5); - int bSize = ABs.get(0).getNumColumns(); - int aSize = ABs.get(0).getNumRows(); - if (!BAs.isEmpty()) { + public static MatrixBlock compute(EinsumRewriteType rewriteType, List ABsInput, List mbBAs, List mbBs, List mbAs, List mbAZs, Double scalar, int numThreads){ + boolean isResultAB =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB; + boolean isResultA = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A; + boolean isResultB = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A__B; + boolean isResult_ = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__; + 
boolean isResultZ = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A_AZ__Z; + boolean isResultBZ =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__BZ; + boolean isResultZB =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A_AZ__ZB; + + ArrayList mbABs = new ArrayList<>(ABsInput); + int bSize = mbABs.get(0).getNumColumns(); + int aSize = mbABs.get(0).getNumRows(); + if (!mbBAs.isEmpty()) { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); - for(MatrixBlock mb : BAs) {//BA->AB - ABs.add(mb.reorgOperations(transpose, null, 0, 0, 0)); + for(MatrixBlock mb : mbBAs) //BA->AB + mbABs.add(mb.reorgOperations(transpose, null, 0, 0, 0)); + } + + if(mbAs.size() > 1) mbAs = multiplyVectorsIntoOne(mbAs, aSize); + if(mbBs.size() > 1) mbBs = multiplyVectorsIntoOne(mbBs, bSize); + + int constDim2 = -1; + int zSize = 0; + int azCount = 0; + // int zCount = 0; + switch(rewriteType){ + case AB_BA_B_A_AZ__Z -> { + constDim2 = mbAZs.get(0).getNumColumns(); + zSize = mbAZs.get(0).getNumColumns(); + azCount = mbAZs.size(); + // if (mbZs != null) zCount = mbZs.size(); + } + case AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB -> { + constDim2 = mbAZs.get(0).getNumColumns(); + zSize = mbAZs.get(0).getNumColumns(); + azCount = mbAZs.size(); } } - AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); - AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); - if(As.size() > 1){ - As = multiplyVectorsIntoOne(As, aSize); - } - if(Bs.size() > 1){ - Bs = multiplyVectorsIntoOne(Bs, bSize); - } - if(Zs != null && Zs.size() > 1){ - Zs = multiplyVectorsIntoOne(Zs, AZs.get(0).getNumColumns()); - } - int constDim2 = -1; - int zSize = 0; - int azCount = 0; - int zCount = 0; - switch(eOpNodeEinsumFuse.einsumRewriteType){ - case AB_BA_B_A_AZ__Z -> { - constDim2 = AZs.get(0).getNumColumns(); - zSize = AZs.get(0).getNumColumns(); - azCount = AZs.size(); - if (Zs != null) zCount = Zs.size(); - } 
- case AB_BA_B_A_AZ__BZ, AB_BA_B_A_AZ__ZB -> { - constDim2 = AZs.get(0).getNumColumns(); - zSize = AZs.get(0).getNumColumns(); - azCount = AZs.size(); - } - } + SpoofRowwise.RowType rowType = switch(rewriteType){ + case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG; + case AB_BA_A__B -> SpoofRowwise.RowType.COL_AGG_T; + case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG; + case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG; + case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST; + case AB_BA_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T; + case AB_BA_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1; + }; + EinsumSpoofRowwise r = new EinsumSpoofRowwise(rewriteType, rowType, constDim2, + mbABs.size()-1, !mbBs.isEmpty() && (!isResultBZ && !isResultZB && !isResultB), mbAs.size(), azCount, zSize); + + ArrayList fuseInputs = new ArrayList<>(); + fuseInputs.addAll(mbABs); + if(!isResultBZ && !isResultZB && !isResultB) + fuseInputs.addAll(mbBs); + fuseInputs.addAll(mbAs); + if (isResultZ || isResultBZ || isResultZB) + fuseInputs.addAll(mbAZs); - SpoofRowwise.RowType rowType = switch(eOpNodeEinsumFuse.einsumRewriteType){ - case AB_BA_B_A__AB -> SpoofRowwise.RowType.NO_AGG; - case AB_BA_B_A__B -> SpoofRowwise.RowType.COL_AGG_T; - case AB_BA_B_A__A -> SpoofRowwise.RowType.ROW_AGG; - case AB_BA_B_A__ -> SpoofRowwise.RowType.FULL_AGG; - case AB_BA_B_A_AZ__Z -> SpoofRowwise.RowType.COL_AGG_CONST; - case AB_BA_B_A_AZ__BZ -> SpoofRowwise.RowType.COL_AGG_B1_T; - case AB_BA_B_A_AZ__ZB -> SpoofRowwise.RowType.COL_AGG_B1; - }; - EinsumSpoofRowwise r = new EinsumSpoofRowwise(eOpNodeEinsumFuse.einsumRewriteType, rowType, constDim2, false, 1, ABs.size()-1,Bs.size(), As.size(), zCount, azCount, zSize); - - - ArrayList fuseInputs = new ArrayList<>(); - - fuseInputs.addAll(ABs); - fuseInputs.addAll(Bs); - fuseInputs.addAll(As); - if (isResultZ || isResultBZ || isResultZB) - fuseInputs.addAll(AZs); ArrayList scalarObjects = new ArrayList<>(); - if(this.scalar != null){ - MatrixBlock scMb = 
this.scalar.computeEOpNode(inputs,numThreads,LOG); - scalarObjects.add(new DoubleObject(scMb.get(0,0))); + if(scalar != null){ + scalarObjects.add(new DoubleObject(scalar)); } - MatrixBlock out = r.execute(fuseInputs, scalarObjects, new MatrixBlock(), numThreads); - if( isResultA || isResultB || isResultZ) - ensureMatrixBlockColumnVector(out); - return out; + MatrixBlock out = r.execute(fuseInputs, scalarObjects, new MatrixBlock(), numThreads); + if(isResultB && !mbBs.isEmpty()){ + LibMatrixMult.vectMultiply(mbBs.get(0).getDenseBlockValues(), out.getDenseBlockValues(), 0,0, mbABs.get(0).getNumColumns()); + } + if(isResultBZ && !mbBs.isEmpty()){ + ensureMatrixBlockColumnVector(mbBs.get(0)); + out = out.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), mbBs.get(0)); + } + if(isResultZB && !mbBs.isEmpty()){ + ensureMatrixBlockRowVector(mbBs.get(0)); + out = out.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), mbBs.get(0)); + } + + if( isResultA || isResultB || isResultZ) + ensureMatrixBlockColumnVector(out); + + return out; + } + @Override + public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, Log LOG) { + ArrayList mbABs = new ArrayList<>(ABs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList()); + List mbBAs = BAs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); + List mbBs = Bs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); + List mbAs = As.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); + List mbAZs = AZs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); + Double scalar = this.scalar == null ? 
null : this.scalar.computeEOpNode(inputs, numThreads, LOG).get(0,0); + return EOpNodeFuse.compute(this.einsumRewriteType, mbABs, mbBAs, mbBs, mbAs, mbAZs, scalar, numThreads); } @Override public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { - for(List list : operands) - for(int i = 0; i < list.size(); i++) list.set(i,list.get(i).reorderChildrenAndOptimize(this, list.get(i).c1, list.get(i).c2)); + for(int i = 0; i < ABs.size(); i++) ABs.set(i,ABs.get(i).reorderChildrenAndOptimize(this, ABs.get(i).c1, ABs.get(i).c2)); + for(int i = 0; i < BAs.size(); i++) BAs.set(i,BAs.get(i).reorderChildrenAndOptimize(this, BAs.get(i).c1, BAs.get(i).c2)); + for(int i = 0; i < As.size(); i++) As.set(i,As.get(i).reorderChildrenAndOptimize(this, As.get(i).c1, As.get(i).c2)); + for(int i = 0; i < Bs.size(); i++) Bs.set(i,Bs.get(i).reorderChildrenAndOptimize(this, Bs.get(i).c1, Bs.get(i).c2)); + for(int i = 0; i < AZs.size(); i++) AZs.set(i,AZs.get(i).reorderChildrenAndOptimize(this, AZs.get(i).c1, AZs.get(i).c2)); return this; } diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java index d0f63832423..8b3c5544e69 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumSpoofRowwise.java @@ -27,102 +27,125 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixMult; public final class EinsumSpoofRowwise extends SpoofRowwise { - private final int _ABCount; - private final int _BCount; - private final int _ACount; - private final int _ZCount; - private final int _AZCount; - private final int _ZSize; + private final int _ABCount; + private final boolean _Bsupplied; + private final int _ACount; + private final int _AZCount; + private final int _ZSize; + private final int _AZStartIndex; + private final EOpNodeFuse.EinsumRewriteType _EinsumRewriteType; - private 
final int _uptoBCumCount; - private final int _uptoZCumCount; - - private final EOpNodeFuse.EinsumRewriteType _EinsumRewriteType; + public EinsumSpoofRowwise(EOpNodeFuse.EinsumRewriteType einsumRewriteType, RowType rowType, long constDim2, + int abCount, boolean bSupplied, int aCount, int azCount, int zSize) { + super(rowType, constDim2, false, 1); + _ABCount = abCount; + _Bsupplied = bSupplied; + _ACount = aCount; + _AZStartIndex = abCount + (_Bsupplied ? 1 : 0) + aCount; + _AZCount = azCount; + _EinsumRewriteType = einsumRewriteType; + _ZSize = zSize; + } - public EinsumSpoofRowwise(EOpNodeFuse.EinsumRewriteType einsumRewriteType, RowType rowType, long constDim2, boolean tb1, int reqVectMem, int abCount, int bCount, int aCount, int zCount, int azCount, int zSize) { - super(rowType, constDim2, tb1, reqVectMem); - _ABCount = abCount; - _BCount = bCount; - _uptoBCumCount = bCount+ abCount; - _ACount = aCount; - _ZCount = zCount; - _uptoZCumCount = _uptoBCumCount + aCount; - _AZCount = azCount; - _EinsumRewriteType = einsumRewriteType; - _ZSize = zSize; - } - protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { - switch (_EinsumRewriteType) { - case AB_BA_B_A__AB -> { - genexec_AB(a,ai,b,scalars,c,ci,len,grix,rix); - if (scalars.length != 0) { - LibMatrixMult.vectMultiplyWrite(scalars[0], c,c,ci,ci, len); - } + protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, + int rix) { + switch(_EinsumRewriteType) { + case AB_BA_B_A__AB -> { + genexec_AB(a, ai, b, null, c, ci, len, grix, rix); + if(scalars.length != 0) { LibMatrixMult.vectMultiplyWrite(scalars[0], c, c, ci, ci, len); } } - case AB_BA_B_A__B -> { - genexec_B(a,ai,b,scalars,c,ci,len,grix,rix); + case AB_BA_A__B -> { + genexec_B(a, ai, b, null, c, ci, len, grix, rix); } - case AB_BA_B_A__A -> { -// HARDCODEDgenexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); - 
genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); - if (scalars.length != 0) { - c[rix] *= scalars[0]; - } + case AB_BA_B_A__A -> { + // HARDCODEDgenexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); + genexec_A_or_(a, ai, b, null, c, ci, len, grix, rix); + if(scalars.length != 0) { c[rix] *= scalars[0]; } } - case AB_BA_B_A__ -> { - genexec_A_or_(a,ai,b,scalars,c,ci,len,grix,rix); - if (scalars.length != 0) { - c[0] *= scalars[0]; + case AB_BA_B_A__ -> { + genexec_A_or_(a, ai, b, null, c, ci, len, grix, rix); + if(scalars.length != 0) { c[0] *= scalars[0]; } + } + case AB_BA_B_A_AZ__Z -> { + double[] temp = {0}; + genexec_A_or_(a, ai, b, null, temp, 0, len, grix, rix); + if(scalars.length != 0) { temp[0] *= scalars[0]; } + if(_AZCount > 1) { + double[] temp2 = new double[_ZSize]; + int bi = _AZStartIndex; + LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix, + _ZSize * rix, 0, _ZSize); + while(bi < _AZStartIndex + _AZCount) { + LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize); + } + LibMatrixMult.vectMultiplyAdd(temp[0], temp2, c, 0, 0, _ZSize); } + else + LibMatrixMult.vectMultiplyAdd(temp[0], b[_AZStartIndex].values(rix), c, _ZSize * rix, 0, _ZSize); } - case AB_BA_B_A_AZ__Z -> { - double[] temp = {0}; - genexec_A_or_(a,ai,b,scalars,temp,0,len,grix,rix); - if (scalars.length != 0) { - temp[0] *= scalars[0]; + case AB_BA_A_AZ__BZ -> { + double[] temp = new double[len]; + genexec_B(a, ai, b, null, temp, 0, len, grix, rix); + if(scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len); + } + if(_AZCount > 1) { + double[] temp2 = new double[_ZSize]; + int bi = _AZStartIndex; + LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix, + _ZSize * rix, 0, _ZSize); + while(bi < _AZStartIndex + _AZCount) { + LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize); + } + 
LibSpoofPrimitives.vectOuterMultAdd(temp, temp2, c, 0, 0, 0, len, _ZSize); } - LibMatrixMult.vectMultiplyAdd(temp[0], b[_uptoZCumCount].values(rix), c, _ZSize*rix,0, _ZSize); - } - case AB_BA_B_A_AZ__BZ -> { - double[] temp = new double[len]; - genexec_B(a,ai,b,scalars,temp,0,len,grix,rix); - if (scalars.length != 0) { - LibMatrixMult.vectMultiplyWrite(scalars[0], temp,temp,0,0,len); + else + LibSpoofPrimitives.vectOuterMultAdd(temp, b[_AZStartIndex].values(rix), c, 0, _ZSize * rix, 0, len, _ZSize); + } + case AB_BA_A_AZ__ZB -> { + double[] temp = new double[len]; + genexec_B(a, ai, b, null, temp, 0, len, grix, rix); + if(scalars.length != 0) { + LibMatrixMult.vectMultiplyWrite(scalars[0], temp, temp, 0, 0, len); } - LibSpoofPrimitives.vectOuterMultAdd(temp, b[_uptoZCumCount].values(rix), c,0, _ZSize*rix, 0, len,_ZSize); - } - case AB_BA_B_A_AZ__ZB -> { - double[] temp = new double[len]; - genexec_B(a,ai,b,scalars,temp,0,len,grix,rix); - if (scalars.length != 0) { - LibMatrixMult.vectMultiplyWrite(scalars[0], temp,temp,0,0,len); + if(_AZCount > 1) { + double[] temp2 = new double[_ZSize]; + int bi = _AZStartIndex; + LibMatrixMult.vectMultiplyWrite(b[bi++].values(0), b[bi++].values(0), temp2, _ZSize * rix, + _ZSize * rix, 0, _ZSize); + while(bi < _AZStartIndex + _AZCount) { + LibMatrixMult.vectMultiplyWrite(temp2, b[bi++].values(0), temp2, 0, _ZSize * rix, 0, _ZSize); + } + LibSpoofPrimitives.vectOuterMultAdd(temp2, temp, c, 0, 0, 0, _ZSize, len); } - LibSpoofPrimitives.vectOuterMultAdd(b[_uptoZCumCount].values(rix),temp , c,_ZSize*rix,0, 0, _ZSize, len); - } - default -> throw new NotImplementedException(); - } + else + LibSpoofPrimitives.vectOuterMultAdd(b[_AZStartIndex].values(rix), temp, c, _ZSize * rix, 0, 0, _ZSize, len); + } + default -> throw new NotImplementedException(); + } + } - } - private void genexec_AB(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, + private void genexec_AB(double[] a, int ai, SideInput[] 
b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { - int bi = 0; - double[] TMP1 = null; - if (_ABCount != 0){ - if(_ABCount == 1 & _ACount == 0 && _BCount == 0){ + int bi = 0; + double[] TMP1 = null; + if(_ABCount != 0) { + if(_ABCount == 1 & _ACount == 0 && !_Bsupplied) { LibMatrixMult.vectMultiplyWrite(a, b[0].values(rix), c, ai, ai, ci, len); return; } - TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); - while (bi < _ABCount) { - if(_ACount == 0 && _BCount == 0 && bi == _ABCount-1) { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), c, 0, ai, ci, len); - }else { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); - } - } - } + TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len); + while(bi < _ABCount) { + if(_ACount == 0 && !_Bsupplied && bi == _ABCount - 1) { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), c, 0, ai, ci, len); + } + else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); + } + } + } - if(_BCount == 1) { + if(_Bsupplied) { if(_ACount == 1) if(TMP1 == null) vectMultiplyWrite(b[bi + 1].values(0)[rix], a, b[bi].values(0), c, ai, 0, ci, len); @@ -132,127 +155,113 @@ else if(TMP1 == null) LibMatrixMult.vectMultiplyWrite(a, b[bi].values(0), c, ai, 0, ci, len); else LibMatrixMult.vectMultiplyWrite(TMP1, b[bi].values(0), c, 0, 0, ci, len); - } else if(_ACount == 1) { + } + else if(_ACount == 1) { if(TMP1 == null) - LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix],a,c, ai, ci, len); + LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix], a, c, ai, ci, len); else - LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix],TMP1,c, 0, ci, len); + LibMatrixMult.vectMultiplyWrite(b[bi].values(0)[rix], TMP1, c, 0, ci, len); } - } + } - private void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, - int rix) { - int bi = 0; - double[] TMP1 = null; - 
if(_ABCount != 0) { - TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len); - while(bi < _ABCount) { - if(_ACount == 0 && _BCount == 0 && bi == _ABCount - 1) { - LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len); - } - else { - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); - } - } - } + private void genexec_B(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, + int rix) { + int bi = 0; + double[] TMP1 = null; + if(_ABCount == 1 && _ACount == 0) + LibMatrixMult.vectMultiplyAdd(a, b[bi++].values(rix), c, ai, ai, 0, len); + else if(_ABCount != 0) { + TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len); + while(bi < _ABCount) { + if(_ACount == 0 && bi == _ABCount - 1) { + LibMatrixMult.vectMultiplyAdd(TMP1, b[bi++].values(rix), c, 0, ai, 0, len); + } + else { + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); + } + } + } - if(_BCount == 1) { - if(_ACount == 1) - if(TMP1 == null) - vectMultiplyAdd(b[bi + 1].values(0)[rix], a, b[bi].values(0), c, ai, 0, 0, len); - else - vectMultiplyAdd(b[bi + 1].values(0)[rix], TMP1, b[bi].values(0), c, 0, 0, 0, len); - else if(TMP1 == null) - LibMatrixMult.vectMultiplyAdd(a, b[bi].values(0), c, ai, 0, 0, len); - else - LibMatrixMult.vectMultiplyAdd(TMP1, b[bi].values(0), c, 0, 0, 0, len); - } - else if(_ACount == 1) { - if(TMP1 == null) - LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], a, c, ai, 0, len); - else - LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], TMP1, c, 0, 0, len); - } - } + if(_ACount == 1) { + if(TMP1 == null) + LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], a, c, ai, 0, len); + else + LibMatrixMult.vectMultiplyAdd(b[bi].values(0)[rix], TMP1, c, 0, 0, len); + } + } - private void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, - long grix, int rix) { - int bi = 0; - double[] 
TMP1 = null; - double TMP2 = 0; - if (_ABCount == 1 && _BCount == 0) - TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(rix),0,ai,len); - else if (_ABCount != 0) { - TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[bi++].values(rix),ai,ai,len); - while (bi < _ABCount) { - if(_BCount == 0 && bi == _ABCount - 1) - TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(rix),0,ai,len); - else - LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); - } - } + private void genexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { + int bi = 0; + double[] TMP1 = null; + double TMP2 = 0; + if(_ABCount == 1 && !_Bsupplied) + TMP2 = LibSpoofPrimitives.dotProduct(a, b[bi++].values(rix), ai, ai, len); + else if(_ABCount != 0) { + TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[bi++].values(rix), ai, ai, len); + while(bi < _ABCount) { + if(!_Bsupplied && bi == _ABCount - 1) + TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[bi++].values(rix), 0, ai, len); + else + LibMatrixMult.vectMultiplyWrite(TMP1, b[bi++].values(rix), TMP1, 0, ai, 0, len); + } + } - if(_BCount == 1) - if(_ABCount != 0) - TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[bi++].values(0),0,0,len); - else - TMP2 = LibSpoofPrimitives.dotProduct(a,b[bi++].values(0),ai,0,len); - else if(_ABCount == 0) - TMP2 = LibSpoofPrimitives.vectSum(a, ai, len); + if(_Bsupplied) + if(_ABCount != 0) TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[bi++].values(0), 0, 0, len); + else TMP2 = LibSpoofPrimitives.dotProduct(a, b[bi++].values(0), ai, 0, len); + else if(_ABCount == 0) TMP2 = LibSpoofPrimitives.vectSum(a, ai, len); - if(_ACount == 1) - TMP2 *= b[bi].values(0)[rix]; + if(_ACount == 1) TMP2 *= b[bi].values(0)[rix]; - if (_EinsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2; - else c[0] += TMP2; - } + if(_EinsumRewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A) c[ci] = TMP2; + else c[0] += TMP2; + } private void 
HARDCODEDgenexec_A_or_(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, - int len, long grix, int rix) { - double[] TMP1 = LibSpoofPrimitives.vectMultWrite(a,b[0].values(rix),ai,ai,len); - double TMP2 = LibSpoofPrimitives.dotProduct(TMP1,b[1].values(0),0,ai,len); - TMP2 *= b[2].values(0)[rix]; - c[rix] = TMP2; + int len, long grix, int rix) { + double[] TMP1 = LibSpoofPrimitives.vectMultWrite(a, b[0].values(rix), ai, ai, len); + double TMP2 = LibSpoofPrimitives.dotProduct(TMP1, b[1].values(0), 0, ai, len); + TMP2 *= b[2].values(0)[rix]; + c[rix] = TMP2; } - protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { + protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, + int alen, int len, long grix, int rix) { throw new RuntimeException("Sparse fused einsum not implemented"); - } - + } // I am not sure if it is worth copying to LibMatrixMult so for now added it here private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; private static final int vLen = SPECIES.length(); - public static void vectMultiplyWrite( final double aval, double[] a, double[] b, double[] c,int ai, int bi, int ci, final int len ) - { - final int bn = len%vLen; + + public static void vectMultiplyWrite(final double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci, + final int len) { + final int bn = len % vLen; //rest, not aligned to vLen-blocks - for( int j = 0; j < bn; j++, ai++, bi++, ci++) - c[ ci ] = aval * b[ bi ] * a[ ai ]; + for(int j = 0; j < bn; j++, ai++, bi++, ci++) + c[ci] = aval * b[bi] * a[ai]; //unrolled vLen-block (for better instruction-level parallelism) DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval); - for( int j = bn; j < len; j+=vLen, ai+=vLen, bi+=vLen, ci+=vLen) - { + for(int j = bn; j < len; j += vLen, ai += vLen, bi += vLen, ci += vLen) { 
DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai); DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi); avalVec.mul(bVec).mul(aVec).intoArray(c, ci); } } - public static void vectMultiplyAdd( final double aval, double[] a, double[] b, double[] c,int ai, int bi, int ci, final int len ) - { - final int bn = len%vLen; + public static void vectMultiplyAdd(final double aval, double[] a, double[] b, double[] c, int ai, int bi, int ci, + final int len) { + final int bn = len % vLen; //rest, not aligned to vLen-blocks - for( int j = 0; j < bn; j++, ai++, bi++, ci++) - c[ ci ] += aval * b[ bi ] * a[ ai ]; + for(int j = 0; j < bn; j++, ai++, bi++, ci++) + c[ci] += aval * b[bi] * a[ai]; //unrolled vLen-block (for better instruction-level parallelism) DoubleVector avalVec = DoubleVector.broadcast(SPECIES, aval); - for( int j = bn; j < len; j+=vLen, ai+=vLen, bi+=vLen, ci+=vLen) - { + for(int j = bn; j < len; j += vLen, ai += vLen, bi += vLen, ci += vLen) { DoubleVector aVec = DoubleVector.fromArray(SPECIES, a, ai); DoubleVector bVec = DoubleVector.fromArray(SPECIES, b, bi); DoubleVector cVec = DoubleVector.fromArray(SPECIES, c, ci); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 05a2e8519b2..2cae4bab272 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -43,18 +43,18 @@ import org.apache.sysds.utils.Explain; import java.util.*; -import java.util.stream.Collectors; import static org.apache.sysds.api.DMLScript.EXPLAIN; import static org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { public static final boolean FORCE_CELL_TPL = false; -// public static final boolean FUSED = true; - public static final boolean 
FUSE_OUTER_MULTIPLY = true; + public static final boolean FUSE_OUTER_MULTIPLY = true; + public static final boolean FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK = true; - public static final boolean PRINT_TRACE = true; + + public static final boolean PRINT_TRACE = false; protected static final Log LOG = LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; @@ -134,24 +134,26 @@ public void processInstruction(ExecutionContext ec) { if(true){ // new way: search for fusions and matrix-multiplications chain in a loop plan = generatePlanFusionAndMM(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); }else { // old way: try to do fusion first and then rest in binary fashion cost based - if(true /*FUSED*/) { + List fuseOps; + do { ret = new ArrayList<>(); - EOpNodeFuse fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, - einc.charToDimensionSize, characterToOccurences, ret); - if(fuse != null) { - ret.add(fuse); - eOpNodes = ret; - } - while(ret.size() > 2 && fuse != null) { - ret = new ArrayList<>(); - fuse = EOpNodeFuse.match(eOpNodes, einc.outChar1, einc.outChar2, einc.charToDimensionSize, - characterToOccurences, ret); - if(fuse != null) { - ret.add(fuse); - eOpNodes = ret; + fuseOps = EOpNodeFuse.findFuseOps(eOpNodes, einc.outChar1, einc.outChar2, einc.charToDimensionSize, characterToOccurences, ret); + + if(!fuseOps.isEmpty()) { + for (EOpNodeFuse fuseOp : fuseOps) { + if (fuseOp.c1 == null) { + eOpNodesScalars.add(fuseOp); + continue; + } + ret.add(fuseOp); +// if (fuseOp.c2 != null) { +// characterToOccurences.put(fuseOp.c2, characterToOccurences.get(fuseOp.c2)+1); +// } +// characterToOccurences.put(fuseOp.c1, characterToOccurences.get(fuseOp.c1)+1); } + eOpNodes = ret; } - } + } while(eOpNodes.size() > 1 && !fuseOps.isEmpty()); Pair> costAndPlan = generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); @@ 
-304,13 +306,13 @@ else if(plan instanceof EOpNodeFuse fuse){ cost = switch (fuse.einsumRewriteType) { case AB_BA_B_A__ -> 1; // thisSize case AB_BA_B_A__AB -> thisSize; - case AB_BA_B_A__B -> thisSize; + case AB_BA_A__B -> thisSize; case AB_BA_B_A__A -> 2; // intermediate is scalar, 2 because if there is some real scalar case AB_BA_B_A_AZ__Z -> 2; // intermediate is scalar - case AB_BA_B_A_AZ__BZ -> thisSize; - case AB_BA_B_A_AZ__ZB -> thisSize; + case AB_BA_A_AZ__BZ -> thisSize; + case AB_BA_A_AZ__ZB -> thisSize; }; - inputs = fuse.operands.stream().flatMap(List::stream).collect(Collectors.toList()); + inputs = fuse.getAllOps(); } for(EOpNode inp : inputs){ @@ -429,16 +431,26 @@ private static List generatePlanFusionAndMM(ArrayList eOpNodes while(lastNumOfOperands != eOpNodes.size() && eOpNodes.size() > 1){ lastNumOfOperands = eOpNodes.size(); - EOpNodeFuse fuse = null; + List fuseOps; do { ret = new ArrayList<>(); - fuse = EOpNodeFuse.match(eOpNodes, outChar1, outChar2, charToSizeMap, charToOccurences, ret); - if(fuse != null) { - if(fuse.c1 == null) eOpNodesScalars.add(fuse); - else ret.add(fuse); + fuseOps = EOpNodeFuse.findFuseOps(eOpNodes, outChar1, outChar2, charToSizeMap, charToOccurences, ret); + + if(!fuseOps.isEmpty()) { + for (EOpNodeFuse fuseOp : fuseOps) { + if (fuseOp.c1 == null) { + eOpNodesScalars.add(fuseOp); + continue; + } + ret.add(fuseOp); +// if (fuseOp.c2 != null) { +// charToOccurences.put(fuseOp.c2, charToOccurences.get(fuseOp.c2)+1); +// } +// charToOccurences.put(fuseOp.c1, charToOccurences.get(fuseOp.c1)+1); + } eOpNodes = ret; } - } while(eOpNodes.size() > 1 && fuse != null); + } while(eOpNodes.size() > 1 && !fuseOps.isEmpty()); ret = new ArrayList<>(); addVectorMultiplies(eOpNodes, eOpNodesScalars,charToOccurences, outChar1, outChar2, ret); From 9c47ddea2ad42f103334fa0c147390b1574f95b0 Mon Sep 17 00:00:00 2001 From: Hubert Krawczyk Date: Sun, 21 Dec 2025 17:05:26 +0100 Subject: [PATCH 13/13] remove comments --- 
.../java/org/apache/sysds/hops/NaryOp.java | 2 +- .../apache/sysds/runtime/einsum/EOpNode.java | 22 +- .../sysds/runtime/einsum/EOpNodeBinary.java | 52 +++-- .../sysds/runtime/einsum/EOpNodeData.java | 12 +- .../sysds/runtime/einsum/EOpNodeFuse.java | 168 +++++++-------- .../sysds/runtime/einsum/EOpNodeUnary.java | 30 ++- .../instructions/cp/EinsumCPInstruction.java | 78 ++++--- .../test/functions/einsum/EinsumTest.java | 199 +++++++++--------- ...codegen.xml => SystemDS-config-einsum.xml} | 9 +- 9 files changed, 298 insertions(+), 274 deletions(-) rename src/test/scripts/functions/einsum/{SystemDS-config-codegen.xml => SystemDS-config-einsum.xml} (83%) diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 6962beadcbc..d752316a526 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -165,7 +165,7 @@ else if ( areDimsBelowThreshold() ) setRequiresRecompileIfNecessary(); //ensure cp exec type for single-node operations - if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST + if ( _op == OpOpN.PRINTF || _op == OpOpN.EVAL || _op == OpOpN.LIST || _op == OpOpN.EINSUM //TODO: cbind/rbind of lists only support in CP right now || (_op == OpOpN.CBIND && getInput().get(0).getDataType().isList()) || (_op == OpOpN.RBIND && getInput().get(0).getDataType().isList()) diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java index b402ec634bf..29c3187c3e1 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNode.java @@ -24,6 +24,8 @@ import scala.Int; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; public abstract class EOpNode { public Character c1; @@ -37,14 +39,28 @@ public EOpNode(Character c1, Character c2, Integer dim1, Integer dim2) { this.dim2 = dim2; } - 
@Override - public String toString() { + public String getOutputString() { if(c1 == null) return "''"; if(c2 == null) return c1.toString(); return c1.toString() + c2.toString(); } + public abstract List getChildren(); - public abstract String[] recursivePrintString(); + public String[] recursivePrintString(){ + ArrayList inpStrings = new ArrayList<>(); + for (EOpNode node : getChildren()) { + inpStrings.add(node.recursivePrintString()); + } + String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new); + String[] res = new String[1 + inpRes.length]; + + res[0] = this.toString(); + + for (int i=0; i inputs, int numOfThreads, Log LOG); diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java index eba00ae442b..d2917121690 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeBinary.java @@ -46,14 +46,14 @@ public class EOpNodeBinary extends EOpNode { - public enum EBinaryOperand { // upper case: char has to remain, lower case: to be summed + public enum EBinaryOperand { // upper case: char remains, lower case: summed (reduced) dimension ////// mm: ////// Ba_aC, // -> BC aB_Ca, // -> CB Ba_Ca, // -> BC aB_aC, // -> BC - ////// elementwisemult and sums ////// + ////// element-wise multiplications and sums ////// aB_aB,// elemwise and colsum -> B Ab_Ab, // elemwise and rowsum ->A Ab_bA, // elemwise, either colsum or rowsum -> A @@ -169,18 +169,12 @@ public static EOpNodeBinary combineMatrixMultiply(EOpNode left, EOpNode right) { } @Override - public String[] recursivePrintString() { - String[] left = this.left.recursivePrintString(); - String[] right = this.right.recursivePrintString(); - String[] res = new String[left.length + right.length+1]; - res[0] = this.getClass().getSimpleName()+" ("+ operand.toString()+") "+this.toString(); - for (int i=0; i getChildren() { + return 
List.of(this.left, this.right); + } + @Override + public String toString() { + return this.getClass().getSimpleName()+" ("+ operand.toString()+") "+getOutputString(); } @Override @@ -193,8 +187,6 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, MatrixBlock res; - if(LOG.isTraceEnabled()) LOG.trace("computing binary "+bin.left +","+bin.right +"->"+bin); - switch (bin.operand){ case AB_AB -> { res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock()); @@ -212,22 +204,28 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, res.getDenseBlockValues()[0] = LibMatrixMult.dotProduct(left.getDenseBlockValues(), right.getDenseBlockValues(), 0,0 , left.getNumRows()); } case Ab_Ab -> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case aB_aB -> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case ab_ab -> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left, right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case ab_ba 
-> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case Ab_bA -> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case aB_Ba -> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), List.of(right), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case AB_BA -> { ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads); @@ -271,14 +269,16 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numThreads, res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); } case Ab_b -> { - res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A, List.of(left), new ArrayList<>(), List.of(right), new ArrayList<>(), new ArrayList<>(), + null, numThreads); } case AB_A -> { ensureMatrixBlockColumnVector(right); res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right); } case aB_a -> { - res = 
EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(),null,numThreads); + res = EOpNodeFuse.compute(EOpNodeFuse.EinsumRewriteType.AB_BA_A__B, List.of(left), new ArrayList<>(), new ArrayList<>(), List.of(right), new ArrayList<>(), + null, numThreads); } case A_B -> { ensureMatrixBlockColumnVector(left); @@ -427,10 +427,6 @@ else if (n1.c2 == n2.c1) { return null; // AB_B }else{ return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2), EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2)); - // if(n1.c1 ==outChar1 || n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){ - // return null; // AB_B - // } - // return Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), EBinaryOperand.Ba_a, Pair.of(n1.c1, null)); } } if(n1.c1 == n2.c2) { diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java index d352586a21e..fd710d19d1b 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeData.java @@ -23,6 +23,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import java.util.ArrayList; +import java.util.List; public class EOpNodeData extends EOpNode { public int matrixIdx; @@ -30,11 +31,14 @@ public EOpNodeData(Character c1, Character c2, Integer dim1, Integer dim2, int m super(c1,c2,dim1,dim2); this.matrixIdx = matrixIdx; } + + @Override + public List getChildren() { + return List.of(); + } @Override - public String[] recursivePrintString() { - String[] res = new String[1]; - res[0] = this.getClass().getSimpleName()+" ("+matrixIdx+") "+this.toString(); - return res; + public String toString() { + return this.getClass().getSimpleName()+" ("+matrixIdx+") "+getOutputString(); } @Override public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThreads, Log LOG) { diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java index b7237301d59..1107462caf9 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeFuse.java @@ -39,6 +39,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.function.Function; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockColumnVector; import static org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction.ensureMatrixBlockRowVector; @@ -67,15 +68,15 @@ public enum EinsumRewriteType{ public List Bs; public List As; public List AZs; -// public List Zs; -// public final List> operands; - public List getAllOps(){ + @Override + public List getChildren(){ List all = new ArrayList<>(); all.addAll(ABs); all.addAll(BAs); all.addAll(Bs); all.addAll(As); all.addAll(AZs); + if (scalar != null) all.add(scalar); return all; }; private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs) { @@ -86,9 +87,8 @@ private EOpNodeFuse(Character c1, Character c2, Integer dim1, Integer dim2, Eins this.Bs = Bs; this.As = As; this.AZs = AZs; -// this.Zs = Zs; } - public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs) { + public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List ABs, List BAs, List Bs, List As, List AZs, List, List>> AXsAndXs) { super(null,null,null, null); switch(einsumRewriteType) { case AB_BA_B_A__A->{ @@ -124,29 +124,13 @@ public EOpNodeFuse(EinsumRewriteType einsumRewriteType, List ABs, List< this.Bs = Bs; this.As = As; this.AZs = AZs; -// this.Zs = Zs; -// this.operands = Arrays.asList(operands); } - @Override - public String[] recursivePrintString() { - ArrayList inpStrings = new ArrayList<>(); - for (EOpNode node : 
getAllOps()) { - inpStrings.add(node.recursivePrintString()); - } - String[] inpRes = inpStrings.stream().flatMap(Arrays::stream).toArray(String[]::new); - String[] scalarRes = this.scalar==null ? new String[]{} : this.scalar.recursivePrintString(); - String[] res = new String[1 + inpRes.length + scalarRes.length]; - - res[0] = this.getClass().getSimpleName()+" ("+einsumRewriteType.toString()+") "+this.toString(); - for (int i=0; i findFuseOps(ArrayList operands, Character outChar1, Character outChar2,/*, Set simplySummableChars,*/ + public static List findFuseOps(ArrayList operands, Character outChar1, Character outChar2, HashMap charToSize, HashMap charToOccurences, ArrayList ret) { ArrayList result = new ArrayList<>(); HashSet matricesChars = new HashSet<>(); @@ -193,7 +177,6 @@ public static List findFuseOps(ArrayList operands, Charact .map(x -> Pair.of(charsToMatrices.get(x).get(0).dim1 * charsToMatrices.get(x).get(0).dim2, x)).toList()); matricesCharsSorted.sort(Comparator.comparing(Pair::getLeft)); ArrayList AZs = new ArrayList<>(); - ArrayList Zs = new ArrayList<>(); HashSet usedMatricesChars = new HashSet<>(); HashSet usedOperands = new HashSet<>(); @@ -206,73 +189,84 @@ public static List findFuseOps(ArrayList operands, Charact String AB = ABCandidate; String BA = "" + b + a; - int BAsCounter = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); - int ABsCounter = charsToMatrices.get(AB).size(); + int BAsCount = (charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); + int ABsCount = charsToMatrices.get(AB).size(); - if(BAsCounter > ABsCounter + 1) { + if(BAsCount > ABsCount + 1) { BA = "" + a + b; AB = "" + b + a; char tmp = a; a = b; b = tmp; - int tmp2 = ABsCounter; - ABsCounter = BAsCounter; - BAsCounter = tmp2; + int tmp2 = ABsCount; + ABsCount = BAsCount; + BAsCount = tmp2; } String A = "" + a; String B = "" + b; - ArrayList Bs = !charsToMatrices.containsKey(B) || usedMatricesChars.contains(B) ? 
new ArrayList<>() : charsToMatrices.get(B); - ArrayList As = !charsToMatrices.containsKey(A) || usedMatricesChars.contains(A) ? new ArrayList<>() : charsToMatrices.get(A); - int AsCounter = As.size(); - int BsCounter = Bs.size(); + int AsCount = (charsToMatrices.containsKey(A) && !usedMatricesChars.contains(A) ? charsToMatrices.get(A).size() : 0); + int BsCount = (charsToMatrices.containsKey(B) && !usedMatricesChars.contains(B) ? charsToMatrices.get(B).size() : 0); - if(AsCounter == 0 && BsCounter == 0 && (ABsCounter + BAsCounter) < 2) { // no elementwise multiplication possible + if(AsCount == 0 && BsCount == 0 && (ABsCount + BAsCount) < 2) { // no elementwise multiplication possible continue; } - int usedBsCount = BsCounter + ABsCounter + BAsCounter; + int usedBsCount = BsCount + ABsCount + BAsCount; boolean doSumA = false; boolean doSumB = charToOccurences.get(b) == usedBsCount && (outChar1 == null || b != outChar1) && (outChar2 == null || b != outChar2); -// boolean doSumZ = false; // there could be multiple AZ-s if Z is summed but for now it is limited to one HashSet AZCandidates = matricesCharsStartingWithChar.get(a); - boolean includeAZ = AZCandidates.size() == 2; // 2 because it also contains AB String AZ = null; Character z = null; + boolean includeAZ = AZCandidates.size() == 2; + if(includeAZ) { - var it = AZCandidates.iterator(); AZ = it.next(); - if(AZ.charAt(1) == b) AZ = it.next(); // AB was chosen instead of AZ - AZs = charsToMatrices.get(AZ); - z = AZ.charAt(1); -// String Z = "" + z; -// Zs = charsToMatrices.get(Z); - if(usedMatricesChars.contains(AZ)) { includeAZ = false; } - int AZsCounter = AZs.size(); - doSumA = charToOccurences.get(a) == AsCounter + ABsCounter + BAsCounter + AZsCounter && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); -// doSumZ = charToOccurences.get(z) == AZsCounter + Zs.size(); - if(!doSumA) { - includeAZ = false; - } - else if(!doSumB) { // check if outer is possible AB,...,AZ->BZ - 
if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY - || (EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK && ((charToSize.get(a) * charToSize.get(b) *(ABsCounter + BAsCounter)) + (charToSize.get(a)*charToSize.get(z)*(AZsCounter))) * 8 < 6 * 1024 * 1024) - || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), - charToSize.get(AB.charAt(0)), charToSize.get(AZCandidates.iterator().next().charAt(1)), - false)) { + for(var AZCandidate : AZCandidates) { + if(AB.equals(AZCandidate)) {continue;} + AZs = charsToMatrices.get(AZCandidate); + z = AZCandidate.charAt(1); + String Z = "" + z; + AZ = "" + a + z; + int AZsCount= AZs.size(); + int ZsCount= charsToMatrices.containsKey(Z) ? charsToMatrices.get(Z).size() : 0; + doSumA = AZsCount + ABsCount + BAsCount + AsCount == charToOccurences.get(a) && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); + boolean doSumZ = AZsCount + ZsCount == charToOccurences.get(z) && (outChar1 == null || z != outChar1) && (outChar2 == null || z != outChar2); + if(!doSumA){ includeAZ = false; + } else if(!doSumB && doSumZ){ // swap the order, to have only one fusion AB,...,AZ->Z + b = z; + z = AB.charAt(1); + AB = "" + a + b; + BA = "" + b + a; + A = "" + a; + B = "" + b; + AZ = "" + a + z; + AZs = charsToMatrices.get(AZ); + doSumB = true; + } else if(!doSumB && !doSumZ){ // outer between B and Z + if(!EinsumCPInstruction.FUSE_OUTER_MULTIPLY + || (EinsumCPInstruction.FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK && ((charToSize.get(a) * charToSize.get(b) *(ABsCount + BAsCount)) + (charToSize.get(a)*charToSize.get(z)*(AZsCount))) * 8 < 6 * 1024 * 1024) + || !LibMatrixMult.isSkinnyRightHandSide(charToSize.get(AB.charAt(0)), charToSize.get(AB.charAt(1)), charToSize.get(AB.charAt(0)), charToSize.get(AZCandidates.iterator().next().charAt(1)),false)) { + includeAZ = false; + } + } else if(doSumB && doSumZ){ + // it will be two separate templates and then mutliply a vectors + } 
else if (doSumB && !doSumZ) { + // ->Z template OK } + break; } - // else AB,...,AZ-> Z possible } if(!includeAZ) { - doSumA = charToOccurences.get(a) == AsCounter + ABsCounter + BAsCounter && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); + doSumA = charToOccurences.get(a) == AsCount + ABsCount + BAsCount && (outChar1 == null || a != outChar1) && (outChar2 == null || a != outChar2); } ArrayList ABs = charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); ArrayList BAs = charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>(); - + ArrayList As = charsToMatrices.containsKey(A) && !usedMatricesChars.contains(A) ? charsToMatrices.get(A) : new ArrayList<>(); + ArrayList Bs = charsToMatrices.containsKey(B) && !usedMatricesChars.contains(B) ? charsToMatrices.get(B) : new ArrayList<>(); Character c1 = null, c2 = null; Integer dim1 = null, dim2 = null; EinsumRewriteType type = null; @@ -332,12 +326,6 @@ else if(doSumB) { if (includeB) usedOperands.addAll(Bs); if (includeAZ) usedOperands.addAll(AZs); -// if(type == EinsumRewriteType.AB_BA_B_A_AZ__Z && AZs.size() > 1){ // multiply all AZs if multiple -// EOpNodeFuse fuseAZs = new EOpNodeFuse(EinsumRewriteType.AB_BA_B_A__AB, new ArrayList<>(AZs), new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), new ArrayList<>()); -// AZs = new ArrayList<>(); -// AZs.add(fuseAZs); -// } - usedMatricesChars.add(AB); usedMatricesChars.add(BA); usedMatricesChars.add(A); @@ -352,7 +340,7 @@ else if(doSumB) { for(EOpNode n : operands) { if(!usedOperands.contains(n)){ ret.add(n); - }else{ + } else { charToOccurences.put(n.c1, charToOccurences.get(n.c1) - 1); if(charToOccurences.get(n.c2)!= null) charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1); @@ -361,7 +349,8 @@ else if(doSumB) { return result; } - public static MatrixBlock compute(EinsumRewriteType rewriteType, List ABsInput, List mbBAs, List mbBs, List mbAs, List mbAZs, Double scalar, int numThreads){ + 
public static MatrixBlock compute(EinsumRewriteType rewriteType, List ABsInput, List mbBAs, List mbBs, List mbAs, List mbAZs, + Double scalar, int numThreads){ boolean isResultAB =rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__AB; boolean isResultA = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_B_A__A; boolean isResultB = rewriteType == EOpNodeFuse.EinsumRewriteType.AB_BA_A__B; @@ -385,15 +374,8 @@ public static MatrixBlock compute(EinsumRewriteType rewriteType, List { - constDim2 = mbAZs.get(0).getNumColumns(); - zSize = mbAZs.get(0).getNumColumns(); - azCount = mbAZs.size(); - // if (mbZs != null) zCount = mbZs.size(); - } - case AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB -> { + case AB_BA_B_A_AZ__Z, AB_BA_A_AZ__BZ, AB_BA_A_AZ__ZB -> { constDim2 = mbAZs.get(0).getNumColumns(); zSize = mbAZs.get(0).getNumColumns(); azCount = mbAZs.size(); @@ -445,22 +427,23 @@ public static MatrixBlock compute(EinsumRewriteType rewriteType, List inputs, int numThreads, Log LOG) { - ArrayList mbABs = new ArrayList<>(ABs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList()); - List mbBAs = BAs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); - List mbBs = Bs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); - List mbAs = As.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); - List mbAZs = AZs.stream().map(n -> n.computeEOpNode(inputs, numThreads, LOG)).toList(); + final Function eOpNodeToMatrixBlock = n -> n.computeEOpNode(inputs, numThreads, LOG); + ArrayList mbABs = new ArrayList<>(ABs.stream().map(eOpNodeToMatrixBlock).toList()); + List mbBAs = BAs.stream().map(eOpNodeToMatrixBlock).toList(); + List mbBs = Bs.stream().map(eOpNodeToMatrixBlock).toList(); + List mbAs = As.stream().map(eOpNodeToMatrixBlock).toList(); + List mbAZs = AZs.stream().map(eOpNodeToMatrixBlock).toList(); Double scalar = this.scalar == null ? 
null : this.scalar.computeEOpNode(inputs, numThreads, LOG).get(0,0); - return EOpNodeFuse.compute(this.einsumRewriteType, mbABs, mbBAs, mbBs, mbAs, mbAZs, scalar, numThreads); + return EOpNodeFuse.compute(this.einsumRewriteType, mbABs, mbBAs, mbBs, mbAs, mbAZs , scalar, numThreads); } @Override public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Character outChar2) { - for(int i = 0; i < ABs.size(); i++) ABs.set(i,ABs.get(i).reorderChildrenAndOptimize(this, ABs.get(i).c1, ABs.get(i).c2)); - for(int i = 0; i < BAs.size(); i++) BAs.set(i,BAs.get(i).reorderChildrenAndOptimize(this, BAs.get(i).c1, BAs.get(i).c2)); - for(int i = 0; i < As.size(); i++) As.set(i,As.get(i).reorderChildrenAndOptimize(this, As.get(i).c1, As.get(i).c2)); - for(int i = 0; i < Bs.size(); i++) Bs.set(i,Bs.get(i).reorderChildrenAndOptimize(this, Bs.get(i).c1, Bs.get(i).c2)); - for(int i = 0; i < AZs.size(); i++) AZs.set(i,AZs.get(i).reorderChildrenAndOptimize(this, AZs.get(i).c1, AZs.get(i).c2)); + ABs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2)); + BAs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2)); + As.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2)); + Bs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2)); + AZs.replaceAll(n -> n.reorderChildrenAndOptimize(this, n.c1, n.c2)); return this; } @@ -468,11 +451,10 @@ public EOpNode reorderChildrenAndOptimize(EOpNode parent, Character outChar1, Ch MatrixBlock mb = new MatrixBlock(mbs.get(0).getNumRows(), mbs.get(0).getNumColumns(), false); mb.allocateDenseBlock(); for(int i = 1; i< mbs.size(); i++) { // multiply Bs - if(i==1){ + if(i==1) LibMatrixMult.vectMultiplyWrite(mbs.get(0).getDenseBlock().values(0), mbs.get(1).getDenseBlock().values(0), mb.getDenseBlock().values(0),0,0,0, size); - }else{ + else LibMatrixMult.vectMultiply(mbs.get(i).getDenseBlock().values(0),mb.getDenseBlock().values(0),0,0, size); - } } return List.of(mb); } diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java index 7f61bd6fb62..a94bbc95656 100644 --- a/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java +++ b/src/main/java/org/apache/sysds/runtime/einsum/EOpNodeUnary.java @@ -20,10 +20,16 @@ package org.apache.sysds.runtime.einsum; import org.apache.commons.logging.Log; +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.AggUnaryOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.runtime.functionobjects.DiagIndex; +import org.apache.sysds.runtime.functionobjects.KahanPlus; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.functionobjects.ReduceCol; +import org.apache.sysds.runtime.functionobjects.ReduceDiag; import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; @@ -31,13 +37,14 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator; import java.util.ArrayList; +import java.util.List; public class EOpNodeUnary extends EOpNode { private final EUnaryOperand eUnaryOperand; public EOpNode child; public enum EUnaryOperand { - DIAG, SUM, SUM_COLS, SUM_ROWS + DIAG, TRACE, SUM, SUM_COLS, SUM_ROWS } public EOpNodeUnary(Character c1, Character c2, Integer dim1, Integer dim2, EOpNode child, EUnaryOperand eUnaryOperand) { super(c1, c2, dim1, dim2); @@ -46,14 +53,12 @@ public EOpNodeUnary(Character c1, Character c2, Integer dim1, Integer dim2, EOpN } @Override - public String[] recursivePrintString() { - String[] childResult = child.recursivePrintString(); - String[] res = new String[1+childResult.length]; - res[0] = this.getClass().getSimpleName()+" ("+eUnaryOperand.toString()+") "+this.toString(); - for (int i=0; i 
getChildren() { + return List.of(child); + } + @Override + public String toString() { + return this.getClass().getSimpleName()+" ("+eUnaryOperand.toString()+") "+this.getOutputString(); } @Override @@ -64,6 +69,13 @@ public MatrixBlock computeEOpNode(ArrayList inputs, int numOfThread ReorgOperator op = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); yield mb.reorgOperations(op, new MatrixBlock(),0,0,0); } + case TRACE -> { + AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), Types.CorrectionLocationType.LASTCOLUMN); + AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject(), numOfThreads); + MatrixBlock res = new MatrixBlock(10, 10, false); + mb.aggregateUnaryOperations(aggun, res,0,null); + yield res; + } case SUM->{ AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), numOfThreads); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java index 2cae4bab272..b8af61d35b4 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java @@ -48,13 +48,12 @@ import static org.apache.sysds.hops.rewrite.RewriteMatrixMultChainOptimization.mmChainDP; public class EinsumCPInstruction extends BuiltinNaryCPInstruction { - public static final boolean FORCE_CELL_TPL = false; + public static final boolean FORCE_CELL_TPL = false; // naive looped solution public static final boolean FUSE_OUTER_MULTIPLY = true; public static final boolean FUSE_OUTER_MULTIPLY_EXCEEDS_L2_CACHE_CHECK = true; - - public static final boolean PRINT_TRACE = false; + public static final boolean PRINT_TRACE = true; protected static final Log LOG = 
LogFactory.getLog(EinsumCPInstruction.class.getName()); public String eqStr; @@ -67,12 +66,8 @@ public EinsumCPInstruction(Operator op, String opcode, String istr, CPOperand ou _numThreads = OptimizerUtils.getConstrainedNumThreads(-1)/2; _in = inputs; this.eqStr = inputs[0].getName(); - if (PRINT_TRACE) { -// System.out.println("fusing outer mult:"+FUSE_OUTER_MULTIPLY); + if (PRINT_TRACE) Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE); - } - else - Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.WARN); } @Override @@ -131,8 +126,8 @@ public void processInstruction(ExecutionContext ec) { ArrayList remainingMatrices; if(!FORCE_CELL_TPL) { - if(true){ // new way: search for fusions and matrix-multiplications chain in a loop - plan = generatePlanFusionAndMM(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); + if(true){ + plan = generateGreedyPlan(eOpNodes, eOpNodesScalars, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); }else { // old way: try to do fusion first and then rest in binary fashion cost based List fuseOps; do { @@ -146,15 +141,10 @@ public void processInstruction(ExecutionContext ec) { continue; } ret.add(fuseOp); -// if (fuseOp.c2 != null) { -// characterToOccurences.put(fuseOp.c2, characterToOccurences.get(fuseOp.c2)+1); -// } -// characterToOccurences.put(fuseOp.c1, characterToOccurences.get(fuseOp.c1)+1); } eOpNodes = ret; } } while(eOpNodes.size() > 1 && !fuseOps.isEmpty()); - Pair> costAndPlan = generateBinaryPlanCostBased(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, einc.outChar1, einc.outChar2); plan = costAndPlan.getRight(); @@ -304,7 +294,7 @@ private static Pair addScalarToPlanFindMinCost(EOpNode plan, H if (plan instanceof EOpNodeBinary bin) inputs = List.of(bin.left, bin.right); else if(plan instanceof EOpNodeFuse fuse){ cost = switch (fuse.einsumRewriteType) { - case AB_BA_B_A__ -> 1; // thisSize + case 
AB_BA_B_A__ -> 1; case AB_BA_B_A__AB -> thisSize; case AB_BA_A__B -> thisSize; case AB_BA_B_A__A -> 2; // intermediate is scalar, 2 because if there is some real scalar @@ -312,7 +302,7 @@ else if(plan instanceof EOpNodeFuse fuse){ case AB_BA_A_AZ__BZ -> thisSize; case AB_BA_A_AZ__ZB -> thisSize; }; - inputs = fuse.getAllOps(); + inputs = fuse.getChildren(); } for(EOpNode inp : inputs){ @@ -373,7 +363,7 @@ private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, ArrayList charToDimensionSize) { for(int i = 0; i< inputStrings.size(); i++){ String s = inputStrings.get(i); - if (s.length() == 0){ + if (s.isEmpty()){ eOpNodesScalars.add(new EOpNodeData(null, null, null, null,i)); inputStrings.set(i, null); continue; @@ -394,7 +384,7 @@ private void addSumDimensionsDiagonalsAndScalars(EinsumContext einc, ArrayList generatePlanFusionAndMM(ArrayList eOpNodes, + private static List generateGreedyPlan(ArrayList eOpNodes, ArrayList eOpNodesScalars, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { ArrayList ret; int lastNumOfOperands = -1; @@ -443,10 +433,6 @@ private static List generatePlanFusionAndMM(ArrayList eOpNodes continue; } ret.add(fuseOp); -// if (fuseOp.c2 != null) { -// charToOccurences.put(fuseOp.c2, charToOccurences.get(fuseOp.c2)+1); -// } -// charToOccurences.put(fuseOp.c1, charToOccurences.get(fuseOp.c1)+1); } eOpNodes = ret; } @@ -470,7 +456,34 @@ private static List generatePlanFusionAndMM(ArrayList eOpNodes return eOpNodes; } - private static EOpNodeBinary optimizeMMChain(List mmChain, HashMap charToSizeMap) { + private static void reverseMMChainIfBeneficial(ArrayList mmChain){ // possibly check the cost instead of number of transposes + char c1 = mmChain.get(0).c1; + char c2 = mmChain.get(0).c2; + int noTransposes = 0; + for (int i=1; i (mmChain.size() / 2 )+1) { + Collections.reverse(mmChain); + } + } + private static EOpNodeBinary optimizeMMChain(List mmChainL, HashMap charToSizeMap) { + ArrayList 
mmChain = new ArrayList<>(mmChainL); + reverseMMChainIfBeneficial(mmChain); ArrayList> dimensions = new ArrayList<>(); for(int i = 0; i < mmChain.size()-1; i++){ @@ -491,8 +504,7 @@ private static EOpNodeBinary optimizeMMChain(List mmChain, HashMap mmChain) { @@ -558,7 +570,7 @@ private static ArrayList> findMatrixMultiplicationChains(ArrayList for(int i = 0; i < operandsTodo.size(); i++){ EOpNode iterateNode = operandsTodo.get(i); - if (doneNodes.contains(iterateNode)) continue;// was added previously somewhere + if (doneNodes.contains(iterateNode)) continue; // was added previously doneNodes.add(iterateNode); LinkedList multiplies = new LinkedList<>(); @@ -611,7 +623,7 @@ private static ArrayList> findMatrixMultiplicationChains(ArrayList return res; } - // old way + // old way, DFS finds all paths private Pair> generateBinaryPlanCostBased(int cost, ArrayList operands, HashMap charToSizeMap, HashMap charToOccurences, Character outChar1, Character outChar2) { Integer minCost = cost; List minNodes = operands; @@ -629,18 +641,15 @@ private Pair> generateBinaryPlanCostBased(int cost, Array return Pair.of(cost, operands); } else if (operands.size() == 1){ - // check for transpose return Pair.of(cost, operands); } - for(int i = 0; i < operands.size()-1; i++){ for (int j = i+1; j < operands.size(); j++){ boolean swap = (operands.get(i).c2 == null && operands.get(j).c2 != null) || operands.get(i).c1 == null; EOpNode n1 = operands.get(!swap ? i : j); EOpNode n2 = operands.get(!swap ? 
j : i); - Triple> t = EOpNodeBinary.TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, outChar1, outChar2); if (t != null){ EOpNodeBinary newNode = new EOpNodeBinary(n1, n2, t.getMiddle()); @@ -650,7 +659,6 @@ else if (operands.size() == 1){ if(n1.c2 != null) charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)-1); if(n2.c1 != null) charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)-1); if(n2.c2 != null) charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)-1); - if(newNode.c1 != null) charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)+1); if(newNode.c2 != null) charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)+1); @@ -660,8 +668,8 @@ else if (operands.size() == 1){ } newOperands.add(newNode); - Pair> furtherPlan = generateBinaryPlanCostBased(thisCost, newOperands,charToSizeMap, charToOccurences, outChar1, outChar2); - if(furtherPlan.getRight().size() < (minNodes.size()) || furtherPlan.getLeft() < minCost){ + Pair> furtherPlan = generateBinaryPlanCostBased(thisCost, newOperands, charToSizeMap, charToOccurences, outChar1, outChar2); + if(furtherPlan.getRight().size() < minNodes.size() || furtherPlan.getLeft() < minCost){ minCost = furtherPlan.getLeft(); minNodes = furtherPlan.getRight(); } diff --git a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java index 95dfd39aaeb..97e1eb83bf9 100644 --- a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java +++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java @@ -47,88 +47,99 @@ public class EinsumTest extends AutomatedTestBase { final private static List TEST_CONFIGS = List.of( - new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // mm - new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))), - new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))), - new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))), - new 
Config("ij,jk->ki", List.of(shape(5, 6), shape(6, 5))), // mm t - new Config("ji,jk->ki", List.of(shape(6, 5), shape(6, 10))), - new Config("ji,kj->ki", List.of(shape(6, 5), shape(10, 6))), - new Config("ij,kj->ki", List.of(shape(5, 6), shape(10, 6))), - new Config("ij,kp,pj->ki", List.of(shape(5,6), shape(5,4), shape(4, 6))), // reordering - new Config("ab,bc,cd,de->ae", List.of(shape(5, 6), shape(6, 5),shape(5, 6), shape(6, 5))), // mm chain - - new Config("ji,jk->i", List.of(shape(6, 5), shape(6, 4))), - new Config("ij,jk->i", List.of(shape(5, 6), shape(6, 4))), - new Config("ji,jk->k", List.of(shape(6, 5), shape(6, 4))), - new Config("ij,jk->k", List.of(shape(5, 6), shape(6, 4))), - new Config("ji,jk->j", List.of(shape(6, 5), shape(6, 4))), - - new Config("ji,ji->ji", List.of(shape(60, 5), shape(60, 5))), // elemwise mult - new Config("ji,ij->ji", List.of(shape(60, 5), shape(5, 60))), // elemwise mult - - new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult - new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult - new Config("ij,i->i", List.of(shape(10, 5), shape(10))), - new Config("ij,i->j", List.of(shape(10, 5), shape(10))), - - new Config("i,i->", List.of(shape(5), shape(5))), // dot - new Config("i,j->", List.of(shape(5), shape(80))), // sum - new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult - new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult - - new Config("ij->", List.of(shape(10, 5))), // sum - new Config("i->", List.of(shape(10))), // sum - new Config("ij->i", List.of(shape(10, 5))), // sum(1) - new Config("ij->j", List.of(shape(10, 5))), // sum(0) - new Config("ij->ji", List.of(shape(10, 5))), // T - new Config("ij->ij", List.of(shape(10, 5))), - new Config("i->i", List.of(shape(10))), - new Config("ii->i", List.of(shape(10, 10))), // Diag - new Config("ii,i->i", List.of(shape(10, 10),shape(10))), // Diag*vec - - new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6, 5))), 
// sum cd to scalar and multiply ab - - new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (cell tpl fallback) - List.of(shape(5, 6), shape(5, 3), shape(5, 10), shape(6, 3), shape(10, 6), shape(10, 3))), - - // test fused: - new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), - new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - - new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), - new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - - new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), - new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - - new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), - new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), - new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), - new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), - - new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6))), - new Config("ij,ij,ij,i,j,iz,z->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6),shape(6))), - - new Config("ij,i,j,iz->z", List.of(shape(10, 
5),shape(10),shape(5),shape(10, 51))), - new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))), - - new Config("ij,ij,ji,j,i, ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), - new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',600, 'j',10,'z', 6)), // //skinny right with outer mm - new Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // // no skinny right - new Config("ij,ij,ji,j,i,iz->zj", Map.of('i',60, 'j',10,'z', 6)), // //skinny right with outer mm - new Config("ij,ij,ij,jk->ik", List.of(shape(10, 5), shape(10, 5),shape(10, 5),shape(5, 10))), - new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5), shape(10, 5),shape(10),shape(10,200),shape(200,7))) + new Config("ij,jk->ik", List.of(shape(5, 6), shape(6, 5))), // mm + new Config("ji,jk->ik", List.of(shape(6, 5), shape(6, 10))), + new Config("ji,kj->ik", List.of(shape(6, 5), shape(10, 6))), + new Config("ij,kj->ik", List.of(shape(5, 6), shape(10, 6))), + new Config("ij,jk->ki", List.of(shape(5, 6), shape(6, 5))), // mm t + new Config("ji,jk->ki", List.of(shape(6, 5), shape(6, 10))), + new Config("ji,kj->ki", List.of(shape(6, 5), shape(10, 6))), + new Config("ij,kj->ki", List.of(shape(5, 6), shape(10, 6))), + new Config("ij,kp,pj->ki", List.of(shape(5,6), shape(5,4), shape(4, 6))), // reordering + new Config("ab,bc,cd,de->ae", List.of(shape(5, 6), shape(6, 5),shape(5, 6), shape(6, 5))), // mm chain + new Config("de,cd,bc,ab->ae", List.of(shape(6, 5), shape(5, 6),shape(6, 5), shape(5, 6))), // mm chain + new Config("ab,cb,de,cd->ae", List.of(shape(5, 6), shape(5,6), shape(6, 5),shape(5, 6))), // mm chain + + new Config("ji,jk->i", List.of(shape(6, 5), shape(6, 4))), + new Config("ij,jk->i", List.of(shape(5, 6), shape(6, 4))), + new Config("ji,jk->k", List.of(shape(6, 5), shape(6, 4))), + new Config("ij,jk->k", List.of(shape(5, 6), shape(6, 4))), + new Config("ji,jk->j", List.of(shape(6, 5), shape(6, 4))), + + new Config("ji,ji->ji", List.of(shape(60, 5), shape(60, 
5))), // elemwise mult + new Config("ji,ji->j", List.of(shape(60, 5), shape(60, 5))), + new Config("ji,ji->i", List.of(shape(60, 5), shape(60, 5))), + new Config("ji,ij->ji", List.of(shape(60, 5), shape(5, 60))), // elemwise mult + new Config("ji,ij->i", List.of(shape(60, 5), shape(5, 60))), + new Config("ji,ij->j", List.of(shape(60, 5), shape(5, 60))), + + new Config("ij,i->ij", List.of(shape(10, 5), shape(10))), // col mult + new Config("ji,i->ij", List.of(shape(5, 10), shape(10))), // row mult + new Config("ij,i->i", List.of(shape(10, 5), shape(10))), + new Config("ij,i->j", List.of(shape(10, 5), shape(10))), + + new Config("i,i->", List.of(shape(5), shape(5))), // dot + new Config("i,j->", List.of(shape(5), shape(80))), // sum + new Config("i,j->ij", List.of(shape(5), shape(80))), // outer vect mult + new Config("i,j->ji", List.of(shape(5), shape(80))), // outer vect mult + + new Config("ij->", List.of(shape(10, 5))), // sum + new Config("i->", List.of(shape(10))), // sum + new Config("ij->i", List.of(shape(10, 5))), // sum(1) + new Config("ij->j", List.of(shape(10, 5))), // sum(0) + new Config("ij->ji", List.of(shape(10, 5))), // T + new Config("ij->ij", List.of(shape(10, 5))), + new Config("i->i", List.of(shape(10))), + new Config("ii->i", List.of(shape(10, 10))), // Diag + new Config("ii->", List.of(shape(10, 10))), // Trace + new Config("ii,i->i", List.of(shape(10, 10),shape(10))), // Diag*vec + + new Config("ab,cd->ba", List.of(shape( 6, 10), shape(6, 5))), // sum cd to scalar and multiply ab + + new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times (the cell tpl fallback) + List.of(shape(5, 6), shape(5, 3), shape(5, 10), shape(6, 3), shape(10, 6), shape(10, 3))), + + // test fused: + new Config("ij,ij,ji->ij", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + 
new Config("ij,ij,ji,j->ij", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->i", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->i", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->j", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->j", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ji->", List.of(shape(10, 5), shape(10, 5), shape(5, 10))), + new Config("ij,ij,ji,i,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10),shape(5))), + new Config("ij,ij,ji,i->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(10))), + new Config("ij,ij,ji,j->", List.of(shape(10, 5), shape(10, 5),shape(5, 10),shape(5))), + + new Config("ij,ij,ij,i,j,iz->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6))), + new Config("ij,ij,ij,i,j,iz,z->z", List.of(shape(5, 5), shape(5, 5), shape(5, 5),shape(5),shape(5),shape(5, 6),shape(6))), + + new Config("ij,i,j,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51))), + new Config("ij,i,j,iz,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 51),shape(10, 51))), + new Config("ij,i,j,iz,iz->z", List.of(shape(10, 5),shape(10),shape(5),shape(10, 4),shape(10, 4))), // order swapped because sizeof(iz) < sizeof(ij), but should still produce the same tmpl + new Config("ij,i,j,iz->z", List.of(shape(20, 10),shape(20),shape(10),shape(20, 10))), + + new Config("ij,ij,ji,j,i, ab,ba,ab,a,b->jb", Map.of('i',10, 'j',5, 'a', 11, 'b', 6)), + new 
Config("ij,ij,ji,j,i,iz->jz", Map.of('i',10, 'j',10,'z', 10)), // // no skinny right + // takes 2 mins each due to R package being super slow so commented out: +// new Config("ij,ij,i,j,iz->jz", Map.of('i',100000, 'j',50,'z', 40), List.of(0.0000001,0.000002, 0.000003,0.000004,0.000004)), // skinny right with outer mm +// new Config("ij,ij,i,iz->jz", Map.of('i',100000, 'j',50,'z', 40), List.of(0.0000001,0.000002, 0.000003,0.000004)), // skinny right with outer mm +// new Config("ij,j,i,iz->zj", Map.of('i',100000, 'j',50,'z', 40), List.of(0.0000001,0.000002, 0.000003,0.000004)), // skinny right with outer mm + new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5), shape(10, 5),shape(10),shape(10,200),shape(200,7))) + ,new Config("ab,ab,a,ag,gz->bz", List.of(shape(10, 5), shape(10, 5),shape(10),shape(10,20),shape(20,7))) + ,new Config("ab,ab,bc,bc->bc", List.of(shape(10, 5), shape(10, 5),shape(5,20),shape(5,20))) ); private final int id; private final String einsumStr; - //private final List shapes; private final File dmlFile; private final File rFile; private final boolean outputScalar; @@ -136,7 +147,6 @@ public class EinsumTest extends AutomatedTestBase public EinsumTest(String einsumStr, List shapes, File dmlFile, File rFile, boolean outputScalar, int id){ this.id = id; this.einsumStr = einsumStr; - //this.shapes = shapes; this.dmlFile = dmlFile; this.rFile = rFile; this.outputScalar = outputScalar; @@ -149,7 +159,6 @@ public static Collection data() throws IOException { int counter = 1; for (Config config : TEST_CONFIGS) { - //List files = new ArrayList<>(); String fullDMLScriptName = "SystemDS_einsum_test" + counter; File dmlFile = File.createTempFile(fullDMLScriptName, ".dml"); @@ -186,12 +195,12 @@ private static StringBuilder createDmlFile(Config config, boolean outputScalar) sb.append("A"); sb.append(i); - if (dims.length == 1) { // A1 = seq(1,100) * 0.0001 + if (dims.length == 1) { // e.g. 
A1 = seq(1,100) * 0.0001 sb.append(" = seq(1,"); sb.append(dims[0]); sb.append(") * "); sb.append(factor); - } else { // A0 = matrix(seq(1,5000), 100, 5) * 0.0001 + } else { // e.g. A0 = matrix(seq(1,5000), 100, 5) * 0.0001 sb.append(" = matrix(seq(1, "); sb.append(dims[0]*dims[1]); sb.append("), "); @@ -239,12 +248,12 @@ private static StringBuilder createRFile(Config config, boolean outputScalar) { sb.append("A"); sb.append(i); - if (dims.length == 1) { // A1 = seq(1,100) * 0.0001 + if (dims.length == 1) { // e.g. A1 = seq(1,100) * 0.0001 sb.append(" = seq(1,"); sb.append(dims[0]); sb.append(") * "); sb.append(factor); - } else { // A0 = matrix(seq(1,5000), 100, 5, byrow=TRUE) * 0.0001 + } else { // e.g. A0 = matrix(seq(1,5000), 100, 5, byrow=TRUE) * 0.0001 sb.append(" = matrix(seq(1, "); sb.append(dims[0]*dims[1]); sb.append("), "); @@ -284,7 +293,7 @@ private static StringBuilder createRFile(Config config, boolean outputScalar) { @Test public void testEinsumWithFiles() { System.out.println("Testing einsum: " + this.einsumStr); - testCodegenIntegration(TEST_NAME_EINSUM+this.id); + test(TEST_NAME_EINSUM+this.id); } @After public void cleanUp() { @@ -344,7 +353,7 @@ private static int[] shape(int... 
dims) { private static final String TEST_NAME_EINSUM = "einsum"; private static final String TEST_DIR = "functions/einsum/"; private static final String TEST_CLASS_DIR = TEST_DIR + EinsumTest.class.getSimpleName() + "/"; - private final static String TEST_CONF = "SystemDS-config-codegen.xml"; + private final static String TEST_CONF = "SystemDS-config-einsum.xml"; private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF); private static double eps = Math.pow(10, -10); @@ -356,7 +365,7 @@ public void setUp() { addTestConfiguration( TEST_NAME_EINSUM+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { String.valueOf(i) }) ); } - private void testCodegenIntegration( String testname) + private void test(String testname) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; ExecMode platformOld = setExecMode(ExecType.CP); @@ -376,19 +385,19 @@ private void testCodegenIntegration( String testname) OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false; runTest(true, false, null, -1); -// if(true) throw new RuntimeException("aa"); runRScript(true); + HashMap dmlfile; + HashMap rfile; if(outputScalar){ - HashMap dmlfile = readDMLScalarFromOutputDir("S"); - HashMap rfile = readRScalarFromExpectedDir("S"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + dmlfile = readDMLScalarFromOutputDir("S"); + rfile = readRScalarFromExpectedDir("S"); }else { //compare matrices - HashMap dmlfile = readDMLMatrixFromOutputDir("S"); - HashMap rfile = readRMatrixFromExpectedDir("S"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + dmlfile = readDMLMatrixFromOutputDir("S"); + rfile = readRMatrixFromExpectedDir("S"); } + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); } finally { resetExecMode(platformOld); diff --git a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml b/src/test/scripts/functions/einsum/SystemDS-config-einsum.xml similarity index 
83% rename from src/test/scripts/functions/einsum/SystemDS-config-codegen.xml rename to src/test/scripts/functions/einsum/SystemDS-config-einsum.xml index 626b31ebd76..f6640593c42 100644 --- a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml +++ b/src/test/scripts/functions/einsum/SystemDS-config-einsum.xml @@ -23,9 +23,6 @@ 2 true 1 - - - 16 - - auto - \ No newline at end of file + 16 + auto +