Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
Expand Down Expand Up @@ -497,6 +498,8 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<Stri
ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, ConfigurationManager.getDMLConfig(), STATISTICS ? STATISTICS_COUNT : 0, null);
}
finally {
//cleanup OOC streams and cache
OOCEvictionManager.reset();
//cleanup scratch_space and all working dirs
cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
FederatedData.clearWorkGroup();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
//add static HOP DAG rewrite rules
_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
_dagRuleSet.add( new RewriteInjectOOCTee() );
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
Expand All @@ -94,6 +93,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
_dagRuleSet.add( new RewriteQuantizationFusedCompression() );


//add statement block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
Expand Down Expand Up @@ -152,6 +152,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() );
_sbRuleSet.add( new RewriteRemoveEmptyForLoops() );
_sbRuleSet.add( new RewriteInjectOOCTee() );
}

/**
Expand Down
212 changes: 152 additions & 60 deletions src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.parser.StatementBlock;

import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -49,73 +50,20 @@
* 2. <b>Apply Rewrites (Modification):</b> Iterate over the collected candidates, insert a
* {@code TeeOp} for each of them, and safely rewire the graph.
*/
public class RewriteInjectOOCTee extends HopRewriteRule {
public class RewriteInjectOOCTee extends StatementBlockRewriteRule {

// When true, a shared hop only qualifies as a tee candidate if it also
// matches isSelfTranposePattern(hop); when false, that extra check is skipped.
public static boolean APPLY_ONLY_XtX_PATTERN = false;

// Bookkeeping for transient-read matrix hops, keyed by variable name.
// NOTE(review): these maps/sets are static, i.e. shared by all instances of this
// rule; no reset is visible in this excerpt -- confirm they are cleared between runs.
// _transientVars: running counter per variable name (incremented by 2 for reads that
// are rewrite candidates or when forceTee is set, otherwise by 1).
private static final Map<String, Integer> _transientVars = new HashMap<>();
// _transientHops: all transient-read hops recorded for a variable name.
private static final Map<String, List<Hop>> _transientHops = new HashMap<>();
// teeTransientVars: variable names whose counter in _transientVars exceeded 1.
private static final Set<String> teeTransientVars = new HashSet<>();

// IDs of hops for which a tee hop has already been created.
private static final Set<Long> rewrittenHops = new HashSet<>();
// Maps the ID of an already handled hop to the tee hop that was created for it.
private static final Map<Long, Hop> handledHop = new HashMap<>();

// Candidates collected during the first (read-only) pass; they are rewritten
// in the second pass and the list is cleared at the start of every pass.
private final List<Hop> rewriteCandidates = new ArrayList<>();

/**
 * Rewrites a generic hop DAG that may have several roots. The work is split into
 * a collection pass that leaves the graph untouched and a subsequent modification
 * pass that injects tee hops for every collected candidate.
 *
 * @param roots high-level operator roots, may be {@code null}
 * @param state program rewrite status
 * @return the given list of roots ({@code null} if {@code roots} was {@code null})
 */
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
	if (roots != null) {
		// start from an empty candidate list for this invocation
		rewriteCandidates.clear();

		// pass 1: only collect candidates, no graph modifications
		for (Hop dagRoot : roots) {
			dagRoot.resetVisitStatus();
			findRewriteCandidates(dagRoot);
		}

		// pass 2: inject tee hops for everything collected above
		for (Hop sharedHop : rewriteCandidates)
			applyTopDownTeeRewrite(sharedHop);
	}

	return roots;
}

/**
 * Rewrites a predicate hop DAG, i.e., a DAG with exactly one root. Candidates are
 * first collected without touching the graph and afterwards rewritten one by one.
 *
 * @param root high-level operator root, may be {@code null}
 * @param state program rewrite status
 * @return the given root ({@code null} if {@code root} was {@code null})
 */
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
	if (root != null) {
		// start from an empty candidate list for this invocation
		rewriteCandidates.clear();

		// pass 1: only collect candidates, no graph modifications
		root.resetVisitStatus();
		findRewriteCandidates(root);

		// pass 2: inject tee hops for everything collected above
		for (Hop sharedHop : rewriteCandidates)
			applyTopDownTeeRewrite(sharedHop);
	}

	return root;
}
private boolean forceTee = false;

