diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 65805b5c2ed..748a0c43ac0 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -71,6 +71,7 @@ import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; +import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.LineageCacheConfig; import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy; @@ -497,6 +498,8 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, MapApply Rewrites (Modification): Iterate over the collected candidate and put * {@code TeeOp}, and safely rewire the graph. */ -public class RewriteInjectOOCTee extends HopRewriteRule { +public class RewriteInjectOOCTee extends StatementBlockRewriteRule { public static boolean APPLY_ONLY_XtX_PATTERN = false; + + private static final Map _transientVars = new HashMap<>(); + private static final Map> _transientHops = new HashMap<>(); + private static final Set teeTransientVars = new HashSet<>(); private static final Set rewrittenHops = new HashSet<>(); private static final Map handledHop = new HashMap<>(); // Maintain a list of candidates to rewrite in the second pass private final List rewriteCandidates = new ArrayList<>(); - - /** - * Handle a generic (last-level) hop DAG with multiple roots. - * - * @param roots high-level operator roots - * @param state program rewrite status - * @return list of high-level operators - */ - @Override - public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { - if (roots == null) { - return null; - } - - // Clear candidates for this pass - rewriteCandidates.clear(); - - // PASS 1: Identify candidates without modifying the graph - for (Hop root : roots) { - root.resetVisitStatus(); - findRewriteCandidates(root); - } - - // PASS 2: Apply rewrites to identified candidates - for (Hop candidate : rewriteCandidates) { - applyTopDownTeeRewrite(candidate); - } - - return roots; - } - - /** - * Handle a predicate hop DAG with exactly one root. - * - * @param root high-level operator root - * @param state program rewrite status - * @return high-level operator - */ - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { - if (root == null) { - return null; - } - - // Clear candidates for this pass - rewriteCandidates.clear(); - - // PASS 1: Identify candidates without modifying the graph - root.resetVisitStatus(); - findRewriteCandidates(root); - - // PASS 2: Apply rewrites to identified candidates - for (Hop candidate : rewriteCandidates) { - applyTopDownTeeRewrite(candidate); - } - - return root; - } + private boolean forceTee = false; /** * First pass: Find candidates for rewrite without modifying the graph. 
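The rewrite above follows a collect-then-modify scheme: a first pass walks each HOP DAG and only records shared matrix intermediates (hops with more than one parent), and a second pass splices a TeeOp between each recorded candidate and its consumers, so the graph is never mutated while it is being traversed. Below is a minimal, self-contained sketch of that pattern on a toy DAG; TwoPassTeeSketch, Node, and their members are illustrative stand-ins for Hop/DataOp, not SystemDS classes.

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TwoPassTeeSketch {
	// Hypothetical stand-in for a HOP: named node with input and parent (consumer) lists.
	static final class Node {
		final String name;
		final List<Node> inputs = new ArrayList<>();
		final List<Node> parents = new ArrayList<>();
		Node(String name) { this.name = name; }
		void addInput(Node in) { inputs.add(in); in.parents.add(this); }
	}

	public static void main(String[] args) {
		// X is consumed by two operators -> candidate for a tee
		Node x = new Node("X");
		Node op1 = new Node("t(X)");
		Node op2 = new Node("X %*% v");
		op1.addInput(x);
		op2.addInput(x);

		// PASS 1: collect candidates without touching the graph
		List<Node> candidates = new ArrayList<>();
		Set<Node> visited = new HashSet<>();
		collect(op1, visited, candidates);
		collect(op2, visited, candidates);

		// PASS 2: rewire each candidate through a tee node
		for (Node c : candidates)
			injectTee(c);

		System.out.println("op1 input: " + op1.inputs.get(0).name); // tee(X)
		System.out.println("op2 input: " + op2.inputs.get(0).name); // tee(X)
	}

	private static void collect(Node n, Set<Node> visited, List<Node> out) {
		if (!visited.add(n))
			return; // already visited in this pass
		for (Node in : n.inputs)
			collect(in, visited, out);
		if (n.parents.size() > 1) // shared intermediate -> tee candidate
			out.add(n);
	}

	private static void injectTee(Node shared) {
		Node tee = new Node("tee(" + shared.name + ")");
		// defensive copy of the consumers, since we modify parent lists while rewiring
		List<Node> consumers = new ArrayList<>(shared.parents);
		shared.parents.clear();
		tee.addInput(shared);
		for (Node consumer : consumers) {
			int idx = consumer.inputs.indexOf(shared);
			consumer.inputs.set(idx, tee);
			tee.parents.add(consumer);
		}
	}
}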
@@ -137,6 +85,35 @@ private void findRewriteCandidates(Hop hop) { findRewriteCandidates(input); } + boolean isRewriteCandidate = DMLScript.USE_OOC + && hop.getDataType().isMatrix() + && !HopRewriteUtils.isData(hop, OpOpData.TEE) + && hop.getParent().size() > 1 + && (!APPLY_ONLY_XtX_PATTERN || isSelfTranposePattern(hop)); + + if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && hop.getDataType().isMatrix()) { + _transientVars.compute(hop.getName(), (key, ctr) -> { + int incr = (isRewriteCandidate || forceTee) ? 2 : 1; + + int ret = ctr == null ? 0 : ctr; + ret += incr; + + if (ret > 1) + teeTransientVars.add(hop.getName()); + + return ret; + }); + + _transientHops.compute(hop.getName(), (key, hops) -> { + if (hops == null) + return new ArrayList<>(List.of(hop)); + hops.add(hop); + return hops; + }); + + return; // We do not tee transient reads but rather inject before TWrite or PRead as caching stream + } + // Check if this hop is a candidate for OOC Tee injection if (DMLScript.USE_OOC && hop.getDataType().isMatrix() @@ -160,11 +137,17 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { return; } + int consumerCount = sharedInput.getParent().size(); + if (LOG.isDebugEnabled()) { + LOG.debug("Inject tee for hop " + sharedInput.getHopID() + " (" + + sharedInput.getName() + "), consumers=" + consumerCount); + } + // Take a defensive copy of consumers before modifying the graph ArrayList consumers = new ArrayList<>(sharedInput.getParent()); // Create the new TeeOp with the original hop as input - DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), + DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), sharedInput.getDataType(), sharedInput.getValueType(), Types.OpOpData.TEE, null, sharedInput.getDim1(), sharedInput.getDim2(), sharedInput.getNnz(), sharedInput.getBlocksize()); HopRewriteUtils.addChildReference(teeOp, sharedInput); @@ -177,6 +160,11 @@ private void applyTopDownTeeRewrite(Hop sharedInput) { // Record that we've handled this hop handledHop.put(sharedInput.getHopID(), teeOp); rewrittenHops.add(sharedInput.getHopID()); + + if (LOG.isDebugEnabled()) { + LOG.debug("Created tee hop " + teeOp.getHopID() + " -> " + + teeOp.getName()); + } } @SuppressWarnings("unused") @@ -196,4 +184,108 @@ else if (HopRewriteUtils.isMatrixMultiply(parent)) { } return hasTransposeConsumer && hasMatrixMultiplyConsumer; } + + @Override + public boolean createsSplitDag() { + return false; + } + + @Override + public List rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) { + if (!DMLScript.USE_OOC) + return List.of(sb); + + rewriteSB(sb, state); + + for (String tVar : teeTransientVars) { + List tHops = _transientHops.get(tVar); + + if (tHops == null) + continue; + + for (Hop affectedHops : tHops) { + applyTopDownTeeRewrite(affectedHops); + } + + tHops.clear(); + } + + removeRedundantTeeChains(sb); + + return List.of(sb); + } + + @Override + public List rewriteStatementBlocks(List sbs, ProgramRewriteStatus state) { + if (!DMLScript.USE_OOC) + return sbs; + + for (StatementBlock sb : sbs) + rewriteSB(sb, state); + + for (String tVar : teeTransientVars) { + List tHops = _transientHops.get(tVar); + + if (tHops == null) + continue; + + for (Hop affectedHops : tHops) { + applyTopDownTeeRewrite(affectedHops); + } + } + + for (StatementBlock sb : sbs) + removeRedundantTeeChains(sb); + + return sbs; + } + + private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) { + rewriteCandidates.clear(); + + if (sb.getHops() != null) { + for(Hop hop : sb.getHops()) { + 
hop.resetVisitStatus(); + findRewriteCandidates(hop); + } + } + + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + } + + private void removeRedundantTeeChains(StatementBlock sb) { + if (sb == null || sb.getHops() == null) + return; + + Hop.resetVisitStatus(sb.getHops()); + for (Hop hop : sb.getHops()) + removeRedundantTeeChains(hop); + Hop.resetVisitStatus(sb.getHops()); + } + + private void removeRedundantTeeChains(Hop hop) { + if (hop.isVisited()) + return; + + ArrayList inputs = new ArrayList<>(hop.getInput()); + for (Hop in : inputs) + removeRedundantTeeChains(in); + + if (HopRewriteUtils.isData(hop, OpOpData.TEE) && hop.getInput().size() == 1) { + Hop teeInput = hop.getInput().get(0); + if (HopRewriteUtils.isData(teeInput, OpOpData.TEE)) { + if (LOG.isDebugEnabled()) { + LOG.debug("Remove redundant tee hop " + hop.getHopID() + + " (" + hop.getName() + ") -> " + teeInput.getHopID() + + " (" + teeInput.getName() + ")"); + } + HopRewriteUtils.rewireAllParentChildReferences(hop, teeInput); + HopRewriteUtils.removeAllChildReferences(hop); + } + } + + hop.setVisited(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 34a8aa18631..d826af89c0e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -471,12 +471,12 @@ public boolean hasBroadcastHandle() { return _bcHandle != null && _bcHandle.hasBackReference(); } - public OOCStream getStreamHandle() { + public synchronized OOCStream getStreamHandle() { if( !hasStreamHandle() ) { final SubscribableTaskQueue _mStream = new SubscribableTaskQueue<>(); - _streamHandle = _mStream; DataCharacteristics dc = getDataCharacteristics(); MatrixBlock src = (MatrixBlock)acquireReadAndRelease(); + _streamHandle = _mStream; LongStream.range(0, dc.getNumBlocks()) .mapToObj(i -> UtilFunctions.createIndexedMatrixBlock(src, dc, i)) .forEach( blk -> { @@ -489,7 +489,14 @@ public OOCStream getStreamHandle() { _mStream.closeInput(); } - return _streamHandle.getReadStream(); + OOCStream stream = _streamHandle.getReadStream(); + if (!stream.hasStreamCache()) + _streamHandle = null; // To ensure read once + return stream; + } + + public OOCStreamable getStreamable() { + return _streamHandle; } /** @@ -499,7 +506,7 @@ public OOCStream getStreamHandle() { * @return true if existing, false otherwise */ public boolean hasStreamHandle() { - return _streamHandle != null && !_streamHandle.isProcessed(); + return _streamHandle != null; } @SuppressWarnings({ "rawtypes", "unchecked" }) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java index 783981e0f12..50143cd0ad7 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java @@ -45,7 +45,7 @@ public class LocalTaskQueue protected LinkedList _data = null; protected boolean _closedInput = false; - private DMLRuntimeException _failure = null; + protected DMLRuntimeException _failure = null; private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName()); public LocalTaskQueue() diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 5dd8e55e821..83421bf5d82 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -46,6 +46,8 @@ import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5; @@ -1026,6 +1028,9 @@ private void processCopyInstruction(ExecutionContext ec) { if ( dd == null ) throw new DMLRuntimeException("Unexpected error: could not find a data object for variable name:" + getInput1().getName() + ", while processing instruction " +this.toString()); + if (DMLScript.USE_OOC && dd instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject)dd).getStreamable(), 1); + // remove existing variable bound to target name Data input2_data = ec.removeVariable(getInput2().getName()); @@ -1117,6 +1122,8 @@ private void processSetFileNameInstruction(ExecutionContext ec){ public static void processRmvarInstruction( ExecutionContext ec, String varname ) { // remove variable from symbol table Data dat = ec.removeVariable(varname); + if (DMLScript.USE_OOC && dat instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject) dat).getStreamable(), -1); //cleanup matrix data on fs/hdfs (if necessary) if( dat != null ) ec.cleanupDataObject(dat); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index d7c80e4de3c..cdc23911516 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -24,6 +24,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.HashMap; import java.util.Map; @@ -39,6 +40,7 @@ public class CachingStream implements OOCStreamable { // original live stream private final OOCStream _source; + private final IntArrayList _consumptionCounts = new IntArrayList(); // stream identifier private final long _streamId; @@ -54,6 +56,10 @@ public class CachingStream implements OOCStreamable { private DMLRuntimeException _failure; + private boolean deletable = false; + private int maxConsumptionCount = 0; + private int cachePins = 0; + public CachingStream(OOCStream source) { this(source, _streamSeq.getNextID()); } @@ -61,23 +67,43 @@ public CachingStream(OOCStream source) { public CachingStream(OOCStream source, long streamId) { _source = source; _streamId = streamId; - source.setSubscriber(() -> { + source.setSubscriber(tmp -> { try { - boolean closed = fetchFromStream(); - Runnable[] mSubscribers = _subscribers; + final IndexedMatrixValue task = tmp.get(); + int blk; + Runnable[] mSubscribers; + + synchronized (this) { + if(task != LocalTaskQueue.NO_MORE_TASKS) { + if (!_cacheInProgress) + throw new 
DMLRuntimeException("Stream is closed"); + OOCEvictionManager.put(_streamId, _numBlocks, task); + if (_index != null) + _index.put(task.getIndexes(), _numBlocks); + blk = _numBlocks; + _numBlocks++; + _consumptionCounts.add(0); + notifyAll(); + } + else { + _cacheInProgress = false; // caching is complete + notifyAll(); + blk = -1; + } + + mSubscribers = _subscribers; + } if(mSubscribers != null) { for(Runnable mSubscriber : mSubscribers) mSubscriber.run(); - if (closed) { + if (blk == -1) { synchronized (this) { _subscribers = null; } } } - } catch (InterruptedException e) { - throw new DMLRuntimeException(e); } catch (DMLRuntimeException e) { // Propagate failure to subscribers _failure = e; @@ -98,25 +124,28 @@ public CachingStream(OOCStream source, long streamId) { }); } - private synchronized boolean fetchFromStream() throws InterruptedException { - if(!_cacheInProgress) - throw new DMLRuntimeException("Stream is closed"); + public synchronized void scheduleDeletion() { + deletable = true; + if (_cacheInProgress && maxConsumptionCount == 0) + throw new DMLRuntimeException("Cannot have a caching stream with no listeners"); + for (int i = 0; i < _consumptionCounts.size(); i++) { + tryDeleteBlock(i); + } + } - IndexedMatrixValue task = _source.dequeue(); + public String toString() { + return "CachingStream@" + _streamId; + } - if(task != LocalTaskQueue.NO_MORE_TASKS) { - OOCEvictionManager.put(_streamId, _numBlocks, task); - if (_index != null) - _index.put(task.getIndexes(), _numBlocks); - _numBlocks++; - notifyAll(); - return false; - } - else { - _cacheInProgress = false; // caching is complete - notifyAll(); - return true; - } + private synchronized void tryDeleteBlock(int i) { + if (cachePins > 0) + return; // Block deletion is prevented + + int count = _consumptionCounts.getInt(i); + if (count > maxConsumptionCount) + throw new DMLRuntimeException("Cannot have more than " + maxConsumptionCount + " consumptions."); + if (count == maxConsumptionCount) + OOCEvictionManager.forget(_streamId, i); } public synchronized IndexedMatrixValue get(int idx) throws InterruptedException { @@ -129,6 +158,16 @@ else if (idx < _numBlocks) { if (_index != null) // Ensure index is up to date _index.putIfAbsent(out.getIndexes(), idx); + int newCount = _consumptionCounts.getInt(idx)+1; + + if (newCount > maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow! Expected: " + maxConsumptionCount); + + _consumptionCounts.set(idx, newCount); + + if (deletable) + tryDeleteBlock(idx); + return out; } else if (!_cacheInProgress) return (IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS; @@ -137,8 +176,31 @@ else if (idx < _numBlocks) { } } + public synchronized int findCachedIndex(MatrixIndexes idx) { + return _index.get(idx); + } + public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) { - return OOCEvictionManager.get(_streamId, _index.get(idx)); + int mIdx = _index.get(idx); + int newCount = _consumptionCounts.getInt(mIdx)+1; + if (newCount > maxConsumptionCount) + throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount); + _consumptionCounts.set(mIdx, newCount); + + IndexedMatrixValue imv = OOCEvictionManager.get(_streamId, mIdx); + + if (deletable) + tryDeleteBlock(mIdx); + + return imv; + } + + /** + * Finds a cached item without counting it as a consumption. 
+ */ + public synchronized IndexedMatrixValue peekCached(MatrixIndexes idx) { + int mIdx = _index.get(idx); + return OOCEvictionManager.get(_streamId, mIdx); } public synchronized void activateIndexing() { @@ -161,12 +223,18 @@ public boolean isProcessed() { return false; } - @Override - public void setSubscriber(Runnable subscriber) { + public void setSubscriber(Runnable subscriber, boolean incrConsumers) { + if (deletable) + throw new DMLRuntimeException("Cannot register a new subscriber on " + this + " because it has been flagged for deletion"); + int mNumBlocks; + boolean cacheInProgress; synchronized (this) { mNumBlocks = _numBlocks; - if (_cacheInProgress) { + cacheInProgress = _cacheInProgress; + if (incrConsumers) + maxConsumptionCount++; + if (cacheInProgress) { int newLen = _subscribers == null ? 1 : _subscribers.length + 1; Runnable[] newSubscribers = new Runnable[newLen]; @@ -181,7 +249,44 @@ public void setSubscriber(Runnable subscriber) { for (int i = 0; i < mNumBlocks; i++) subscriber.run(); - if (!_cacheInProgress) + if (!cacheInProgress) subscriber.run(); // To fetch the NO_MORE_TASKS element } + + /** + * Artificially increases the subscriber count. + * Only use this if certain blocks are accessed more than once. + */ + public synchronized void incrSubscriberCount(int count) { + maxConsumptionCount += count; + } + + /** + * Artificially increases the processing count of a block. + */ + public synchronized void incrProcessingCount(int i, int count) { + _consumptionCounts.set(i, _consumptionCounts.getInt(i)+count); + + if (deletable) + tryDeleteBlock(i); + } + + /** + * Force-pins blocks in the cache so they are not subject to block deletion. + */ + public synchronized void pinStream() { + cachePins++; + } + + /** + * Unpins the stream, allowing blocks to be deleted from cache. 
+ */ + public synchronized void unpinStream() { + cachePins--; + + if (cachePins == 0) { + for (int i = 0; i < _consumptionCounts.size(); i++) + tryDeleteBlock(i); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java index 1d555da8d6c..175d81d6e06 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java @@ -115,6 +115,34 @@ public boolean isAligned() { return (_indexRange.rowStart % _blocksize) == 0 && (_indexRange.colStart % _blocksize) == 0; } + public int getNumConsumptions(MatrixIndexes index) { + long blockRow = index.getRowIndex() - 1; + long blockCol = index.getColumnIndex() - 1; + + if(!_blockRange.isWithin(blockRow, blockCol)) + return 0; + + long blockRowStart = blockRow * _blocksize; + long blockRowEnd = blockRowStart + _blocksize - 1; + long blockColStart = blockCol * _blocksize; + long blockColEnd = blockColStart + _blocksize - 1; + + long overlapRowStart = Math.max(_indexRange.rowStart, blockRowStart); + long overlapRowEnd = Math.min(_indexRange.rowEnd, blockRowEnd); + long overlapColStart = Math.max(_indexRange.colStart, blockColStart); + long overlapColEnd = Math.min(_indexRange.colEnd, blockColEnd); + + if(overlapRowStart > overlapRowEnd || overlapColStart > overlapColEnd) + return 0; + + int outRowStart = (int) ((overlapRowStart - _indexRange.rowStart) / _blocksize); + int outRowEnd = (int) ((overlapRowEnd - _indexRange.rowStart) / _blocksize); + int outColStart = (int) ((overlapColStart - _indexRange.colStart) / _blocksize); + int outColEnd = (int) ((overlapColEnd - _indexRange.colStart) / _blocksize); + + return (outRowEnd - outRowStart + 1) * (outColEnd - outColStart + 1); + } + public boolean putNext(MatrixIndexes index, T data, BiConsumer> emitter) { long blockRow = index.getRowIndex() - 1; long blockCol = index.getColumnIndex() - 1; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java index a04a77677cd..33c6675051e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java @@ -33,6 +33,7 @@ import org.apache.sysds.runtime.util.IndexRange; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -43,10 +44,10 @@ public MatrixIndexingOOCInstruction(CPOperand in, CPOperand rl, CPOperand ru, CP super(in, rl, ru, cl, cu, out, opcode, istr); } - protected MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, - CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) { - super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr); - } +// protected MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, +// CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) { +// super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr); +// } @Override public void processInstruction(ExecutionContext ec) { @@ -96,8 +97,9 @@ public void processInstruction(ExecutionContext ec) { final int 
outBlockCols = (int) Math.ceil((double) (ix.colSpan() + 1) / blocksize); final int totalBlocks = outBlockRows * outBlockCols; final AtomicInteger producedBlocks = new AtomicInteger(0); + CompletableFuture future = new CompletableFuture<>(); - CompletableFuture future = filterOOC(qIn, tmp -> { + filterOOC(qIn, tmp -> { MatrixIndexes inIdx = tmp.getIndexes(); long blockRow = inIdx.getRowIndex() - 1; long blockCol = inIdx.getColumnIndex() - 1; @@ -124,12 +126,12 @@ public void processInstruction(ExecutionContext ec) { long outBlockCol = blockCol - firstBlockCol + 1; qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(outBlockRow, outBlockCol), outBlock)); - if(producedBlocks.incrementAndGet() >= totalBlocks) { - CompletableFuture f = futureRef.get(); - if(f != null) - f.cancel(true); - } + if(producedBlocks.incrementAndGet() >= totalBlocks) + future.complete(null); }, tmp -> { + if (future.isDone()) // Then we may skip blocks and avoid submitting tasks + return false; + long blockRow = tmp.getIndexes().getRowIndex() - 1; long blockCol = tmp.getIndexes().getColumnIndex() - 1; return blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && @@ -139,20 +141,23 @@ public void processInstruction(ExecutionContext ec) { return; } - final BlockAligner aligner = new BlockAligner<>(ix, blocksize); + final BlockAligner aligner = new BlockAligner<>(ix, blocksize); + final ConcurrentHashMap consumptionCounts = new ConcurrentHashMap<>(); // We may need to construct our own intermediate stream to properly manage the cached items boolean hasIntermediateStream = !qIn.hasStreamCache(); final CachingStream cachedStream = hasIntermediateStream ? new CachingStream(new SubscribableTaskQueue<>()) : qOut.getStreamCache(); cachedStream.activateIndexing(); + cachedStream.incrSubscriberCount(1); // We may require re-consumption of blocks (up to 4 times) + final CompletableFuture future = new CompletableFuture<>(); - CompletableFuture future = filterOOC(qIn.getReadStream(), tmp -> { + filterOOC(qIn.getReadStream(), tmp -> { if (hasIntermediateStream) { // We write to an intermediate stream to ensure that these matrix blocks are properly cached cachedStream.getWriteStream().enqueue(tmp); } - boolean completed = aligner.putNext(tmp.getIndexes(), new IndexedBlockMeta(tmp), (idx, sector) -> { + boolean completed = aligner.putNext(tmp.getIndexes(), tmp.getIndexes(), (idx, sector) -> { int targetBlockRow = (int) (idx.getRowIndex() - 1); int targetBlockCol = (int) (idx.getColumnIndex() - 1); @@ -176,18 +181,18 @@ public void processInstruction(ExecutionContext ec) { for(int r = 0; r < rowSegments; r++) { for(int c = 0; c < colSegments; c++) { - IndexedBlockMeta ibm = sector.get(r, c); - if(ibm == null) + MatrixIndexes mIdx = sector.get(r, c); + if(mIdx == null) continue; - IndexedMatrixValue mv = cachedStream.findCached(ibm.idx); + IndexedMatrixValue mv = cachedStream.peekCached(mIdx); MatrixBlock srcBlock = (MatrixBlock) mv.getValue(); if(target == null) target = new MatrixBlock(nRows, nCols, srcBlock.isInSparseFormat()); - long srcBlockRowStart = (ibm.idx.getRowIndex() - 1) * blocksize; - long srcBlockColStart = (ibm.idx.getColumnIndex() - 1) * blocksize; + long srcBlockRowStart = (mIdx.getRowIndex() - 1) * blocksize; + long srcBlockColStart = (mIdx.getColumnIndex() - 1) * blocksize; long sliceRowStartGlobal = Math.max(targetRowStartGlobal, srcBlockRowStart); long sliceRowEndGlobal = Math.min(targetRowEndGlobal, srcBlockRowStart + srcBlock.getNumRows() - 1); @@ -205,21 +210,31 @@ public void 
processInstruction(ExecutionContext ec) { MatrixBlock sliced = srcBlock.slice(sliceRowStart, sliceRowEnd, sliceColStart, sliceColEnd); sliced.putInto(target, targetRowOffset, targetColOffset, true); + final int maxConsumptions = aligner.getNumConsumptions(mIdx); + + Integer con = consumptionCounts.compute(mIdx, (k, v) -> { + if (v == null) + v = 0; + v = v+1; + if (v == maxConsumptions) + return null; + return v; + }); + + if (con == null) + cachedStream.incrProcessingCount(cachedStream.findCachedIndex(mIdx), 1); } } qOut.enqueue(new IndexedMatrixValue(idx, target)); }); - if(completed) { - // All blocks have been processed; we can cancel the future - // Currently, this does not affect processing (predicates prevent task submission anyway). - // However, a cancelled future may allow early file read aborts once implemented. - CompletableFuture f = futureRef.get(); - if(f != null) - f.cancel(true); - } + if(completed) + future.complete(null); }, tmp -> { + if (future.isDone()) // Then we may skip blocks and avoid submitting tasks + return false; + // Pre-filter incoming blocks to avoid unnecessary task submission long blockRow = tmp.getIndexes().getRowIndex() - 1; long blockCol = tmp.getIndexes().getColumnIndex() - 1; @@ -228,8 +243,15 @@ public void processInstruction(ExecutionContext ec) { }, () -> { aligner.close(); qOut.closeInput(); + }, tmp -> { + // If elements are not processed in an existing caching stream, we increment the process counter to allow block deletion + if (!hasIntermediateStream) + cachedStream.incrProcessingCount(cachedStream.findCachedIndex(tmp.getIndexes()), 1); }); futureRef.set(future); + + if (hasIntermediateStream) + cachedStream.scheduleDeletion(); // We can immediately delete blocks after consumption } //left indexing else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) { @@ -239,16 +261,4 @@ else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) { throw new DMLRuntimeException( "Invalid opcode (" + opcode + ") encountered in MatrixIndexingOOCInstruction."); } - - private static class IndexedBlockMeta { - public final MatrixIndexes idx; - ////public final long nrows; - //public final long ncols; - - public IndexedBlockMeta(IndexedMatrixValue mv) { - this.idx = mv.getIndexes(); - //this.nrows = mv.getValue().getNumRows(); - //this.ncols = mv.getValue().getNumColumns(); - } - } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java index f5ae7573b0a..dace1ab9e53 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java @@ -84,7 +84,7 @@ public class OOCEvictionManager { // Configuration: OOC buffer limit as percentage of heap - private static final double OOC_BUFFER_PERCENTAGE = 0.15 * 0.01 * 2; // 15% of heap + private static final double OOC_BUFFER_PERCENTAGE = 0.15; // 15% of heap private static final double PARTITION_EVICTION_SIZE = 64 * 1024 * 1024; // 64 MB @@ -170,6 +170,40 @@ private static class BlockEntry { LocalFileUtils.createLocalFileIfNotExist(_spillDir); } + public static void reset() { + TeeOOCInstruction.reset(); + if (!_cache.isEmpty()) { + System.err.println("There are dangling elements in the OOC Eviction cache: " + _cache.size()); + } + _size.set(0); + _cache.clear(); + _spillLocations.clear(); + _partitions.clear(); + _partitionCounter.set(0); + _streamPartitions.clear(); + } + + /** 
+ * Removes a block from the cache without setting its data to null. + */ + public static void forget(long streamId, int blockId) { + BlockEntry e; + synchronized (_cacheLock) { + e = _cache.remove(streamId + "_" + blockId); + } + + if (e != null) { + e.lock.lock(); + try { + if (e.state == BlockState.HOT) + _size.addAndGet(-e.size); + } finally { + e.lock.unlock(); + } + } + } + /** * Store a block in the OOC cache (serialize once) */ diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index ca13cfdb2c3..1b6862361ed 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.util.OOCJoin; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -131,11 +132,15 @@ protected OOCStream createWritableStream() { return new SubscribableTaskQueue<>(); } - protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Runnable finalizer) { + protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Runnable finalizer) { + return filterOOC(qIn, processor, predicate, finalizer, null); + } + + protected CompletableFuture filterOOC(OOCStream qIn, Consumer processor, Function predicate, Runnable finalizer, Consumer onNotProcessed) { if (_inQueues == null || _outQueues == null) throw new NotImplementedException("filterOOC requires manual specification of all input and output streams for error propagation"); - return submitOOCTasks(qIn, processor, finalizer, predicate); + return submitOOCTasks(qIn, processor, finalizer, predicate, onNotProcessed != null ? 
(i, tmp) -> onNotProcessed.accept(tmp) : null); } protected CompletableFuture mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { @@ -163,10 +168,16 @@ protected CompletableFuture broadcastJoinOOC(OOCStream> availableLeftInput = new ConcurrentHashMap<>(); Map availableBroadcastInput = new ConcurrentHashMap<>(); - return submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { + CompletableFuture future = submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> { P key = on.apply(tmp); if (i == 0) { // qIn stream @@ -184,11 +195,22 @@ protected CompletableFuture broadcastJoinOOC(OOCStream CompletableFuture broadcastJoinOOC(OOCStream { + availableBroadcastInput.forEach((k, v) -> { + rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1); + }); + availableBroadcastInput.clear(); + qOut.closeInput(); + }); + + if (explicitLeftCaching) + leftCache.scheduleDeletion(); + if (explicitRightCaching) + rightCache.scheduleDeletion(); + + return future; } protected static class BroadcastedElement { @@ -244,7 +283,7 @@ public MatrixIndexes getIndex() { public IndexedMatrixValue getValue() { return value; } - }; + } protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream qIn2, OOCStream qOut, BiFunction mapper, Function on) { return joinOOC(qIn1, qIn2, qOut, mapper, on, on); @@ -257,12 +296,18 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream final CompletableFuture future = new CompletableFuture<>(); + boolean explicitLeftCaching = !qIn1.hasStreamCache(); + boolean explicitRightCaching = !qIn2.hasStreamCache(); + // We need to construct our own stream to properly manage the cached items in the hash join - CachingStream leftCache = qIn1.hasStreamCache() ? qIn1.getStreamCache() : new CachingStream((SubscribableTaskQueue)qIn1); // We have to assume this generic type for now - CachingStream rightCache = qIn2.hasStreamCache() ? qIn2.getStreamCache() : new CachingStream((SubscribableTaskQueue)qIn2); // We have to assume this generic type for now + CachingStream leftCache = explicitLeftCaching ? new CachingStream((OOCStream) qIn1) : qIn1.getStreamCache(); + CachingStream rightCache = explicitRightCaching ? 
new CachingStream((OOCStream) qIn2) : qIn2.getStreamCache(); leftCache.activateIndexing(); rightCache.activateIndexing(); + leftCache.incrSubscriberCount(1); + rightCache.incrSubscriberCount(1); + final OOCJoin join = new OOCJoin<>((idx, left, right) -> { T leftObj = (T) leftCache.findCached(left); T rightObj = (T) rightCache.findCached(right); @@ -280,36 +325,40 @@ protected CompletableFuture joinOOC(OOCStream qIn1, OOCStream future.complete(null); }); + if (explicitLeftCaching) + leftCache.scheduleDeletion(); + if (explicitRightCaching) + rightCache.scheduleDeletion(); + return future; } protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer) { + return submitOOCTasks(queues, consumer, finalizer, null); + } + + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, BiConsumer onNotProcessed) { List> futures = new ArrayList<>(queues.size()); for (int i = 0; i < queues.size(); i++) futures.add(new CompletableFuture<>()); - return submitOOCTasks(queues, consumer, finalizer, futures, null); + return submitOOCTasks(queues, consumer, finalizer, futures, null, onNotProcessed); } - protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, List> futures, BiFunction predicate) { + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer consumer, Runnable finalizer, List> futures, BiFunction predicate, BiConsumer onNotProcessed) { addInStream(queues.toArray(OOCStream[]::new)); ExecutorService pool = CommonThreadPool.get(); final List activeTaskCtrs = new ArrayList<>(queues.size()); - final List streamsClosed = new ArrayList<>(queues.size()); - for (int i = 0; i < queues.size(); i++) { - activeTaskCtrs.add(new AtomicInteger(0)); - streamsClosed.add(new AtomicBoolean(false)); - } + for (int i = 0; i < queues.size(); i++) + activeTaskCtrs.add(new AtomicInteger(1)); - final AtomicInteger globalTaskCtr = new AtomicInteger(0); final CompletableFuture globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)); if (_outQueues == null) _outQueues = Collections.emptySet(); final Runnable oocFinalizer = oocTask(finalizer, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new)); - final Object globalLock = new Object(); int i = 0; @SuppressWarnings("unused") @@ -319,84 +368,67 @@ protected CompletableFuture submitOOCTasks(final List> qu for (OOCStream queue : queues) { final int k = i; final AtomicInteger localTaskCtr = activeTaskCtrs.get(k); - final AtomicBoolean localStreamClosed = streamsClosed.get(k); final CompletableFuture localFuture = futures.get(k); + final AtomicBoolean closeRaceWatchdog = new AtomicBoolean(false); //System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")"); - queue.setSubscriber(oocTask(() -> { - final T item = queue.dequeue(); + queue.setSubscriber(oocTask(callback -> { + final T item = callback.get(); - if (predicate != null && item != null && !predicate.apply(k, item)) // Can get closed due to cancellation - return; + if(item == null) { + if(!closeRaceWatchdog.compareAndSet(false, true)) + throw new DMLRuntimeException("Race condition observed: NO_MORE_TASKS callback has been triggered more than once"); - synchronized (globalLock) { - if (localFuture.isDone()) - return; + if(localTaskCtr.decrementAndGet() == 0) { + // Then we can run the finalization procedure already + 
localFuture.complete(null); + } + return; + } - globalTaskCtr.incrementAndGet(); + if(predicate != null && !predicate.apply(k, item)) { // Can get closed due to cancellation + if(onNotProcessed != null) + onNotProcessed.accept(k, item); + return; } - localTaskCtr.incrementAndGet(); + if(localFuture.isDone()) { + if(onNotProcessed != null) + onNotProcessed.accept(k, item); + return; + } + else { + localTaskCtr.incrementAndGet(); + } pool.submit(oocTask(() -> { - if(item != null) { - //System.out.println("Accept" + ((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + ")"); - consumer.accept(k, item); - } - else { - //System.out.println("Close substream (k " + k + ", id " + streamId + ")"); - localStreamClosed.set(true); - } - - boolean runFinalizer = false; - - synchronized (globalLock) { - int localTasks = localTaskCtr.decrementAndGet(); - boolean finalizeStream = localTasks == 0 && localStreamClosed.get(); - - int globalTasks = globalTaskCtr.get() - 1; - - if (finalizeStream || (globalFuture.isDone() && localTasks == 0)) { - localFuture.complete(null); + // TODO For caching streams, we have no guarantee that item is still in memory -> NullPointer possible + consumer.accept(k, item); - if (globalFuture.isDone() && globalTasks == 0) - runFinalizer = true; - } - - globalTaskCtr.decrementAndGet(); - } - - if (runFinalizer) - oocFinalizer.run(); + if(localTaskCtr.decrementAndGet() == 0) + localFuture.complete(null); }, localFuture, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); + + if(closeRaceWatchdog.get()) // Sanity check + throw new DMLRuntimeException("Race condition observed"); }, null, Stream.concat(_outQueues.stream(), _inQueues.stream()).toArray(OOCStream[]::new))); i++; } - pool.shutdown(); - globalFuture.whenComplete((res, e) -> { - if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) + if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) { futures.forEach(f -> { - if (!f.isDone()) { - if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) + if(!f.isDone()) { + if(globalFuture.isCancelled() || globalFuture.isCompletedExceptionally()) f.cancel(true); else f.complete(null); } }); - - boolean runFinalizer; - - synchronized (globalLock) { - runFinalizer = globalTaskCtr.get() == 0; } - if (runFinalizer) - oocFinalizer.run(); - - //System.out.println("Shutdown (id " + streamId + ")"); + oocFinalizer.run(); }); return globalFuture; } @@ -405,8 +437,8 @@ protected CompletableFuture submitOOCTasks(OOCStream queue, Consume return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer); } - protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer, Function predicate) { - return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp)); + protected CompletableFuture submitOOCTasks(OOCStream queue, Consumer consumer, Runnable finalizer, Function predicate, BiConsumer onNotProcessed) { + return submitOOCTasks(List.of(queue), (i, tmp) -> consumer.accept(tmp), finalizer, List.of(new CompletableFuture()), (i, tmp) -> predicate.apply(tmp), onNotProcessed); } protected CompletableFuture submitOOCTask(Runnable r, OOCStream... queues) { @@ -450,6 +482,31 @@ private Runnable oocTask(Runnable r, CompletableFuture future, OOCStream< }; } + private Consumer> oocTask(Consumer> c, CompletableFuture future, OOCStream... 
queues) { + return callback -> { + try { + c.accept(callback); + } + catch (Exception ex) { + DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex); + + if (_failed) // Do avoid infinite cycles + throw re; + + _failed = true; + + for (OOCStream q : queues) + q.propagateFailure(re); + + if (future != null) + future.completeExceptionally(re); + + // Rethrow to ensure proper future handling + throw re; + } + }; + } + /** * Tracks blocks and their counts to enable early emission * once all blocks for a given index are processed. diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java index 1a12cb138b7..f02c847e055 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -22,6 +22,8 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import java.util.function.Consumer; + public interface OOCStream extends OOCStreamable { void enqueue(T t); @@ -36,4 +38,28 @@ public interface OOCStream extends OOCStreamable { boolean hasStreamCache(); CachingStream getStreamCache(); + + /** + * Registers a new subscriber that consumes the stream. + * While there is no guarantee for any specific order, the closing item LocalTaskQueue.NO_MORE_TASKS + * is guaranteed to be invoked after every other item has finished processing. Thus, the NO_MORE_TASKS + * callback can be used to free dependent resources and close output streams. + */ + void setSubscriber(Consumer> subscriber); + + class QueueCallback { + private final T _result; + private final DMLRuntimeException _failure; + + public QueueCallback(T result, DMLRuntimeException failure) { + _result = result; + _failure = failure; + } + + public T get() { + if (_failure != null) + throw _failure; + return _result; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java index bdc4086bdcd..af2c0afa660 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java @@ -25,6 +25,4 @@ public interface OOCStreamable { OOCStream getWriteStream(); boolean isProcessed(); - - void setSubscriber(Runnable subscriber); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java new file mode 100644 index 00000000000..b7a16778ab7 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Watchdog to help debug OOC streams/tasks that never close. + */ +public final class OOCWatchdog { + public static final boolean WATCH = false; + private static final ConcurrentHashMap OPEN = new ConcurrentHashMap<>(); + private static final ScheduledExecutorService EXEC = + Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "TemporaryWatchdog"); + t.setDaemon(true); + return t; + }); + + private static final long STALE_MS = TimeUnit.SECONDS.toMillis(10); + private static final long SCAN_INTERVAL_MS = TimeUnit.SECONDS.toMillis(10); + + static { + EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS, SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS); + } + + private OOCWatchdog() { + // no-op + } + + public static void registerOpen(String id, String desc, String context, OOCStream stream) { + OPEN.put(id, new Entry(desc, context, System.currentTimeMillis(), stream)); + } + + public static void addEvent(String id, String eventMsg) { + Entry e = OPEN.get(id); + if (e != null) + e.events.add(eventMsg); + } + + public static void registerClose(String id) { + OPEN.remove(id); + } + + private static void scan() { + long now = System.currentTimeMillis(); + for (Map.Entry e : OPEN.entrySet()) { + if (now - e.getValue().openedAt >= STALE_MS) { + if (e.getValue().events.isEmpty()) + continue; // Probably just a stream that has no consumer (remains to be checked why this can happen) + System.err.println("[TemporaryWatchdog] Still open after " + (now - e.getValue().openedAt) + "ms: " + + e.getKey() + " (" + e.getValue().desc + ")" + + (e.getValue().context != null ? 
" ctx=" + e.getValue().context : "")); + } + } + } + + private static class Entry { + final String desc; + final String context; + final long openedAt; + final OOCStream stream; + ConcurrentLinkedQueue events; + + Entry(String desc, String context, long openedAt, OOCStream stream) { + this.desc = desc; + this.context = context; + this.openedAt = openedAt; + this.stream = stream; + this.events = new ConcurrentLinkedQueue<>(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java index e56d32e4401..d70fc3ccb94 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java @@ -20,7 +20,6 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.commons.lang3.NotImplementedException; -import org.apache.commons.lang3.mutable.MutableObject; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.DMLRuntimeException; @@ -43,7 +42,6 @@ import java.util.LinkedHashMap; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.atomic.AtomicBoolean; public class ParameterizedBuiltinOOCInstruction extends ComputationOOCInstruction { @@ -110,29 +108,26 @@ else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) { Data finalPattern = pattern; - AtomicBoolean found = new AtomicBoolean(false); + addInStream(qIn); + addOutStream(); // This instruction has no output stream - MutableObject> futureRef = new MutableObject<>(); - CompletableFuture future = submitOOCTasks(qIn, tmp -> { - boolean contains = ((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue()); + CompletableFuture future = new CompletableFuture<>(); - if (contains) { - found.set(true); + filterOOC(qIn, tmp -> { + boolean contains = ((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue()); - // Now we may complete the future - if (futureRef.getValue() != null) - futureRef.getValue().complete(null); - } - }, () -> {}); - futureRef.setValue(future); + if (contains) + future.complete(true); + }, tmp -> !future.isDone(), // Don't start a separate worker if result already known + () -> future.complete(false)); // Then the pattern was not found + boolean ret; try { - futureRef.getValue().get(); + ret = future.get(); } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException(e); } - boolean ret = found.get(); ec.setScalarOutput(output.getName(), new BooleanObject(ret)); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java index 6edc4ecf270..5b996da0dbe 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java @@ -23,13 +23,22 @@ import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + public class PlaybackStream implements OOCStream, OOCStreamable { private 
final CachingStream _streamCache; - private int _streamIdx; + private final AtomicInteger _streamIdx; + private final AtomicInteger _taskCtr; + private final AtomicBoolean _subscriberSet; public PlaybackStream(CachingStream streamCache) { this._streamCache = streamCache; - this._streamIdx = 0; + this._streamIdx = new AtomicInteger(0); + this._taskCtr = new AtomicInteger(1); + this._subscriberSet = new AtomicBoolean(false); + streamCache.incrSubscriberCount(1); } @Override @@ -44,15 +53,29 @@ public void closeInput() { @Override public LocalTaskQueue toLocalTaskQueue() { - final SubscribableTaskQueue q = new SubscribableTaskQueue<>(); - setSubscriber(() -> q.enqueue(dequeue())); + final LocalTaskQueue q = new LocalTaskQueue<>(); + setSubscriber(val -> { + if (val.get() == null) { + q.closeInput(); + return; + } + try { + q.enqueueTask(val.get()); + } + catch(InterruptedException e) { + throw new RuntimeException(e); + } + }); return q; } @Override - public synchronized IndexedMatrixValue dequeue() { + public IndexedMatrixValue dequeue() { + if (_subscriberSet.get()) + throw new IllegalStateException("Cannot dequeue from a playback stream if a subscriber has been set"); + try { - return _streamCache.get(_streamIdx++); + return _streamCache.get(_streamIdx.getAndIncrement()); } catch (InterruptedException e) { throw new DMLRuntimeException(e); } @@ -74,8 +97,35 @@ public boolean isProcessed() { } @Override - public void setSubscriber(Runnable subscriber) { - _streamCache.setSubscriber(subscriber); + public void setSubscriber(Consumer> subscriber) { + if (!_subscriberSet.compareAndSet(false, true)) + throw new IllegalArgumentException("Subscriber cannot be set multiple times"); + + /** + * To guarantee that NO_MORE_TASKS is invoked after all subscriber calls + * finished, we keep track of running tasks using a task counter. + */ + _streamCache.setSubscriber(() -> { + try { + _taskCtr.incrementAndGet(); + + IndexedMatrixValue val; + + try { + val = _streamCache.get(_streamIdx.getAndIncrement()); + } catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + + if (val != null) + subscriber.accept(new QueueCallback<>(val, null)); + + if (_taskCtr.addAndGet(val == null ? 
-2 : -1) == 0) + subscriber.accept(new QueueCallback<>(null, null)); + } catch (DMLRuntimeException e) { + subscriber.accept(new QueueCallback<>(null, e)); + } + }, false); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java index f136ffc2bb6..7563d8471b6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -22,80 +22,172 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import java.util.LinkedList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + public class SubscribableTaskQueue extends LocalTaskQueue implements OOCStream { - private Runnable _subscriber; - @Override - public synchronized void enqueue(T t) { - try { - super.enqueueTask(t); - } - catch (InterruptedException e) { - throw new DMLRuntimeException(e); + private final AtomicInteger _availableCtr = new AtomicInteger(1); + private final AtomicBoolean _closed = new AtomicBoolean(false); + private volatile Consumer> _subscriber = null; + private String _watchdogId; + + public SubscribableTaskQueue() { + if (OOCWatchdog.WATCH) { + _watchdogId = "STQ-" + hashCode(); + // Capture a short context to help identify origin + OOCWatchdog.registerOpen(_watchdogId, "SubscribableTaskQueue@" + hashCode(), getCtxMsg(), this); } + } - if(_subscriber != null) - _subscriber.run(); + private String getCtxMsg() { + StackTraceElement[] st = new Exception().getStackTrace(); + // Skip the first few frames (constructor, createWritableStream, etc.) 
+ StringBuilder sb = new StringBuilder(); + int limit = Math.min(st.length, 7); + for(int i = 2; i < limit; i++) { + sb.append(st[i].getClassName()).append(".").append(st[i].getMethodName()).append(":") + .append(st[i].getLineNumber()); + if(i < limit - 1) + sb.append(" <- "); + } + return sb.toString(); } @Override - public T dequeue() { - try { - return super.dequeueTask(); + public void enqueue(T t) { + if (t == NO_MORE_TASKS) + throw new DMLRuntimeException("Cannot enqueue NO_MORE_TASKS item"); + + int cnt = _availableCtr.incrementAndGet(); + + if (cnt <= 1) { // Then the queue was already closed and we disallow further enqueues + _availableCtr.decrementAndGet(); // Undo increment + throw new DMLRuntimeException("Cannot enqueue into closed SubscribableTaskQueue"); } - catch (InterruptedException e) { - throw new DMLRuntimeException(e); + + Consumer> s = _subscriber; + + if (s != null) { + s.accept(new QueueCallback<>(t, _failure)); + onDeliveryFinished(); + return; + } + + synchronized (this) { + // Re-check that subscriber is really null to avoid race conditions + if (_subscriber == null) { + try { + super.enqueueTask(t); + } + catch(InterruptedException e) { + throw new DMLRuntimeException(e); + } + return; + } + // Otherwise do not insert and re-schedule subscriber invocation + s = _subscriber; } + + // Last case if due to race a subscriber has been set + s.accept(new QueueCallback<>(t, _failure)); + onDeliveryFinished(); } @Override - public synchronized void closeInput() { - super.closeInput(); - - if(_subscriber != null) { - _subscriber.run(); - _subscriber = null; - } + public void enqueueTask(T t) { + enqueue(t); } @Override - public LocalTaskQueue toLocalTaskQueue() { - return this; + public T dequeue() { + try { + if (OOCWatchdog.WATCH) + OOCWatchdog.addEvent(_watchdogId, "dequeue -- " + getCtxMsg()); + T deq = super.dequeueTask(); + if (deq != NO_MORE_TASKS) + onDeliveryFinished(); + return deq; + } + catch(InterruptedException e) { + throw new DMLRuntimeException(e); + } } @Override - public OOCStream getReadStream() { - return this; + public T dequeueTask() { + return dequeue(); } @Override - public OOCStream getWriteStream() { - return this; + public void closeInput() { + if (_closed.compareAndSet(false, true)) { + super.closeInput(); + onDeliveryFinished(); + } else { + throw new IllegalStateException("Multiple close input calls"); + } } @Override - public void setSubscriber(Runnable subscriber) { - int queueSize; + public void setSubscriber(Consumer> subscriber) { + if(subscriber == null) + throw new IllegalArgumentException("Cannot set subscriber to null"); - synchronized (this) { + LinkedList data; + + synchronized(this) { if(_subscriber != null) throw new DMLRuntimeException("Cannot set multiple subscribers"); - _subscriber = subscriber; - queueSize = _data.size(); - queueSize += _closedInput ? 
1 : 0; // To trigger the NO_MORE_TASK element + if(_failure != null) + throw _failure; + data = _data; + _data = new LinkedList<>(); + _subscriber = subscriber; // register the subscriber after capturing the backlog + } + + for (T t : data) { + subscriber.accept(new QueueCallback<>(t, _failure)); + onDeliveryFinished(); + } } - for (int i = 0; i < queueSize; i++) - subscriber.run(); + private void onDeliveryFinished() { + int ctr = _availableCtr.decrementAndGet(); + + if (ctr == 0) { + Consumer<QueueCallback<T>> s = _subscriber; + if (s != null) + s.accept(new QueueCallback<>((T) LocalTaskQueue.NO_MORE_TASKS, _failure)); + + if (OOCWatchdog.WATCH) + OOCWatchdog.registerClose(_watchdogId); + } } @Override public synchronized void propagateFailure(DMLRuntimeException re) { super.propagateFailure(re); + Consumer<QueueCallback<T>> s = _subscriber; + if(s != null) + s.accept(new QueueCallback<>(null, re)); + } + + @Override + public LocalTaskQueue<T> toLocalTaskQueue() { + return this; + } + + @Override + public OOCStream<T> getReadStream() { + return this; + } - if(_subscriber != null) - _subscriber.run(); + @Override + public OOCStream<T> getWriteStream() { + return this; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index fd80b4e6e90..aba36297e7f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -25,8 +25,37 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import java.util.concurrent.ConcurrentHashMap; + public class TeeOOCInstruction extends ComputationOOCInstruction { + private static final ConcurrentHashMap<CachingStream, Integer> refCtr = new ConcurrentHashMap<>(); + + public static void reset() { + if (!refCtr.isEmpty()) { + System.err.println("There are some dangling streams still in the cache: " + refCtr); + refCtr.clear(); + } + } + + /** + * Adjusts the reference counter of a stream by the given (possibly negative) amount; + * once the counter drops to zero or below, deletion of the cached stream is scheduled. + */ + public static void incrRef(OOCStreamable stream, int incr) { + if (!(stream instanceof CachingStream)) + return; + + Integer ref = refCtr.compute((CachingStream)stream, (k, v) -> { + if (v == null) + v = 0; + v += incr; + return v <= 0 ? null : v; + }); + + if (ref == null) + ((CachingStream)stream).scheduleDeletion(); + } + protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, String opcode, String istr) { super(type, null, in1, out, opcode, istr); } @@ -45,9 +74,20 @@ public void processInstruction( ExecutionContext ec ) { MatrixObject min = ec.getMatrixObject(input1); OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle(); + CachingStream handle = qIn.hasStreamCache() ?
qIn.getStreamCache() : new CachingStream(qIn); + + if (!qIn.hasStreamCache()) { + // We also set the input stream handle + min.setStreamHandle(handle); + incrRef(handle, 2); + } + else { + incrRef(handle, 1); + } + //get output and create new resettable stream MatrixObject mo = ec.getMatrixObject(output); - mo.setStreamHandle(new CachingStream(qIn)); + mo.setStreamHandle(handle); mo.setMetaData(min.getMetaData()); } } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java new file mode 100644 index 00000000000..e20b7ec4269 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class PCATest extends AutomatedTestBase { + private final static String TEST_NAME1 = "PCA"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + PCATest.class.getSimpleName() + "/"; + //private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "X"; + private static final String OUTPUT_NAME_1 = "PC"; + private static final String OUTPUT_NAME_2 = "V"; + + private final static int rows = 50000; + private final static int cols = 1000; + private final static int maxVal = 2; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testPCA() { + boolean allow_opfusion = OptimizerUtils.ALLOW_OPERATOR_FUSION; + OptimizerUtils.ALLOW_OPERATOR_FUSION = false; // some fused ops are not implemented yet + runPCATest(16); + OptimizerUtils.ALLOW_OPERATOR_FUSION = allow_opfusion; + } + + private void runPCATest(int k) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "hops", 
"-stats", "-ooc", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1), output(OUTPUT_NAME_2)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] X_data = getRandomMatrix(rows, cols, 0, maxVal, 1, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY); + X_data = null; + X_mb = null; + + runTest(true, false, null, -1); + + //check replace OOC op + //Assert.assertTrue("OOC wasn't used for replacement", + // heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.REPLACE)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "hops", "-stats", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1 + "_target"), output(OUTPUT_NAME_2 + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + /*MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1 + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + + MatrixBlock ret2_1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2), + Types.FileFormat.BINARY, rows, cols, 1000); + MatrixBlock ret2_2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2 + "_target"), + Types.FileFormat.BINARY, rows, cols, 1000); + TestUtils.compareMatrices(ret2_1, ret2_2, eps);*/ + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/PCA.dml b/src/test/scripts/functions/ooc/PCA.dml new file mode 100644 index 00000000000..567d701ec06 --- /dev/null +++ b/src/test/scripts/functions/ooc/PCA.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +k = $2; + +[PC, V] = pca(X=X, K=k) + +write(PC, $3, format="binary"); +write(V, $4, format="binary");