/**
* First pass: Find candidates for rewrite without modifying the graph.
Expand All @@ -137,6 +85,35 @@ private void findRewriteCandidates(Hop hop) {
findRewriteCandidates(input);
}

boolean isRewriteCandidate = DMLScript.USE_OOC
&& hop.getDataType().isMatrix()
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
&& hop.getParent().size() > 1
&& (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern(hop));

if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && hop.getDataType().isMatrix()) {
_transientVars.compute(hop.getName(), (key, ctr) -> {
int incr = (isRewriteCandidate || forceTee) ? 2 : 1;

int ret = ctr == null ? 0 : ctr;
ret += incr;

if (ret > 1)
teeTransientVars.add(hop.getName());

return ret;
});

_transientHops.compute(hop.getName(), (key, hops) -> {
if (hops == null)
return new ArrayList<>(List.of(hop));
hops.add(hop);
return hops;
});

return; // We do not tee transient reads but rather inject before TWrite or PRead as caching stream
}

// Check if this hop is a candidate for OOC Tee injection
if (DMLScript.USE_OOC
&& hop.getDataType().isMatrix()
Expand All @@ -160,11 +137,17 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
return;
}

int consumerCount = sharedInput.getParent().size();
if (LOG.isDebugEnabled()) {
LOG.debug("Inject tee for hop " + sharedInput.getHopID() + " ("
+ sharedInput.getName() + "), consumers=" + consumerCount);
}

// Take a defensive copy of consumers before modifying the graph
ArrayList<Hop> consumers = new ArrayList<>(sharedInput.getParent());

// Create the new TeeOp with the original hop as input
DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
sharedInput.getDataType(), sharedInput.getValueType(), Types.OpOpData.TEE, null,
sharedInput.getDim1(), sharedInput.getDim2(), sharedInput.getNnz(), sharedInput.getBlocksize());
HopRewriteUtils.addChildReference(teeOp, sharedInput);
Expand All @@ -177,6 +160,11 @@ private void applyTopDownTeeRewrite(Hop sharedInput) {
// Record that we've handled this hop
handledHop.put(sharedInput.getHopID(), teeOp);
rewrittenHops.add(sharedInput.getHopID());

if (LOG.isDebugEnabled()) {
LOG.debug("Created tee hop " + teeOp.getHopID() + " -> "
+ teeOp.getName());
}
}

@SuppressWarnings("unused")
Expand All @@ -196,4 +184,108 @@ else if (HopRewriteUtils.isMatrixMultiply(parent)) {
}
return hasTransposeConsumer && hasMatrixMultiplyConsumer;
}

/**
 * Reports whether this rule creates split DAGs.
 *
 * @return always {@code false}
 */
@Override
public boolean createsSplitDag() {
	return false;
}

/**
 * Applies the tee injection to a single statement block. Nothing is rewritten
 * unless {@code DMLScript.USE_OOC} is set.
 *
 * @param sb statement block to rewrite
 * @param state program rewrite status
 * @return an unmodifiable single-element list holding {@code sb}
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
	if (DMLScript.USE_OOC) {
		// collect and rewrite the shared hops inside this block
		rewriteSB(sb, state);

		// inject tees for the recorded transient reads of every flagged variable
		for (String varName : teeTransientVars) {
			List<Hop> recordedReads = _transientHops.get(varName);
			if (recordedReads != null) {
				for (Hop read : recordedReads)
					applyTopDownTeeRewrite(read);

				// the recorded reads of this variable are dropped once processed
				recordedReads.clear();
			}
		}

		// collapse tee hops that directly consume another tee hop
		removeRedundantTeeChains(sb);
	}

	return List.of(sb);
}

/**
 * Applies the tee injection to a sequence of statement blocks. Nothing is
 * rewritten unless {@code DMLScript.USE_OOC} is set.
 *
 * @param sbs statement blocks to rewrite
 * @param state program rewrite status
 * @return the given list {@code sbs}
 */
@Override
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
	if (DMLScript.USE_OOC) {
		// collect and rewrite the shared hops of every block first
		for (StatementBlock block : sbs)
			rewriteSB(block, state);

		// inject tees for the recorded transient reads of every flagged variable;
		// the recorded read lists themselves are left untouched here
		for (String varName : teeTransientVars) {
			List<Hop> recordedReads = _transientHops.get(varName);
			if (recordedReads != null) {
				for (Hop read : recordedReads)
					applyTopDownTeeRewrite(read);
			}
		}

		// collapse tee hops that directly consume another tee hop
		for (StatementBlock block : sbs)
			removeRedundantTeeChains(block);
	}

	return sbs;
}

/**
 * Runs the two-pass rewrite on the hop roots of one statement block: candidates
 * are collected first and tee hops are injected for them afterwards.
 *
 * @param sb statement block whose hop roots are processed (roots may be {@code null})
 * @param state program rewrite status (not used by this method)
 */
private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) {
	// candidates of previously processed blocks must not leak into this one
	rewriteCandidates.clear();

	// pass 1: only collect candidates, no graph modifications
	boolean hasRoots = sb.getHops() != null;
	if (hasRoots) {
		for (Hop dagRoot : sb.getHops()) {
			dagRoot.resetVisitStatus();
			findRewriteCandidates(dagRoot);
		}
	}

	// pass 2: inject tee hops for everything collected above
	for (Hop sharedHop : rewriteCandidates)
		applyTopDownTeeRewrite(sharedHop);
}

/**
 * Removes tee hops that directly consume another tee hop from all hop DAGs of the
 * given statement block. The visit status of the roots is reset before and after
 * the traversal.
 *
 * @param sb statement block to clean up; {@code null} or a block without hops is a no-op
 */
private void removeRedundantTeeChains(StatementBlock sb) {
	if (sb != null && sb.getHops() != null) {
		Hop.resetVisitStatus(sb.getHops());
		for (Hop dagRoot : sb.getHops())
			removeRedundantTeeChains(dagRoot);
		Hop.resetVisitStatus(sb.getHops());
	}
}

/**
 * Recursively visits the inputs of the given hop and afterwards, if the hop is a
 * tee with exactly one input that is itself a tee, rewires the parents of the
 * outer tee via {@code HopRewriteUtils.rewireAllParentChildReferences} and drops
 * the outer tee's child references. Hops already marked as visited are skipped.
 *
 * @param hop current hop of the traversal
 */
private void removeRedundantTeeChains(Hop hop) {
	if (!hop.isVisited()) {
		// iterate over a snapshot because the rewiring below may change input lists
		ArrayList<Hop> snapshot = new ArrayList<>(hop.getInput());
		for (Hop child : snapshot)
			removeRedundantTeeChains(child);

		boolean teeOverTee = HopRewriteUtils.isData(hop, OpOpData.TEE)
			&& hop.getInput().size() == 1
			&& HopRewriteUtils.isData(hop.getInput().get(0), OpOpData.TEE);

		if (teeOverTee) {
			Hop upstreamTee = hop.getInput().get(0);
			if (LOG.isDebugEnabled()) {
				LOG.debug("Remove redundant tee hop " + hop.getHopID()
					+ " (" + hop.getName() + ") -> " + upstreamTee.getHopID()
					+ " (" + upstreamTee.getName() + ")");
			}
			HopRewriteUtils.rewireAllParentChildReferences(hop, upstreamTee);
			HopRewriteUtils.removeAllChildReferences(hop);
		}

		hop.setVisited();
	}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,12 @@ public boolean hasBroadcastHandle() {
return _bcHandle != null && _bcHandle.hasBackReference();
}

public OOCStream<IndexedMatrixValue> getStreamHandle() {
public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
if( !hasStreamHandle() ) {
final SubscribableTaskQueue<IndexedMatrixValue> _mStream = new SubscribableTaskQueue<>();
_streamHandle = _mStream;
DataCharacteristics dc = getDataCharacteristics();
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
_streamHandle = _mStream;
LongStream.range(0, dc.getNumBlocks())
.mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i))
.forEach( blk -> {
Expand All @@ -489,7 +489,14 @@ public OOCStream<IndexedMatrixValue> getStreamHandle() {
_mStream.closeInput();
}

return _streamHandle.getReadStream();
OOCStream<IndexedMatrixValue> stream = _streamHandle.getReadStream();
if (!stream.hasStreamCache())
_streamHandle = null; // To ensure read once
return stream;
}

/**
 * Returns the stream handle currently attached to this object as-is, without
 * creating one on demand.
 *
 * @return the current stream handle, or {@code null} if none is set
 */
public OOCStreamable<IndexedMatrixValue> getStreamable() {
	return this._streamHandle;
}

/**
Expand All @@ -499,7 +506,7 @@ public OOCStream<IndexedMatrixValue> getStreamHandle() {
* @return true if existing, false otherwise
*/
public boolean hasStreamHandle() {
// NOTE(review): two return statements follow each other here; as written only the
// first one can execute and the second one is unreachable. Presumably these are the
// before/after variants of the same line -- confirm which one is intended to remain.
return _streamHandle != null && !_streamHandle.isProcessed();
return _streamHandle != null;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public class LocalTaskQueue<T>

protected LinkedList<T> _data = null;
protected boolean _closedInput = false;
private DMLRuntimeException _failure = null;
protected DMLRuntimeException _failure = null;
private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName());

public LocalTaskQueue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.ooc.CachingStream;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5;
Expand Down Expand Up @@ -1026,6 +1028,9 @@ private void processCopyInstruction(ExecutionContext ec) {
if ( dd == null )
throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString());

if (DMLScript.USE_OOC && dd instanceof MatrixObject)
TeeOOCInstruction.incrRef(((MatrixObject)dd).getStreamable(), 1);

// remove existing variable bound to target name
Data input2_data = ec.removeVariable(getInput2().getName());

Expand Down Expand Up @@ -1117,6 +1122,8 @@ private void processSetFileNameInstruction(ExecutionContext ec){
public static void processRmvarInstruction( ExecutionContext ec, String varname ) {
// remove variable from symbol table
Data dat = ec.removeVariable(varname);
if (DMLScript.USE_OOC && dat instanceof MatrixObject)
TeeOOCInstruction.incrRef(((MatrixObject) dat).getStreamable(), -1);
//cleanup matrix data on fs/hdfs (if necessary)
if( dat != null )
ec.cleanupDataObject(dat);
Expand Down
Loading
Loading