From bad7b8fbd58846d895d5cde92e130e54e0c9cf84 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 16 Dec 2025 15:59:15 +0100 Subject: [PATCH 01/58] Commiting leftover before changing tactic --- examples/oracle/counters.gm | 23 ----------------------- src/graphviz.cc | 6 ++++++ src/graphviz.hh | 1 + src/interpreter.cc | 2 +- 4 files changed, 8 insertions(+), 24 deletions(-) delete mode 100644 examples/oracle/counters.gm diff --git a/examples/oracle/counters.gm b/examples/oracle/counters.gm deleted file mode 100644 index 90126bd..0000000 --- a/examples/oracle/counters.gm +++ /dev/null @@ -1,23 +0,0 @@ -// Two threads update the same 'variable' to the same value but there -// should still be a race - -x = 0; -$t = spawn { - // The thread inherits a copy of the snapshot of the heap state known by the - // spawning thread at the point when the thread was spawned. - // So we always know that the assert(x == 0) will succeed. - // This thread will then proceed to mutate its versioned copy of the heap. - assert(x == 0); - x = 1; -}; - -// The spawning thread continues but will mutate and read only from its version -// of the heap which is isolated from the thread it just spawned. So, we know -// that assert(x == 0) will always succeed. -assert(x == 0); -x = 1; - -// When we join the threads, the joining thread will attempt to pull in the -// changes made to the thread. Even thought the two values are the same, this -// represents a data race and so we have to crash. 
-join $t; // Data race \ No newline at end of file diff --git a/src/graphviz.cc b/src/graphviz.cc index 1198ba3..cd9bd1e 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -45,8 +45,13 @@ namespace graph { file << "\t" << (size_t)n << "[fillcolor = " << color << "];" << std::endl; } + void GraphvizPrinter::emitShape(const Node* n, const std::string& shape) { + file << "\t" << (size_t)n << "[shape = " << shape << "];" << std::endl; + } + void GraphvizPrinter::emitConflict(const Node* n, const Conflict& conflict) { emitFillColor(n, "red"); + // emitShape(n, "doubleoctagon"); auto [s1, s2] = conflict.sources; emitConflictEdge(n, s1.get()); emitConflictEdge(n, s2.get()); @@ -138,6 +143,7 @@ namespace graph { void GraphvizPrinter::visitAssertionFailure(const AssertionFailure* n) { emitNode(n, "Assert " + n->cond); emitFillColor(n, "red"); + // emitShape(n, "doubleoctagon"); emitProgramOrderEdge(n, n->next.get()); visitProgramOrder(n->next.get()); } diff --git a/src/graphviz.hh b/src/graphviz.hh index 263bad0..a75ca2f 100644 --- a/src/graphviz.hh +++ b/src/graphviz.hh @@ -24,6 +24,7 @@ namespace gitmem { void emitProgramOrderEdge(const Node* from, const Node* to); void emitReadFromEdge(const Node* from, const Node* to); void emitFillColor(const Node* n, const std::string& color); + void emitShape(const Node* n, const std::string& shape); void emitConflictEdge(const Node* from, const Node* to); void emitSyncEdge(const Node* from, const Node* to); void emitConflict(const Node* n, const Conflict& conflict); diff --git a/src/interpreter.cc b/src/interpreter.cc index 37c63ad..bfe83ff 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -443,7 +443,7 @@ namespace gitmem if (auto term = std::get_if(&delta_or_term)) { thread->terminated = *term; - thread_append_node(ctx); + // thread_append_node(ctx); return *term; } From a4c85111a08fffcc80cb24f80fb12f3af55cfb9d Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 17 Dec 2025 16:48:37 +0000 Subject: [PATCH 02/58] 
namespacing a few things --- CMakeLists.txt | 2 ++ src/debugger.cc | 4 +-- src/gitmem.cc | 4 +-- src/gitmem_trieste.cc | 2 +- src/internal.hh | 11 ++++-- src/interpreter.cc | 73 ++++++++++++++++++++------------------- src/interpreter.hh | 25 +++++++------- src/lang.hh | 7 +++- src/parser.cc | 11 ++++-- src/passes/branching.cc | 9 +++-- src/passes/check_refs.cc | 10 ++++-- src/passes/expressions.cc | 7 +++- src/passes/statements.cc | 10 ++++-- src/reader.cc | 8 +++-- 14 files changed, 111 insertions(+), 72 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b7e65e..04a2d21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,8 @@ FetchContent_MakeAvailable(trieste) set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED True) +include_directories(src) + add_executable(gitmem src/gitmem.cc src/reader.cc diff --git a/src/debugger.cc b/src/debugger.cc index d0d3eef..01492fe 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -248,7 +248,7 @@ namespace gitmem return false; case TerminationStatus::assertion_failure_exception: { - auto expr = thread->block->at(thread->pc) / Stmt / Expr; + auto expr = thread->block->at(thread->pc) / lang::Stmt / lang::Expr; msg = "Thread " + std::to_string(tid) + " failed assertion '" + std::string(expr->location().view()) + "' and was terminated"; return false; } @@ -265,7 +265,7 @@ namespace gitmem /** Interpret the AST in an interactive way, letting the user choose which * thread to schedule next. 
*/ - int interpret_interactive(const Node ast, const std::filesystem::path &output_file) + int interpret_interactive(const trieste::Node ast, const std::filesystem::path &output_file) { GlobalContext gctx(ast); diff --git a/src/gitmem.cc b/src/gitmem.cc index 5ea111f..37fe4de 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -58,7 +58,7 @@ int main(int argc, char **argv) return 1; } - auto reader = gitmem::reader().file(input_path); + auto reader = gitmem::lang::reader().file(input_path); auto result = reader.read(); if (!result.ok) @@ -75,7 +75,7 @@ int main(int argc, char **argv) gitmem::verbose << "Output will be written to " << output_path << std::endl; int exit_status; - wf::push_back(gitmem::wf); + wf::push_back(gitmem::lang::wf); if (model_check) { exit_status = gitmem::model_check(result.ast, output_path); diff --git a/src/gitmem_trieste.cc b/src/gitmem_trieste.cc index 19cd4ef..0ebc28b 100644 --- a/src/gitmem_trieste.cc +++ b/src/gitmem_trieste.cc @@ -3,5 +3,5 @@ int main(int argc, char** argv) { - return trieste::Driver(gitmem::reader()).run(argc, argv); + return trieste::Driver(gitmem::lang::reader()).run(argc, argv); } diff --git a/src/internal.hh b/src/internal.hh index 4687161..ec01c4d 100644 --- a/src/internal.hh +++ b/src/internal.hh @@ -1,8 +1,10 @@ #pragma once #include "lang.hh" -namespace gitmem -{ +namespace gitmem { + +namespace lang { + using namespace trieste; Parse parser(); @@ -83,4 +85,7 @@ namespace gitmem | (Cond <<= Expr * Const) ; // clang-format on -} + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.cc b/src/interpreter.cc index bfe83ff..3fa86b2 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -5,8 +5,8 @@ #include "interpreter.hh" #include "graphviz.hh" -namespace gitmem -{ +namespace gitmem { + using namespace trieste; /* Interpreter for a gitmem program. 
Threads can read and write local @@ -25,8 +25,8 @@ namespace gitmem bool is_syncing(Node stmt) { - auto s = stmt / Stmt; - return s == Join || s == Lock || s == Unlock; + auto s = stmt / lang::Stmt; + return s == lang::Join || s == lang::Lock || s == lang::Unlock; } bool is_syncing(Thread &thread) @@ -132,8 +132,8 @@ namespace gitmem */ std::variant evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { - auto e = expr / Expr; - if (e == Reg) + auto e = expr / lang::Expr; + if (e == lang::Reg) { // It is invalid to read a previously unwritten value auto var = std::string(expr->location().view()); @@ -146,7 +146,7 @@ namespace gitmem return TerminationStatus::unassigned_variable_read_exception; } } - else if (e == Var) + else if (e == lang::Var) { // It is invalid to read a previously unwritten value auto var = std::string(expr->location().view()); @@ -163,11 +163,11 @@ namespace gitmem return TerminationStatus::unassigned_variable_read_exception; } } - else if (e == Const) + else if (e == lang::Const) { return size_t(std::stoi(std::string(e->location().view()))); } - else if (e == Add) + else if (e == lang::Add) { size_t sum = 0; for (auto &child : *e) @@ -178,7 +178,7 @@ namespace gitmem } return sum; } - else if (e == Spawn) + else if (e == lang::Spawn) { // Spawning is a sync point, commit local pending commits, and // copy the global state to the spawned thread @@ -187,16 +187,16 @@ namespace gitmem auto node = std::make_shared(tid); ThreadContext new_ctx = { Locals(), ctx.globals, node }; - gctx.threads.push_back(std::make_shared(new_ctx, e / Block)); + gctx.threads.push_back(std::make_shared(new_ctx, e / lang::Block)); thread_append_node(ctx, tid, node); return tid; } - else if (e == Eq || e == Neq) + else if (e == lang::Eq || e == lang::Neq) { - auto lhs = e / Lhs; - auto rhs = e / Rhs; + auto lhs = e / lang::Lhs; + auto rhs = e / lang::Rhs; auto lhsEval = evaluate_expression(lhs, gctx, ctx); if (std::holds_alternative(lhsEval)) return 
lhsEval; @@ -204,7 +204,7 @@ namespace gitmem auto rhsEval = evaluate_expression(rhs, gctx, ctx); if (std::holds_alternative(rhsEval)) return rhsEval; - return e == Eq? (std::get(lhsEval)) == (std::get(rhsEval)) + return e == lang::Eq? (std::get(lhsEval)) == (std::get(rhsEval)) : (std::get(lhsEval)) != (std::get(rhsEval)); } else @@ -219,22 +219,22 @@ namespace gitmem */ std::variant run_statement(Node stmt, GlobalContext &gctx, ThreadContext &ctx, const ThreadID& tid) { - auto s = stmt / Stmt; - if (s == Nop) + auto s = stmt / lang::Stmt; + if (s == lang::Nop) { verbose << "Nop" << std::endl; } - else if (s == Jump) + else if (s == lang::Jump) { - auto cnst = s / Const; + auto cnst = s / lang::Const; auto delta = std::stoi(std::string(cnst->location().view())); assert(delta > 0); return delta; } - else if (s == Cond) + else if (s == lang::Cond) { - auto expr = s / Expr; - auto cnst = s / Const; + auto expr = s / lang::Expr; + auto cnst = s / lang::Const; auto result = evaluate_expression(expr, gctx, ctx); if (auto b = std::get_if(&result)) @@ -248,21 +248,21 @@ namespace gitmem return std::get(result); } } - else if (s == Assign) + else if (s == lang::Assign) { - auto lhs = s / LVal; + auto lhs = s / lang::LVal; auto var = std::string(lhs->location().view()); - auto rhs = s / Expr; + auto rhs = s / lang::Expr; auto val_or_term = evaluate_expression(rhs, gctx, ctx); if(size_t* val = std::get_if(&val_or_term)) { - if (lhs == Reg) + if (lhs == lang::Reg) { // Local variables can be re-assigned whenever verbose << "Set register '" << lhs->location().view() << "' to " << *val << std::endl; ctx.locals[var] = *val; } - else if (lhs == Var) + else if (lhs == lang::Var) { // Global variable writes need to create a new commit id // to track the history of updates @@ -284,12 +284,12 @@ namespace gitmem return std::get(val_or_term); } } - else if (s == Join) + else if (s == lang::Join) { // A join must waiting for the terminating thread to continue, // we don't want to 
re-evaluate the expression repeatedly as this // may be effecting so store the result in the cache. - auto expr = s / Expr; + auto expr = s / lang::Expr; if (!gctx.cache.contains(expr)) { @@ -332,12 +332,12 @@ namespace gitmem return 0; } } - else if (s == Lock) + else if (s == lang::Lock) { // We can only lock unlocked locks, if a lock hasn't been used // before it is implicitly created, we then commit the pending // updates of this thread and pull the updates from the lock. - auto v = s / Var; + auto v = s / lang::Var; auto var = std::string(v->location().view()); auto& lock = gctx.locks[var]; @@ -363,14 +363,14 @@ namespace gitmem verbose << "Locked " << var << std::endl; } - else if (s == Unlock) + else if (s == lang::Unlock) { // We can only unlock locks we previously locked. We commit any // pending updates and then copy the threads versioned globals // to the locks versioned globals (nobody could have changed // them since we locked the lock). commit(ctx.globals); - auto v = s / Var; + auto v = s / lang::Var; auto var = std::string(v->location().view()); auto& lock = gctx.locks[var]; @@ -387,9 +387,9 @@ namespace gitmem verbose << "Unlocked " << var << std::endl; } - else if (s == Assert) + else if (s == lang::Assert) { - auto expr = s / Expr; + auto expr = s / lang::Expr; auto result_or_term = evaluate_expression(expr, gctx, ctx); if (size_t* result = std::get_if(&result_or_term)) { @@ -609,4 +609,5 @@ namespace gitmem return result; } -} + +} // gitmem \ No newline at end of file diff --git a/src/interpreter.hh b/src/interpreter.hh index df2c7fc..ae59de3 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -5,8 +5,8 @@ #include "graph.hh" #include "graphviz.hh" -namespace gitmem -{ +namespace gitmem { + /* For debug printing */ inline struct Verbose { @@ -72,7 +72,7 @@ namespace gitmem struct Thread { ThreadContext ctx; - Node block; + trieste::Node block; size_t pc = 0; ThreadStatus terminated = std::nullopt; @@ -120,14 +120,14 @@ namespace gitmem 
{ Threads threads; Locks locks; - NodeMap cache; + lang::NodeMap cache; std::shared_ptr entry_node; std::unordered_map> commit_map; Commit uuid = 0; - GlobalContext(const Node &ast) + GlobalContext(const trieste::Node &ast) { - Node starting_block = ast / File / Block; + trieste::Node starting_block = ast / lang::File / lang::Block; entry_node = std::make_shared(0); ThreadContext starting_ctx = {{}, {}, entry_node}; auto main_thread = std::make_shared(starting_ctx, starting_block); @@ -174,9 +174,9 @@ namespace gitmem if (t->terminated || dynamic_pointer_cast(t->ctx.tail->next)) continue; - Node block = t->block; + trieste::Node block = t->block; size_t &pc = t->pc; - Node stmt = block->at(pc); + trieste::Node stmt = block->at(pc); thread_append_node(t->ctx, std::string(stmt->location().view())); } @@ -201,13 +201,14 @@ namespace gitmem inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } // Entry functions - int interpret(const Node, const std::filesystem::path &output_file); - int interpret_interactive(const Node, const std::filesystem::path &output_file); - int model_check(const Node, const std::filesystem::path &output_file); + int interpret(const trieste::Node, const std::filesystem::path &output_file); + int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file); + int model_check(const trieste::Node, const std::filesystem::path &output_file); // Internal functions int run_threads(GlobalContext &); std::variant progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); -} + +} // namespace gitmem \ No newline at end of file diff --git a/src/lang.hh b/src/lang.hh index 4aa3ef0..3c387e2 100644 --- a/src/lang.hh +++ b/src/lang.hh @@ -3,6 +3,9 @@ namespace gitmem { + +namespace lang { + using namespace trieste; Reader reader(); @@ -74,4 +77,6 @@ namespace gitmem ; // clang-format on -} +} // namespace lang + +} // namespace gitmem diff --git a/src/parser.cc b/src/parser.cc index 
4eb95ff..0c6faa3 100644 --- a/src/parser.cc +++ b/src/parser.cc @@ -1,8 +1,10 @@ #include "lang.hh" #include "internal.hh" -namespace gitmem -{ +namespace gitmem { + +namespace lang { + using namespace trieste; using namespace trieste::detail; @@ -135,4 +137,7 @@ namespace gitmem return p; } -} + +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/branching.cc b/src/passes/branching.cc index 152cd10..e7e018c 100644 --- a/src/passes/branching.cc +++ b/src/passes/branching.cc @@ -1,7 +1,8 @@ #include "../internal.hh" -namespace gitmem -{ +namespace gitmem { + +namespace lang { using namespace trieste; PassDef branching() @@ -28,4 +29,6 @@ namespace gitmem }}; } -} +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/check_refs.cc b/src/passes/check_refs.cc index ca712b0..664a6b4 100644 --- a/src/passes/check_refs.cc +++ b/src/passes/check_refs.cc @@ -1,7 +1,9 @@ #include "../internal.hh" -namespace gitmem -{ +namespace gitmem { + +namespace lang { + using namespace trieste; PassDef check_refs() @@ -27,4 +29,6 @@ namespace gitmem }}; } -} +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/expressions.cc b/src/passes/expressions.cc index 2c9d957..e945315 100644 --- a/src/passes/expressions.cc +++ b/src/passes/expressions.cc @@ -2,6 +2,9 @@ namespace gitmem { + +namespace lang { + using namespace trieste; PassDef expressions() @@ -136,4 +139,6 @@ namespace gitmem }}; } -} +} // namespace lang + +} // namespace gitmem \ No newline at end of file diff --git a/src/passes/statements.cc b/src/passes/statements.cc index 8b52cdb..22aa86c 100644 --- a/src/passes/statements.cc +++ b/src/passes/statements.cc @@ -1,7 +1,9 @@ #include "../internal.hh" -namespace gitmem -{ +namespace gitmem { + +namespace lang { + using namespace trieste; PassDef statements() @@ -200,4 +202,6 @@ namespace gitmem }}; } -} +} // namespace lang + +} // namespace 
gitmem \ No newline at end of file diff --git a/src/reader.cc b/src/reader.cc index e3648be..3942a53 100644 --- a/src/reader.cc +++ b/src/reader.cc @@ -2,6 +2,8 @@ namespace gitmem { +namespace lang { + using namespace trieste; Reader reader() @@ -14,8 +16,10 @@ Reader reader() check_refs(), branching(), }, - gitmem::parser(), + gitmem::lang::parser(), }; } -} +} // namespace lang + +} // namespace gitmem \ No newline at end of file From 02ac59c8816e813e1a5d8f1610257a9de44babfc Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 17 Dec 2025 16:49:08 +0000 Subject: [PATCH 03/58] adding linear and branching --- src/branching/versionstore.hh | 42 +++++++++ src/linear/versionstore.cc | 105 ++++++++++++++++++++++ src/linear/versionstore.hh | 161 ++++++++++++++++++++++++++++++++++ 3 files changed, 308 insertions(+) create mode 100644 src/branching/versionstore.hh create mode 100644 src/linear/versionstore.cc create mode 100644 src/linear/versionstore.hh diff --git a/src/branching/versionstore.hh b/src/branching/versionstore.hh new file mode 100644 index 0000000..688955c --- /dev/null +++ b/src/branching/versionstore.hh @@ -0,0 +1,42 @@ +#pragma once + + +#include +#include +#include +#include +#include + +namespace gitmem { + +namespace branching { + + /* A 'Global' is a structure to capture the current synchronising objects + * representation of a global variable. The structure is the current value, + * the current commit id for the variable, and the history of commited ids. 
+ */ + + using Commit = size_t; + using CommitHistory = std::vector; + + struct Global + { + size_t val; + std::optional commit; + CommitHistory history; + }; + + using Globals = std::unordered_map; + + using Locals = std::unordered_map; + + + struct Conflict + { + std::string var; + std::pair commits; + }; + +} // namespace branching + +} // namespace gitmem \ No newline at end of file diff --git a/src/linear/versionstore.cc b/src/linear/versionstore.cc new file mode 100644 index 0000000..9ae55f6 --- /dev/null +++ b/src/linear/versionstore.cc @@ -0,0 +1,105 @@ +#include "versionstore.hh" +#include + +namespace gitmem { + +namespace linear { + +// ----------------------------- +// LocalVersionStore +// ----------------------------- + +void LocalVersionStore::stage(ObjectNumber obj, Value value) { + _staging[obj] = value; +} + +void LocalVersionStore::clear_staging() { + _staging.clear(); +} + +void LocalVersionStore::advance_base(Timestamp ts) { + _base_timestamp = ts; +} + +// ----------------------------- +// GlobalVersionStore +// ----------------------------- + +ObjectNumber GlobalVersionStore::allocate_object() { + return _next_object++; +} + +std::optional GlobalVersionStore::check_conflicts( + Timestamp base, + const std::unordered_map& changes +) const { + for (const auto& [obj, _] : changes) { + auto it = _history.find(obj); + if (it == _history.end()) { + continue; + } + + const Version& head = it->second.back(); + if (head.timestamp() > base) { + return Conflict{ + .object = obj, + .local_base = base, + .global_head = head.timestamp() + }; + } + } + return std::nullopt; +} + +Timestamp GlobalVersionStore::apply_changes( + Timestamp base, + const std::unordered_map& changes +) { + if (auto conflict = check_conflicts(base, changes)) { + throw std::logic_error("apply_changes called with conflicts"); + } + + Timestamp new_ts = _timestamp++; + for (const auto& [obj, value] : changes) { + _history[obj].emplace_back(new_ts, value); + } + + _timestamp = new_ts; 
+ return new_ts; +} + +// ----------------------------- +// GlobalVersionHistory (protocol) +// ----------------------------- + +std::optional GlobalVersionHistory::push(LocalVersionStore& local) { + if (auto conflict = _global.check_conflicts( + local.base_timestamp(), + local.staged_changes())) { + return conflict; + } + + Timestamp new_base = _global.apply_changes( + local.base_timestamp(), + local.staged_changes() + ); + + local.clear_staging(); + local.advance_base(new_base); + return std::nullopt; +} + +std::optional GlobalVersionHistory::pull(LocalVersionStore& local) { + if (auto conflict = _global.check_conflicts( + local.base_timestamp(), + local.staged_changes())) { + return conflict; + } + + local.advance_base(_global.current_timestamp()); + return std::nullopt; +} + +} // namespace linear + +} // namespace gitmem \ No newline at end of file diff --git a/src/linear/versionstore.hh b/src/linear/versionstore.hh new file mode 100644 index 0000000..c0d1f7f --- /dev/null +++ b/src/linear/versionstore.hh @@ -0,0 +1,161 @@ +#pragma once + + +#include +#include +#include +#include +#include + +namespace gitmem { + +namespace linear { + +// ----------------------------- +// Timestamp +// ----------------------------- + +class LargeCounter { + uint64_t _epoch{0}; + uint64_t _counter{0}; + +public: + auto operator<=>(const LargeCounter&) const = default; + + LargeCounter& operator++() { + if (_counter == UINT64_MAX) { + _counter = 0; + assert(_epoch != UINT64_MAX && "timestamp overflow"); + ++_epoch; + } else { + ++_counter; + } + return *this; + } + + LargeCounter operator++(int) { + LargeCounter old = *this; + ++(*this); + return old; + } +}; + +using Timestamp = LargeCounter; +using Value = size_t; +using ObjectNumber = uint64_t; + +// ----------------------------- +// Version +// ----------------------------- + +class Version { + Timestamp _timestamp; + Value _value; + +public: + Version(Timestamp ts, Value value) + : _timestamp(ts), _value(value) {} + + 
Timestamp timestamp() const { return _timestamp; } + Value value() const { return _value; } +}; + +using VersionHistory = std::vector; + +// ----------------------------- +// Conflict +// ----------------------------- + +struct Conflict { + ObjectNumber object; + Timestamp local_base; + Timestamp global_head; +}; + +// ----------------------------- +// LocalVersionStore +// ----------------------------- + +class LocalVersionStore { + Timestamp _base_timestamp{}; + std::unordered_map _staging; + +public: + Timestamp base_timestamp() const { return _base_timestamp; } + const auto& staged_changes() const { return _staging; } + + void stage(ObjectNumber obj, Value value); + void clear_staging(); + void advance_base(Timestamp ts); +}; + +// ----------------------------- +// GlobalVersionStore +// ----------------------------- + +class GlobalVersionStore { + Timestamp _timestamp{}; + ObjectNumber _next_object{0}; + std::unordered_map _history; + +public: + Timestamp current_timestamp() const { return _timestamp; } + + ObjectNumber allocate_object(); + + std::optional check_conflicts( + Timestamp base, + const std::unordered_map& changes + ) const; + + Timestamp apply_changes( + Timestamp base, + const std::unordered_map& changes + ); +}; + +// ----------------------------- +// Synchronisation Protocol +// ----------------------------- + +class GlobalVersionHistory { + GlobalVersionStore _global; + +public: + std::optional push(LocalVersionStore& local); + std::optional pull(LocalVersionStore& local); +}; + +} // namespace linear + +namespace branching { + + /* A 'Global' is a structure to capture the current synchronising objects + * representation of a global variable. The structure is the current value, + * the current commit id for the variable, and the history of commited ids. 
+ */ + + using Commit = size_t; + using CommitHistory = std::vector; + + struct Global + { + size_t val; + std::optional commit; + CommitHistory history; + }; + + using Globals = std::unordered_map; + + using Locals = std::unordered_map; + + + struct Conflict + { + std::string var; + std::pair commits; + }; + +} // namespace branching + +} // namespace gitmem \ No newline at end of file From dfbcb5f9ec539387591a589aeb2181d09c183f80 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Thu, 18 Dec 2025 11:49:48 +0000 Subject: [PATCH 04/58] breaking a bunch of mixed up things into better abstraction to enable two sync protocols --- CMakeLists.txt | 5 +- src/execution_state.cc | 19 ++ src/execution_state.hh | 176 +++++++++++++++ src/gitmem.cc | 6 +- src/interpreter.cc | 212 +++++++----------- src/interpreter.hh | 176 +-------------- .../{versionstore.cc => version_store.cc} | 2 +- .../{versionstore.hh => version_store.hh} | 1 - src/sync_protocol.cc | 170 ++++++++++++++ src/sync_protocol.hh | 117 ++++++++++ 10 files changed, 570 insertions(+), 314 deletions(-) create mode 100644 src/execution_state.cc create mode 100644 src/execution_state.hh rename src/linear/{versionstore.cc => version_store.cc} (98%) rename src/linear/{versionstore.hh => version_store.hh} (99%) create mode 100644 src/sync_protocol.cc create mode 100644 src/sync_protocol.hh diff --git a/CMakeLists.txt b/CMakeLists.txt index 04a2d21..a3d0bbd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,13 +20,14 @@ add_executable(gitmem src/gitmem.cc src/reader.cc src/parser.cc + src/execution_state.cc src/passes/expressions.cc src/passes/statements.cc src/passes/check_refs.cc src/passes/branching.cc src/interpreter.cc - src/debugger.cc - src/model_checker.cc + # src/debugger.cc + # src/model_checker.cc src/graphviz.cc ) diff --git a/src/execution_state.cc b/src/execution_state.cc new file mode 100644 index 0000000..b3c1274 --- /dev/null +++ b/src/execution_state.cc @@ -0,0 +1,19 @@ +#include 
"execution_state.hh" +#include "sync_protocol.hh" + +namespace gitmem { + + GlobalContext::GlobalContext(const trieste::Node &ast) { + trieste::Node starting_block = ast / lang::File / lang::Block; + entry_node = std::make_shared(0); + // ThreadContext starting_ctx = {{}, {}, entry_node}; + // auto main_thread = std::make_shared(starting_ctx, starting_block); + + // this->threads = {main_thread}; + // this->locks = {}; + // this->cache = {}; + } + + + GlobalContext::~GlobalContext() = default; +} \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh new file mode 100644 index 0000000..2f4d09e --- /dev/null +++ b/src/execution_state.hh @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include +#include + +#include "graphviz.hh" +#include "lang.hh" + +namespace gitmem { + + /* A 'Global' is a structure to capture the current synchronising objects + * representation of a global variable. The structure is the current value, + * the current commit id for the variable, and the history of commited ids. 
+ */ + + class SyncProtocol; + + struct Commit { size_t value; }; + + struct Global + { + size_t val; + std::optional commit; + std::vector history; + }; + + using Globals = std::unordered_map; + + enum class TerminationStatus + { + completed, + datarace_exception, + unlock_exception, + assertion_failure_exception, + unassigned_variable_read_exception, + }; + + using Locals = std::unordered_map; + + struct ThreadContext + { + Locals locals; + std::shared_ptr tail; + }; + + using ThreadStatus = std::optional; + + struct Thread + { + ThreadContext ctx; + trieste::Node block; + size_t pc = 0; + ThreadStatus terminated = std::nullopt; + + bool operator==(const Thread &other) const + { + return false; + // Globals have a history that we don't care about, so we only + // compare values + // if (ctx.globals.size() != other.ctx.globals.size()) + // return false; + // for (const auto &[var, global] : ctx.globals) + // { + // if (!other.ctx.globals.contains(var) || + // ctx.globals.at(var).val != other.ctx.globals.at(var).val) + // { + // return false; + // } + // } + // return ctx.locals == other.ctx.locals && + // block == other.block && + // pc == other.pc && + // terminated == other.terminated; + } + }; + + using ThreadID = size_t; + + struct Lock + { + Globals globals; + std::optional owner = std::nullopt; + std::shared_ptr last; + }; + + template + std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args); + + template<> + std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt); + + struct GlobalContext + { + // Execution state + std::vector> threads; + std::unordered_map locks; + + // AST evaluation cache + lang::NodeMap cache; + + // Graph root + std::shared_ptr entry_node; + + // Synchronisation semantics (policy) + std::unique_ptr protocol; + + GlobalContext(const trieste::Node &ast); + ~GlobalContext(); + + bool operator==(const GlobalContext &other) const + { + return false; + // if (threads.size() != other.threads.size() || 
locks.size() != other.locks.size()) + // return false; + + // // Threads may have been spawned in a different order, so we + // // find the thread with the same block in the other context + // for (auto &thread : threads) + // { + // auto it = std::find_if(other.threads.begin(), other.threads.end(), + // [&thread](auto &t) + // { return t->block == thread->block; }); + // if (it == other.threads.end() || !(*thread == **it)) + // return false; + // } + + // for (auto &[name, lock] : locks) + // { + // if (!other.locks.contains(name)) + // return false; + // auto &other_lock = other.locks.at(name); + // if (lock.owner != other_lock.owner) + // return false; + // } + // return true; + } + + void print_execution_graph(const std::filesystem::path &output_path) const + { + // Loop over the threads and add pending nodes to running threads + // to indicate a threads next step + for (const auto& t: threads) + { + assert(t->ctx.tail); + if (t->terminated || dynamic_pointer_cast(t->ctx.tail->next)) + continue; + + trieste::Node block = t->block; + size_t &pc = t->pc; + trieste::Node stmt = block->at(pc); + thread_append_node(t->ctx, std::string(stmt->location().view())); + } + + graph::GraphvizPrinter gv(output_path); + gv.visit(entry_node.get()); + } + }; + + enum class ProgressStatus + { + progress, + no_progress + }; + + inline bool operator!(ProgressStatus p) { return p == ProgressStatus::no_progress; } + + inline ProgressStatus operator||(const ProgressStatus &p1, const ProgressStatus &p2) + { + return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) ? 
ProgressStatus::progress : ProgressStatus::no_progress; + } + + inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } + +} // namespace gitmem \ No newline at end of file diff --git a/src/gitmem.cc b/src/gitmem.cc index 37fe4de..ec1efd1 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -78,11 +78,13 @@ int main(int argc, char **argv) wf::push_back(gitmem::lang::wf); if (model_check) { - exit_status = gitmem::model_check(result.ast, output_path); + assert(false && "currently broken"); + // exit_status = gitmem::model_check(result.ast, output_path); } else if (interactive) { - exit_status = gitmem::interpret_interactive(result.ast, output_path); + assert(false && "currently broken"); + // exit_status = gitmem::interpret_interactive(result.ast, output_path); } else { diff --git a/src/interpreter.cc b/src/interpreter.cc index 3fa86b2..5152a34 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -4,6 +4,7 @@ #include "interpreter.hh" #include "graphviz.hh" +#include "sync_protocol.hh" namespace gitmem { @@ -34,78 +35,12 @@ namespace gitmem { return !thread.terminated && is_syncing(thread.block->at(thread.pc)); } - /* At a commit point, walk through all the versioned variables and see if - * they have a pending commit, if so commit the value by appending to - * the variables history. - */ - void commit(Globals &globals) { - for (auto& [var, global] : globals) { - if (global.commit) - { - global.history.push_back(*global.commit); - verbose << "Committed global '" << var << "' with id " << *global.commit << std::endl; - global.commit.reset(); - } - } - } - - - /* A versioned value can be fastforwarded to another version, if one - * version's history is a prefix of another version's history. - * A conflict between two commit histories exists if neither history is a - * prefix of the other. 
- */ - std::optional> has_conflict(CommitHistory& h1, CommitHistory& h2) - { - size_t length = std::min(h1.size(), h2.size()); - - for (size_t i = 0; i < length; i++) - { - if (h1[i] != h2[i]) return std::pair{h1[i], h2[i]}; - } - - return std::nullopt; - } - struct Conflict { std::string var; std::pair commits; }; - /* Walk through all the global versions from source and update the versions - * in destination to be the most up-to-date version (this could come from - * either source or destination). This means destination will now also - * include variables it previously did not know about. - */ - std::optional pull(Globals &dst, Globals &src) { - for (auto& [var, global] : src) { - if (dst.contains(var)) - { - auto& src_var = src[var]; - auto& dst_var = dst[var]; - if (auto conflict = has_conflict(src_var.history, dst_var.history)) - { - auto [s1, s2] = *conflict; - verbose << "A data race on '" << var << "' was detected from commits " << s1 << " and " << s2 << std::endl; - return Conflict(var, *conflict); - } - else if (src_var.history.size() > dst_var.history.size()) - { - verbose << "Fast-forward '" << var << "' to id " << src_var.val << std::endl; - dst_var.val = src_var.val; - dst_var.history = src_var.history; - } - } - else - { - dst[var].val = src[var].val; - dst[var].history = src[var].history; - } - } - return std::nullopt; - } - template std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args) { @@ -130,8 +65,7 @@ namespace gitmem { /* Evaluating an expression either returns the result of the expression or * a the exceptional termination status of the thread. 
*/ - std::variant evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) - { + std::variant evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { auto e = expr / lang::Expr; if (e == lang::Reg) { @@ -148,18 +82,10 @@ namespace gitmem { } else if (e == lang::Var) { - // It is invalid to read a previously unwritten value auto var = std::string(expr->location().view()); - if (ctx.globals.contains(var)) - { - auto& global = ctx.globals[var]; - auto commit = global.commit.value_or(global.history.back()); - auto source_node = gctx.commit_map[commit]; - thread_append_node(ctx, var, global.val, commit, source_node); - return global.val; - } - else - { + if (std::optional result = gctx.protocol->read(ctx, var) ) { + return *result; + } else { // It is invalid to read a previously unwritten value return TerminationStatus::unassigned_variable_read_exception; } } @@ -180,16 +106,19 @@ namespace gitmem { } else if (e == lang::Spawn) { - // Spawning is a sync point, commit local pending commits, and - // copy the global state to the spawned thread - commit(ctx.globals); ThreadID tid = gctx.threads.size(); auto node = std::make_shared(tid); + ThreadContext child_ctx = { Locals(), node }; + gctx.threads.push_back(std::make_shared(child_ctx, e / lang::Block)); + thread_append_node(ctx, tid, node); - ThreadContext new_ctx = { Locals(), ctx.globals, node }; - gctx.threads.push_back(std::make_shared(new_ctx, e / lang::Block)); + if (std::optional conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { + assert(false); // handle this + } - thread_append_node(ctx, tid, node); + // Spawning is a sync point, commit local pending commits, and + // copy the global state to the spawned thread + // commit(ctx.globals); return tid; } @@ -264,15 +193,17 @@ namespace gitmem { } else if (lhs == lang::Var) { - // Global variable writes need to create a new commit id - // to track the history of updates - auto &global = ctx.globals[var]; - global.val = *val; 
- global.commit = gctx.uuid++; - verbose << "Set global '" << lhs->location().view() << "' to " << *val << " with id " << *(global.commit) << std::endl; - - auto node = thread_append_node(ctx, var, global.val, *global.commit); - gctx.commit_map[*(global.commit)] = node; + gctx.protocol->write(ctx, var, *val); + + // // Global variable writes need to create a new commit id + // // to track the history of updates + // auto &global = ctx.globals[var]; + // global.val = *val; + // global.commit = gctx.uuid++; + // verbose << "Set global '" << lhs->location().view() << "' to " << *val << " with id " << *(global.commit) << std::endl; + + // auto node = thread_append_node(ctx, var, global.val, *global.commit); + // gctx.commit_map[*(global.commit)] = node; } else { @@ -308,23 +239,27 @@ namespace gitmem { // thread will not necessarily have commited them), we then // pull the updates into the joining thread. auto result = gctx.cache[expr]; - auto& thread = gctx.threads[result]; - if (thread->terminated && (*thread->terminated == TerminationStatus::completed)) + auto& joinee = gctx.threads[result]; + if (joinee->terminated && (*joinee->terminated == TerminationStatus::completed)) { - commit(ctx.globals); - commit(thread->ctx.globals); - verbose << "Pulling from thread " << result << std::endl; - if(auto conflict = pull(ctx.globals, thread->ctx.globals)) - { - using graph::Node; - auto [s1, s2] = conflict->commits; - auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - auto graph_conflict = graph::Conflict(conflict->var, sources); - thread_append_node(ctx, result, thread->ctx.tail, graph_conflict); + if(auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { + assert(false && "todo"); return TerminationStatus::datarace_exception; + } else { + thread_append_node(ctx, result, joinee->ctx.tail); } - - thread_append_node(ctx, result, thread->ctx.tail); + // commit(ctx.globals); + // commit(thread->ctx.globals); + // verbose << 
"Pulling from thread " << result << std::endl; + // if(auto conflict = pull(ctx.globals, thread->ctx.globals)) + // { + // using graph::Node; + // auto [s1, s2] = conflict->commits; + // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; + // auto graph_conflict = graph::Conflict(conflict->var, sources); + // thread_append_node(ctx, result, thread->ctx.tail, graph_conflict); + // return TerminationStatus::datarace_exception; + // } } else { @@ -347,45 +282,48 @@ namespace gitmem { } lock.owner = tid; - commit(ctx.globals); - if(auto conflict = pull(ctx.globals, lock.globals)) - { - using graph::Node; - auto [s1, s2] = conflict->commits; - auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - auto graph_conflict = graph::Conflict(conflict->var, sources); - thread_append_node(ctx, var, lock.last, graph_conflict); - return TerminationStatus::datarace_exception; - } - - thread_append_node(ctx, var, lock.last); + assert(false && "todo"); + // commit(ctx.globals); + // if(auto conflict = pull(ctx.globals, lock.globals)) + // { + // using graph::Node; + // auto [s1, s2] = conflict->commits; + // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; + // auto graph_conflict = graph::Conflict(conflict->var, sources); + // thread_append_node(ctx, var, lock.last, graph_conflict); + // return TerminationStatus::datarace_exception; + // } + + // thread_append_node(ctx, var, lock.last); verbose << "Locked " << var << std::endl; } else if (s == lang::Unlock) { - // We can only unlock locks we previously locked. We commit any - // pending updates and then copy the threads versioned globals - // to the locks versioned globals (nobody could have changed - // them since we locked the lock). 
- commit(ctx.globals); - auto v = s / lang::Var; - auto var = std::string(v->location().view()); + assert(false && "todo"); - auto& lock = gctx.locks[var]; - if (!lock.owner || (lock.owner && *lock.owner != tid)) - { - return TerminationStatus::unlock_exception; - } + // // We can only unlock locks we previously locked. We commit any + // // pending updates and then copy the threads versioned globals + // // to the locks versioned globals (nobody could have changed + // // them since we locked the lock). + // commit(ctx.globals); + // auto v = s / lang::Var; + // auto var = std::string(v->location().view()); + + // auto& lock = gctx.locks[var]; + // if (!lock.owner || (lock.owner && *lock.owner != tid)) + // { + // return TerminationStatus::unlock_exception; + // } - lock.globals = ctx.globals; - lock.owner.reset(); + // lock.globals = ctx.globals; + // lock.owner.reset(); - thread_append_node(ctx, var); - lock.last = ctx.tail; + // thread_append_node(ctx, var); + // lock.last = ctx.tail; - verbose << "Unlocked " << var << std::endl; + // verbose << "Unlocked " << var << std::endl; } else if (s == lang::Assert) { diff --git a/src/interpreter.hh b/src/interpreter.hh index ae59de3..ac8fb23 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -4,10 +4,11 @@ #include "lang.hh" #include "graph.hh" #include "graphviz.hh" +#include "execution_state.hh" namespace gitmem { - /* For debug printing */ + /* For debug printing */ inline struct Verbose { bool enabled = false; @@ -32,178 +33,11 @@ namespace gitmem { } } verbose; - /* A 'Global' is a structure to capture the current synchronising objects - * representation of a global variable. The structure is the current value, - * the current commit id for the variable, and the history of commited ids. 
- */ - - using Commit = size_t; - using CommitHistory = std::vector; - - struct Global - { - size_t val; - std::optional commit; - CommitHistory history; - }; - - using Globals = std::unordered_map; - - enum class TerminationStatus - { - completed, - datarace_exception, - unlock_exception, - assertion_failure_exception, - unassigned_variable_read_exception, - }; - - using Locals = std::unordered_map; - - struct ThreadContext - { - Locals locals; - Globals globals; - std::shared_ptr tail; - }; - - using ThreadStatus = std::optional; - - struct Thread - { - ThreadContext ctx; - trieste::Node block; - size_t pc = 0; - ThreadStatus terminated = std::nullopt; - - bool operator==(const Thread &other) const - { - // Globals have a history that we don't care about, so we only - // compare values - if (ctx.globals.size() != other.ctx.globals.size()) - return false; - for (const auto &[var, global] : ctx.globals) - { - if (!other.ctx.globals.contains(var) || - ctx.globals.at(var).val != other.ctx.globals.at(var).val) - { - return false; - } - } - return ctx.locals == other.ctx.locals && - block == other.block && - pc == other.pc && - terminated == other.terminated; - } - }; - - using ThreadID = size_t; - - struct Lock - { - Globals globals; - std::optional owner = std::nullopt; - std::shared_ptr last; - }; - - using Threads = std::vector>; - - using Locks = std::unordered_map; - - template - std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args); - - template<> - std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt); - - struct GlobalContext - { - Threads threads; - Locks locks; - lang::NodeMap cache; - std::shared_ptr entry_node; - std::unordered_map> commit_map; - Commit uuid = 0; - - GlobalContext(const trieste::Node &ast) - { - trieste::Node starting_block = ast / lang::File / lang::Block; - entry_node = std::make_shared(0); - ThreadContext starting_ctx = {{}, {}, entry_node}; - auto main_thread = std::make_shared(starting_ctx, 
starting_block); - - this->threads = {main_thread}; - this->locks = {}; - this->cache = {}; - } - - bool operator==(const GlobalContext &other) const - { - if (threads.size() != other.threads.size() || locks.size() != other.locks.size()) - return false; - - // Threads may have been spawned in a different order, so we - // find the thread with the same block in the other context - for (auto &thread : threads) - { - auto it = std::find_if(other.threads.begin(), other.threads.end(), - [&thread](auto &t) - { return t->block == thread->block; }); - if (it == other.threads.end() || !(*thread == **it)) - return false; - } - - for (auto &[name, lock] : locks) - { - if (!other.locks.contains(name)) - return false; - auto &other_lock = other.locks.at(name); - if (lock.owner != other_lock.owner) - return false; - } - return true; - } - - void print_execution_graph(const std::filesystem::path &output_path) const - { - // Loop over the threads and add pending nodes to running threads - // to indicate a threads next step - for (const auto& t: threads) - { - assert(t->ctx.tail); - if (t->terminated || dynamic_pointer_cast(t->ctx.tail->next)) - continue; - - trieste::Node block = t->block; - size_t &pc = t->pc; - trieste::Node stmt = block->at(pc); - thread_append_node(t->ctx, std::string(stmt->location().view())); - } - - graph::GraphvizPrinter gv(output_path); - gv.visit(entry_node.get()); - } - }; - - enum class ProgressStatus - { - progress, - no_progress - }; - - inline bool operator!(ProgressStatus p) { return p == ProgressStatus::no_progress; } - - inline ProgressStatus operator||(const ProgressStatus &p1, const ProgressStatus &p2) - { - return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) ? 
ProgressStatus::progress : ProgressStatus::no_progress; - } - - inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } - // Entry functions int interpret(const trieste::Node, const std::filesystem::path &output_file); - int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file); - int model_check(const trieste::Node, const std::filesystem::path &output_file); + + // int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file); + // int model_check(const trieste::Node, const std::filesystem::path &output_file); // Internal functions int run_threads(GlobalContext &); diff --git a/src/linear/versionstore.cc b/src/linear/version_store.cc similarity index 98% rename from src/linear/versionstore.cc rename to src/linear/version_store.cc index 9ae55f6..90c6941 100644 --- a/src/linear/versionstore.cc +++ b/src/linear/version_store.cc @@ -1,4 +1,4 @@ -#include "versionstore.hh" +#include "version_store.hh" #include namespace gitmem { diff --git a/src/linear/versionstore.hh b/src/linear/version_store.hh similarity index 99% rename from src/linear/versionstore.hh rename to src/linear/version_store.hh index c0d1f7f..896c805 100644 --- a/src/linear/versionstore.hh +++ b/src/linear/version_store.hh @@ -1,6 +1,5 @@ #pragma once - #include #include #include diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc new file mode 100644 index 0000000..d758e2a --- /dev/null +++ b/src/sync_protocol.cc @@ -0,0 +1,170 @@ +#include "sync_protocol.hh" + +namespace gitmem { + +// -------------------- +// LinearSyncProtocol +// -------------------- + +std::optional LinearSyncProtocol::read(ThreadContext& ctx, const std::string& var) { + return std::nullopt; +} + +void LinearSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { + +} + +std::optional LinearSyncProtocol::on_spawn( + ThreadContext& parent, + ThreadContext& child, + GlobalContext& gctx +) { + // push parent to 
global history + // child inherits parent view + // (implementation uses your linear history abstraction) +} + +std::optional LinearSyncProtocol::on_join( + ThreadContext& joiner, + ThreadContext& joinee, + GlobalContext& gctx +) { + // push both, pull into joiner + return std::nullopt; +} + +std::optional LinearSyncProtocol::on_lock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx +) { + // push thread, pull from global + return std::nullopt; +} + +std::optional LinearSyncProtocol::on_unlock( + ThreadContext& thread, + Lock&, + GlobalContext& gctx +) { + // push thread +} + +// -------------------- +// BranchingSyncProtocol +// -------------------- + +// /* At a commit point, walk through all the versioned variables and see if +// * they have a pending commit, if so commit the value by appending to +// * the variables history. +// */ +// void commit(Globals &globals) { +// for (auto& [var, global] : globals) { +// if (global.commit) +// { +// global.history.push_back(*global.commit); +// verbose << "Committed global '" << var << "' with id " << *global.commit << std::endl; +// global.commit.reset(); +// } +// } +// } + + +// /* A versioned value can be fastforwarded to another version, if one +// * version's history is a prefix of another version's history. +// * A conflict between two commit histories exists if neither history is a +// * prefix of the other. +// */ +// std::optional> has_conflict(CommitHistory& h1, CommitHistory& h2) +// { +// size_t length = std::min(h1.size(), h2.size()); + +// for (size_t i = 0; i < length; i++) +// { +// if (h1[i] != h2[i]) return std::pair{h1[i], h2[i]}; +// } + +// return std::nullopt; +// } + +// /* Walk through all the global versions from source and update the versions +// * in destination to be the most up-to-date version (this could come from +// * either source or destination). This means destination will now also +// * include variables it previously did not know about. 
+// */ +// std::optional pull(Globals &dst, Globals &src) { +// for (auto& [var, global] : src) { +// if (dst.contains(var)) +// { +// auto& src_var = src[var]; +// auto& dst_var = dst[var]; +// if (auto conflict = has_conflict(src_var.history, dst_var.history)) +// { +// auto [s1, s2] = *conflict; +// verbose << "A data race on '" << var << "' was detected from commits " << s1 << " and " << s2 << std::endl; +// return Conflict(var, *conflict); +// } +// else if (src_var.history.size() > dst_var.history.size()) +// { +// verbose << "Fast-forward '" << var << "' to id " << src_var.val << std::endl; +// dst_var.val = src_var.val; +// dst_var.history = src_var.history; +// } +// } +// else +// { +// dst[var].val = src[var].val; +// dst[var].history = src[var].history; +// } +// } +// return std::nullopt; +// } + +std::optional BranchingSyncProtocol::read(ThreadContext& ctx, const std::string& var) { + return std::nullopt; +} + +void BranchingSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { + +} + +std::optional BranchingSyncProtocol::on_spawn( + ThreadContext& parent, + ThreadContext& child, + GlobalContext& +) { + // commit(parent.globals); + // child.globals = parent.globals; +} + +std::optional BranchingSyncProtocol::on_join( + ThreadContext& joiner, + ThreadContext& joinee, + GlobalContext& +) { + // commit(joiner.globals); + // commit(joinee.globals); + // return pull(joiner.globals, joinee.globals); + return std::nullopt; +} + +std::optional BranchingSyncProtocol::on_lock( + ThreadContext& thread, + Lock& lock, + GlobalContext& +) { + // commit(thread.globals); + // return pull(thread.globals, lock.globals); + return std::nullopt; +} + +std::optional BranchingSyncProtocol::on_unlock( + ThreadContext& thread, + Lock& lock, + GlobalContext& +) { + // commit(thread.globals); + // lock.globals = thread.globals; +} + +} // namespace gitmem diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh new file mode 100644 index 
0000000..1c75310 --- /dev/null +++ b/src/sync_protocol.hh @@ -0,0 +1,117 @@ +#pragma once + +#include +#include +#include "execution_state.hh" + +namespace gitmem { + +struct Conflict; + +class SyncProtocol { +public: + virtual ~SyncProtocol() = default; + + // Read a shared variable into the thread context + virtual std::optional read(ThreadContext& ctx, const std::string& var) = 0; + + // Write a shared variable (staged, not committed) + virtual void write(ThreadContext& ctx, const std::string& var, size_t value) = 0; + + virtual std::optional on_spawn( + ThreadContext& parent, + ThreadContext& child, + GlobalContext& gctx + ) = 0; + + virtual std::optional on_join( + ThreadContext& joiner, + ThreadContext& joinee, + GlobalContext& gctx + ) = 0; + + virtual std::optional on_lock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx + ) = 0; + + virtual std::optional on_unlock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx + ) = 0; +}; + +// --------------------------------- +// Concrete protocols +// --------------------------------- + +class LinearSyncProtocol final : public SyncProtocol { + // std::unordered_map> ts_nodes; + +public: + std::optional read(ThreadContext& ctx, const std::string& var) override; + + void write(ThreadContext& ctx, const std::string& var, size_t value) override; + + std::optional on_spawn( + ThreadContext& parent, + ThreadContext& child, + GlobalContext& gctx + ) override; + + std::optional on_join( + ThreadContext& joiner, + ThreadContext& joinee, + GlobalContext& gctx + ) override; + + std::optional on_lock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx + ) override; + + std::optional on_unlock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx + ) override; +}; + +class BranchingSyncProtocol final : public SyncProtocol { + + std::unordered_map> commit_nodes; + +public: + std::optional read(ThreadContext& ctx, const std::string& var) override; + + void write(ThreadContext& ctx, const 
std::string& var, size_t value) override; + + std::optional on_spawn( + ThreadContext& parent, + ThreadContext& child, + GlobalContext& gctx + ) override; + + std::optional on_join( + ThreadContext& joiner, + ThreadContext& joinee, + GlobalContext& gctx + ) override; + + std::optional on_lock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx + ) override; + + std::optional on_unlock( + ThreadContext& thread, + Lock& lock, + GlobalContext& gctx + ) override; +}; + +} // namespace gitmem From 97e686220cc28afdc152bb5915b9d90e21b1e97a Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 19 Dec 2025 09:34:09 +0000 Subject: [PATCH 05/58] building infra for two sync protocols --- CMakeLists.txt | 2 + .../{versionstore.hh => version_store.hh} | 19 +- src/execution_state.cc | 86 ++++- src/execution_state.hh | 181 +++------ src/interpreter.cc | 352 ++++++++---------- src/interpreter.hh | 68 ++-- src/linear/version_store.cc | 86 +++-- src/linear/version_store.hh | 49 +-- src/sync_protocol.cc | 151 +++++++- src/sync_protocol.hh | 94 ++++- 10 files changed, 596 insertions(+), 492 deletions(-) rename src/branching/{versionstore.hh => version_store.hh} (50%) diff --git a/CMakeLists.txt b/CMakeLists.txt index a3d0bbd..7469df2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,9 +25,11 @@ add_executable(gitmem src/passes/statements.cc src/passes/check_refs.cc src/passes/branching.cc + src/linear/version_store.cc src/interpreter.cc # src/debugger.cc # src/model_checker.cc + src/sync_protocol.cc src/graphviz.cc ) diff --git a/src/branching/versionstore.hh b/src/branching/version_store.hh similarity index 50% rename from src/branching/versionstore.hh rename to src/branching/version_store.hh index 688955c..f59b886 100644 --- a/src/branching/versionstore.hh +++ b/src/branching/version_store.hh @@ -28,15 +28,28 @@ namespace branching { using Globals = std::unordered_map; - using Locals = std::unordered_map; - - struct Conflict { std::string var; std::pair commits; }; 
+ struct LocalVersionStore {}; + + // Join logic + // commit(ctx.globals); + // commit(thread->ctx.globals); + // verbose << "Pulling from thread " << result << std::endl; + // if(auto conflict = pull(ctx.globals, thread->ctx.globals)) + // { + // using graph::Node; + // auto [s1, s2] = conflict->commits; + // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; + // auto graph_conflict = graph::Conflict(conflict->var, sources); + // thread_append_node(ctx, result, thread->ctx.tail, graph_conflict); + // return TerminationStatus::datarace_exception; + // } + } // namespace branching } // namespace gitmem \ No newline at end of file diff --git a/src/execution_state.cc b/src/execution_state.cc index b3c1274..1ee2f5d 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -3,17 +3,89 @@ namespace gitmem { - GlobalContext::GlobalContext(const trieste::Node &ast) { + bool Thread::operator==(const Thread &other) const + { + return false; + // Globals have a history that we don't care about, so we only + // compare values + // if (ctx.globals.size() != other.ctx.globals.size()) + // return false; + // for (const auto &[var, global] : ctx.globals) + // { + // if (!other.ctx.globals.contains(var) || + // ctx.globals.at(var).val != other.ctx.globals.at(var).val) + // { + // return false; + // } + // } + // return ctx.locals == other.ctx.locals && + // block == other.block && + // pc == other.pc && + // terminated == other.terminated; + } + + GlobalContext::GlobalContext(const trieste::Node &ast, std::unique_ptr protocol): + protocol(std::move(protocol)) + { trieste::Node starting_block = ast / lang::File / lang::Block; - entry_node = std::make_shared(0); - // ThreadContext starting_ctx = {{}, {}, entry_node}; - // auto main_thread = std::make_shared(starting_ctx, starting_block); + ThreadContext starting_ctx = { + .locals = {}, + .tail = std::make_shared(0) + }; + auto main_thread = std::make_shared(starting_ctx, starting_block); 
- // this->threads = {main_thread}; - // this->locks = {}; - // this->cache = {}; + this->threads = {main_thread}; + this->locks = {}; + this->cache = {}; } GlobalContext::~GlobalContext() = default; + + void GlobalContext::print_execution_graph(const std::filesystem::path &output_path) const { + // Loop over the threads and add pending nodes to running threads + // to indicate a threads next step + for (const auto& t: threads) + { + assert(t->ctx.tail); + if (t->terminated || dynamic_pointer_cast(t->ctx.tail->next)) + continue; + + trieste::Node block = t->block; + size_t &pc = t->pc; + trieste::Node stmt = block->at(pc); + thread_append_node(t->ctx, std::string(stmt->location().view())); + } + + graph::GraphvizPrinter gv(output_path); + gv.visit(entry_node.get()); + } + + bool GlobalContext::operator==(const GlobalContext &other) const { + return false; + // if (threads.size() != other.threads.size() || locks.size() != other.locks.size()) + // return false; + + // // Threads may have been spawned in a different order, so we + // // find the thread with the same block in the other context + // for (auto &thread : threads) + // { + // auto it = std::find_if(other.threads.begin(), other.threads.end(), + // [&thread](auto &t) + // { return t->block == thread->block; }); + // if (it == other.threads.end() || !(*thread == **it)) + // return false; + // } + + // for (auto &[name, lock] : locks) + // { + // if (!other.locks.contains(name)) + // return false; + // auto &other_lock = other.locks.at(name); + // if (lock.owner != other_lock.owner) + // return false; + // } + // return true; + } + } \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh index 2f4d09e..2254f5a 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -7,82 +7,47 @@ #include "graphviz.hh" #include "lang.hh" +#include "linear/version_store.hh" +#include "branching/version_store.hh" namespace gitmem { - /* A 'Global' is a structure to capture the 
current synchronising objects - * representation of a global variable. The structure is the current value, - * the current commit id for the variable, and the history of commited ids. - */ - class SyncProtocol; - struct Commit { size_t value; }; - - struct Global - { - size_t val; - std::optional commit; - std::vector history; + enum class TerminationStatus { + completed, + datarace_exception, + unlock_exception, + assertion_failure_exception, + unassigned_variable_read_exception, }; - using Globals = std::unordered_map; + struct ThreadContext { + std::unordered_map locals; + std::shared_ptr tail; - enum class TerminationStatus - { - completed, - datarace_exception, - unlock_exception, - assertion_failure_exception, - unassigned_variable_read_exception, - }; + struct LinearData { linear::LocalVersionStore store; }; + struct BranchingData { branching::LocalVersionStore store; }; - using Locals = std::unordered_map; - - struct ThreadContext - { - Locals locals; - std::shared_ptr tail; + std::optional linear; + std::optional branching; }; - using ThreadStatus = std::optional; - - struct Thread - { - ThreadContext ctx; - trieste::Node block; - size_t pc = 0; - ThreadStatus terminated = std::nullopt; - - bool operator==(const Thread &other) const - { - return false; - // Globals have a history that we don't care about, so we only - // compare values - // if (ctx.globals.size() != other.ctx.globals.size()) - // return false; - // for (const auto &[var, global] : ctx.globals) - // { - // if (!other.ctx.globals.contains(var) || - // ctx.globals.at(var).val != other.ctx.globals.at(var).val) - // { - // return false; - // } - // } - // return ctx.locals == other.ctx.locals && - // block == other.block && - // pc == other.pc && - // terminated == other.terminated; - } + struct Thread { + ThreadContext ctx; + trieste::Node block; + size_t pc = 0; + std::optional terminated = std::nullopt; + + bool operator==(const Thread &other) const; }; using ThreadID = size_t; - struct Lock 
- { - Globals globals; - std::optional owner = std::nullopt; - std::shared_ptr last; + struct Lock { + // Globals globals; + // std::optional owner = std::nullopt; + // std::shared_ptr last; }; template @@ -91,86 +56,26 @@ namespace gitmem { template<> std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt); - struct GlobalContext - { - // Execution state - std::vector> threads; - std::unordered_map locks; - - // AST evaluation cache - lang::NodeMap cache; - - // Graph root - std::shared_ptr entry_node; - - // Synchronisation semantics (policy) - std::unique_ptr protocol; - - GlobalContext(const trieste::Node &ast); - ~GlobalContext(); - - bool operator==(const GlobalContext &other) const - { - return false; - // if (threads.size() != other.threads.size() || locks.size() != other.locks.size()) - // return false; - - // // Threads may have been spawned in a different order, so we - // // find the thread with the same block in the other context - // for (auto &thread : threads) - // { - // auto it = std::find_if(other.threads.begin(), other.threads.end(), - // [&thread](auto &t) - // { return t->block == thread->block; }); - // if (it == other.threads.end() || !(*thread == **it)) - // return false; - // } - - // for (auto &[name, lock] : locks) - // { - // if (!other.locks.contains(name)) - // return false; - // auto &other_lock = other.locks.at(name); - // if (lock.owner != other_lock.owner) - // return false; - // } - // return true; - } - - void print_execution_graph(const std::filesystem::path &output_path) const - { - // Loop over the threads and add pending nodes to running threads - // to indicate a threads next step - for (const auto& t: threads) - { - assert(t->ctx.tail); - if (t->terminated || dynamic_pointer_cast(t->ctx.tail->next)) - continue; - - trieste::Node block = t->block; - size_t &pc = t->pc; - trieste::Node stmt = block->at(pc); - thread_append_node(t->ctx, std::string(stmt->location().view())); - } - - graph::GraphvizPrinter 
gv(output_path); - gv.visit(entry_node.get()); - } - }; + struct GlobalContext { + // Execution state + std::vector> threads; + std::unordered_map locks; - enum class ProgressStatus - { - progress, - no_progress - }; + // AST evaluation cache + lang::NodeMap cache; + + // Graph root + std::shared_ptr entry_node; - inline bool operator!(ProgressStatus p) { return p == ProgressStatus::no_progress; } + // Synchronisation semantics (policy) + std::unique_ptr protocol; - inline ProgressStatus operator||(const ProgressStatus &p1, const ProgressStatus &p2) - { - return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) ? ProgressStatus::progress : ProgressStatus::no_progress; - } + GlobalContext(const trieste::Node &ast, std::unique_ptr protocol); + ~GlobalContext(); - inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } + bool operator==(const GlobalContext &other) const; + + void print_execution_graph(const std::filesystem::path &output_path) const; + }; } // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.cc b/src/interpreter.cc index 5152a34..56587e5 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -35,11 +35,6 @@ namespace gitmem { return !thread.terminated && is_syncing(thread.block->at(thread.pc)); } - struct Conflict - { - std::string var; - std::pair commits; - }; template std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args) @@ -108,11 +103,11 @@ namespace gitmem { { ThreadID tid = gctx.threads.size(); auto node = std::make_shared(tid); - ThreadContext child_ctx = { Locals(), node }; + ThreadContext child_ctx = { std::unordered_map(), node }; gctx.threads.push_back(std::make_shared(child_ctx, e / lang::Block)); thread_append_node(ctx, tid, node); - if (std::optional conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { + if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { assert(false); // handle this } @@ -146,212 +141,166 
@@ namespace gitmem { * counter (0 if waiting for some other thread) or the exceptional * termination status of the thread. */ - std::variant run_statement(Node stmt, GlobalContext &gctx, ThreadContext &ctx, const ThreadID& tid) - { - auto s = stmt / lang::Stmt; - if (s == lang::Nop) - { - verbose << "Nop" << std::endl; - } - else if (s == lang::Jump) - { - auto cnst = s / lang::Const; - auto delta = std::stoi(std::string(cnst->location().view())); - assert(delta > 0); - return delta; - } - else if (s == lang::Cond) - { - auto expr = s / lang::Expr; - auto cnst = s / lang::Const; - auto result = evaluate_expression(expr, gctx, ctx); + std::variant run_statement(Node stmt, GlobalContext &gctx, ThreadContext &ctx, const ThreadID& tid) { + auto s = stmt / lang::Stmt; + if (s == lang::Nop) { - if (auto b = std::get_if(&result)) - { - auto delta = std::stoi(std::string(cnst->location().view())); - assert(delta > 0); - return *b? 1 : delta; - } - else - { - return std::get(result); - } - } - else if (s == lang::Assign) - { - auto lhs = s / lang::LVal; - auto var = std::string(lhs->location().view()); - auto rhs = s / lang::Expr; - auto val_or_term = evaluate_expression(rhs, gctx, ctx); - if(size_t* val = std::get_if(&val_or_term)) - { - if (lhs == lang::Reg) - { - // Local variables can be re-assigned whenever - verbose << "Set register '" << lhs->location().view() << "' to " << *val << std::endl; - ctx.locals[var] = *val; - } - else if (lhs == lang::Var) - { - gctx.protocol->write(ctx, var, *val); + verbose << "Nop" << std::endl; - // // Global variable writes need to create a new commit id - // // to track the history of updates - // auto &global = ctx.globals[var]; - // global.val = *val; - // global.commit = gctx.uuid++; - // verbose << "Set global '" << lhs->location().view() << "' to " << *val << " with id " << *(global.commit) << std::endl; + } else if (s == lang::Jump) { - // auto node = thread_append_node(ctx, var, global.val, *global.commit); - // 
gctx.commit_map[*(global.commit)] = node; - } - else - { - throw std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); - } - } - else - { - return std::get(val_or_term); - } - } - else if (s == lang::Join) - { - // A join must waiting for the terminating thread to continue, - // we don't want to re-evaluate the expression repeatedly as this - // may be effecting so store the result in the cache. - auto expr = s / lang::Expr; + auto cnst = s / lang::Const; + auto delta = std::stoi(std::string(cnst->location().view())); + assert(delta > 0); + return delta; - if (!gctx.cache.contains(expr)) - { - auto val_or_term = evaluate_expression(expr, gctx, ctx); - if (size_t* val = std::get_if(&val_or_term)) - { - gctx.cache[expr] = *val; - } - else - { - return std::get(val_or_term); - } - } + } else if (s == lang::Cond) { - // when joining, we commit the updates of both threads (the joined - // thread will not necessarily have commited them), we then - // pull the updates into the joining thread. 
- auto result = gctx.cache[expr]; - auto& joinee = gctx.threads[result]; - if (joinee->terminated && (*joinee->terminated == TerminationStatus::completed)) - { - if(auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { - assert(false && "todo"); - return TerminationStatus::datarace_exception; - } else { - thread_append_node(ctx, result, joinee->ctx.tail); - } - // commit(ctx.globals); - // commit(thread->ctx.globals); - // verbose << "Pulling from thread " << result << std::endl; - // if(auto conflict = pull(ctx.globals, thread->ctx.globals)) - // { - // using graph::Node; - // auto [s1, s2] = conflict->commits; - // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - // auto graph_conflict = graph::Conflict(conflict->var, sources); - // thread_append_node(ctx, result, thread->ctx.tail, graph_conflict); - // return TerminationStatus::datarace_exception; - // } - } - else - { - verbose << "Waiting on thread " << result << std::endl; - return 0; - } - } - else if (s == lang::Lock) - { - // We can only lock unlocked locks, if a lock hasn't been used - // before it is implicitly created, we then commit the pending - // updates of this thread and pull the updates from the lock. 
- auto v = s / lang::Var; - auto var = std::string(v->location().view()); - - auto& lock = gctx.locks[var]; - if (lock.owner) { - verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; - return 0; - } + auto expr = s / lang::Expr; + auto cnst = s / lang::Const; + auto result = evaluate_expression(expr, gctx, ctx); - lock.owner = tid; - assert(false && "todo"); - // commit(ctx.globals); - // if(auto conflict = pull(ctx.globals, lock.globals)) - // { - // using graph::Node; - // auto [s1, s2] = conflict->commits; - // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - // auto graph_conflict = graph::Conflict(conflict->var, sources); - // thread_append_node(ctx, var, lock.last, graph_conflict); - // return TerminationStatus::datarace_exception; - // } + if (auto b = std::get_if(&result)) { + auto delta = std::stoi(std::string(cnst->location().view())); + assert(delta > 0); + return *b? 1 : delta; + } else { + return std::get(result); + } - // thread_append_node(ctx, var, lock.last); + } else if (s == lang::Assign) { - verbose << "Locked " << var << std::endl; + auto lhs = s / lang::LVal; + auto var = std::string(lhs->location().view()); + auto rhs = s / lang::Expr; + auto val_or_term = evaluate_expression(rhs, gctx, ctx); - } - else if (s == lang::Unlock) - { - assert(false && "todo"); + if(size_t* val = std::get_if(&val_or_term)) { + if (lhs == lang::Reg) { - // // We can only unlock locks we previously locked. We commit any - // // pending updates and then copy the threads versioned globals - // // to the locks versioned globals (nobody could have changed - // // them since we locked the lock). 
- // commit(ctx.globals); - // auto v = s / lang::Var; - // auto var = std::string(v->location().view()); + // Local variables can be re-assigned whenever + verbose << "Set register '" << lhs->location().view() << "' to " << *val << std::endl; + ctx.locals[var] = *val; - // auto& lock = gctx.locks[var]; - // if (!lock.owner || (lock.owner && *lock.owner != tid)) - // { - // return TerminationStatus::unlock_exception; - // } + } else if (lhs == lang::Var) { - // lock.globals = ctx.globals; - // lock.owner.reset(); + gctx.protocol->write(ctx, var, *val); - // thread_append_node(ctx, var); - // lock.last = ctx.tail; + // // Global variable writes need to create a new commit id + // // to track the history of updates + // auto &global = ctx.globals[var]; + // global.val = *val; + // global.commit = gctx.uuid++; + // verbose << "Set global '" << lhs->location().view() << "' to " << *val << " with id " << *(global.commit) << std::endl; - // verbose << "Unlocked " << var << std::endl; + // auto node = thread_append_node(ctx, var, global.val, *global.commit); + // gctx.commit_map[*(global.commit)] = node; + } else { + throw std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); + } + } else { + return std::get(val_or_term); } - else if (s == lang::Assert) - { - auto expr = s / lang::Expr; - auto result_or_term = evaluate_expression(expr, gctx, ctx); - if (size_t* result = std::get_if(&result_or_term)) - { - if (*result) - { - verbose << "Assertion passed: " << expr->location().view() << std::endl; - } - else - { - verbose << "Assertion failed: " << expr->location().view() << std::endl; - thread_append_node(ctx, std::string(expr->location().view())); - return TerminationStatus::assertion_failure_exception; - } - } - else - { - return std::get(result_or_term); - } + } else if (s == lang::Join) { + // A join must waiting for the terminating thread to continue, + // we don't want to re-evaluate the expression repeatedly as this + // may be effecting so 
store the result in the cache. + auto expr = s / lang::Expr; + + if (!gctx.cache.contains(expr)) { + auto val_or_term = evaluate_expression(expr, gctx, ctx); + if (size_t* val = std::get_if(&val_or_term)) { + gctx.cache[expr] = *val; + } else { + return std::get(val_or_term); + } } - else - { - throw std::runtime_error("Unknown statement: " + std::string(stmt->type().str())); + + auto result = gctx.cache[expr]; + auto& joinee = gctx.threads[result]; + if (joinee->terminated && (*joinee->terminated == TerminationStatus::completed)) { + if(auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { + return TerminationStatus::datarace_exception; + } else { + thread_append_node(ctx, result, joinee->ctx.tail); + } + + } else { + verbose << "Waiting on thread " << result << std::endl; + return 0; } - return 1; + } else if (s == lang::Lock) { + assert(false && "todo"); + // We can only lock unlocked locks, if a lock hasn't been used + // before it is implicitly created, we then commit the pending + // updates of this thread and pull the updates from the lock. 
+ // auto v = s / lang::Var; + // auto var = std::string(v->location().view()); + + // auto& lock = gctx.locks[var]; + // if (lock.owner) { + // verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; + // return 0; + // } + + // lock.owner = tid; + // commit(ctx.globals); + // if(auto conflict = pull(ctx.globals, lock.globals)) + // { + // using graph::Node; + // auto [s1, s2] = conflict->commits; + // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; + // auto graph_conflict = graph::Conflict(conflict->var, sources); + // thread_append_node(ctx, var, lock.last, graph_conflict); + // return TerminationStatus::datarace_exception; + // } + + // thread_append_node(ctx, var, lock.last); + + // verbose << "Locked " << var << std::endl; + } else if (s == lang::Unlock) { + assert(false && "todo"); + + // // We can only unlock locks we previously locked. We commit any + // // pending updates and then copy the threads versioned globals + // // to the locks versioned globals (nobody could have changed + // // them since we locked the lock). 
+ // commit(ctx.globals); + // auto v = s / lang::Var; + // auto var = std::string(v->location().view()); + + // auto& lock = gctx.locks[var]; + // if (!lock.owner || (lock.owner && *lock.owner != tid)) + // { + // return TerminationStatus::unlock_exception; + // } + + // lock.globals = ctx.globals; + // lock.owner.reset(); + + // thread_append_node(ctx, var); + // lock.last = ctx.tail; + + // verbose << "Unlocked " << var << std::endl; + + } else if (s == lang::Assert) { + auto expr = s / lang::Expr; + auto result_or_term = evaluate_expression(expr, gctx, ctx); + if (size_t* result = std::get_if(&result_or_term)) { + if (*result) { + verbose << "Assertion passed: " << expr->location().view() << std::endl; + } else { + verbose << "Assertion failed: " << expr->location().view() << std::endl; + thread_append_node(ctx, std::string(expr->location().view())); + return TerminationStatus::assertion_failure_exception; + } + } else { + return std::get(result_or_term); + } + } else { + throw std::runtime_error("Unknown statement: " + std::string(stmt->type().str())); + } + return 1; } /* Run a particular thread until it reaches a synchronisation point or until @@ -367,6 +316,10 @@ namespace gitmem { size_t &pc = thread->pc; ThreadContext &ctx = thread->ctx; + if (pc == 0) { + gctx.protocol->on_start(thread->ctx, gctx); + } + bool first_statement = true; while(pc < block->size()) { @@ -397,6 +350,8 @@ namespace gitmem { } thread->terminated = TerminationStatus::completed; + gctx.protocol->on_end(thread->ctx, gctx); + thread_append_node(ctx); return TerminationStatus::completed; } @@ -541,9 +496,10 @@ namespace gitmem { int interpret(const Node ast, const std::filesystem::path &output_path) { - GlobalContext gctx(ast); + // TODO: allow both protocols + GlobalContext gctx(ast, std::make_unique()); auto result = run_threads(gctx); - gctx.print_execution_graph(output_path); + // gctx.print_execution_graph(output_path); FIXME return result; } diff --git a/src/interpreter.hh 
b/src/interpreter.hh index ac8fb23..5665d06 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -9,40 +9,38 @@ namespace gitmem { /* For debug printing */ - inline struct Verbose - { - bool enabled = false; - - template - const Verbose &operator<<(const T &msg) const - { - if (enabled) - { - std::cout << msg; - } - return *this; - } - - const Verbose &operator<<(std::ostream &(*manip)(std::ostream &)) const - { - if (enabled) - { - std::cout << manip; - } - return *this; - } - } verbose; - - // Entry functions - int interpret(const trieste::Node, const std::filesystem::path &output_file); - - // int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file); - // int model_check(const trieste::Node, const std::filesystem::path &output_file); - - // Internal functions - int run_threads(GlobalContext &); - - std::variant - progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); + inline struct Verbose { + bool enabled = false; + + template + const Verbose &operator<<(const T &msg) const { + if (enabled) std::cout << msg; + return *this; + } + + const Verbose &operator<<(std::ostream &(*manip)(std::ostream &)) const { + if (enabled) std::cout << manip; + return *this; + } + } verbose; + + // Entry functions + int interpret(const trieste::Node, const std::filesystem::path &output_file); + + // int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file); + // int model_check(const trieste::Node, const std::filesystem::path &output_file); + + // Internal functions + int run_threads(GlobalContext &); + + enum class ProgressStatus { progress, no_progress }; + inline bool operator!(ProgressStatus p) { return p == ProgressStatus::no_progress; } + inline ProgressStatus operator||(const ProgressStatus &p1, const ProgressStatus &p2) { + return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) ? 
ProgressStatus::progress : ProgressStatus::no_progress; + } + inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } + + std::variant + progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); } // namespace gitmem \ No newline at end of file diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index 90c6941..6d3e110 100644 --- a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -1,5 +1,8 @@ -#include "version_store.hh" #include +#include + +#include "version_store.hh" +#include "sync_protocol.hh" namespace gitmem { @@ -21,12 +24,48 @@ void LocalVersionStore::advance_base(Timestamp ts) { _base_timestamp = ts; } +std::optional LocalVersionStore::get_staged(ObjectNumber obj) { + auto it = _staging.find(obj); + return it != _staging.end() ? std::make_optional(it->second) : std::nullopt; +} + // ----------------------------- // GlobalVersionStore // ----------------------------- -ObjectNumber GlobalVersionStore::allocate_object() { - return _next_object++; +ObjectNumber GlobalVersionStore::get_object_number(std::string var) { + auto it = _object_numbers.find(var); + if (it != _object_numbers.end()) { + return it->second; + } else { + ObjectNumber number = _next_object++; + _object_numbers[var] = number; + return number; + } +} + +std::string GlobalVersionStore::get_object_name(ObjectNumber find) { + for (const auto& [name, number] : _object_numbers) { + if (number == find) + return name; + } + assert(false && "failed to find object name for object number"); + return ""; +} + +std::optional GlobalVersionStore::get_version_for_timestamp(ObjectNumber obj, Timestamp ts) const { + const auto it = _history.find(obj); + + if (it == _history.end()) + return std::nullopt; + + const VersionHistory& history = it->second; + for (VersionHistory::const_reverse_iterator riter = history.rbegin(); riter != history.rend(); ++riter) { + if (riter->timestamp() <= ts) + return riter->value(); + } + + return 
std::nullopt; } std::optional GlobalVersionStore::check_conflicts( @@ -39,12 +78,13 @@ std::optional GlobalVersionStore::check_conflicts( continue; } - const Version& head = it->second.back(); - if (head.timestamp() > base) { + const Version& latest = it->second.back(); + if (latest.timestamp() > base) { + std::cout << "conflict" << std::endl; return Conflict{ .object = obj, .local_base = base, - .global_head = head.timestamp() + .global_head = latest.timestamp() }; } } @@ -59,7 +99,7 @@ Timestamp GlobalVersionStore::apply_changes( throw std::logic_error("apply_changes called with conflicts"); } - Timestamp new_ts = _timestamp++; + Timestamp new_ts = ++_timestamp; for (const auto& [obj, value] : changes) { _history[obj].emplace_back(new_ts, value); } @@ -68,38 +108,6 @@ Timestamp GlobalVersionStore::apply_changes( return new_ts; } -// ----------------------------- -// GlobalVersionHistory (protocol) -// ----------------------------- - -std::optional GlobalVersionHistory::push(LocalVersionStore& local) { - if (auto conflict = _global.check_conflicts( - local.base_timestamp(), - local.staged_changes())) { - return conflict; - } - - Timestamp new_base = _global.apply_changes( - local.base_timestamp(), - local.staged_changes() - ); - - local.clear_staging(); - local.advance_base(new_base); - return std::nullopt; -} - -std::optional GlobalVersionHistory::pull(LocalVersionStore& local) { - if (auto conflict = _global.check_conflicts( - local.base_timestamp(), - local.staged_changes())) { - return conflict; - } - - local.advance_base(_global.current_timestamp()); - return std::nullopt; -} - } // namespace linear } // namespace gitmem \ No newline at end of file diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index 896c805..9adcc2a 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -86,6 +86,7 @@ public: void stage(ObjectNumber obj, Value value); void clear_staging(); void advance_base(Timestamp ts); + std::optional 
get_staged(ObjectNumber obj); }; // ----------------------------- @@ -96,11 +97,15 @@ class GlobalVersionStore { Timestamp _timestamp{}; ObjectNumber _next_object{0}; std::unordered_map _history; + std::unordered_map _object_numbers; public: Timestamp current_timestamp() const { return _timestamp; } - ObjectNumber allocate_object(); + ObjectNumber get_object_number(std::string); + std::string get_object_name(ObjectNumber); + + std::optional get_version_for_timestamp(ObjectNumber, Timestamp) const; std::optional check_conflicts( Timestamp base, @@ -113,48 +118,6 @@ public: ); }; -// ----------------------------- -// Synchronisation Protocol -// ----------------------------- - -class GlobalVersionHistory { - GlobalVersionStore _global; - -public: - std::optional push(LocalVersionStore& local); - std::optional pull(LocalVersionStore& local); -}; - } // namespace linear -namespace branching { - - /* A 'Global' is a structure to capture the current synchronising objects - * representation of a global variable. The structure is the current value, - * the current commit id for the variable, and the history of commited ids. 
- */ - - using Commit = size_t; - using CommitHistory = std::vector; - - struct Global - { - size_t val; - std::optional commit; - CommitHistory history; - }; - - using Globals = std::unordered_map; - - using Locals = std::unordered_map; - - - struct Conflict - { - std::string var; - std::pair commits; - }; - -} // namespace branching - } // namespace gitmem \ No newline at end of file diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index d758e2a..a1cf41c 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -6,54 +6,154 @@ namespace gitmem { // LinearSyncProtocol // -------------------- -std::optional LinearSyncProtocol::read(ThreadContext& ctx, const std::string& var) { +std::optional LinearSyncProtocol::push(linear::LocalVersionStore& local) { + if (auto conflict = _global_store.check_conflicts( + local.base_timestamp(), + local.staged_changes())) { + + // reshape the conflict + return std::make_optional( + _global_store.get_object_name(conflict->object), + std::make_pair(conflict->local_base, conflict->global_head) + ); + + } + + linear::Timestamp new_base = _global_store.apply_changes( + local.base_timestamp(), + local.staged_changes() + ); + + local.clear_staging(); + local.advance_base(new_base); return std::nullopt; } -void LinearSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { +std::optional LinearSyncProtocol::pull(linear::LocalVersionStore& local) { + if (auto conflict = _global_store.check_conflicts( + local.base_timestamp(), + local.staged_changes())) { + return std::make_optional( + _global_store.get_object_name(conflict->object), + std::make_pair(conflict->local_base, conflict->global_head) + ); + + } + + local.advance_base(_global_store.current_timestamp()); + return std::nullopt; } -std::optional LinearSyncProtocol::on_spawn( +LinearSyncProtocol::~LinearSyncProtocol() = default; + +std::optional LinearSyncProtocol::read(ThreadContext& ctx, const std::string& var) { + linear::ObjectNumber number = 
_global_store.get_object_number(var); + + if (auto result = store(ctx).get_staged(number)) + return result; + + std::optional value = _global_store.get_version_for_timestamp(number, store(ctx).base_timestamp()); + if (!value) + return std::nullopt; + + // we do not need to record the staged value for correctness + // TODO: there is something about working out if a value has changed vs been written + + return *value; +} + +void LinearSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { + // write into the staging area of the thread + store(ctx).stage(_global_store.get_object_number(var), value); +} + +std::optional> LinearSyncProtocol::on_spawn( ThreadContext& parent, ThreadContext& child, GlobalContext& gctx ) { + // TODO: i think we can drop the globalcontext but check after branching is added + // push parent to global history + if (auto conflict = push(store(parent))) + return std::make_unique(std::move(*conflict)); + + // TODO: we should probably separate start and spawn // child inherits parent view - // (implementation uses your linear history abstraction) + if (auto conflict = pull(store(child))) + return std::make_unique(std::move(*conflict)); + + return std::nullopt; } -std::optional LinearSyncProtocol::on_join( +std::optional> LinearSyncProtocol::on_join( ThreadContext& joiner, ThreadContext& joinee, GlobalContext& gctx ) { - // push both, pull into joiner + std::cout << "on_join" << std::endl; + // we assume the joinee has already terminated and pushed + + // pull changes into parent + if (auto conflict = pull(store(joiner))) + return std::make_unique(std::move(*conflict)); + return std::nullopt; } -std::optional LinearSyncProtocol::on_lock( +std::optional> LinearSyncProtocol::on_start( + ThreadContext& thread, + GlobalContext& gctx +) { + std::cout << "on_start" << std::endl; + + // pull state from global history + auto conflict = pull(store(thread)); + assert(!conflict && "cannot conflict from starting state"); + + return 
std::nullopt; +}; + +std::optional> LinearSyncProtocol::on_end( + ThreadContext& thread, + GlobalContext& gctx + ) { + std::cout << "on_end" << std::endl; + + // push changes to global history + if (auto conflict = push(store(thread))) + return std::make_unique(std::move(*conflict)); + + return std::nullopt; +}; + +std::optional> LinearSyncProtocol::on_lock( ThreadContext& thread, Lock& lock, GlobalContext& gctx ) { + assert(false && "todo lock"); // push thread, pull from global return std::nullopt; } -std::optional LinearSyncProtocol::on_unlock( +std::optional> LinearSyncProtocol::on_unlock( ThreadContext& thread, Lock&, GlobalContext& gctx ) { + assert(false && "todo unlock"); // push thread + return std::nullopt; } // -------------------- // BranchingSyncProtocol // -------------------- +BranchingSyncProtocol::~BranchingSyncProtocol() = default; + // /* At a commit point, walk through all the versioned variables and see if // * they have a pending commit, if so commit the value by appending to // * the variables history. @@ -92,7 +192,7 @@ std::optional LinearSyncProtocol::on_unlock( // * either source or destination). This means destination will now also // * include variables it previously did not know about. 
// */ -// std::optional pull(Globals &dst, Globals &src) { +// std::optional> pull(Globals &dst, Globals &src) { // for (auto& [var, global] : src) { // if (dst.contains(var)) // { @@ -121,50 +221,73 @@ std::optional LinearSyncProtocol::on_unlock( // } std::optional BranchingSyncProtocol::read(ThreadContext& ctx, const std::string& var) { + assert(false && "Todo read"); return std::nullopt; } void BranchingSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { - + assert(false && "Todo write"); } -std::optional BranchingSyncProtocol::on_spawn( +std::optional> BranchingSyncProtocol::on_spawn( ThreadContext& parent, ThreadContext& child, GlobalContext& ) { + assert(false && "Todo on_spawn"); // commit(parent.globals); // child.globals = parent.globals; + return std::nullopt; } -std::optional BranchingSyncProtocol::on_join( +std::optional> BranchingSyncProtocol::on_join( ThreadContext& joiner, ThreadContext& joinee, GlobalContext& ) { + assert(false && "Todo on_join"); // commit(joiner.globals); // commit(joinee.globals); // return pull(joiner.globals, joinee.globals); return std::nullopt; } -std::optional BranchingSyncProtocol::on_lock( +std::optional> BranchingSyncProtocol::on_start( + ThreadContext& thread, + GlobalContext& gctx +) { + assert(false && "Todo on_start"); + return std::nullopt; +}; + +std::optional> BranchingSyncProtocol::on_end( + ThreadContext& thread, + GlobalContext& gctx + ) { + assert(false && "Todo on_end"); + return std::nullopt; +}; + +std::optional> BranchingSyncProtocol::on_lock( ThreadContext& thread, Lock& lock, GlobalContext& ) { + assert(false && "Todo on_lock"); // commit(thread.globals); // return pull(thread.globals, lock.globals); return std::nullopt; } -std::optional BranchingSyncProtocol::on_unlock( +std::optional> BranchingSyncProtocol::on_unlock( ThreadContext& thread, Lock& lock, GlobalContext& ) { + assert(false && "Todo on_unlock"); // commit(thread.globals); // lock.globals = thread.globals; + return 
std::nullopt; } } // namespace gitmem diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 1c75310..a3449c9 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -3,10 +3,31 @@ #include #include #include "execution_state.hh" +#include "linear/version_store.hh" +#include "branching/version_store.hh" + +/* i want an on_start and on_end event i think too */ namespace gitmem { -struct Conflict; +struct ConflictBase { + virtual ~ConflictBase() = default; + virtual std::string name() const = 0; +}; + +template +struct Conflict : ConflictBase { + std::string var; + std::pair versions; + + std::string name() const override { return var; } + + Conflict(std::string var, std::pair versions): + var(std::move(var)), versions(std::move(versions)) {} +}; + +using LinearConflict = Conflict; +using BranchingConflict = Conflict; class SyncProtocol { public: @@ -18,25 +39,35 @@ public: // Write a shared variable (staged, not committed) virtual void write(ThreadContext& ctx, const std::string& var, size_t value) = 0; - virtual std::optional on_spawn( + virtual std::optional> on_spawn( ThreadContext& parent, ThreadContext& child, GlobalContext& gctx ) = 0; - virtual std::optional on_join( + virtual std::optional> on_join( ThreadContext& joiner, ThreadContext& joinee, GlobalContext& gctx ) = 0; - virtual std::optional on_lock( + virtual std::optional> on_start( + ThreadContext& thread, + GlobalContext& gctx + ) = 0; + + virtual std::optional> on_end( + ThreadContext& thread, + GlobalContext& gctx + ) = 0; + + virtual std::optional> on_lock( ThreadContext& thread, Lock& lock, GlobalContext& gctx ) = 0; - virtual std::optional on_unlock( + virtual std::optional> on_unlock( ThreadContext& thread, Lock& lock, GlobalContext& gctx @@ -48,32 +79,53 @@ public: // --------------------------------- class LinearSyncProtocol final : public SyncProtocol { - // std::unordered_map> ts_nodes; + linear::GlobalVersionStore _global_store; + + static linear::LocalVersionStore& 
store(ThreadContext& ctx) { + if (!ctx.linear) ctx.linear.emplace(); + return ctx.linear->store; + } + + std::optional push(linear::LocalVersionStore& local); + std::optional pull(linear::LocalVersionStore& local); public: + ~LinearSyncProtocol() override; + + std::optional read(ThreadContext& ctx, const std::string& var) override; void write(ThreadContext& ctx, const std::string& var, size_t value) override; - std::optional on_spawn( + std::optional> on_spawn( ThreadContext& parent, ThreadContext& child, GlobalContext& gctx ) override; - std::optional on_join( + std::optional> on_join( ThreadContext& joiner, ThreadContext& joinee, GlobalContext& gctx ) override; - std::optional on_lock( + std::optional> on_start( + ThreadContext& thread, + GlobalContext& gctx + ) override; + + std::optional> on_end( + ThreadContext& thread, + GlobalContext& gctx + ) override; + + std::optional> on_lock( ThreadContext& thread, Lock& lock, GlobalContext& gctx ) override; - std::optional on_unlock( + std::optional> on_unlock( ThreadContext& thread, Lock& lock, GlobalContext& gctx @@ -82,32 +134,44 @@ public: class BranchingSyncProtocol final : public SyncProtocol { - std::unordered_map> commit_nodes; + // std::unordered_map> commit_nodes; public: + ~BranchingSyncProtocol() override; + std::optional read(ThreadContext& ctx, const std::string& var) override; void write(ThreadContext& ctx, const std::string& var, size_t value) override; - std::optional on_spawn( + std::optional> on_spawn( ThreadContext& parent, ThreadContext& child, GlobalContext& gctx ) override; - std::optional on_join( + std::optional> on_join( ThreadContext& joiner, ThreadContext& joinee, GlobalContext& gctx ) override; - std::optional on_lock( + std::optional> on_start( + ThreadContext& thread, + GlobalContext& gctx + ) override; + + std::optional> on_end( + ThreadContext& thread, + GlobalContext& gctx + ) override; + + std::optional> on_lock( ThreadContext& thread, Lock& lock, GlobalContext& gctx ) override; - 
std::optional on_unlock( + std::optional> on_unlock( ThreadContext& thread, Lock& lock, GlobalContext& gctx From 36fc7faec7d50195d6a7a00f1d5a28c57891b28e Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 19 Dec 2025 10:18:12 +0000 Subject: [PATCH 06/58] simple spawn join works with linear --- src/gitmem.cc | 1 + src/interpreter.cc | 2 ++ src/interpreter.hh | 16 ---------------- src/linear/version_store.cc | 1 - src/linear/version_store.hh | 5 +++++ src/sync_protocol.cc | 20 ++++++++++++-------- src/sync_protocol.hh | 9 ++++++--- 7 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/gitmem.cc b/src/gitmem.cc index ec1efd1..012f125 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -2,6 +2,7 @@ #include "lang.hh" #include "interpreter.hh" +#include "debug.hh" int main(int argc, char **argv) { diff --git a/src/interpreter.cc b/src/interpreter.cc index 56587e5..bb66718 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -5,6 +5,7 @@ #include "interpreter.hh" #include "graphviz.hh" #include "sync_protocol.hh" +#include "debug.hh" namespace gitmem { @@ -220,6 +221,7 @@ namespace gitmem { auto& joinee = gctx.threads[result]; if (joinee->terminated && (*joinee->terminated == TerminationStatus::completed)) { if(auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { + verbose << (**conflict) << std::endl; return TerminationStatus::datarace_exception; } else { thread_append_node(ctx, result, joinee->ctx.tail); diff --git a/src/interpreter.hh b/src/interpreter.hh index 5665d06..be7a063 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -8,22 +8,6 @@ namespace gitmem { - /* For debug printing */ - inline struct Verbose { - bool enabled = false; - - template - const Verbose &operator<<(const T &msg) const { - if (enabled) std::cout << msg; - return *this; - } - - const Verbose &operator<<(std::ostream &(*manip)(std::ostream &)) const { - if (enabled) std::cout << manip; - return *this; - } - } verbose; - // Entry functions int 
interpret(const trieste::Node, const std::filesystem::path &output_file); diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index 6d3e110..227fe14 100644 --- a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -80,7 +80,6 @@ std::optional GlobalVersionStore::check_conflicts( const Version& latest = it->second.back(); if (latest.timestamp() > base) { - std::cout << "conflict" << std::endl; return Conflict{ .object = obj, .local_base = base, diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index 9adcc2a..a0018bf 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -37,6 +37,11 @@ public: ++(*this); return old; } + + friend std::ostream& operator<<(std::ostream& os, const LargeCounter& counter) { + os << counter._epoch << ":" << counter._counter; + return os; + } }; using Timestamp = LargeCounter; diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index a1cf41c..f06aa59 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -1,7 +1,15 @@ +#include #include "sync_protocol.hh" +#include "debug.hh" namespace gitmem { +template +std::ostream& Conflict::print(std::ostream& os) const { + os << "conflict on " << var << " { " << versions.first << ", " << versions.second << " }"; + return os; +} + // -------------------- // LinearSyncProtocol // -------------------- @@ -74,16 +82,12 @@ std::optional> LinearSyncProtocol::on_spawn( GlobalContext& gctx ) { // TODO: i think we can drop the globalcontext but check after branching is added + verbose << "on_spawn" << std::endl; // push parent to global history if (auto conflict = push(store(parent))) return std::make_unique(std::move(*conflict)); - // TODO: we should probably separate start and spawn - // child inherits parent view - if (auto conflict = pull(store(child))) - return std::make_unique(std::move(*conflict)); - return std::nullopt; } @@ -92,7 +96,7 @@ std::optional> LinearSyncProtocol::on_join( ThreadContext& joinee, 
GlobalContext& gctx ) { - std::cout << "on_join" << std::endl; + verbose << "on_join" << std::endl; // we assume the joinee has already terminated and pushed // pull changes into parent @@ -106,7 +110,7 @@ std::optional> LinearSyncProtocol::on_start( ThreadContext& thread, GlobalContext& gctx ) { - std::cout << "on_start" << std::endl; + verbose << "on_start" << std::endl; // pull state from global history auto conflict = pull(store(thread)); @@ -119,7 +123,7 @@ std::optional> LinearSyncProtocol::on_end( ThreadContext& thread, GlobalContext& gctx ) { - std::cout << "on_end" << std::endl; + verbose << "on_end" << std::endl; // push changes to global history if (auto conflict = push(store(thread))) diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index a3449c9..7e077aa 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -12,7 +12,10 @@ namespace gitmem { struct ConflictBase { virtual ~ConflictBase() = default; - virtual std::string name() const = 0; + virtual std::ostream& print(std::ostream& os) const = 0; + friend std::ostream& operator<<(std::ostream& os, const ConflictBase& conflict) { + return conflict.print(os); + } }; template @@ -20,10 +23,10 @@ struct Conflict : ConflictBase { std::string var; std::pair versions; - std::string name() const override { return var; } - Conflict(std::string var, std::pair versions): var(std::move(var)), versions(std::move(versions)) {} + + std::ostream& print(std::ostream& os) const override; }; using LinearConflict = Conflict; From ec6ebe654051a80062d2543639bc1e0cdf68fbb3 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 19 Dec 2025 10:31:35 +0000 Subject: [PATCH 07/58] pre clang format --- src/debugger.cc | 1 + src/execution_state.hh | 4 +-- src/interpreter.cc | 79 ++++++++++++++++++++++-------------------- src/interpreter.hh | 5 +-- src/model_checker.cc | 1 + 5 files changed, 46 insertions(+), 44 deletions(-) diff --git a/src/debugger.cc b/src/debugger.cc index 01492fe..133d2ba 100644 --- 
a/src/debugger.cc +++ b/src/debugger.cc @@ -1,5 +1,6 @@ #include +#include "debugger.hh" #include "interpreter.hh" namespace gitmem diff --git a/src/execution_state.hh b/src/execution_state.hh index 2254f5a..7b3a33f 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -46,8 +46,8 @@ namespace gitmem { struct Lock { // Globals globals; - // std::optional owner = std::nullopt; - // std::shared_ptr last; + std::optional owner = std::nullopt; + std::shared_ptr last; }; template diff --git a/src/interpreter.cc b/src/interpreter.cc index bb66718..81b889a 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -232,58 +232,61 @@ namespace gitmem { return 0; } } else if (s == lang::Lock) { - assert(false && "todo"); // We can only lock unlocked locks, if a lock hasn't been used // before it is implicitly created, we then commit the pending // updates of this thread and pull the updates from the lock. - // auto v = s / lang::Var; - // auto var = std::string(v->location().view()); + auto v = s / lang::Var; + auto var = std::string(v->location().view()); + + Lock& lock = gctx.locks[var]; + if (lock.owner) { + verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; + return 0; + } - // auto& lock = gctx.locks[var]; - // if (lock.owner) { - // verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; - // return 0; - // } + lock.owner = tid; + if(auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { + verbose << (**conflict) << std::endl; + // using graph::Node; + // auto [s1, s2] = conflict->commits; + // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; + // auto graph_conflict = graph::Conflict(conflict->var, sources); + // thread_append_node(ctx, var, lock.last, graph_conflict); + return TerminationStatus::datarace_exception; + } - // lock.owner = tid; - // commit(ctx.globals); - // if(auto conflict = pull(ctx.globals, lock.globals)) - // { - // 
using graph::Node; - // auto [s1, s2] = conflict->commits; - // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - // auto graph_conflict = graph::Conflict(conflict->var, sources); - // thread_append_node(ctx, var, lock.last, graph_conflict); - // return TerminationStatus::datarace_exception; - // } - - // thread_append_node(ctx, var, lock.last); - - // verbose << "Locked " << var << std::endl; + thread_append_node(ctx, var, lock.last); + + verbose << "Locked " << var << std::endl; } else if (s == lang::Unlock) { assert(false && "todo"); - // // We can only unlock locks we previously locked. We commit any - // // pending updates and then copy the threads versioned globals - // // to the locks versioned globals (nobody could have changed - // // them since we locked the lock). + // We can only unlock locks we previously locked. We commit any + // pending updates and then copy the threads versioned globals + // to the locks versioned globals (nobody could have changed + // them since we locked the lock). 
+ // commit(ctx.globals); - // auto v = s / lang::Var; - // auto var = std::string(v->location().view()); + auto v = s / lang::Var; + auto var = std::string(v->location().view()); - // auto& lock = gctx.locks[var]; - // if (!lock.owner || (lock.owner && *lock.owner != tid)) - // { - // return TerminationStatus::unlock_exception; - // } + auto& lock = gctx.locks[var]; + if (!lock.owner || (lock.owner && *lock.owner != tid)) { + return TerminationStatus::unlock_exception; + } + + if(auto conflict = gctx.protocol->on_unlock(ctx, lock, gctx)) { + verbose << (**conflict) << std::endl; + return TerminationStatus::datarace_exception; + } // lock.globals = ctx.globals; - // lock.owner.reset(); + lock.owner.reset(); - // thread_append_node(ctx, var); - // lock.last = ctx.tail; + thread_append_node(ctx, var); + lock.last = ctx.tail; - // verbose << "Unlocked " << var << std::endl; + verbose << "Unlocked " << var << std::endl; } else if (s == lang::Assert) { auto expr = s / lang::Expr; diff --git a/src/interpreter.hh b/src/interpreter.hh index be7a063..43334c0 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -8,12 +8,9 @@ namespace gitmem { - // Entry functions + // Entry function int interpret(const trieste::Node, const std::filesystem::path &output_file); - // int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file); - // int model_check(const trieste::Node, const std::filesystem::path &output_file); - // Internal functions int run_threads(GlobalContext &); diff --git a/src/model_checker.cc b/src/model_checker.cc index c7fff60..bdd3af5 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -1,3 +1,4 @@ +#include "model_checker.hh" #include "interpreter.hh" namespace gitmem From ac0099f4ee952777ef550724c685129d4df39f4c Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 19 Dec 2025 10:37:01 +0000 Subject: [PATCH 08/58] clang-format --- src/branching/version_store.hh | 82 ++- src/debugger.cc | 627 ++++++++++------------ 
src/execution_state.cc | 151 +++--- src/execution_state.hh | 108 ++-- src/gitmem.cc | 168 +++--- src/gitmem_trieste.cc | 5 +- src/graph.hh | 311 +++++------ src/graphviz.cc | 301 ++++++----- src/graphviz.hh | 61 ++- src/internal.hh | 24 +- src/interpreter.cc | 943 ++++++++++++++++----------------- src/interpreter.hh | 37 +- src/lang.hh | 107 ++-- src/linear/version_store.cc | 43 +- src/linear/version_store.hh | 32 +- src/main.cc | 6 +- src/model_checker.cc | 380 +++++++------ src/parser.cc | 212 ++++---- src/passes/branching.cc | 48 +- src/passes/check_refs.cc | 41 +- src/passes/expressions.cc | 228 ++++---- src/passes/statements.cc | 345 ++++++------ src/reader.cc | 17 +- src/sync_protocol.cc | 172 +++--- src/sync_protocol.hh | 198 +++---- 25 files changed, 2171 insertions(+), 2476 deletions(-) diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index f59b886..fed99f7 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -1,54 +1,52 @@ #pragma once - -#include -#include -#include #include #include +#include +#include +#include namespace gitmem { namespace branching { - /* A 'Global' is a structure to capture the current synchronising objects - * representation of a global variable. The structure is the current value, - * the current commit id for the variable, and the history of commited ids. 
- */ - - using Commit = size_t; - using CommitHistory = std::vector; - - struct Global - { - size_t val; - std::optional commit; - CommitHistory history; - }; - - using Globals = std::unordered_map; - - struct Conflict - { - std::string var; - std::pair commits; - }; - - struct LocalVersionStore {}; - - // Join logic - // commit(ctx.globals); - // commit(thread->ctx.globals); - // verbose << "Pulling from thread " << result << std::endl; - // if(auto conflict = pull(ctx.globals, thread->ctx.globals)) - // { - // using graph::Node; - // auto [s1, s2] = conflict->commits; - // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - // auto graph_conflict = graph::Conflict(conflict->var, sources); - // thread_append_node(ctx, result, thread->ctx.tail, graph_conflict); - // return TerminationStatus::datarace_exception; - // } +/* A 'Global' is a structure to capture the current synchronising objects + * representation of a global variable. The structure is the current value, + * the current commit id for the variable, and the history of commited ids. 
+ */ + +using Commit = size_t; +using CommitHistory = std::vector; + +struct Global { + size_t val; + std::optional commit; + CommitHistory history; +}; + +using Globals = std::unordered_map; + +struct Conflict { + std::string var; + std::pair commits; +}; + +struct LocalVersionStore {}; + +// Join logic +// commit(ctx.globals); +// commit(thread->ctx.globals); +// verbose << "Pulling from thread " << result << std::endl; +// if(auto conflict = pull(ctx.globals, thread->ctx.globals)) +// { +// using graph::Node; +// auto [s1, s2] = conflict->commits; +// auto sources = std::pair, +// std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; auto +// graph_conflict = graph::Conflict(conflict->var, sources); +// thread_append_node(ctx, result, thread->ctx.tail, +// graph_conflict); return TerminationStatus::datarace_exception; +// } } // namespace branching diff --git a/src/debugger.cc b/src/debugger.cc index 133d2ba..2c7d597 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -3,379 +3,310 @@ #include "debugger.hh" #include "interpreter.hh" -namespace gitmem -{ - /** A command that can be parsed by the debugger. Some commands store a - * ThreadID argument. */ - struct Command - { - enum - { - Step, // Run a specified thread to next sync point - Finish, // Finish the rest of the program - Restart, // Start the program from the beginning - List, // List all threads - Print, // Print the execution graph - Graph, // Toggle automatically printing the execution graph - Quit, // Quit the interpreter - Info, // Show commands - Skip, // Do nothing, used for invalid commands - } cmd; - ThreadID argument = 0; - }; +namespace gitmem { +/** A command that can be parsed by the debugger. Some commands store a + * ThreadID argument. 
*/ +struct Command { + enum { + Step, // Run a specified thread to next sync point + Finish, // Finish the rest of the program + Restart, // Start the program from the beginning + List, // List all threads + Print, // Print the execution graph + Graph, // Toggle automatically printing the execution graph + Quit, // Quit the interpreter + Info, // Show commands + Skip, // Do nothing, used for invalid commands + } cmd; + ThreadID argument = 0; +}; - void show_global(const std::string &var, const Global &global) - { - std::cout << var << " = " << global.val - << " [" << (global.commit ? std::to_string(*global.commit) : "_") << "; "; - for (size_t i = 0; i < global.history.size(); ++i) - { - std::cout << global.history[i]; - if (i < global.history.size() - 1) - { - std::cout << ", "; - } - } - std::cout << "]" << std::endl; +void show_global(const std::string &var, const Global &global) { + std::cout << var << " = " << global.val << " [" + << (global.commit ? std::to_string(*global.commit) : "_") << "; "; + for (size_t i = 0; i < global.history.size(); ++i) { + std::cout << global.history[i]; + if (i < global.history.size() - 1) { + std::cout << ", "; } + } + std::cout << "]" << std::endl; +} - /** Print the state of a thread, including its local and global variables, - * and the current position in the program. 
*/ - void show_thread(const Thread &thread, size_t tid) - { - std::cout << "---- Thread " << tid << std::endl; - if (thread.ctx.locals.size() > 0) - { - for (auto &[reg, val] : thread.ctx.locals) - { - std::cout << reg << " = " << val << std::endl; - } - std::cout << "--" << std::endl; - } - - if (thread.ctx.globals.size() > 0) - { - for (auto &[var, val] : thread.ctx.globals) - { - show_global(var, val); - } - std::cout << "--" << std::endl; - } - - size_t idx = 0; - for (const auto &stmt : *thread.block) - { - if (idx == thread.pc) - { - std::cout << "-> "; - } - else - { - std::cout << " "; - } - // Fix indentation of nested blocks - auto s = std::string(stmt->location().view()); - s = std::regex_replace(s, std::regex("\n"), "\n "); - std::cout << s << ";" << std::endl; - - idx++; - } - if (thread.pc == thread.block->size()) - { - std::cout << "-> " << std::endl; - } +/** Print the state of a thread, including its local and global variables, + * and the current position in the program. */ +void show_thread(const Thread &thread, size_t tid) { + std::cout << "---- Thread " << tid << std::endl; + if (thread.ctx.locals.size() > 0) { + for (auto &[reg, val] : thread.ctx.locals) { + std::cout << reg << " = " << val << std::endl; } + std::cout << "--" << std::endl; + } - void show_lock(const std::string &lock_name, const struct Lock &lock) - { - std::cout << lock_name << ": "; - if (lock.owner) - { - std::cout << "held by thread " << *lock.owner; - } - else - { - std::cout << ""; - } - std::cout << std::endl; - for (auto &[var, global] : lock.globals) - { - show_global(var, global); - } + if (thread.ctx.globals.size() > 0) { + for (auto &[var, val] : thread.ctx.globals) { + show_global(var, val); } + std::cout << "--" << std::endl; + } - /** Show the global context, including locks and non-completed threads. If - * show_all is true, show all threads, even those that have terminated - * normally. 
*/ - void show_global_context(const GlobalContext &gctx, bool show_all = false) - { - auto &threads = gctx.threads; - bool showed_any = false; - for (size_t i = 0; i < threads.size(); i++) - { - auto thread = threads[i]; - if (show_all || !thread->terminated || *threads[i]->terminated != TerminationStatus::completed) - { - show_thread(*threads[i], i); - std::cout << std::endl; - showed_any = true; - } - } + size_t idx = 0; + for (const auto &stmt : *thread.block) { + if (idx == thread.pc) { + std::cout << "-> "; + } else { + std::cout << " "; + } + // Fix indentation of nested blocks + auto s = std::string(stmt->location().view()); + s = std::regex_replace(s, std::regex("\n"), "\n "); + std::cout << s << ";" << std::endl; - if (showed_any && gctx.locks.size() > 0) - { - std::cout << "---- Locks" << std::endl; + idx++; + } + if (thread.pc == thread.block->size()) { + std::cout << "-> " << std::endl; + } +} - for (const auto &[lock_name, lock] : gctx.locks) - { - show_lock(lock_name, lock); - } +void show_lock(const std::string &lock_name, const struct Lock &lock) { + std::cout << lock_name << ": "; + if (lock.owner) { + std::cout << "held by thread " << *lock.owner; + } else { + std::cout << ""; + } + std::cout << std::endl; + for (auto &[var, global] : lock.globals) { + show_global(var, global); + } +} - if (gctx.locks.size() > 0) - std::cout << "--" << std::endl; - } +/** Show the global context, including locks and non-completed threads. If + * show_all is true, show all threads, even those that have terminated + * normally. */ +void show_global_context(const GlobalContext &gctx, bool show_all = false) { + auto &threads = gctx.threads; + bool showed_any = false; + for (size_t i = 0; i < threads.size(); i++) { + auto thread = threads[i]; + if (show_all || !thread->terminated || + *threads[i]->terminated != TerminationStatus::completed) { + show_thread(*threads[i], i); + std::cout << std::endl; + showed_any = true; } + } - /** Parse a command. 
See the help string for the 'Info' command for details. - */ - Command parse_command(std::string &input) - { - auto command = std::string(input); - command.erase(0, command.find_first_not_of(" \t\n\r")); - command.erase(command.find_last_not_of(" \t\n\r") + 1); + if (showed_any && gctx.locks.size() > 0) { + std::cout << "---- Locks" << std::endl; - if (command.find_first_not_of("0123456789") == std::string::npos) - { - // Interpret numbers as stepping - return {Command::Step, std::stoul(command)}; - } - else if (command == "s" || (command.at(0) == 's' && !std::isalpha(command.at(1)))) - { - auto arg = command.substr(1); - arg.erase(0, arg.find_first_not_of(" \t\n\r")); - if (arg.size() > 0 && arg.find_first_not_of("0123456789") == std::string::npos) - { - return {Command::Step, std::stoul(arg)}; - } - else - { - std::cout << "Expected thread id" << std::endl; - return {Command::Skip}; - } - } - else if (command == "q") - { - return {Command::Quit}; - } - else if (command == "r") - { - return {Command::Restart}; - } - else if (command == "f") - { - return {Command::Finish}; - } - else if (command == "l") - { - return {Command::List}; - } - else if (command == "g") - { - return {Command::Graph}; - } - else if (command == "p") - { - return {Command::Print}; - } - else if (command == "?") - { - return {Command::Info}; - } - else - { - std::cout << "Unknown command: " << input << std::endl; - return {Command::Skip}; - } + for (const auto &[lock_name, lock] : gctx.locks) { + show_lock(lock_name, lock); } - /** Perform the Step command on a given thread. Error messages are assigned - * to `msg`. The return value signals whether threads should be printed - * after stepping or not. 
*/ - bool step_thread(ThreadID tid, GlobalContext &gctx, std::string &msg) - { - if (tid >= gctx.threads.size()) - { - msg = "Invalid thread id: " + std::to_string(tid); - return false; - } + if (gctx.locks.size() > 0) + std::cout << "--" << std::endl; + } +} - auto thread = gctx.threads[tid]; - if (auto term = thread->terminated) - { - if (*term == TerminationStatus::completed) - { - msg = "Thread " + std::to_string(tid) + " has terminated normally"; - } - else - { - msg = "Thread " + std::to_string(tid) + " has terminated with an error"; - } - return false; - } +/** Parse a command. See the help string for the 'Info' command for details. + */ +Command parse_command(std::string &input) { + auto command = std::string(input); + command.erase(0, command.find_first_not_of(" \t\n\r")); + command.erase(command.find_last_not_of(" \t\n\r") + 1); - auto prog_or_term = progress_thread(gctx, tid, thread); - if (ProgressStatus *prog = std::get_if(&prog_or_term)) - { - if (!*prog) - { - auto stmt = thread->block->at(thread->pc); - msg = "Thread " + std::to_string(tid) + " is blocking on '" + std::string(stmt->location().view()) + "'"; - return false; - } - } - else if (TerminationStatus *term = std::get_if(&prog_or_term)) - { - switch (*term) - { - case TerminationStatus::completed: - msg = "Thread " + std::to_string(tid) + " terminated normally"; - return true; - case TerminationStatus::datarace_exception: - // TODO: Say on which variable the datarace occurred. To - // do this, have pull return an optional variable that - // is in a race and have the data race exception - // remember that variable. 
- msg = "Thread " + std::to_string(tid) + " encountered a data race and was terminated"; - return false; - case TerminationStatus::assertion_failure_exception: - { - auto expr = thread->block->at(thread->pc) / lang::Stmt / lang::Expr; - msg = "Thread " + std::to_string(tid) + " failed assertion '" + std::string(expr->location().view()) + "' and was terminated"; - return false; - } - case TerminationStatus::unassigned_variable_read_exception: - throw std::runtime_error("Thread " + std::to_string(tid) + " read an uninitialised variable"); - case TerminationStatus::unlock_exception: - throw std::runtime_error("Thread " + std::to_string(tid) + " unlocked an unlocked lock"); - default: - throw std::runtime_error("Thread " + std::to_string(tid) + " has an unhandled termination state"); - } - } - return true; + if (command.find_first_not_of("0123456789") == std::string::npos) { + // Interpret numbers as stepping + return {Command::Step, std::stoul(command)}; + } else if (command == "s" || + (command.at(0) == 's' && !std::isalpha(command.at(1)))) { + auto arg = command.substr(1); + arg.erase(0, arg.find_first_not_of(" \t\n\r")); + if (arg.size() > 0 && + arg.find_first_not_of("0123456789") == std::string::npos) { + return {Command::Step, std::stoul(arg)}; + } else { + std::cout << "Expected thread id" << std::endl; + return {Command::Skip}; } + } else if (command == "q") { + return {Command::Quit}; + } else if (command == "r") { + return {Command::Restart}; + } else if (command == "f") { + return {Command::Finish}; + } else if (command == "l") { + return {Command::List}; + } else if (command == "g") { + return {Command::Graph}; + } else if (command == "p") { + return {Command::Print}; + } else if (command == "?") { + return {Command::Info}; + } else { + std::cout << "Unknown command: " << input << std::endl; + return {Command::Skip}; + } +} - /** Interpret the AST in an interactive way, letting the user choose which - * thread to schedule next. 
*/ - int interpret_interactive(const trieste::Node ast, const std::filesystem::path &output_file) - { - GlobalContext gctx(ast); +/** Perform the Step command on a given thread. Error messages are assigned + * to `msg`. The return value signals whether threads should be printed + * after stepping or not. */ +bool step_thread(ThreadID tid, GlobalContext &gctx, std::string &msg) { + if (tid >= gctx.threads.size()) { + msg = "Invalid thread id: " + std::to_string(tid); + return false; + } - size_t prev_no_threads = 1; - Command command = {Command::List}; - std::string msg = ""; - bool print_graphs = true; - gctx.print_execution_graph(output_file); - while (command.cmd != Command::Quit) - { - if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) - { - bool show_all = command.cmd == Command::List; - show_global_context(gctx, show_all); - } - prev_no_threads = gctx.threads.size(); + auto thread = gctx.threads[tid]; + if (auto term = thread->terminated) { + if (*term == TerminationStatus::completed) { + msg = "Thread " + std::to_string(tid) + " has terminated normally"; + } else { + msg = "Thread " + std::to_string(tid) + " has terminated with an error"; + } + return false; + } + + auto prog_or_term = progress_thread(gctx, tid, thread); + if (ProgressStatus *prog = std::get_if(&prog_or_term)) { + if (!*prog) { + auto stmt = thread->block->at(thread->pc); + msg = "Thread " + std::to_string(tid) + " is blocking on '" + + std::string(stmt->location().view()) + "'"; + return false; + } + } else if (TerminationStatus *term = + std::get_if(&prog_or_term)) { + switch (*term) { + case TerminationStatus::completed: + msg = "Thread " + std::to_string(tid) + " terminated normally"; + return true; + case TerminationStatus::datarace_exception: + // TODO: Say on which variable the datarace occurred. To + // do this, have pull return an optional variable that + // is in a race and have the data race exception + // remember that variable. 
+ msg = "Thread " + std::to_string(tid) + + " encountered a data race and was terminated"; + return false; + case TerminationStatus::assertion_failure_exception: { + auto expr = thread->block->at(thread->pc) / lang::Stmt / lang::Expr; + msg = "Thread " + std::to_string(tid) + " failed assertion '" + + std::string(expr->location().view()) + "' and was terminated"; + return false; + } + case TerminationStatus::unassigned_variable_read_exception: + throw std::runtime_error("Thread " + std::to_string(tid) + + " read an uninitialised variable"); + case TerminationStatus::unlock_exception: + throw std::runtime_error("Thread " + std::to_string(tid) + + " unlocked an unlocked lock"); + default: + throw std::runtime_error("Thread " + std::to_string(tid) + + " has an unhandled termination state"); + } + } + return true; +} + +/** Interpret the AST in an interactive way, letting the user choose which + * thread to schedule next. */ +int interpret_interactive(const trieste::Node ast, + const std::filesystem::path &output_file) { + GlobalContext gctx(ast); - if (!msg.empty()) - { - std::cout << msg << std::endl; - msg.clear(); - } + size_t prev_no_threads = 1; + Command command = {Command::List}; + std::string msg = ""; + bool print_graphs = true; + gctx.print_execution_graph(output_file); + while (command.cmd != Command::Quit) { + if (command.cmd != Command::Skip || + prev_no_threads != gctx.threads.size()) { + bool show_all = command.cmd == Command::List; + show_global_context(gctx, show_all); + } + prev_no_threads = gctx.threads.size(); - std::cout << "> "; - std::string input; - std::getline(std::cin, input); - if (!input.empty() && input.find_first_not_of(" \t\n\r") != std::string::npos) - { - command = parse_command(input); - } + if (!msg.empty()) { + std::cout << msg << std::endl; + msg.clear(); + } - if (command.cmd == Command::Step) - { - auto tid = command.argument; - if (!step_thread(tid, gctx, msg)) command = {Command::Skip}; + std::cout << "> "; + std::string 
input; + std::getline(std::cin, input); + if (!input.empty() && + input.find_first_not_of(" \t\n\r") != std::string::npos) { + command = parse_command(input); + } - if (print_graphs) - { - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; - } - } - else if (command.cmd == Command::Finish) - { - // Finish the program - if (!run_threads(gctx)) - msg = "Program finished successfully"; - else - msg = "Program terminated with an error"; + if (command.cmd == Command::Step) { + auto tid = command.argument; + if (!step_thread(tid, gctx, msg)) + command = {Command::Skip}; - if (print_graphs) - { - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; - } - } - else if (command.cmd == Command::Restart) - { - // Start the program from the beginning - gctx = GlobalContext(ast); - command = {Command::List}; - if (print_graphs) - { - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; - } - } - else if (command.cmd == Command::List) - { - // Listing is a no-op - } - else if (command.cmd == Command::Graph) - { - // Toggle printing execution graph automatically - print_graphs = !print_graphs; - std::cout << "graphs " << (print_graphs ? 
"will" : "won't") << " print automatically" << std::endl; - command = {Command::Skip}; - } - else if (command.cmd == Command::Print) - { - // Print the execution graph - gctx.print_execution_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; - command = {Command::Skip}; - } - else if (command.cmd == Command::Skip) - { - // Skip is a no-op - } - else if (command.cmd == Command::Info) - { - std::cout << "Commands:" << std::endl; - std::cout << "s [tid] - Step to next sync point in thread" << std::endl; - std::cout << "[tid] - Step to next sync point in thread" << std::endl; - std::cout << "f - Finish the program" << std::endl; - std::cout << "r - Restart the program" << std::endl; - std::cout << "l - List all threads" << std::endl; - std::cout << "g - Toggle printing the execution graph at sync points" << std::endl; - std::cout << "p - Printing the execution graph at current sync point" << std::endl; - std::cout << "q - Quit the interpreter" << std::endl; - std::cout << "? 
- Display this help message" << std::endl; - command = {Command::Skip}; - } - else if (command.cmd == Command::Quit) - { - // Quit is a no-op - } - } + if (print_graphs) { + gctx.print_execution_graph(output_file); + verbose << "Execution graph written to " << output_file << std::endl; + } + } else if (command.cmd == Command::Finish) { + // Finish the program + if (!run_threads(gctx)) + msg = "Program finished successfully"; + else + msg = "Program terminated with an error"; - return 0; + if (print_graphs) { + gctx.print_execution_graph(output_file); + verbose << "Execution graph written to " << output_file << std::endl; + } + } else if (command.cmd == Command::Restart) { + // Start the program from the beginning + gctx = GlobalContext(ast); + command = {Command::List}; + if (print_graphs) { + gctx.print_execution_graph(output_file); + verbose << "Execution graph written to " << output_file << std::endl; + } + } else if (command.cmd == Command::List) { + // Listing is a no-op + } else if (command.cmd == Command::Graph) { + // Toggle printing execution graph automatically + print_graphs = !print_graphs; + std::cout << "graphs " << (print_graphs ? 
"will" : "won't") + << " print automatically" << std::endl; + command = {Command::Skip}; + } else if (command.cmd == Command::Print) { + // Print the execution graph + gctx.print_execution_graph(output_file); + verbose << "Execution graph written to " << output_file << std::endl; + command = {Command::Skip}; + } else if (command.cmd == Command::Skip) { + // Skip is a no-op + } else if (command.cmd == Command::Info) { + std::cout << "Commands:" << std::endl; + std::cout << "s [tid] - Step to next sync point in thread" << std::endl; + std::cout << "[tid] - Step to next sync point in thread" << std::endl; + std::cout << "f - Finish the program" << std::endl; + std::cout << "r - Restart the program" << std::endl; + std::cout << "l - List all threads" << std::endl; + std::cout << "g - Toggle printing the execution graph at sync points" + << std::endl; + std::cout << "p - Printing the execution graph at current sync point" + << std::endl; + std::cout << "q - Quit the interpreter" << std::endl; + std::cout << "? 
- Display this help message" << std::endl; + command = {Command::Skip}; + } else if (command.cmd == Command::Quit) { + // Quit is a no-op } + } + + return 0; } +} // namespace gitmem diff --git a/src/execution_state.cc b/src/execution_state.cc index 1ee2f5d..d769e13 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -3,89 +3,88 @@ namespace gitmem { - bool Thread::operator==(const Thread &other) const - { - return false; - // Globals have a history that we don't care about, so we only - // compare values - // if (ctx.globals.size() != other.ctx.globals.size()) - // return false; - // for (const auto &[var, global] : ctx.globals) - // { - // if (!other.ctx.globals.contains(var) || - // ctx.globals.at(var).val != other.ctx.globals.at(var).val) - // { - // return false; - // } - // } - // return ctx.locals == other.ctx.locals && - // block == other.block && - // pc == other.pc && - // terminated == other.terminated; - } +bool Thread::operator==(const Thread &other) const { + return false; + // Globals have a history that we don't care about, so we only + // compare values + // if (ctx.globals.size() != other.ctx.globals.size()) + // return false; + // for (const auto &[var, global] : ctx.globals) + // { + // if (!other.ctx.globals.contains(var) || + // ctx.globals.at(var).val != other.ctx.globals.at(var).val) + // { + // return false; + // } + // } + // return ctx.locals == other.ctx.locals && + // block == other.block && + // pc == other.pc && + // terminated == other.terminated; +} - GlobalContext::GlobalContext(const trieste::Node &ast, std::unique_ptr protocol): - protocol(std::move(protocol)) - { - trieste::Node starting_block = ast / lang::File / lang::Block; - ThreadContext starting_ctx = { - .locals = {}, - .tail = std::make_shared(0) - }; - auto main_thread = std::make_shared(starting_ctx, starting_block); +GlobalContext::GlobalContext(const trieste::Node &ast, + std::unique_ptr protocol) + : protocol(std::move(protocol)) { + trieste::Node 
starting_block = ast / lang::File / lang::Block; + ThreadContext starting_ctx = {.locals = {}, + .tail = std::make_shared(0)}; + auto main_thread = std::make_shared(starting_ctx, starting_block); - this->threads = {main_thread}; - this->locks = {}; - this->cache = {}; - } - - - GlobalContext::~GlobalContext() = default; + this->threads = {main_thread}; + this->locks = {}; + this->cache = {}; +} - void GlobalContext::print_execution_graph(const std::filesystem::path &output_path) const { - // Loop over the threads and add pending nodes to running threads - // to indicate a threads next step - for (const auto& t: threads) - { - assert(t->ctx.tail); - if (t->terminated || dynamic_pointer_cast(t->ctx.tail->next)) - continue; +GlobalContext::~GlobalContext() = default; - trieste::Node block = t->block; - size_t &pc = t->pc; - trieste::Node stmt = block->at(pc); - thread_append_node(t->ctx, std::string(stmt->location().view())); - } +void GlobalContext::print_execution_graph( + const std::filesystem::path &output_path) const { + // Loop over the threads and add pending nodes to running threads + // to indicate a threads next step + for (const auto &t : threads) { + assert(t->ctx.tail); + if (t->terminated || + dynamic_pointer_cast(t->ctx.tail->next)) + continue; - graph::GraphvizPrinter gv(output_path); - gv.visit(entry_node.get()); + trieste::Node block = t->block; + size_t &pc = t->pc; + trieste::Node stmt = block->at(pc); + thread_append_node(t->ctx, + std::string(stmt->location().view())); } - bool GlobalContext::operator==(const GlobalContext &other) const { - return false; - // if (threads.size() != other.threads.size() || locks.size() != other.locks.size()) - // return false; + graph::GraphvizPrinter gv(output_path); + gv.visit(entry_node.get()); +} - // // Threads may have been spawned in a different order, so we - // // find the thread with the same block in the other context - // for (auto &thread : threads) - // { - // auto it = 
std::find_if(other.threads.begin(), other.threads.end(), - // [&thread](auto &t) - // { return t->block == thread->block; }); - // if (it == other.threads.end() || !(*thread == **it)) - // return false; - // } +bool GlobalContext::operator==(const GlobalContext &other) const { + return false; + // if (threads.size() != other.threads.size() || locks.size() != + // other.locks.size()) + // return false; - // for (auto &[name, lock] : locks) - // { - // if (!other.locks.contains(name)) - // return false; - // auto &other_lock = other.locks.at(name); - // if (lock.owner != other_lock.owner) - // return false; - // } - // return true; - } + // // Threads may have been spawned in a different order, so we + // // find the thread with the same block in the other context + // for (auto &thread : threads) + // { + // auto it = std::find_if(other.threads.begin(), other.threads.end(), + // [&thread](auto &t) + // { return t->block == thread->block; }); + // if (it == other.threads.end() || !(*thread == **it)) + // return false; + // } + + // for (auto &[name, lock] : locks) + // { + // if (!other.locks.contains(name)) + // return false; + // auto &other_lock = other.locks.at(name); + // if (lock.owner != other_lock.owner) + // return false; + // } + // return true; +} -} \ No newline at end of file +} // namespace gitmem \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh index 7b3a33f..768a83e 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -1,81 +1,87 @@ #pragma once -#include #include -#include #include +#include +#include +#include "branching/version_store.hh" #include "graphviz.hh" #include "lang.hh" #include "linear/version_store.hh" -#include "branching/version_store.hh" namespace gitmem { - class SyncProtocol; +class SyncProtocol; - enum class TerminationStatus { - completed, - datarace_exception, - unlock_exception, - assertion_failure_exception, - unassigned_variable_read_exception, - }; +enum class 
TerminationStatus { + completed, + datarace_exception, + unlock_exception, + assertion_failure_exception, + unassigned_variable_read_exception, +}; - struct ThreadContext { - std::unordered_map locals; - std::shared_ptr tail; +struct ThreadContext { + std::unordered_map locals; + std::shared_ptr tail; - struct LinearData { linear::LocalVersionStore store; }; - struct BranchingData { branching::LocalVersionStore store; }; - - std::optional linear; - std::optional branching; + struct LinearData { + linear::LocalVersionStore store; + }; + struct BranchingData { + branching::LocalVersionStore store; }; - struct Thread { - ThreadContext ctx; - trieste::Node block; - size_t pc = 0; - std::optional terminated = std::nullopt; + std::optional linear; + std::optional branching; +}; - bool operator==(const Thread &other) const; - }; +struct Thread { + ThreadContext ctx; + trieste::Node block; + size_t pc = 0; + std::optional terminated = std::nullopt; - using ThreadID = size_t; + bool operator==(const Thread &other) const; +}; - struct Lock { - // Globals globals; - std::optional owner = std::nullopt; - std::shared_ptr last; - }; +using ThreadID = size_t; - template - std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args); +struct Lock { + // Globals globals; + std::optional owner = std::nullopt; + std::shared_ptr last; +}; - template<> - std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt); +template +std::shared_ptr thread_append_node(ThreadContext &ctx, Args &&...args); - struct GlobalContext { - // Execution state - std::vector> threads; - std::unordered_map locks; +template <> +std::shared_ptr +thread_append_node(ThreadContext &ctx, std::string &&stmt); - // AST evaluation cache - lang::NodeMap cache; +struct GlobalContext { + // Execution state + std::vector> threads; + std::unordered_map locks; - // Graph root - std::shared_ptr entry_node; + // AST evaluation cache + lang::NodeMap cache; - // Synchronisation semantics (policy) - 
std::unique_ptr protocol; + // Graph root + std::shared_ptr entry_node; - GlobalContext(const trieste::Node &ast, std::unique_ptr protocol); - ~GlobalContext(); + // Synchronisation semantics (policy) + std::unique_ptr protocol; - bool operator==(const GlobalContext &other) const; + GlobalContext(const trieste::Node &ast, + std::unique_ptr protocol); + ~GlobalContext(); - void print_execution_graph(const std::filesystem::path &output_path) const; - }; + bool operator==(const GlobalContext &other) const; + + void print_execution_graph(const std::filesystem::path &output_path) const; +}; } // namespace gitmem \ No newline at end of file diff --git a/src/gitmem.cc b/src/gitmem.cc index 012f125..3d88804 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -1,104 +1,82 @@ #include -#include "lang.hh" -#include "interpreter.hh" #include "debug.hh" +#include "interpreter.hh" +#include "lang.hh" -int main(int argc, char **argv) -{ - using namespace trieste; - CLI::App app; - - std::filesystem::path input_path; - app.add_option("input", input_path, "Path to the input file ")->required()->check(CLI::ExistingFile); - - std::filesystem::path output_path = ""; - app.add_option( - "-o,--output", - output_path, - "Path to the output file." - ); - - bool verbose = false; - app.add_flag( - "-v,--verbose", - verbose, - "Enable verbose output from the interpreter." - ); - - // TODO: These should probably be subcommands - bool interactive = false; - app.add_flag( - "-i,--interactive", - interactive, - "Enable interactive scheduling mode (use command ? 
for help)."); - - bool model_check = false; - app.add_flag( - "-e,--explore", - model_check, - "Explore all possible execution paths."); - - try - { - app.parse(argc, argv); - } - catch (const CLI::ParseError &e) - { - return app.exit(e); +int main(int argc, char **argv) { + using namespace trieste; + CLI::App app; + + std::filesystem::path input_path; + app.add_option("input", input_path, "Path to the input file ") + ->required() + ->check(CLI::ExistingFile); + + std::filesystem::path output_path = ""; + app.add_option("-o,--output", output_path, "Path to the output file."); + + bool verbose = false; + app.add_flag("-v,--verbose", verbose, + "Enable verbose output from the interpreter."); + + // TODO: These should probably be subcommands + bool interactive = false; + app.add_flag("-i,--interactive", interactive, + "Enable interactive scheduling mode (use command ? for help)."); + + bool model_check = false; + app.add_flag("-e,--explore", model_check, + "Explore all possible execution paths."); + + try { + app.parse(argc, argv); + } catch (const CLI::ParseError &e) { + return app.exit(e); + } + + try { + gitmem::verbose.enabled = verbose; + + gitmem::verbose << "Reading file " << input_path << std::endl; + if (!std::filesystem::exists(input_path)) { + std::cerr << "Input file does not exist: " << input_path << std::endl; + return 1; } - try - { - gitmem::verbose.enabled = verbose; - - gitmem::verbose << "Reading file " << input_path << std::endl; - if (!std::filesystem::exists(input_path)) - { - std::cerr << "Input file does not exist: " << input_path << std::endl; - return 1; - } - - auto reader = gitmem::lang::reader().file(input_path); - auto result = reader.read(); - - if (!result.ok) - { - trieste::logging::Error err; - result.print_errors(err); - trieste::logging::Debug() << result.ast; - return 1; - } - - if (output_path.empty()) - output_path = input_path.stem().replace_extension(".dot"); - - gitmem::verbose << "Output will be written to " << output_path << 
std::endl; - - int exit_status; - wf::push_back(gitmem::lang::wf); - if (model_check) - { - assert(false && "currently broken"); - // exit_status = gitmem::model_check(result.ast, output_path); - } - else if (interactive) - { - assert(false && "currently broken"); - // exit_status = gitmem::interpret_interactive(result.ast, output_path); - } - else - { - exit_status = gitmem::interpret(result.ast, output_path); - } - wf::pop_front(); - - gitmem::verbose << "Execution finished with exit status " << exit_status << std::endl; - return exit_status; + auto reader = gitmem::lang::reader().file(input_path); + auto result = reader.read(); + + if (!result.ok) { + trieste::logging::Error err; + result.print_errors(err); + trieste::logging::Debug() << result.ast; + return 1; } - catch (const std::exception &e) - { - std::cerr << "Exception caught: " << e.what() << std::endl; - return 1; + + if (output_path.empty()) + output_path = input_path.stem().replace_extension(".dot"); + + gitmem::verbose << "Output will be written to " << output_path << std::endl; + + int exit_status; + wf::push_back(gitmem::lang::wf); + if (model_check) { + assert(false && "currently broken"); + // exit_status = gitmem::model_check(result.ast, output_path); + } else if (interactive) { + assert(false && "currently broken"); + // exit_status = gitmem::interpret_interactive(result.ast, output_path); + } else { + exit_status = gitmem::interpret(result.ast, output_path); } + wf::pop_front(); + + gitmem::verbose << "Execution finished with exit status " << exit_status + << std::endl; + return exit_status; + } catch (const std::exception &e) { + std::cerr << "Exception caught: " << e.what() << std::endl; + return 1; + } } diff --git a/src/gitmem_trieste.cc b/src/gitmem_trieste.cc index 0ebc28b..a715915 100644 --- a/src/gitmem_trieste.cc +++ b/src/gitmem_trieste.cc @@ -1,7 +1,6 @@ -#include #include "lang.hh" +#include -int main(int argc, char** argv) -{ +int main(int argc, char **argv) { return 
trieste::Driver(gitmem::lang::reader()).run(argc, argv); } diff --git a/src/graph.hh b/src/graph.hh index 3a3c071..a201439 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -1,181 +1,144 @@ #pragma once +#include #include #include -#include -#include namespace gitmem { - namespace graph { - - struct Visitor; - - struct Node - { - std::shared_ptr next = nullptr; - - virtual void accept(Visitor*) const = 0; - }; - - struct Start; - struct End; - struct Write; - struct Read; - struct Spawn; - struct Join; - struct Lock; - struct Unlock; - struct AssertionFailure; - struct Pending; - - struct Conflict - { - std::string var; - std::pair, std::shared_ptr> sources; - }; - - struct Visitor - { - virtual void visitStart(const Start*) = 0; - virtual void visitEnd(const End*) = 0; - virtual void visitWrite(const Write*) = 0; - virtual void visitRead(const Read*) = 0; - virtual void visitSpawn(const Spawn*) = 0; - virtual void visitJoin(const Join*) = 0; - virtual void visitLock(const Lock*) = 0; - virtual void visitUnlock(const Unlock*) = 0; - virtual void visitAssertionFailure(const AssertionFailure*) = 0; - virtual void visitPending(const Pending*) = 0; - virtual void visit(const Node* n) { n->accept(this); } - }; - - struct Start : Node - { - size_t id; - - Start(size_t id): id(id) {} - - void accept(Visitor* v) const override - { - v->visitStart(this); - } - }; - - struct End : Node - { - End() {} - - void accept(Visitor* v) const override - { - v->visitEnd(this); - } - }; - - struct Write : Node - { - const std::string var; - const size_t value; - const size_t id; - - Write(const std::string var, const size_t value, const size_t id): var(var), value(value), id(id) {} - - void accept(Visitor* v) const override - { - v->visitWrite(this); - } - }; - - struct Read : Node - { - const std::string var; - const size_t value; - const size_t id; - const std::shared_ptr sauce; - - - Read(const std::string var, const size_t value, const size_t id, const std::shared_ptr sauce): 
var(var), value(value), id(id), sauce(sauce) {} - - void accept(Visitor* v) const override - { - v->visitRead(this); - } - }; - - struct Spawn : Node - { - const size_t tid; - const std::shared_ptr spawned; - - Spawn(const size_t tid, const std::shared_ptr spawned): tid(tid), spawned(spawned) {} - - void accept(Visitor* v) const override - { - v->visitSpawn(this); - } - }; - - struct Join : Node - { - const size_t tid; - const std::shared_ptr joinee; - const std::optional conflict; - - Join(const size_t tid, const std::shared_ptr joinee, std::optional conflict = std::nullopt): tid(tid), joinee(joinee), conflict(conflict) {} - - void accept(Visitor* v) const override - { - v->visitJoin(this); - } - }; - - struct Lock : Node - { - const std::string var; - const std::shared_ptr ordered_after; - const std::optional conflict; - - Lock(const std::string var, const std::shared_ptr ordered_after, std::optional conflict = std::nullopt): var(var), ordered_after(ordered_after), conflict(conflict) {} - - void accept(Visitor* v) const override - { - v->visitLock(this); - } - }; - - struct Unlock : Node - { - const std::string var; - - Unlock(const std::string var): var(var) {} - - void accept(Visitor* v) const override - { - v->visitUnlock(this); - } - }; - - struct AssertionFailure : Node - { - const std::string cond; - - AssertionFailure(const std::string &cond): cond(cond) {} - - void accept(Visitor* v) const override - { - v->visitAssertionFailure(this); - } - }; - - struct Pending : Node - { - const std::string statement; - - Pending(const std::string statement): statement(statement) {} - void accept(Visitor* v) const override - { - v->visitPending(this); - } - }; - } -} +namespace graph { + +struct Visitor; + +struct Node { + std::shared_ptr next = nullptr; + + virtual void accept(Visitor *) const = 0; +}; + +struct Start; +struct End; +struct Write; +struct Read; +struct Spawn; +struct Join; +struct Lock; +struct Unlock; +struct AssertionFailure; +struct Pending; + 
+struct Conflict { + std::string var; + std::pair, std::shared_ptr> sources; +}; + +struct Visitor { + virtual void visitStart(const Start *) = 0; + virtual void visitEnd(const End *) = 0; + virtual void visitWrite(const Write *) = 0; + virtual void visitRead(const Read *) = 0; + virtual void visitSpawn(const Spawn *) = 0; + virtual void visitJoin(const Join *) = 0; + virtual void visitLock(const Lock *) = 0; + virtual void visitUnlock(const Unlock *) = 0; + virtual void visitAssertionFailure(const AssertionFailure *) = 0; + virtual void visitPending(const Pending *) = 0; + virtual void visit(const Node *n) { n->accept(this); } +}; + +struct Start : Node { + size_t id; + + Start(size_t id) : id(id) {} + + void accept(Visitor *v) const override { v->visitStart(this); } +}; + +struct End : Node { + End() {} + + void accept(Visitor *v) const override { v->visitEnd(this); } +}; + +struct Write : Node { + const std::string var; + const size_t value; + const size_t id; + + Write(const std::string var, const size_t value, const size_t id) + : var(var), value(value), id(id) {} + + void accept(Visitor *v) const override { v->visitWrite(this); } +}; + +struct Read : Node { + const std::string var; + const size_t value; + const size_t id; + const std::shared_ptr sauce; + + Read(const std::string var, const size_t value, const size_t id, + const std::shared_ptr sauce) + : var(var), value(value), id(id), sauce(sauce) {} + + void accept(Visitor *v) const override { v->visitRead(this); } +}; + +struct Spawn : Node { + const size_t tid; + const std::shared_ptr spawned; + + Spawn(const size_t tid, const std::shared_ptr spawned) + : tid(tid), spawned(spawned) {} + + void accept(Visitor *v) const override { v->visitSpawn(this); } +}; + +struct Join : Node { + const size_t tid; + const std::shared_ptr joinee; + const std::optional conflict; + + Join(const size_t tid, const std::shared_ptr joinee, + std::optional conflict = std::nullopt) + : tid(tid), joinee(joinee), conflict(conflict) 
{} + + void accept(Visitor *v) const override { v->visitJoin(this); } +}; + +struct Lock : Node { + const std::string var; + const std::shared_ptr ordered_after; + const std::optional conflict; + + Lock(const std::string var, const std::shared_ptr ordered_after, + std::optional conflict = std::nullopt) + : var(var), ordered_after(ordered_after), conflict(conflict) {} + + void accept(Visitor *v) const override { v->visitLock(this); } +}; + +struct Unlock : Node { + const std::string var; + + Unlock(const std::string var) : var(var) {} + + void accept(Visitor *v) const override { v->visitUnlock(this); } +}; + +struct AssertionFailure : Node { + const std::string cond; + + AssertionFailure(const std::string &cond) : cond(cond) {} + + void accept(Visitor *v) const override { v->visitAssertionFailure(this); } +}; + +struct Pending : Node { + const std::string statement; + + Pending(const std::string statement) : statement(statement) {} + void accept(Visitor *v) const override { v->visitPending(this); } +}; +} // namespace graph +} // namespace gitmem diff --git a/src/graphviz.cc b/src/graphviz.cc index cd9bd1e..ee6a109 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -4,155 +4,164 @@ namespace gitmem { namespace graph { - using std::to_string; - - void GraphvizPrinter::emitNode(const Node* n, const std::string& label, const std::string& style) { - file << "\t" << (size_t)n << "[label=\"" << label << "\", shape=rectangle, style=\"rounded,filled\", "; - if (!style.empty()) file << style; - file << "]" << ";" << std::endl; - } - - void GraphvizPrinter::emitEdge(const Node* from, const Node* to, const std::string& label, const std::string& style) { - if (!from || !to) return; - - file << "\t" << (size_t)from << " -> " << (size_t)to; - if (!style.empty() || !label.empty()) { - file << "["; - if (!style.empty()) file << style; - if (!label.empty()) file << " label=\"" << label << "\""; - file << "]"; - } - file << ";" << std::endl; - } - - void 
GraphvizPrinter::emitProgramOrderEdge(const Node* from, const Node* to) { - emitEdge(from, to, ""); - } - - void GraphvizPrinter::emitReadFromEdge(const Node* from, const Node* to) { - emitEdge(from, to, "rf", "style=dashed, constraint=false"); - } - - void GraphvizPrinter::emitConflictEdge(const Node* from, const Node* to) { - emitEdge(from, to, "race", "style=dashed, color=red, constraint=false"); - } - - void GraphvizPrinter::emitSyncEdge(const Node* from, const Node* to) { - emitEdge(from, to, "sync", "style=bold, constraint=false"); - } - - void GraphvizPrinter::emitFillColor(const Node* n, const std::string& color) { - file << "\t" << (size_t)n << "[fillcolor = " << color << "];" << std::endl; - } - - void GraphvizPrinter::emitShape(const Node* n, const std::string& shape) { - file << "\t" << (size_t)n << "[shape = " << shape << "];" << std::endl; - } - - void GraphvizPrinter::emitConflict(const Node* n, const Conflict& conflict) { - emitFillColor(n, "red"); - // emitShape(n, "doubleoctagon"); - auto [s1, s2] = conflict.sources; - emitConflictEdge(n, s1.get()); - emitConflictEdge(n, s2.get()); - } - - GraphvizPrinter::GraphvizPrinter(std::string filename) noexcept { - file.open(filename); - } - - void GraphvizPrinter::visit(const Node* n) { - file << "digraph G {" << std::endl; +using std::to_string; + +void GraphvizPrinter::emitNode(const Node *n, const std::string &label, + const std::string &style) { + file << "\t" << (size_t)n << "[label=\"" << label + << "\", shape=rectangle, style=\"rounded,filled\", "; + if (!style.empty()) + file << style; + file << "]" << ";" << std::endl; +} + +void GraphvizPrinter::emitEdge(const Node *from, const Node *to, + const std::string &label, + const std::string &style) { + if (!from || !to) + return; + + file << "\t" << (size_t)from << " -> " << (size_t)to; + if (!style.empty() || !label.empty()) { + file << "["; + if (!style.empty()) + file << style; + if (!label.empty()) + file << " label=\"" << label << "\""; + file << 
"]"; + } + file << ";" << std::endl; +} + +void GraphvizPrinter::emitProgramOrderEdge(const Node *from, const Node *to) { + emitEdge(from, to, ""); +} + +void GraphvizPrinter::emitReadFromEdge(const Node *from, const Node *to) { + emitEdge(from, to, "rf", "style=dashed, constraint=false"); +} + +void GraphvizPrinter::emitConflictEdge(const Node *from, const Node *to) { + emitEdge(from, to, "race", "style=dashed, color=red, constraint=false"); +} + +void GraphvizPrinter::emitSyncEdge(const Node *from, const Node *to) { + emitEdge(from, to, "sync", "style=bold, constraint=false"); +} + +void GraphvizPrinter::emitFillColor(const Node *n, const std::string &color) { + file << "\t" << (size_t)n << "[fillcolor = " << color << "];" << std::endl; +} + +void GraphvizPrinter::emitShape(const Node *n, const std::string &shape) { + file << "\t" << (size_t)n << "[shape = " << shape << "];" << std::endl; +} + +void GraphvizPrinter::emitConflict(const Node *n, const Conflict &conflict) { + emitFillColor(n, "red"); + // emitShape(n, "doubleoctagon"); + auto [s1, s2] = conflict.sources; + emitConflictEdge(n, s1.get()); + emitConflictEdge(n, s2.get()); +} + +GraphvizPrinter::GraphvizPrinter(std::string filename) noexcept { + file.open(filename); +} + +void GraphvizPrinter::visit(const Node *n) { + file << "digraph G {" << std::endl; + n->accept(this); + file << "}" << std::endl; +} + +void GraphvizPrinter::visitProgramOrder(const Node *n) { + if (n) { n->accept(this); + } else { file << "}" << std::endl; } - - void GraphvizPrinter::visitProgramOrder(const Node* n) { - if(n) - { - n->accept(this); - } - else - { - file << "}" << std::endl; - } - } - - void GraphvizPrinter::visitStart(const Start* n) { - file << "subgraph cluster_Thread_" << n->id << "{" << std::endl; - file << "\tlabel = \"Thread #" << n->id << "\";" << std::endl; - file << "\tcolor=black;" << std::endl; - emitNode(n, "", "shape=circle width=.3 style=filled color=black"); - emitProgramOrderEdge(n, n->next.get()); - 
visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitEnd(const End* n) { - assert(!n->next); - emitNode(n, "", "shape=doublecircle width=.2 style=empty"); - file << "}" << std::endl; - } - - void GraphvizPrinter::visitWrite(const Write* n) { - emitNode(n, "W" + n->var + " = " + to_string(n->value)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitRead(const Read* n) { - emitNode(n, "R" + n->var + " = " + to_string(n->value)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - - assert(n->sauce); - emitReadFromEdge(n, n->sauce.get()); - } - - void GraphvizPrinter::visitSpawn(const Spawn* n) { - emitNode(n, "Spawn " + std::to_string(n->tid)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - if (n->spawned) { - emitSyncEdge(n, n->spawned.get()); - visitProgramOrder(n->spawned.get()); - } - } - - void GraphvizPrinter::visitJoin(const Join* n) { - emitNode(n, "Join " + std::to_string(n->tid)); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - if (n->joinee) emitSyncEdge(n->joinee.get(), n); - if (n->conflict) emitConflict(n, n->conflict.value()); - } - - void GraphvizPrinter::visitLock(const Lock* n) { - emitNode(n, "Lock " + n->var); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - if (n->ordered_after) emitSyncEdge(n->ordered_after.get(), n); - if (n->conflict) emitConflict(n, n->conflict.value()); - } - - void GraphvizPrinter::visitUnlock(const Unlock* n) { - emitNode(n, "Unlock " + n->var); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitAssertionFailure(const AssertionFailure* n) { - emitNode(n, "Assert " + n->cond); - emitFillColor(n, "red"); - // emitShape(n, "doubleoctagon"); - emitProgramOrderEdge(n, n->next.get()); - visitProgramOrder(n->next.get()); - } - - void GraphvizPrinter::visitPending(const 
Pending* n) { - assert(!n->next); - emitNode(n, "" + n->statement + "", "style=dashed"); - file << "}" << std::endl; - } +} + +void GraphvizPrinter::visitStart(const Start *n) { + file << "subgraph cluster_Thread_" << n->id << "{" << std::endl; + file << "\tlabel = \"Thread #" << n->id << "\";" << std::endl; + file << "\tcolor=black;" << std::endl; + emitNode(n, "", "shape=circle width=.3 style=filled color=black"); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitEnd(const End *n) { + assert(!n->next); + emitNode(n, "", "shape=doublecircle width=.2 style=empty"); + file << "}" << std::endl; +} + +void GraphvizPrinter::visitWrite(const Write *n) { + emitNode(n, "W" + n->var + " = " + to_string(n->value)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitRead(const Read *n) { + emitNode(n, "R" + n->var + " = " + to_string(n->value)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + + assert(n->sauce); + emitReadFromEdge(n, n->sauce.get()); +} + +void GraphvizPrinter::visitSpawn(const Spawn *n) { + emitNode(n, "Spawn " + std::to_string(n->tid)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + if (n->spawned) { + emitSyncEdge(n, n->spawned.get()); + visitProgramOrder(n->spawned.get()); + } +} + +void GraphvizPrinter::visitJoin(const Join *n) { + emitNode(n, "Join " + std::to_string(n->tid)); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + if (n->joinee) + emitSyncEdge(n->joinee.get(), n); + if (n->conflict) + emitConflict(n, n->conflict.value()); +} + +void GraphvizPrinter::visitLock(const Lock *n) { + emitNode(n, "Lock " + n->var); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); + if (n->ordered_after) + emitSyncEdge(n->ordered_after.get(), n); + if (n->conflict) + emitConflict(n, n->conflict.value()); +} + +void 
GraphvizPrinter::visitUnlock(const Unlock *n) { + emitNode(n, "Unlock " + n->var); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitAssertionFailure(const AssertionFailure *n) { + emitNode(n, "Assert " + n->cond); + emitFillColor(n, "red"); + // emitShape(n, "doubleoctagon"); + emitProgramOrderEdge(n, n->next.get()); + visitProgramOrder(n->next.get()); +} + +void GraphvizPrinter::visitPending(const Pending *n) { + assert(!n->next); + emitNode(n, "" + n->statement + "", "style=dashed"); + file << "}" << std::endl; +} } // namespace graph } // namespace gitmem diff --git a/src/graphviz.hh b/src/graphviz.hh index a75ca2f..adfef24 100644 --- a/src/graphviz.hh +++ b/src/graphviz.hh @@ -2,33 +2,36 @@ #include "graph.hh" namespace gitmem { - namespace graph { - struct GraphvizPrinter : Visitor { - void visitStart(const Start*) override; - void visitEnd(const End*) override; - void visitWrite(const Write*) override; - void visitRead(const Read*) override; - void visitSpawn(const Spawn*) override; - void visitJoin(const Join*) override; - void visitLock(const Lock*) override; - void visitUnlock(const Unlock*) override; - void visitAssertionFailure(const AssertionFailure*) override; - void visitPending(const Pending*) override; - void visit(const Node* n) override; +namespace graph { +struct GraphvizPrinter : Visitor { + void visitStart(const Start *) override; + void visitEnd(const End *) override; + void visitWrite(const Write *) override; + void visitRead(const Read *) override; + void visitSpawn(const Spawn *) override; + void visitJoin(const Join *) override; + void visitLock(const Lock *) override; + void visitUnlock(const Unlock *) override; + void visitAssertionFailure(const AssertionFailure *) override; + void visitPending(const Pending *) override; + void visit(const Node *n) override; - GraphvizPrinter(std::string filename) noexcept; - private: - std::ofstream file; - void emitNode(const Node* n, const 
std::string& label, const std::string& style = ""); - void emitEdge(const Node* from, const Node* to, const std::string& label, const std::string& style = ""); - void emitProgramOrderEdge(const Node* from, const Node* to); - void emitReadFromEdge(const Node* from, const Node* to); - void emitFillColor(const Node* n, const std::string& color); - void emitShape(const Node* n, const std::string& shape); - void emitConflictEdge(const Node* from, const Node* to); - void emitSyncEdge(const Node* from, const Node* to); - void emitConflict(const Node* n, const Conflict& conflict); - void visitProgramOrder(const Node* n); - }; - } -} + GraphvizPrinter(std::string filename) noexcept; + +private: + std::ofstream file; + void emitNode(const Node *n, const std::string &label, + const std::string &style = ""); + void emitEdge(const Node *from, const Node *to, const std::string &label, + const std::string &style = ""); + void emitProgramOrderEdge(const Node *from, const Node *to); + void emitReadFromEdge(const Node *from, const Node *to); + void emitFillColor(const Node *n, const std::string &color); + void emitShape(const Node *n, const std::string &shape); + void emitConflictEdge(const Node *from, const Node *to); + void emitSyncEdge(const Node *from, const Node *to); + void emitConflict(const Node *n, const Conflict &conflict); + void visitProgramOrder(const Node *n); +}; +} // namespace graph +} // namespace gitmem diff --git a/src/internal.hh b/src/internal.hh index ec01c4d..c258ab7 100644 --- a/src/internal.hh +++ b/src/internal.hh @@ -5,21 +5,21 @@ namespace gitmem { namespace lang { - using namespace trieste; +using namespace trieste; - Parse parser(); - PassDef expressions(); - PassDef statements(); - PassDef check_refs(); - PassDef branching(); +Parse parser(); +PassDef expressions(); +PassDef statements(); +PassDef check_refs(); +PassDef branching(); - inline const auto parse_token = - Reg | Var | Const | Nop | Brace | Paren | - Spawn | Join | Lock | Unlock | Assert | 
If | Else; +inline const auto parse_token = Reg | Var | Const | Nop | Brace | Paren | + Spawn | Join | Lock | Unlock | Assert | If | + Else; - inline const auto parse_op = Group | Assign | Eq | Neq | Add | Semi; +inline const auto parse_op = Group | Assign | Eq | Neq | Add | Semi; - // clang-format off +// clang-format off inline const wf::Wellformed parser_wf = (Top <<= File) | (File <<= ~parse_op) @@ -84,7 +84,7 @@ namespace lang { | (Jump <<= Const) | (Cond <<= Expr * Const) ; - // clang-format on +// clang-format on } // namespace lang diff --git a/src/interpreter.cc b/src/interpreter.cc index 81b889a..b892bc5 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -1,512 +1,501 @@ +#include #include #include -#include -#include "interpreter.hh" +#include "debug.hh" #include "graphviz.hh" +#include "interpreter.hh" #include "sync_protocol.hh" -#include "debug.hh" namespace gitmem { - using namespace trieste; - - /* Interpreter for a gitmem program. Threads can read and write local - * variables as well as versioned global variables. Globals are not stored - * in a single memory location but instead in the state of 'synchronising - * objects' which include threads and locks. Synchronising actions between - * threads, and between threads and locks, synchronise the versioned memory - * and if both objects see updates to the same versioned global variable - * then a data race is detected. These synchronising actions include: - * - thread t1 joining a thread t2, which waits for t2 to complete before - * trying to 'pull' the new data into t1 - * - t locking a lock l, which waits for the lock l to be available before - * trying to 'pull' the new data into t - * - t unlocking a lock l, which updates l to have t's versioned memory - */ - - bool is_syncing(Node stmt) - { - auto s = stmt / lang::Stmt; - return s == lang::Join || s == lang::Lock || s == lang::Unlock; +using namespace trieste; + +/* Interpreter for a gitmem program. 
Threads can read and write local + * variables as well as versioned global variables. Globals are not stored + * in a single memory location but instead in the state of 'synchronising + * objects' which include threads and locks. Synchronising actions between + * threads, and between threads and locks, synchronise the versioned memory + * and if both objects see updates to the same versioned global variable + * then a data race is detected. These synchronising actions include: + * - thread t1 joining a thread t2, which waits for t2 to complete before + * trying to 'pull' the new data into t1 + * - t locking a lock l, which waits for the lock l to be available before + * trying to 'pull' the new data into t + * - t unlocking a lock l, which updates l to have t's versioned memory + */ + +bool is_syncing(Node stmt) { + auto s = stmt / lang::Stmt; + return s == lang::Join || s == lang::Lock || s == lang::Unlock; +} + +bool is_syncing(Thread &thread) { + return !thread.terminated && is_syncing(thread.block->at(thread.pc)); +} + +template +std::shared_ptr thread_append_node(ThreadContext &ctx, Args &&...args) { + assert(ctx.tail); + auto node = std::make_shared(std::forward(args)...); + ctx.tail->next = node; + ctx.tail = node; + return node; +} + +template <> +std::shared_ptr +thread_append_node(ThreadContext &ctx, std::string &&stmt) { + // pending nodes don't update the tail position as we will destroy them + // once we execute the node + auto s = std::regex_replace(stmt, std::regex("\n"), "\\l "); + auto node = make_shared(std::move(s)); + ctx.tail->next = node; + return node; +} + +/* Evaluating an expression either returns the result of the expression or + * a the exceptional termination status of the thread. 
+ */ +std::variant +evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { + auto e = expr / lang::Expr; + if (e == lang::Reg) { + // It is invalid to read a previously unwritten value + auto var = std::string(expr->location().view()); + if (ctx.locals.contains(var)) { + return ctx.locals[var]; + } else { + return TerminationStatus::unassigned_variable_read_exception; + } + } else if (e == lang::Var) { + auto var = std::string(expr->location().view()); + if (std::optional result = gctx.protocol->read(ctx, var)) { + return *result; + } else { // It is invalid to read a previously unwritten value + return TerminationStatus::unassigned_variable_read_exception; + } + } else if (e == lang::Const) { + return size_t(std::stoi(std::string(e->location().view()))); + } else if (e == lang::Add) { + size_t sum = 0; + for (auto &child : *e) { + auto result = evaluate_expression(child, gctx, ctx); + if (std::holds_alternative(result)) + return result; + sum += std::get(result); + } + return sum; + } else if (e == lang::Spawn) { + ThreadID tid = gctx.threads.size(); + auto node = std::make_shared(tid); + ThreadContext child_ctx = {std::unordered_map(), node}; + gctx.threads.push_back( + std::make_shared(child_ctx, e / lang::Block)); + thread_append_node(ctx, tid, node); + + if (std::optional> conflict = + gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { + assert(false); // handle this } - bool is_syncing(Thread &thread) - { - return !thread.terminated && is_syncing(thread.block->at(thread.pc)); + // Spawning is a sync point, commit local pending commits, and + // copy the global state to the spawned thread + // commit(ctx.globals); + + return tid; + } else if (e == lang::Eq || e == lang::Neq) { + auto lhs = e / lang::Lhs; + auto rhs = e / lang::Rhs; + + auto lhsEval = evaluate_expression(lhs, gctx, ctx); + if (std::holds_alternative(lhsEval)) + return lhsEval; + + auto rhsEval = evaluate_expression(rhs, gctx, ctx); + if (std::holds_alternative(rhsEval)) + 
return rhsEval; + + return e == lang::Eq + ? (std::get(lhsEval)) == (std::get(rhsEval)) + : (std::get(lhsEval)) != (std::get(rhsEval)); + } else { + throw std::runtime_error("Unknown expression: " + + std::string(expr->type().str())); + } +} + +/* Evaluating a statement either returns the resulting change of the program + * counter (0 if waiting for some other thread) or the exceptional + * termination status of the thread. + */ +std::variant run_statement(Node stmt, + GlobalContext &gctx, + ThreadContext &ctx, + const ThreadID &tid) { + auto s = stmt / lang::Stmt; + if (s == lang::Nop) { + + verbose << "Nop" << std::endl; + + } else if (s == lang::Jump) { + + auto cnst = s / lang::Const; + auto delta = std::stoi(std::string(cnst->location().view())); + assert(delta > 0); + return delta; + + } else if (s == lang::Cond) { + + auto expr = s / lang::Expr; + auto cnst = s / lang::Const; + auto result = evaluate_expression(expr, gctx, ctx); + + if (auto b = std::get_if(&result)) { + auto delta = std::stoi(std::string(cnst->location().view())); + assert(delta > 0); + return *b ? 
1 : delta; + } else { + return std::get(result); } + } else if (s == lang::Assign) { - template - std::shared_ptr thread_append_node(ThreadContext& ctx, Args&&...args) - { - assert(ctx.tail); - auto node = std::make_shared(std::forward(args)...); - ctx.tail->next = node; - ctx.tail = node; - return node; - } + auto lhs = s / lang::LVal; + auto var = std::string(lhs->location().view()); + auto rhs = s / lang::Expr; + auto val_or_term = evaluate_expression(rhs, gctx, ctx); - template<> - std::shared_ptr thread_append_node(ThreadContext& ctx, std::string&& stmt) - { - // pending nodes don't update the tail position as we will destroy them - // once we execute the node - auto s = std::regex_replace(stmt, std::regex("\n"), "\\l "); - auto node = make_shared(std::move(s)); - ctx.tail->next = node; - return node; - } + if (size_t *val = std::get_if(&val_or_term)) { + if (lhs == lang::Reg) { + + // Local variables can be re-assigned whenever + verbose << "Set register '" << lhs->location().view() << "' to " << *val + << std::endl; + ctx.locals[var] = *val; + + } else if (lhs == lang::Var) { - /* Evaluating an expression either returns the result of the expression or - * a the exceptional termination status of the thread. 
- */ - std::variant evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { - auto e = expr / lang::Expr; - if (e == lang::Reg) - { - // It is invalid to read a previously unwritten value - auto var = std::string(expr->location().view()); - if (ctx.locals.contains(var)) - { - return ctx.locals[var]; - } - else - { - return TerminationStatus::unassigned_variable_read_exception; - } - } - else if (e == lang::Var) - { - auto var = std::string(expr->location().view()); - if (std::optional result = gctx.protocol->read(ctx, var) ) { - return *result; - } else { // It is invalid to read a previously unwritten value - return TerminationStatus::unassigned_variable_read_exception; - } - } - else if (e == lang::Const) - { - return size_t(std::stoi(std::string(e->location().view()))); - } - else if (e == lang::Add) - { - size_t sum = 0; - for (auto &child : *e) - { - auto result = evaluate_expression(child, gctx, ctx); - if (std::holds_alternative(result)) return result; - sum += std::get(result); - } - return sum; - } - else if (e == lang::Spawn) - { - ThreadID tid = gctx.threads.size(); - auto node = std::make_shared(tid); - ThreadContext child_ctx = { std::unordered_map(), node }; - gctx.threads.push_back(std::make_shared(child_ctx, e / lang::Block)); - thread_append_node(ctx, tid, node); - - if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { - assert(false); // handle this - } - - // Spawning is a sync point, commit local pending commits, and - // copy the global state to the spawned thread - // commit(ctx.globals); - - return tid; - } - else if (e == lang::Eq || e == lang::Neq) - { - auto lhs = e / lang::Lhs; - auto rhs = e / lang::Rhs; - - auto lhsEval = evaluate_expression(lhs, gctx, ctx); - if (std::holds_alternative(lhsEval)) return lhsEval; - - auto rhsEval = evaluate_expression(rhs, gctx, ctx); - if (std::holds_alternative(rhsEval)) return rhsEval; - - return e == lang::Eq? 
(std::get(lhsEval)) == (std::get(rhsEval)) - : (std::get(lhsEval)) != (std::get(rhsEval)); - } - else - { - throw std::runtime_error("Unknown expression: " + std::string(expr->type().str())); - } + gctx.protocol->write(ctx, var, *val); + + // // Global variable writes need to create a new commit id + // // to track the history of updates + // auto &global = ctx.globals[var]; + // global.val = *val; + // global.commit = gctx.uuid++; + // verbose << "Set global '" << lhs->location().view() << "' to " << + // *val << " with id " << *(global.commit) << std::endl; + + // auto node = thread_append_node(ctx, var, global.val, + // *global.commit); gctx.commit_map[*(global.commit)] = node; + } else { + throw std::runtime_error("Bad left-hand side: " + + std::string(lhs->type().str())); + } + } else { + return std::get(val_or_term); + } + } else if (s == lang::Join) { + // A join must waiting for the terminating thread to continue, + // we don't want to re-evaluate the expression repeatedly as this + // may be effecting so store the result in the cache. + auto expr = s / lang::Expr; + + if (!gctx.cache.contains(expr)) { + auto val_or_term = evaluate_expression(expr, gctx, ctx); + if (size_t *val = std::get_if(&val_or_term)) { + gctx.cache[expr] = *val; + } else { + return std::get(val_or_term); + } } - /* Evaluating a statement either returns the resulting change of the program - * counter (0 if waiting for some other thread) or the exceptional - * termination status of the thread. 
- */ - std::variant run_statement(Node stmt, GlobalContext &gctx, ThreadContext &ctx, const ThreadID& tid) { - auto s = stmt / lang::Stmt; - if (s == lang::Nop) { - - verbose << "Nop" << std::endl; - - } else if (s == lang::Jump) { - - auto cnst = s / lang::Const; - auto delta = std::stoi(std::string(cnst->location().view())); - assert(delta > 0); - return delta; - - } else if (s == lang::Cond) { - - auto expr = s / lang::Expr; - auto cnst = s / lang::Const; - auto result = evaluate_expression(expr, gctx, ctx); - - if (auto b = std::get_if(&result)) { - auto delta = std::stoi(std::string(cnst->location().view())); - assert(delta > 0); - return *b? 1 : delta; - } else { - return std::get(result); - } - - } else if (s == lang::Assign) { - - auto lhs = s / lang::LVal; - auto var = std::string(lhs->location().view()); - auto rhs = s / lang::Expr; - auto val_or_term = evaluate_expression(rhs, gctx, ctx); - - if(size_t* val = std::get_if(&val_or_term)) { - if (lhs == lang::Reg) { - - // Local variables can be re-assigned whenever - verbose << "Set register '" << lhs->location().view() << "' to " << *val << std::endl; - ctx.locals[var] = *val; - - } else if (lhs == lang::Var) { - - gctx.protocol->write(ctx, var, *val); - - // // Global variable writes need to create a new commit id - // // to track the history of updates - // auto &global = ctx.globals[var]; - // global.val = *val; - // global.commit = gctx.uuid++; - // verbose << "Set global '" << lhs->location().view() << "' to " << *val << " with id " << *(global.commit) << std::endl; - - // auto node = thread_append_node(ctx, var, global.val, *global.commit); - // gctx.commit_map[*(global.commit)] = node; - } else { - throw std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); - } - } else { - return std::get(val_or_term); - } - } else if (s == lang::Join) { - // A join must waiting for the terminating thread to continue, - // we don't want to re-evaluate the expression repeatedly as this - // 
may be effecting so store the result in the cache. - auto expr = s / lang::Expr; - - if (!gctx.cache.contains(expr)) { - auto val_or_term = evaluate_expression(expr, gctx, ctx); - if (size_t* val = std::get_if(&val_or_term)) { - gctx.cache[expr] = *val; - } else { - return std::get(val_or_term); - } - } - - auto result = gctx.cache[expr]; - auto& joinee = gctx.threads[result]; - if (joinee->terminated && (*joinee->terminated == TerminationStatus::completed)) { - if(auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { - verbose << (**conflict) << std::endl; - return TerminationStatus::datarace_exception; - } else { - thread_append_node(ctx, result, joinee->ctx.tail); - } - - } else { - verbose << "Waiting on thread " << result << std::endl; - return 0; - } - } else if (s == lang::Lock) { - // We can only lock unlocked locks, if a lock hasn't been used - // before it is implicitly created, we then commit the pending - // updates of this thread and pull the updates from the lock. - auto v = s / lang::Var; - auto var = std::string(v->location().view()); - - Lock& lock = gctx.locks[var]; - if (lock.owner) { - verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; - return 0; - } - - lock.owner = tid; - if(auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { - verbose << (**conflict) << std::endl; - // using graph::Node; - // auto [s1, s2] = conflict->commits; - // auto sources = std::pair, std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - // auto graph_conflict = graph::Conflict(conflict->var, sources); - // thread_append_node(ctx, var, lock.last, graph_conflict); - return TerminationStatus::datarace_exception; - } - - thread_append_node(ctx, var, lock.last); - - verbose << "Locked " << var << std::endl; - } else if (s == lang::Unlock) { - assert(false && "todo"); - - // We can only unlock locks we previously locked. 
We commit any - // pending updates and then copy the threads versioned globals - // to the locks versioned globals (nobody could have changed - // them since we locked the lock). - - // commit(ctx.globals); - auto v = s / lang::Var; - auto var = std::string(v->location().view()); - - auto& lock = gctx.locks[var]; - if (!lock.owner || (lock.owner && *lock.owner != tid)) { - return TerminationStatus::unlock_exception; - } - - if(auto conflict = gctx.protocol->on_unlock(ctx, lock, gctx)) { - verbose << (**conflict) << std::endl; - return TerminationStatus::datarace_exception; - } - - // lock.globals = ctx.globals; - lock.owner.reset(); - - thread_append_node(ctx, var); - lock.last = ctx.tail; - - verbose << "Unlocked " << var << std::endl; - - } else if (s == lang::Assert) { - auto expr = s / lang::Expr; - auto result_or_term = evaluate_expression(expr, gctx, ctx); - if (size_t* result = std::get_if(&result_or_term)) { - if (*result) { - verbose << "Assertion passed: " << expr->location().view() << std::endl; - } else { - verbose << "Assertion failed: " << expr->location().view() << std::endl; - thread_append_node(ctx, std::string(expr->location().view())); - return TerminationStatus::assertion_failure_exception; - } - } else { - return std::get(result_or_term); - } + auto result = gctx.cache[expr]; + auto &joinee = gctx.threads[result]; + if (joinee->terminated && + (*joinee->terminated == TerminationStatus::completed)) { + if (auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { + verbose << (**conflict) << std::endl; + return TerminationStatus::datarace_exception; } else { - throw std::runtime_error("Unknown statement: " + std::string(stmt->type().str())); + thread_append_node(ctx, result, joinee->ctx.tail); } - return 1; + + } else { + verbose << "Waiting on thread " << result << std::endl; + return 0; + } + } else if (s == lang::Lock) { + // We can only lock unlocked locks, if a lock hasn't been used + // before it is implicitly created, we then 
commit the pending + // updates of this thread and pull the updates from the lock. + auto v = s / lang::Var; + auto var = std::string(v->location().view()); + + Lock &lock = gctx.locks[var]; + if (lock.owner) { + verbose << "Waiting for lock " << var << " owned by " + << lock.owner.value() << std::endl; + return 0; } - /* Run a particular thread until it reaches a synchronisation point or until - * it terminates. Report whether the thread was able to progress or not, or - * whether it terminated. - */ - std::variant run_single_thread_to_sync(GlobalContext& gctx, const ThreadID tid, std::shared_ptr thread) - { - if (thread->terminated) { - return *(thread->terminated); - } - Node block = thread->block; - size_t &pc = thread->pc; - ThreadContext &ctx = thread->ctx; - - if (pc == 0) { - gctx.protocol->on_start(thread->ctx, gctx); - } - - bool first_statement = true; - while(pc < block->size()) - { - Node stmt = block->at(pc); - - if (!first_statement && is_syncing(stmt)) - { - return ProgressStatus::progress; - } - - auto delta_or_term = run_statement(stmt, gctx, ctx, tid); - if (auto term = std::get_if(&delta_or_term)) - { - thread->terminated = *term; - // thread_append_node(ctx); - return *term; - } - - auto delta = std::get(delta_or_term); - - if(delta == 0) - { - return first_statement ? 
ProgressStatus::no_progress : ProgressStatus::progress; - } - - pc += delta; - first_statement = false; - } - - thread->terminated = TerminationStatus::completed; - gctx.protocol->on_end(thread->ctx, gctx); - - thread_append_node(ctx); - return TerminationStatus::completed; + lock.owner = tid; + if (auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { + verbose << (**conflict) << std::endl; + // using graph::Node; + // auto [s1, s2] = conflict->commits; + // auto sources = std::pair, + // std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; + // auto graph_conflict = graph::Conflict(conflict->var, sources); + // thread_append_node(ctx, var, lock.last, + // graph_conflict); + return TerminationStatus::datarace_exception; } - /** - * Run a thread to the next sync point, including any threads spawned by that thread - */ - std::variant - progress_thread(GlobalContext &gctx, const ThreadID tid, std::shared_ptr thread) - { - auto no_threads = gctx.threads.size(); - auto prog_or_term = run_single_thread_to_sync(gctx, tid, thread); - - bool any_progress = std::holds_alternative(prog_or_term) && - std::get(prog_or_term) == ProgressStatus::progress; - for (size_t i = no_threads; i < gctx.threads.size(); ++i) - { - // If there are new threads, we can run them to sync as well - any_progress = true; - auto new_thread = gctx.threads[i]; - if (!is_syncing(*new_thread)) - { - verbose << "==== Thread " << i << " (spawn) ====" << std::endl; - progress_thread(gctx, i, new_thread); - } - } - - if (std::holds_alternative(prog_or_term)) - return prog_or_term; - - return any_progress ? ProgressStatus::progress : ProgressStatus::no_progress; + thread_append_node(ctx, var, lock.last); + + verbose << "Locked " << var << std::endl; + } else if (s == lang::Unlock) { + assert(false && "todo"); + + // We can only unlock locks we previously locked. 
We commit any + // pending updates and then copy the threads versioned globals + // to the locks versioned globals (nobody could have changed + // them since we locked the lock). + + // commit(ctx.globals); + auto v = s / lang::Var; + auto var = std::string(v->location().view()); + + auto &lock = gctx.locks[var]; + if (!lock.owner || (lock.owner && *lock.owner != tid)) { + return TerminationStatus::unlock_exception; } - /* Try to evaluate all threads until a sync point or termination point - */ - std::variant run_threads_to_sync(GlobalContext& gctx) - { - verbose << "-----------------------" << std::endl; - bool all_completed = true; - ProgressStatus any_progress = ProgressStatus::no_progress; - for (size_t i = 0; i < gctx.threads.size(); ++i) - { - verbose << "==== t" << i << " ====" << std::endl; - auto thread = gctx.threads[i]; - if (!thread->terminated) - { - auto prog_or_term = run_single_thread_to_sync(gctx, i, thread); - if (ProgressStatus* prog = std::get_if(&prog_or_term)) - { - any_progress |= *prog; - } - else - { - // We could return termination status of any error here and stop - // at the first error - thread->terminated = std::get(prog_or_term); - any_progress |= ProgressStatus::progress; - } - - all_completed &= thread->terminated.has_value(); - // if a thread spawns a new thread, it will end up at the end so - // we will always include the new threads in the termination - // criteria - } - } - - if (all_completed) return TerminationStatus::completed; - - return any_progress; + if (auto conflict = gctx.protocol->on_unlock(ctx, lock, gctx)) { + verbose << (**conflict) << std::endl; + return TerminationStatus::datarace_exception; } - bool is_finished(std::variant& prog_or_term) - { - // Either, the system is stuck and made no progress in which case there - // is a deadlock (or a thread is stuck waiting for a crashed thread?) 
- if (ProgressStatus* prog = std::get_if(&prog_or_term)) - return (*prog) == ProgressStatus::no_progress; + // lock.globals = ctx.globals; + lock.owner.reset(); + + thread_append_node(ctx, var); + lock.last = ctx.tail; - // Or, there was some termination criteria in which case we stop - return true; + verbose << "Unlocked " << var << std::endl; + + } else if (s == lang::Assert) { + auto expr = s / lang::Expr; + auto result_or_term = evaluate_expression(expr, gctx, ctx); + if (size_t *result = std::get_if(&result_or_term)) { + if (*result) { + verbose << "Assertion passed: " << expr->location().view() << std::endl; + } else { + verbose << "Assertion failed: " << expr->location().view() << std::endl; + thread_append_node( + ctx, std::string(expr->location().view())); + return TerminationStatus::assertion_failure_exception; + } + } else { + return std::get(result_or_term); + } + } else { + throw std::runtime_error("Unknown statement: " + + std::string(stmt->type().str())); + } + return 1; +} + +/* Run a particular thread until it reaches a synchronisation point or until + * it terminates. Report whether the thread was able to progress or not, or + * whether it terminated. + */ +std::variant +run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, + std::shared_ptr thread) { + if (thread->terminated) { + return *(thread->terminated); + } + Node block = thread->block; + size_t &pc = thread->pc; + ThreadContext &ctx = thread->ctx; + + if (pc == 0) { + gctx.protocol->on_start(thread->ctx, gctx); + } + + bool first_statement = true; + while (pc < block->size()) { + Node stmt = block->at(pc); + + if (!first_statement && is_syncing(stmt)) { + return ProgressStatus::progress; } - /* Try to evaluate all threads until they have all terminated in some way - * or we have reached a stuck configuration. 
- */ - int run_threads(GlobalContext &gctx) - { - std::variant prog_or_term; - do { - prog_or_term = run_threads_to_sync(gctx); - } while (!is_finished(prog_or_term)); - - verbose << "----------- execution complete -----------" << std::endl; - - bool exception_detected = false; - for (size_t i = 0; i < gctx.threads.size(); ++i) - { - const auto& thread = gctx.threads[i]; - if (thread->terminated) - { - switch (thread->terminated.value()) - { - case TerminationStatus::completed: - verbose << "Thread " << i << " terminated normally" << std::endl; - break; - - case TerminationStatus::unlock_exception: - verbose << "Thread " << i << " unlocked a lock it does not own" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::datarace_exception: - verbose << "Thread " << i << " encountered a data-race" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::assertion_failure_exception: - verbose << "Thread " << i << " failed an assertion" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::unassigned_variable_read_exception: - verbose << "Thread " << i << " read an uninitialised value" << std::endl; - exception_detected = true; - break; - - default: - verbose << "Thread " << i << " has an unhandled termination state" << std::endl; - break; - } - } - else - { - exception_detected = true; - thread_append_node(thread->ctx); - verbose << "Thread " << i << " is stuck" << std::endl; - } - } - - return exception_detected ? 
1 : 0; + auto delta_or_term = run_statement(stmt, gctx, ctx, tid); + if (auto term = std::get_if(&delta_or_term)) { + thread->terminated = *term; + // thread_append_node(ctx); + return *term; } - int interpret(const Node ast, const std::filesystem::path &output_path) - { - // TODO: allow both protocols - GlobalContext gctx(ast, std::make_unique()); - auto result = run_threads(gctx); - // gctx.print_execution_graph(output_path); FIXME + auto delta = std::get(delta_or_term); - return result; + if (delta == 0) { + return first_statement ? ProgressStatus::no_progress + : ProgressStatus::progress; } -} // gitmem \ No newline at end of file + pc += delta; + first_statement = false; + } + + thread->terminated = TerminationStatus::completed; + gctx.protocol->on_end(thread->ctx, gctx); + + thread_append_node(ctx); + return TerminationStatus::completed; +} + +/** + * Run a thread to the next sync point, including any threads spawned by that + * thread + */ +std::variant +progress_thread(GlobalContext &gctx, const ThreadID tid, + std::shared_ptr thread) { + auto no_threads = gctx.threads.size(); + auto prog_or_term = run_single_thread_to_sync(gctx, tid, thread); + + bool any_progress = + std::holds_alternative(prog_or_term) && + std::get(prog_or_term) == ProgressStatus::progress; + for (size_t i = no_threads; i < gctx.threads.size(); ++i) { + // If there are new threads, we can run them to sync as well + any_progress = true; + auto new_thread = gctx.threads[i]; + if (!is_syncing(*new_thread)) { + verbose << "==== Thread " << i << " (spawn) ====" << std::endl; + progress_thread(gctx, i, new_thread); + } + } + + if (std::holds_alternative(prog_or_term)) + return prog_or_term; + + return any_progress ? 
ProgressStatus::progress : ProgressStatus::no_progress; +} + +/* Try to evaluate all threads until a sync point or termination point + */ +std::variant +run_threads_to_sync(GlobalContext &gctx) { + verbose << "-----------------------" << std::endl; + bool all_completed = true; + ProgressStatus any_progress = ProgressStatus::no_progress; + for (size_t i = 0; i < gctx.threads.size(); ++i) { + verbose << "==== t" << i << " ====" << std::endl; + auto thread = gctx.threads[i]; + if (!thread->terminated) { + auto prog_or_term = run_single_thread_to_sync(gctx, i, thread); + if (ProgressStatus *prog = std::get_if(&prog_or_term)) { + any_progress |= *prog; + } else { + // We could return termination status of any error here and stop + // at the first error + thread->terminated = std::get(prog_or_term); + any_progress |= ProgressStatus::progress; + } + + all_completed &= thread->terminated.has_value(); + // if a thread spawns a new thread, it will end up at the end so + // we will always include the new threads in the termination + // criteria + } + } + + if (all_completed) + return TerminationStatus::completed; + + return any_progress; +} + +bool is_finished( + std::variant &prog_or_term) { + // Either, the system is stuck and made no progress in which case there + // is a deadlock (or a thread is stuck waiting for a crashed thread?) + if (ProgressStatus *prog = std::get_if(&prog_or_term)) + return (*prog) == ProgressStatus::no_progress; + + // Or, there was some termination criteria in which case we stop + return true; +} + +/* Try to evaluate all threads until they have all terminated in some way + * or we have reached a stuck configuration. 
+ */ +int run_threads(GlobalContext &gctx) { + std::variant prog_or_term; + do { + prog_or_term = run_threads_to_sync(gctx); + } while (!is_finished(prog_or_term)); + + verbose << "----------- execution complete -----------" << std::endl; + + bool exception_detected = false; + for (size_t i = 0; i < gctx.threads.size(); ++i) { + const auto &thread = gctx.threads[i]; + if (thread->terminated) { + switch (thread->terminated.value()) { + case TerminationStatus::completed: + verbose << "Thread " << i << " terminated normally" << std::endl; + break; + + case TerminationStatus::unlock_exception: + verbose << "Thread " << i << " unlocked a lock it does not own" + << std::endl; + exception_detected = true; + break; + + case TerminationStatus::datarace_exception: + verbose << "Thread " << i << " encountered a data-race" << std::endl; + exception_detected = true; + break; + + case TerminationStatus::assertion_failure_exception: + verbose << "Thread " << i << " failed an assertion" << std::endl; + exception_detected = true; + break; + + case TerminationStatus::unassigned_variable_read_exception: + verbose << "Thread " << i << " read an uninitialised value" + << std::endl; + exception_detected = true; + break; + + default: + verbose << "Thread " << i << " has an unhandled termination state" + << std::endl; + break; + } + } else { + exception_detected = true; + thread_append_node(thread->ctx); + verbose << "Thread " << i << " is stuck" << std::endl; + } + } + + return exception_detected ? 
1 : 0; +} + +int interpret(const Node ast, const std::filesystem::path &output_path) { + // TODO: allow both protocols + GlobalContext gctx(ast, std::make_unique()); + auto result = run_threads(gctx); + // gctx.print_execution_graph(output_path); FIXME + + return result; +} + +} // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.hh b/src/interpreter.hh index 43334c0..b6970ca 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -1,27 +1,34 @@ #pragma once -#include -#include "lang.hh" +#include "execution_state.hh" #include "graph.hh" #include "graphviz.hh" -#include "execution_state.hh" +#include "lang.hh" +#include namespace gitmem { - // Entry function - int interpret(const trieste::Node, const std::filesystem::path &output_file); +// Entry function +int interpret(const trieste::Node, const std::filesystem::path &output_file); - // Internal functions - int run_threads(GlobalContext &); +// Internal functions +int run_threads(GlobalContext &); - enum class ProgressStatus { progress, no_progress }; - inline bool operator!(ProgressStatus p) { return p == ProgressStatus::no_progress; } - inline ProgressStatus operator||(const ProgressStatus &p1, const ProgressStatus &p2) { - return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) ? ProgressStatus::progress : ProgressStatus::no_progress; - } - inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { p1 = (p1 || p2); } +enum class ProgressStatus { progress, no_progress }; +inline bool operator!(ProgressStatus p) { + return p == ProgressStatus::no_progress; +} +inline ProgressStatus operator||(const ProgressStatus &p1, + const ProgressStatus &p2) { + return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) + ? 
ProgressStatus::progress + : ProgressStatus::no_progress; +} +inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { + p1 = (p1 || p2); +} - std::variant - progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); +std::variant +progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); } // namespace gitmem \ No newline at end of file diff --git a/src/lang.hh b/src/lang.hh index 3c387e2..e244b99 100644 --- a/src/lang.hh +++ b/src/lang.hh @@ -1,62 +1,61 @@ #pragma once #include -namespace gitmem -{ +namespace gitmem { namespace lang { - using namespace trieste; - - Reader reader(); - - // Variables - inline const auto Reg = TokenDef("reg", flag::print); - inline const auto Var = TokenDef("var", flag::print); - - // Constants - inline const auto Const = TokenDef("const", flag::print); - - // Arithmetic - inline const auto Add = TokenDef("+"); - - // Comparison - inline const auto Eq = TokenDef("=="); - inline const auto Neq = TokenDef("!="); - - // Statements - inline const auto Semi = TokenDef(";"); - inline const auto Assign = TokenDef("=", flag::lookup); - inline const auto Spawn = TokenDef("spawn"); - inline const auto Join = TokenDef("join"); - inline const auto Lock = TokenDef("lock"); - inline const auto Unlock = TokenDef("unlock"); - inline const auto Nop = TokenDef("nop"); - inline const auto Assert = TokenDef("assert"); - inline const auto If = TokenDef("if"); - inline const auto Else = TokenDef("else"); - - // Branching - inline const auto Jump = TokenDef("jump"); - inline const auto Cond = TokenDef("cond"); - - // Grouping tokens - inline const auto Brace = TokenDef("brace"); - inline const auto Paren = TokenDef("paren"); - - inline const auto Stmt = TokenDef("stmt"); - inline const auto Expr = TokenDef("expr"); - inline const auto Block = TokenDef("block", flag::symtab | flag::defbeforeuse); - - // Convenience - inline const auto LVal = TokenDef("lval"); - inline const auto Lhs = TokenDef("lhs"); - inline const auto Rhs 
= TokenDef("rhs"); - inline const auto Op = TokenDef("op"); - inline const auto Then = TokenDef("then"); - - // Well-formedness - // clang-format off +using namespace trieste; + +Reader reader(); + +// Variables +inline const auto Reg = TokenDef("reg", flag::print); +inline const auto Var = TokenDef("var", flag::print); + +// Constants +inline const auto Const = TokenDef("const", flag::print); + +// Arithmetic +inline const auto Add = TokenDef("+"); + +// Comparison +inline const auto Eq = TokenDef("=="); +inline const auto Neq = TokenDef("!="); + +// Statements +inline const auto Semi = TokenDef(";"); +inline const auto Assign = TokenDef("=", flag::lookup); +inline const auto Spawn = TokenDef("spawn"); +inline const auto Join = TokenDef("join"); +inline const auto Lock = TokenDef("lock"); +inline const auto Unlock = TokenDef("unlock"); +inline const auto Nop = TokenDef("nop"); +inline const auto Assert = TokenDef("assert"); +inline const auto If = TokenDef("if"); +inline const auto Else = TokenDef("else"); + +// Branching +inline const auto Jump = TokenDef("jump"); +inline const auto Cond = TokenDef("cond"); + +// Grouping tokens +inline const auto Brace = TokenDef("brace"); +inline const auto Paren = TokenDef("paren"); + +inline const auto Stmt = TokenDef("stmt"); +inline const auto Expr = TokenDef("expr"); +inline const auto Block = TokenDef("block", flag::symtab | flag::defbeforeuse); + +// Convenience +inline const auto LVal = TokenDef("lval"); +inline const auto Lhs = TokenDef("lhs"); +inline const auto Rhs = TokenDef("rhs"); +inline const auto Op = TokenDef("op"); +inline const auto Then = TokenDef("then"); + +// Well-formedness +// clang-format off inline const wf::Wellformed wf = (Top <<= File) | (File <<= Block) @@ -75,7 +74,7 @@ namespace lang { | (Jump <<= Const) | (Cond <<= Expr * Const) ; - // clang-format on +// clang-format on } // namespace lang diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index 227fe14..dc2404b 100644 --- 
a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -1,8 +1,8 @@ -#include #include +#include -#include "version_store.hh" #include "sync_protocol.hh" +#include "version_store.hh" namespace gitmem { @@ -16,13 +16,9 @@ void LocalVersionStore::stage(ObjectNumber obj, Value value) { _staging[obj] = value; } -void LocalVersionStore::clear_staging() { - _staging.clear(); -} +void LocalVersionStore::clear_staging() { _staging.clear(); } -void LocalVersionStore::advance_base(Timestamp ts) { - _base_timestamp = ts; -} +void LocalVersionStore::advance_base(Timestamp ts) { _base_timestamp = ts; } std::optional LocalVersionStore::get_staged(ObjectNumber obj) { auto it = _staging.find(obj); @@ -45,7 +41,7 @@ ObjectNumber GlobalVersionStore::get_object_number(std::string var) { } std::string GlobalVersionStore::get_object_name(ObjectNumber find) { - for (const auto& [name, number] : _object_numbers) { + for (const auto &[name, number] : _object_numbers) { if (number == find) return name; } @@ -53,14 +49,17 @@ std::string GlobalVersionStore::get_object_name(ObjectNumber find) { return ""; } -std::optional GlobalVersionStore::get_version_for_timestamp(ObjectNumber obj, Timestamp ts) const { +std::optional +GlobalVersionStore::get_version_for_timestamp(ObjectNumber obj, + Timestamp ts) const { const auto it = _history.find(obj); if (it == _history.end()) return std::nullopt; - const VersionHistory& history = it->second; - for (VersionHistory::const_reverse_iterator riter = history.rbegin(); riter != history.rend(); ++riter) { + const VersionHistory &history = it->second; + for (VersionHistory::const_reverse_iterator riter = history.rbegin(); + riter != history.rend(); ++riter) { if (riter->timestamp() <= ts) return riter->value(); } @@ -69,37 +68,31 @@ std::optional GlobalVersionStore::get_version_for_timestamp(ObjectNumber } std::optional GlobalVersionStore::check_conflicts( - Timestamp base, - const std::unordered_map& changes -) const { - for (const auto& [obj, 
_] : changes) { + Timestamp base, + const std::unordered_map &changes) const { + for (const auto &[obj, _] : changes) { auto it = _history.find(obj); if (it == _history.end()) { continue; } - const Version& latest = it->second.back(); + const Version &latest = it->second.back(); if (latest.timestamp() > base) { return Conflict{ - .object = obj, - .local_base = base, - .global_head = latest.timestamp() - }; + .object = obj, .local_base = base, .global_head = latest.timestamp()}; } } return std::nullopt; } Timestamp GlobalVersionStore::apply_changes( - Timestamp base, - const std::unordered_map& changes -) { + Timestamp base, const std::unordered_map &changes) { if (auto conflict = check_conflicts(base, changes)) { throw std::logic_error("apply_changes called with conflicts"); } Timestamp new_ts = ++_timestamp; - for (const auto& [obj, value] : changes) { + for (const auto &[obj, value] : changes) { _history[obj].emplace_back(new_ts, value); } diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index a0018bf..81210fa 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -1,10 +1,10 @@ #pragma once -#include -#include -#include #include #include +#include +#include +#include namespace gitmem { @@ -19,9 +19,9 @@ class LargeCounter { uint64_t _counter{0}; public: - auto operator<=>(const LargeCounter&) const = default; + auto operator<=>(const LargeCounter &) const = default; - LargeCounter& operator++() { + LargeCounter &operator++() { if (_counter == UINT64_MAX) { _counter = 0; assert(_epoch != UINT64_MAX && "timestamp overflow"); @@ -38,7 +38,8 @@ public: return old; } - friend std::ostream& operator<<(std::ostream& os, const LargeCounter& counter) { + friend std::ostream &operator<<(std::ostream &os, + const LargeCounter &counter) { os << counter._epoch << ":" << counter._counter; return os; } @@ -57,8 +58,7 @@ class Version { Value _value; public: - Version(Timestamp ts, Value value) - : _timestamp(ts), _value(value) {} + 
Version(Timestamp ts, Value value) : _timestamp(ts), _value(value) {} Timestamp timestamp() const { return _timestamp; } Value value() const { return _value; } @@ -86,7 +86,7 @@ class LocalVersionStore { public: Timestamp base_timestamp() const { return _base_timestamp; } - const auto& staged_changes() const { return _staging; } + const auto &staged_changes() const { return _staging; } void stage(ObjectNumber obj, Value value); void clear_staging(); @@ -112,15 +112,13 @@ public: std::optional get_version_for_timestamp(ObjectNumber, Timestamp) const; - std::optional check_conflicts( - Timestamp base, - const std::unordered_map& changes - ) const; + std::optional + check_conflicts(Timestamp base, + const std::unordered_map &changes) const; - Timestamp apply_changes( - Timestamp base, - const std::unordered_map& changes - ); + Timestamp + apply_changes(Timestamp base, + const std::unordered_map &changes); }; } // namespace linear diff --git a/src/main.cc b/src/main.cc index bb5a948..db19c05 100644 --- a/src/main.cc +++ b/src/main.cc @@ -1,8 +1,6 @@ -#include #include "reader.cc" +#include - -int main(int argc, char** argv) -{ +int main(int argc, char **argv) { return trieste::Driver(grunq::reader()).run(argc, argv); } diff --git a/src/model_checker.cc b/src/model_checker.cc index bdd3af5..450d97c 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -1,217 +1,199 @@ #include "model_checker.hh" #include "interpreter.hh" -namespace gitmem -{ - using namespace trieste; - - /** - * A TraceNode represents a point in the space of possible schedulings. A - * path in a tree of TraceNodes represents a scheduling, with the thread ID - * of each node being the thread that was scheduled at that point. When - * there are no more children to explore, or when one thread has crashed, - * the TraceNode is marked as complete so that the next run will not explore - * it again. 
- * - */ - struct TraceNode - { - size_t tid_; - bool complete; - std::vector> children; - - TraceNode(const size_t tid) : tid_(tid), complete(false) {} - - std::shared_ptr extend(ThreadID tid) - { - children.push_back(std::make_shared(tid)); - return children.back(); - } +namespace gitmem { +using namespace trieste; + +/** + * A TraceNode represents a point in the space of possible schedulings. A + * path in a tree of TraceNodes represents a scheduling, with the thread ID + * of each node being the thread that was scheduled at that point. When + * there are no more children to explore, or when one thread has crashed, + * the TraceNode is marked as complete so that the next run will not explore + * it again. + * + */ +struct TraceNode { + size_t tid_; + bool complete; + std::vector> children; + + TraceNode(const size_t tid) : tid_(tid), complete(false) {} + + std::shared_ptr extend(ThreadID tid) { + children.push_back(std::make_shared(tid)); + return children.back(); + } + + bool is_leaf() const { return children.empty(); } +}; + +/** + * Print the traces of the program, one trace per line. Each trace is a + * sequence of thread IDs that were scheduled in that order. + */ +template +void print_traces(S &stream, const std::vector> &traces) { + for (const auto &trace : traces) { + for (const auto &tid : trace) { + stream << tid << " "; + } + stream << std::endl; + } +} - bool is_leaf() const - { - return children.empty(); - } - }; - - /** - * Print the traces of the program, one trace per line. Each trace is a - * sequence of thread IDs that were scheduled in that order. - */ - template - void print_traces(S &stream, const std::vector> &traces) - { - for (const auto &trace : traces) - { - for (const auto &tid : trace) - { - stream << tid << " "; - } - stream << std::endl; +/** Build an output path for the execution graph, appending an index to the + * filename to avoid overwriting previous graphs. 
*/ +std::filesystem::path +build_output_path(const std::filesystem::path &output_path, const size_t idx) { + auto parent = output_path.parent_path(); + auto name = output_path.stem().string(); + auto ext = output_path.extension().string(); + return parent / (name + "_" + std::to_string(idx) + ext); +} + +/** + * Explore all possible execution paths of the program, printing one trace + * for each distinct final state that led to an error. + */ +int model_check(const Node ast, const std::filesystem::path &output_path) { + GlobalContext gctx(ast); + + auto final_contexts = std::vector{}; + auto failing_contexts = std::vector{}; + auto deadlocked_contexts = std::vector{}; + + auto final_traces = std::vector>{}; + auto failing_traces = std::vector>{}; + auto deadlocked_traces = std::vector>{}; + + const auto root = std::make_shared(0); + auto cursor = root; + auto current_trace = std::vector{0}; // Start with the main thread + verbose << "==== Thread " << cursor->tid_ << " ====" << std::endl; + progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); + + while (!root->complete) { + while (!cursor->children.empty() && !cursor->children.back()->complete) { + // We have a child that is not complete, we can extend that trace + cursor = cursor->children.back(); + current_trace.push_back(cursor->tid_); + verbose << "==== Thread " << cursor->tid_ + << " (replay) ====" << std::endl; + progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); + } + + // Try to find a thread to schedule next + size_t start_idx = + cursor->children.empty() ? 
0 : cursor->children.back()->tid_ + 1; + size_t no_threads = gctx.threads.size(); + bool made_progress = false; + for (size_t i = start_idx; i < no_threads && !made_progress; ++i) { + auto thread = gctx.threads[i]; + if (!thread->terminated) { + // Run the thread to the next sync point + verbose << "==== Thread " << i << " ====" << std::endl; + auto prog_or_term = progress_thread(gctx, i, thread); + if (std::holds_alternative(prog_or_term)) { + // Thread terminated, we can extend the trace + made_progress = true; + cursor = cursor->extend(i); + current_trace.push_back(i); + if (std::get(prog_or_term) != + TerminationStatus::completed) { + // Thread terminated with an error, we can stop here + verbose << "Thread " << i << " terminated with an error" + << std::endl; + cursor->complete = true; + } + } else if (std::get(prog_or_term) == + ProgressStatus::progress) { + // Thread made progress, we can continue + made_progress = true; + cursor = cursor->extend(i); + current_trace.push_back(i); } + } } - /** Build an output path for the execution graph, appending an index to the - * filename to avoid overwriting previous graphs. */ - std::filesystem::path build_output_path(const std::filesystem::path &output_path, const size_t idx) - { - auto parent = output_path.parent_path(); - auto name = output_path.stem().string(); - auto ext = output_path.extension().string(); - return parent / (name + "_" + std::to_string(idx) + ext); + if (!made_progress) { + // No threads made progress, we can stop here + cursor->complete = true; } - /** - * Explore all possible execution paths of the program, printing one trace - * for each distinct final state that led to an error. 
- */ - int model_check(const Node ast, const std::filesystem::path &output_path) - { - GlobalContext gctx(ast); - - auto final_contexts = std::vector{}; - auto failing_contexts = std::vector{}; - auto deadlocked_contexts = std::vector{}; - - auto final_traces = std::vector>{}; - auto failing_traces = std::vector>{}; - auto deadlocked_traces = std::vector>{}; - - const auto root = std::make_shared(0); - auto cursor = root; - auto current_trace = std::vector{0}; // Start with the main thread - verbose << "==== Thread " << cursor->tid_ << " ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); - - while (!root->complete) - { - while (!cursor->children.empty() && !cursor->children.back()->complete) - { - // We have a child that is not complete, we can extend that trace - cursor = cursor->children.back(); - current_trace.push_back(cursor->tid_); - verbose << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); - } - - // Try to find a thread to schedule next - size_t start_idx = cursor->children.empty() ? 
0 : cursor->children.back()->tid_ + 1; - size_t no_threads = gctx.threads.size(); - bool made_progress = false; - for (size_t i = start_idx; i < no_threads && !made_progress; ++i) - { - auto thread = gctx.threads[i]; - if (!thread->terminated) - { - // Run the thread to the next sync point - verbose << "==== Thread " << i << " ====" << std::endl; - auto prog_or_term = progress_thread(gctx, i, thread); - if (std::holds_alternative(prog_or_term)) - { - // Thread terminated, we can extend the trace - made_progress = true; - cursor = cursor->extend(i); - current_trace.push_back(i); - if (std::get(prog_or_term) != TerminationStatus::completed) - { - // Thread terminated with an error, we can stop here - verbose << "Thread " << i << " terminated with an error" << std::endl; - cursor->complete = true; - } - } - else if (std::get(prog_or_term) == ProgressStatus::progress) - { - // Thread made progress, we can continue - made_progress = true; - cursor = cursor->extend(i); - current_trace.push_back(i); - } - } - } - - if (!made_progress) - { - // No threads made progress, we can stop here - cursor->complete = true; - } - - bool all_completed = std::all_of(gctx.threads.begin(), gctx.threads.end(), - [](const auto &thread) - { return thread->terminated && *thread->terminated == TerminationStatus::completed; }); - bool any_crashed = - std::any_of(gctx.threads.begin(), gctx.threads.end(), - [](const auto &thread) - { return thread->terminated && *thread->terminated != TerminationStatus::completed; }); - - bool is_deadlock = !all_completed && !made_progress && cursor->is_leaf(); - - if (all_completed || any_crashed || is_deadlock) - { - // Remember final state if it is new - if (!std::any_of(final_contexts.begin(), final_contexts.end(), - [&gctx](const GlobalContext &state) - { return state == gctx; })) - { - final_contexts.push_back(gctx); - final_traces.push_back(current_trace); - if (any_crashed) - { - failing_traces.push_back(current_trace); - 
failing_contexts.push_back(gctx); - } - else if (is_deadlock) - { - deadlocked_traces.push_back(current_trace); - deadlocked_contexts.push_back(gctx); - } - } - - cursor->complete = true; - } - - if (cursor->complete && !root->complete) - { - // Reset the cursor to the root and start a new trace - verbose << std::endl - << "Restarting trace..." << std::endl; - gctx = GlobalContext(ast); - - cursor = root; - current_trace.clear(); - current_trace.push_back(0); // Start with the main thread again - verbose << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); - } + bool all_completed = std::all_of( + gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { + return thread->terminated && + *thread->terminated == TerminationStatus::completed; + }); + bool any_crashed = std::any_of( + gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { + return thread->terminated && + *thread->terminated != TerminationStatus::completed; + }); + + bool is_deadlock = !all_completed && !made_progress && cursor->is_leaf(); + + if (all_completed || any_crashed || is_deadlock) { + // Remember final state if it is new + if (!std::any_of( + final_contexts.begin(), final_contexts.end(), + [&gctx](const GlobalContext &state) { return state == gctx; })) { + final_contexts.push_back(gctx); + final_traces.push_back(current_trace); + if (any_crashed) { + failing_traces.push_back(current_trace); + failing_contexts.push_back(gctx); + } else if (is_deadlock) { + deadlocked_traces.push_back(current_trace); + deadlocked_contexts.push_back(gctx); } + } - verbose << "Found a total of " << final_traces.size() << " trace(s) with distinct final states:" << std::endl; - print_traces(verbose, final_traces); + cursor->complete = true; + } - size_t idx = 0; - if (!failing_traces.empty()) - { - std::cout << "Found " << failing_traces.size() << " trace(s) with errors:" << std::endl; - print_traces(std::cout, 
failing_traces); + if (cursor->complete && !root->complete) { + // Reset the cursor to the root and start a new trace + verbose << std::endl << "Restarting trace..." << std::endl; + gctx = GlobalContext(ast); + + cursor = root; + current_trace.clear(); + current_trace.push_back(0); // Start with the main thread again + verbose << "==== Thread " << cursor->tid_ + << " (replay) ====" << std::endl; + progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); + } + } - for (const auto &ctx : failing_contexts) - { - auto path = build_output_path(output_path, idx++); - ctx.print_execution_graph(path); - } - } + verbose << "Found a total of " << final_traces.size() + << " trace(s) with distinct final states:" << std::endl; + print_traces(verbose, final_traces); - if (!deadlocked_traces.empty()) - { - std::cout << "Found " << deadlocked_traces.size() << " trace(s) leading to deadlock:" << std::endl; - print_traces(std::cout, deadlocked_traces); + size_t idx = 0; + if (!failing_traces.empty()) { + std::cout << "Found " << failing_traces.size() + << " trace(s) with errors:" << std::endl; + print_traces(std::cout, failing_traces); - for (const auto &ctx : deadlocked_contexts) - { - auto path = build_output_path(output_path, idx++); - ctx.print_execution_graph(path); - } - } + for (const auto &ctx : failing_contexts) { + auto path = build_output_path(output_path, idx++); + ctx.print_execution_graph(path); + } + } + + if (!deadlocked_traces.empty()) { + std::cout << "Found " << deadlocked_traces.size() + << " trace(s) leading to deadlock:" << std::endl; + print_traces(std::cout, deadlocked_traces); - return deadlocked_traces.empty() && failing_traces.empty() ? 0 : 1; + for (const auto &ctx : deadlocked_contexts) { + auto path = build_output_path(output_path, idx++); + ctx.print_execution_graph(path); } + } + + return deadlocked_traces.empty() && failing_traces.empty() ? 
0 : 1; } +} // namespace gitmem diff --git a/src/parser.cc b/src/parser.cc index 0c6faa3..212e574 100644 --- a/src/parser.cc +++ b/src/parser.cc @@ -1,142 +1,140 @@ -#include "lang.hh" #include "internal.hh" +#include "lang.hh" namespace gitmem { namespace lang { - using namespace trieste; - using namespace trieste::detail; - - Parse parser() - { - Parse p(depth::file, parser_wf); - auto infix = [](Make& m, Token t) { - // This precedence table maps infix operators to the operators that have - // *higher* precedence, and which should therefore be terminated when that - // operator is encountered. Note that operators with the same precedence - // terminate each other. (for reasons, it has to be defined inside the lambda) - const auto precedence_table = std::map> { - {Add, {}}, - {Eq, {Add}}, - {Neq, {Add}}, +using namespace trieste; +using namespace trieste::detail; + +Parse parser() { + Parse p(depth::file, parser_wf); + auto infix = [](Make &m, Token t) { + // This precedence table maps infix operators to the operators that have + // *higher* precedence, and which should therefore be terminated when that + // operator is encountered. Note that operators with the same precedence + // terminate each other. 
(for reasons, it has to be defined inside the + // lambda) + const auto precedence_table = std::map>{ + {Add, {}}, + {Eq, {Add}}, + {Neq, {Add}}, {Assign, {Add, Eq, Neq}}, - }; - - auto skip = precedence_table.at(t); - m.seq(t, skip); - // Push group to be able to check whether an operand follows - m.push(Group); }; -/* - auto pair_with = [pop_until](Make &m, Token preceding, Token following) { - pop_until(m, preceding, {Paren, Brace, File}); - m.term(); + auto skip = precedence_table.at(t); + m.seq(t, skip); + // Push group to be able to check whether an operand follows + m.push(Group); + }; - if (!m.in(preceding)) { - const std::string msg = (std::string) "Unexpected '" + following.str() + "'"; - m.error(msg); - return; - } + /* + auto pair_with = [pop_until](Make &m, Token preceding, Token following) { + pop_until(m, preceding, {Paren, Brace, File}); + m.term(); - m.pop(preceding); - m.push(following); - }; -*/ + if (!m.in(preceding)) { + const std::string msg = (std::string) "Unexpected '" + following.str() + + "'"; m.error(msg); return; + } - auto pop_until = [](Make &m, Token t, std::initializer_list stop = {File}) { - while (!m.in(t) && !m.group_in(t) - && !m.in(stop) && !m.group_in(stop)) { - m.term(); - m.pop(); - } + m.pop(preceding); + m.push(following); + }; + */ - return (m.in(t) || m.group_in(t)); - }; + auto pop_until = [](Make &m, Token t, + std::initializer_list stop = {File}) { + while (!m.in(t) && !m.group_in(t) && !m.in(stop) && !m.group_in(stop)) { + m.term(); + m.pop(); + } - p("start", + return (m.in(t) || m.group_in(t)); + }; + + p("start", { // Whitespace - "[[:space:]]+" >> [](auto&) { }, // no-op + "[[:space:]]+" >> [](auto &) {}, // no-op // Line comment - "//[^\n]*" >> [](auto&) { }, // no-op + "//[^\n]*" >> [](auto &) {}, // no-op // Constant - "[[:digit:]]+" >> [](auto& m) { m.add(Const); }, + "[[:digit:]]+" >> [](auto &m) { m.add(Const); }, // Addition - R"(\+)" >> [infix](auto& m) { infix(m, Add); }, + R"(\+)" >> [infix](auto &m) { 
infix(m, Add); }, // Comparison - "==" >> [infix](auto& m) { infix(m, Eq); }, - "!=" >> [infix](auto& m) { infix(m, Neq); }, + "==" >> [infix](auto &m) { infix(m, Eq); }, + "!=" >> [infix](auto &m) { infix(m, Neq); }, // Statements - ";" >> [](auto& m) { m.seq(Semi, {Assign, Spawn, Join, Lock, Unlock, Assert, If, Else, Eq, Neq, Add, Group}); }, - "=" >> [infix](auto& m) { infix(m, Assign); }, - "spawn" >> [](auto& m) { m.push(Spawn); }, - "join" >> [](auto& m) { m.push(Join); }, - "lock" >> [](auto& m) { m.push(Lock); }, - "unlock" >> [](auto& m) { m.push(Unlock); }, - "assert" >> [](auto& m) { m.push(Assert); }, - "nop" >> [](auto& m) { m.add(Nop); }, - - "if" >> [](auto& m) { m.push(If); }, - "else" >> [pop_until](auto &m) - { - pop_until(m, Semi, {Brace, Paren, File}); - m.push(Else); - }, + ";" >> + [](auto &m) { + m.seq(Semi, {Assign, Spawn, Join, Lock, Unlock, Assert, If, Else, + Eq, Neq, Add, Group}); + }, + "=" >> [infix](auto &m) { infix(m, Assign); }, + "spawn" >> [](auto &m) { m.push(Spawn); }, + "join" >> [](auto &m) { m.push(Join); }, + "lock" >> [](auto &m) { m.push(Lock); }, + "unlock" >> [](auto &m) { m.push(Unlock); }, + "assert" >> [](auto &m) { m.push(Assert); }, + "nop" >> [](auto &m) { m.add(Nop); }, + + "if" >> [](auto &m) { m.push(If); }, + "else" >> + [pop_until](auto &m) { + pop_until(m, Semi, {Brace, Paren, File}); + m.push(Else); + }, // Variables - R"(\$[_[:alpha:]][_[:alnum:]]*)" >> [](auto& m) { m.add(Reg); }, - R"([_[:alpha:]][_[:alnum:]]*)" >> [](auto& m) { m.add(Var); }, + R"(\$[_[:alpha:]][_[:alnum:]]*)" >> [](auto &m) { m.add(Reg); }, + R"([_[:alpha:]][_[:alnum:]]*)" >> [](auto &m) { m.add(Var); }, // Grouping - "\\{" >> [](auto& m) { m.push(Brace); }, - "\\}" >> [pop_until](auto& m) - { - pop_until(m, Brace, {Paren}); - m.term(); - m.pop(Brace); - m.extend(Brace); - if (m.group_in(If)) - { - m.term(); - m.pop(If); - } - else if (m.group_in(Else)) - { - m.term(); - m.pop(Else); - } - if (m.group_in({Semi, Brace, File})) - { - 
m.seq(Semi); - } - }, - - "\\(" >> [](auto& m) { m.push(Paren); }, - "\\)" >> [pop_until](auto& m) - { - pop_until(m, Paren, {Brace}); - m.term(); - m.pop(Paren); - m.extend(Paren); - }, - } - ); - - p.done([pop_until](auto& m) { - if (!m.in(Semi)) - m.error("Expected ';' at end of file"); - pop_until(m, File, {Brace, Paren}); + "\\{" >> [](auto &m) { m.push(Brace); }, + "\\}" >> + [pop_until](auto &m) { + pop_until(m, Brace, {Paren}); + m.term(); + m.pop(Brace); + m.extend(Brace); + if (m.group_in(If)) { + m.term(); + m.pop(If); + } else if (m.group_in(Else)) { + m.term(); + m.pop(Else); + } + if (m.group_in({Semi, Brace, File})) { + m.seq(Semi); + } + }, + + "\\(" >> [](auto &m) { m.push(Paren); }, + "\\)" >> + [pop_until](auto &m) { + pop_until(m, Paren, {Brace}); + m.term(); + m.pop(Paren); + m.extend(Paren); + }, }); - return p; - } + p.done([pop_until](auto &m) { + if (!m.in(Semi)) + m.error("Expected ';' at end of file"); + pop_until(m, File, {Brace, Paren}); + }); + + return p; +} } // namespace lang diff --git a/src/passes/branching.cc b/src/passes/branching.cc index e7e018c..cf589bd 100644 --- a/src/passes/branching.cc +++ b/src/passes/branching.cc @@ -3,31 +3,31 @@ namespace gitmem { namespace lang { - using namespace trieste; +using namespace trieste; - PassDef branching() - { - return { - "branching", - branching_wf, - dir::bottomup | dir::once, - { - T(Stmt) << (T(If)[If] << (T(Expr)[Expr] * T(Block)[Then] * T(Block)[Else])) >> - [](Match &_) -> Node - { - auto then_length = std::to_string(_(Then)->size() + 1 + 1); // +1 for the jump - auto else_length = std::to_string(_(Else)->size() + 1); - auto cond_loc = Location("if (" + std::string(_(Expr)->location().view()) + ") jump " + then_length); - auto jump_loc = Location("jump " + else_length); - auto cond = (Stmt ^ cond_loc) << (Cond << _(Expr) << (Const ^ then_length)); - auto jump = (Stmt ^ jump_loc) << (Jump << (Const ^ else_length)); - return Seq << cond - << *_(Then) - << jump - << *_(Else); - }, 
- }}; - } +PassDef branching() { + return {"branching", + branching_wf, + dir::bottomup | dir::once, + { + T(Stmt) << (T(If)[If] << (T(Expr)[Expr] * T(Block)[Then] * + T(Block)[Else])) >> + [](Match &_) -> Node { + auto then_length = + std::to_string(_(Then)->size() + 1 + 1); // +1 for the jump + auto else_length = std::to_string(_(Else)->size() + 1); + auto cond_loc = + Location("if (" + std::string(_(Expr)->location().view()) + + ") jump " + then_length); + auto jump_loc = Location("jump " + else_length); + auto cond = (Stmt ^ cond_loc) + << (Cond << _(Expr) << (Const ^ then_length)); + auto jump = (Stmt ^ jump_loc) + << (Jump << (Const ^ else_length)); + return Seq << cond << *_(Then) << jump << *_(Else); + }, + }}; +} } // namespace lang diff --git a/src/passes/check_refs.cc b/src/passes/check_refs.cc index 664a6b4..1cdd448 100644 --- a/src/passes/check_refs.cc +++ b/src/passes/check_refs.cc @@ -4,30 +4,25 @@ namespace gitmem { namespace lang { - using namespace trieste; +using namespace trieste; - PassDef check_refs() - { - return { - "check_refs", - statements_wf, - dir::bottomup | dir::once, - { - In(Expr) * T(Reg)[Reg] >> - [](Match &_) -> Node - { - auto reg = _(Reg); - auto enclosing_block = reg->scope(); - auto bindings = reg->lookup(enclosing_block); - if (bindings.empty()) - { - return Error << (ErrorAst << _(Reg)) - << (ErrorMsg ^ "Register has not been assigned"); - } - return NoChange; - }, - }}; - } +PassDef check_refs() { + return {"check_refs", + statements_wf, + dir::bottomup | dir::once, + { + In(Expr) * T(Reg)[Reg] >> [](Match &_) -> Node { + auto reg = _(Reg); + auto enclosing_block = reg->scope(); + auto bindings = reg->lookup(enclosing_block); + if (bindings.empty()) { + return Error << (ErrorAst << _(Reg)) + << (ErrorMsg ^ "Register has not been assigned"); + } + return NoChange; + }, + }}; +} } // namespace lang diff --git a/src/passes/expressions.cc b/src/passes/expressions.cc index e945315..152a413 100644 --- a/src/passes/expressions.cc 
+++ b/src/passes/expressions.cc @@ -1,143 +1,101 @@ #include "../internal.hh" -namespace gitmem -{ +namespace gitmem { namespace lang { - using namespace trieste; - - PassDef expressions() - { - auto Operand = T(Expr) << (T(Reg, Var, Const, Add)); - return { - "expressions", - expressions_wf, - dir::bottomup, - { - --In(Expr) * T(Const, Reg, Var)[Expr] >> - [](Match &_) -> Node - { - return Expr << _(Expr); - }, - - --In(Expr) * T(Spawn)[Spawn] << (T(Brace) * End) >> - [](Match &_) -> Node - { - return Expr << _(Spawn); - }, - - // Additions must have *at least* two operands - --In(Expr) * T(Add)[Add] << (Operand * Operand) >> - [](Match &_) -> Node - { - return Expr << _(Add); - }, - - --In(Expr) * T(Eq, Neq)[Eq] << (Operand * Operand * End) >> - [](Match &_) -> Node - { - return Expr << _(Eq); - }, - - T(Group) << (T(Brace)[Brace] * End) >> - [](Match &_) -> Node - { - return _(Brace); - }, - - T(Group) << (T(Paren)[Paren] * End) >> - [](Match &_) -> Node - { - return _(Paren); - }, - - T(Group) << (T(Expr)[Expr] * End) >> - [](Match &_) -> Node - { - return _(Expr); - }, - - T(Paren) << (T(Expr)[Expr] * End) >> - [](Match &_) -> Node - { - return _(Expr); - }, - - // Error rules - In(Group) * T(Expr) * (!T(Brace))[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected term (did you forget a brace or a semicolon?)"); - }, - - In(Group) * Any * T(Expr)[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected expression"); - }, - - T(Spawn)[Spawn] << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Spawn)) - << (ErrorMsg ^ "Expected body of spawn"); - }, - - --In(Expr) * T(Spawn) << Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid body of spawn"); - }, - - --In(Expr) * T(Add)[Add] << ((T(Group) << End) / (Any * (T(Group) << End))) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << 
_(Add)) - << (ErrorMsg ^ "Expected operand"); - }, - - --In(Expr) * T(Add)[Add] << (Any) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Add)) - << (ErrorMsg ^ "Invalid operands for addition"); - }, - - - --In(Expr) * T(Eq, Neq)[Eq] << (Any * (T(Group) << End)) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Eq)) - << (ErrorMsg ^ "Expected right-hand side of equality"); - }, - - --In(Expr) * T(Eq, Neq)[Eq] << Any >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Eq)) - << (ErrorMsg ^ "Bad equality"); - }, - - Any * T(Paren)[Paren] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Paren)) - << (ErrorMsg ^ "Unexpected parenthesis"); - }, - - T(Paren) * Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected term (did you forget a brace or semicolon?)"); - }, - - }}; - } +using namespace trieste; + +PassDef expressions() { + auto Operand = T(Expr) << (T(Reg, Var, Const, Add)); + return {"expressions", + expressions_wf, + dir::bottomup, + { + --In(Expr) * T(Const, Reg, Var)[Expr] >> + [](Match &_) -> Node { return Expr << _(Expr); }, + + --In(Expr) * T(Spawn)[Spawn] << (T(Brace) * End) >> + [](Match &_) -> Node { return Expr << _(Spawn); }, + + // Additions must have *at least* two operands + --In(Expr) * T(Add)[Add] << (Operand * Operand) >> + [](Match &_) -> Node { return Expr << _(Add); }, + + --In(Expr) * T(Eq, Neq)[Eq] << (Operand * Operand * End) >> + [](Match &_) -> Node { return Expr << _(Eq); }, + + T(Group) << (T(Brace)[Brace] * End) >> + [](Match &_) -> Node { return _(Brace); }, + + T(Group) << (T(Paren)[Paren] * End) >> + [](Match &_) -> Node { return _(Paren); }, + + T(Group) << (T(Expr)[Expr] * End) >> + [](Match &_) -> Node { return _(Expr); }, + + T(Paren) << (T(Expr)[Expr] * End) >> + [](Match &_) -> Node { return _(Expr); }, + + // Error rules + In(Group) * T(Expr) * (!T(Brace))[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << 
_(Expr)) + << (ErrorMsg ^ "Unexpected term (did you forget a " + "brace or a semicolon?)"); + }, + + In(Group) * Any *T(Expr)[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected expression"); + }, + + T(Spawn)[Spawn] << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Spawn)) + << (ErrorMsg ^ "Expected body of spawn"); + }, + + --In(Expr) * T(Spawn) << Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid body of spawn"); + }, + + --In(Expr) * T(Add)[Add] + << ((T(Group) << End) / (Any * (T(Group) << End))) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Add)) + << (ErrorMsg ^ "Expected operand"); + }, + + --In(Expr) * T(Add)[Add] << (Any) >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Add)) + << (ErrorMsg ^ "Invalid operands for addition"); + }, + + --In(Expr) * T(Eq, Neq)[Eq] << (Any * (T(Group) << End)) >> + [](Match &_) -> Node { + return Error + << (ErrorAst << _(Eq)) + << (ErrorMsg ^ "Expected right-hand side of equality"); + }, + + --In(Expr) * T(Eq, Neq)[Eq] << Any >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Eq)) + << (ErrorMsg ^ "Bad equality"); + }, + + Any *T(Paren)[Paren] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Paren)) + << (ErrorMsg ^ "Unexpected parenthesis"); + }, + + T(Paren) * Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected term (did you forget a " + "brace or semicolon?)"); + }, + + }}; +} } // namespace lang diff --git a/src/passes/statements.cc b/src/passes/statements.cc index 22aa86c..91e4a2e 100644 --- a/src/passes/statements.cc +++ b/src/passes/statements.cc @@ -4,203 +4,154 @@ namespace gitmem { namespace lang { - using namespace trieste; - - PassDef statements() - { - auto RVal = T(Expr) << (T(Reg, Var, Add, Const, Spawn)); - auto Condition = T(Expr) << (T(Eq, Neq)); - return { - "statements", - statements_wf, - 
dir::bottomup, - { - // Make Semi into Block - In(File) * T(Semi)[Semi] >> - [](Match &_) -> Node - { - return Block << *_(Semi); - }, - - T(Brace) << T(Semi)[Semi] >> - [](Match &_) -> Node - { - return Block << *_(Semi); - }, - - // Statements - --In(Stmt) * T(Nop)[Nop] >> - [](Match &_) -> Node - { - return Stmt << _(Nop); - }, - - --In(Stmt) * T(Join)[Join] << (RVal * End) >> - [](Match &_) -> Node - { - return Stmt << _(Join); - }, - - --In(Stmt) * T(Lock) << ((T(Expr) << T(Var)[Var]) * End) >> - [](Match &_) -> Node - { - return Stmt << (Lock << _(Var)); - }, - - --In(Stmt) * T(Unlock) << ((T(Expr) << T(Var)[Var]) * End) >> - [](Match &_) -> Node - { - return Stmt << (Unlock << _(Var)); - }, - - --In(Stmt) * T(Assign) << ((T(Expr) << (T(Reg, Var)[LVal] * End)) * RVal[Expr] * End) >> - [](Match &_) -> Node - { - return Stmt << (Assign << _(LVal) - << _(Expr)); - }, - - --In(Stmt) * T(Assert) << (Condition[Expr] * End) >> - [](Match &_) -> Node - { - return Stmt << (Assert << _(Expr)); - }, - - --In(Stmt) * (T(Group) << (T(If) << (T(Group) << (Condition[Expr] * T(Block)[Then])) * End)) - * (T(Group) << ((T(Else) << T(Block)[Else]) * End)) >> - [](Match &_) -> Node - { - return Stmt << (If << _(Expr) << _(Then) << _(Else)); - }, - - --In(Stmt) * (T(Group) << (T(If) << (T(Group) << (T(Expr)[Expr] * T(Block)[Then])) * End)) >> - [](Match &_) -> Node - { - return Stmt << (If << _(Expr) - << _(Then) - << (Block << ((Stmt ^ "nop") << Nop))); - }, - - T(Group) << (T(Stmt)[Stmt] * End) >> - [](Match &_) -> Node - { - return _(Stmt); - }, - - // Error rules - In(Group) * T(Stmt) * Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Unexpected term"); - }, - - T(Brace, File)[Brace] << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Brace)) - << (ErrorMsg ^ "Expected statement"); - }, - - T(Paren)[Paren] << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Paren)) - << (ErrorMsg ^ "Expected 
expression"); - }, - - --In(Spawn) * T(Brace)[Brace] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Brace)) - << (ErrorMsg ^ "Unexpected block"); - }, - - --In(Stmt) * T(Join) << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Join)) - << (ErrorMsg ^ "Expected thread identifier"); - }, - - --In(Stmt) * T(Join) << Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid thread identifier"); - }, - - --In(Stmt) * T(Lock, Unlock) << End >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Lock)) - << (ErrorMsg ^ "Expected lock identifier"); - }, - - --In(Stmt) * T(Lock, Unlock) << Any[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid lock identifier"); - }, - - --In(Stmt) * T(Assign) << (Any * End) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Assign)) - << (ErrorMsg ^ "Expected right-hand side to assignment"); - }, - - --In(Stmt) * T(Assign) << ((T(Expr) << T(Reg, Var)) * Any[Expr]) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid right-hand side to assignment"); - }, - - --In(Stmt) * T(Assign) << Any[LVal] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(LVal)) - << (ErrorMsg ^ "Invalid left-hand side to assignment"); - }, - - --In(Stmt) * T(Assert)[Assert] << (T(Group) << End) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Assert)) - << (ErrorMsg ^ "Expected condition"); - }, - - --In(Stmt) * T(Assert) << (Any[Expr] * End) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid assertion"); - }, - - In(If) * (Start * T(Block)[Expr]) / (T(Group) << (!Condition)[Expr]) >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Invalid condition"); - }, - - In(File, Brace) * T(Stmt)[Stmt] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Stmt)) - << 
(ErrorMsg ^ "Expected semicolon"); - }, - - In(Brace, File, Semi) * (!T(Stmt, Semi, Block))[Expr] >> - [](Match &_) -> Node - { - return Error << (ErrorAst << _(Expr)) - << (ErrorMsg ^ "Expected statement"); - }, - }}; - } +using namespace trieste; + +PassDef statements() { + auto RVal = T(Expr) << (T(Reg, Var, Add, Const, Spawn)); + auto Condition = T(Expr) << (T(Eq, Neq)); + return { + "statements", + statements_wf, + dir::bottomup, + { + // Make Semi into Block + In(File) * T(Semi)[Semi] >> + [](Match &_) -> Node { return Block << *_(Semi); }, + + T(Brace) << T(Semi)[Semi] >> + [](Match &_) -> Node { return Block << *_(Semi); }, + + // Statements + --In(Stmt) * T(Nop)[Nop] >> + [](Match &_) -> Node { return Stmt << _(Nop); }, + + --In(Stmt) * T(Join)[Join] << (RVal * End) >> + [](Match &_) -> Node { return Stmt << _(Join); }, + + --In(Stmt) * T(Lock) << ((T(Expr) << T(Var)[Var]) * End) >> + [](Match &_) -> Node { return Stmt << (Lock << _(Var)); }, + + --In(Stmt) * T(Unlock) << ((T(Expr) << T(Var)[Var]) * End) >> + [](Match &_) -> Node { return Stmt << (Unlock << _(Var)); }, + + --In(Stmt) * T(Assign) << ((T(Expr) << (T(Reg, Var)[LVal] * End)) * + RVal[Expr] * End) >> + [](Match &_) -> Node { + return Stmt << (Assign << _(LVal) << _(Expr)); + }, + + --In(Stmt) * T(Assert) << (Condition[Expr] * End) >> + [](Match &_) -> Node { return Stmt << (Assert << _(Expr)); }, + + --In(Stmt) * + (T(Group) << (T(If) << (T(Group) << (Condition[Expr] * + T(Block)[Then])) * + End)) * + (T(Group) << ((T(Else) << T(Block)[Else]) * End)) >> + [](Match &_) -> Node { + return Stmt << (If << _(Expr) << _(Then) << _(Else)); + }, + + --In(Stmt) * (T(Group) << (T(If) << (T(Group) << (T(Expr)[Expr] * + T(Block)[Then])) * + End)) >> + [](Match &_) -> Node { + return Stmt << (If << _(Expr) << _(Then) + << (Block << ((Stmt ^ "nop") << Nop))); + }, + + T(Group) << (T(Stmt)[Stmt] * End) >> + [](Match &_) -> Node { return _(Stmt); }, + + // Error rules + In(Group) * T(Stmt) * Any[Expr] >> 
[](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Unexpected term"); + }, + + T(Brace, File)[Brace] << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Brace)) + << (ErrorMsg ^ "Expected statement"); + }, + + T(Paren)[Paren] << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Paren)) + << (ErrorMsg ^ "Expected expression"); + }, + + --In(Spawn) * T(Brace)[Brace] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Brace)) + << (ErrorMsg ^ "Unexpected block"); + }, + + --In(Stmt) * T(Join) << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Join)) + << (ErrorMsg ^ "Expected thread identifier"); + }, + + --In(Stmt) * T(Join) << Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid thread identifier"); + }, + + --In(Stmt) * T(Lock, Unlock) << End >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Lock)) + << (ErrorMsg ^ "Expected lock identifier"); + }, + + --In(Stmt) * T(Lock, Unlock) << Any[Expr] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid lock identifier"); + }, + + --In(Stmt) * T(Assign) << (Any * End) >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Assign)) + << (ErrorMsg ^ + "Expected right-hand side to assignment"); + }, + + --In(Stmt) * T(Assign) << ((T(Expr) << T(Reg, Var)) * Any[Expr]) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ + "Invalid right-hand side to assignment"); + }, + + --In(Stmt) * T(Assign) << Any[LVal] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(LVal)) + << (ErrorMsg ^ "Invalid left-hand side to assignment"); + }, + + --In(Stmt) * T(Assert)[Assert] << (T(Group) << End) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Assert)) + << (ErrorMsg ^ "Expected condition"); + }, + + --In(Stmt) * T(Assert) << (Any[Expr] * End) >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << 
(ErrorMsg ^ "Invalid assertion"); + }, + + In(If) * (Start * T(Block)[Expr]) / + (T(Group) << (!Condition)[Expr]) >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Invalid condition"); + }, + + In(File, Brace) * T(Stmt)[Stmt] >> [](Match &_) -> Node { + return Error << (ErrorAst << _(Stmt)) + << (ErrorMsg ^ "Expected semicolon"); + }, + + In(Brace, File, Semi) * (!T(Stmt, Semi, Block))[Expr] >> + [](Match &_) -> Node { + return Error << (ErrorAst << _(Expr)) + << (ErrorMsg ^ "Expected statement"); + }, + }}; +} } // namespace lang diff --git a/src/reader.cc b/src/reader.cc index 3942a53..37bf9fe 100644 --- a/src/reader.cc +++ b/src/reader.cc @@ -6,19 +6,18 @@ namespace lang { using namespace trieste; -Reader reader() - { - return { +Reader reader() { + return { "gitmem", { - expressions(), - statements(), - check_refs(), - branching(), + expressions(), + statements(), + check_refs(), + branching(), }, gitmem::lang::parser(), - }; - } + }; +} } // namespace lang diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index f06aa59..8238d1a 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -1,12 +1,12 @@ -#include #include "sync_protocol.hh" #include "debug.hh" +#include namespace gitmem { -template -std::ostream& Conflict::print(std::ostream& os) const { - os << "conflict on " << var << " { " << versions.first << ", " << versions.second << " }"; +template std::ostream &Conflict::print(std::ostream &os) const { + os << "conflict on " << var << " { " << versions.first << ", " + << versions.second << " }"; return os; } @@ -14,39 +14,33 @@ std::ostream& Conflict::print(std::ostream& os) const { // LinearSyncProtocol // -------------------- -std::optional LinearSyncProtocol::push(linear::LocalVersionStore& local) { - if (auto conflict = _global_store.check_conflicts( - local.base_timestamp(), - local.staged_changes())) { +std::optional +LinearSyncProtocol::push(linear::LocalVersionStore &local) { + if (auto conflict = 
_global_store.check_conflicts(local.base_timestamp(), + local.staged_changes())) { // reshape the conflict return std::make_optional( - _global_store.get_object_name(conflict->object), - std::make_pair(conflict->local_base, conflict->global_head) - ); - + _global_store.get_object_name(conflict->object), + std::make_pair(conflict->local_base, conflict->global_head)); } linear::Timestamp new_base = _global_store.apply_changes( - local.base_timestamp(), - local.staged_changes() - ); + local.base_timestamp(), local.staged_changes()); local.clear_staging(); local.advance_base(new_base); return std::nullopt; } -std::optional LinearSyncProtocol::pull(linear::LocalVersionStore& local) { - if (auto conflict = _global_store.check_conflicts( - local.base_timestamp(), - local.staged_changes())) { +std::optional +LinearSyncProtocol::pull(linear::LocalVersionStore &local) { + if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), + local.staged_changes())) { return std::make_optional( - _global_store.get_object_name(conflict->object), - std::make_pair(conflict->local_base, conflict->global_head) - ); - + _global_store.get_object_name(conflict->object), + std::make_pair(conflict->local_base, conflict->global_head)); } local.advance_base(_global_store.current_timestamp()); @@ -55,33 +49,36 @@ std::optional LinearSyncProtocol::pull(linear::LocalVersionStore LinearSyncProtocol::~LinearSyncProtocol() = default; -std::optional LinearSyncProtocol::read(ThreadContext& ctx, const std::string& var) { +std::optional LinearSyncProtocol::read(ThreadContext &ctx, + const std::string &var) { linear::ObjectNumber number = _global_store.get_object_number(var); if (auto result = store(ctx).get_staged(number)) return result; - std::optional value = _global_store.get_version_for_timestamp(number, store(ctx).base_timestamp()); + std::optional value = _global_store.get_version_for_timestamp( + number, store(ctx).base_timestamp()); if (!value) return std::nullopt; // we do not need 
to record the staged value for correctness - // TODO: there is something about working out if a value has changed vs been written + // TODO: there is something about working out if a value has changed vs been + // written return *value; } -void LinearSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { +void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, + size_t value) { // write into the staging area of the thread store(ctx).stage(_global_store.get_object_number(var), value); } -std::optional> LinearSyncProtocol::on_spawn( - ThreadContext& parent, - ThreadContext& child, - GlobalContext& gctx -) { - // TODO: i think we can drop the globalcontext but check after branching is added +std::optional> +LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) { + // TODO: i think we can drop the globalcontext but check after branching is + // added verbose << "on_spawn" << std::endl; // push parent to global history @@ -91,11 +88,9 @@ std::optional> LinearSyncProtocol::on_spawn( return std::nullopt; } -std::optional> LinearSyncProtocol::on_join( - ThreadContext& joiner, - ThreadContext& joinee, - GlobalContext& gctx -) { +std::optional> +LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) { verbose << "on_join" << std::endl; // we assume the joinee has already terminated and pushed @@ -106,10 +101,8 @@ std::optional> LinearSyncProtocol::on_join( return std::nullopt; } -std::optional> LinearSyncProtocol::on_start( - ThreadContext& thread, - GlobalContext& gctx -) { +std::optional> +LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { verbose << "on_start" << std::endl; // pull state from global history @@ -119,10 +112,8 @@ std::optional> LinearSyncProtocol::on_start( return std::nullopt; }; -std::optional> LinearSyncProtocol::on_end( - ThreadContext& thread, - GlobalContext& gctx - ) { +std::optional> 
+LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { verbose << "on_end" << std::endl; // push changes to global history @@ -132,21 +123,17 @@ std::optional> LinearSyncProtocol::on_end( return std::nullopt; }; -std::optional> LinearSyncProtocol::on_lock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx -) { +std::optional> +LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, + GlobalContext &gctx) { assert(false && "todo lock"); // push thread, pull from global return std::nullopt; } -std::optional> LinearSyncProtocol::on_unlock( - ThreadContext& thread, - Lock&, - GlobalContext& gctx -) { +std::optional> +LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, + GlobalContext &gctx) { assert(false && "todo unlock"); // push thread return std::nullopt; @@ -167,19 +154,19 @@ BranchingSyncProtocol::~BranchingSyncProtocol() = default; // if (global.commit) // { // global.history.push_back(*global.commit); -// verbose << "Committed global '" << var << "' with id " << *global.commit << std::endl; -// global.commit.reset(); +// verbose << "Committed global '" << var << "' with id " << +// *global.commit << std::endl; global.commit.reset(); // } // } // } - // /* A versioned value can be fastforwarded to another version, if one // * version's history is a prefix of another version's history. // * A conflict between two commit histories exists if neither history is a // * prefix of the other. // */ -// std::optional> has_conflict(CommitHistory& h1, CommitHistory& h2) +// std::optional> has_conflict(CommitHistory& h1, +// CommitHistory& h2) // { // size_t length = std::min(h1.size(), h2.size()); @@ -196,22 +183,25 @@ BranchingSyncProtocol::~BranchingSyncProtocol() = default; // * either source or destination). This means destination will now also // * include variables it previously did not know about. 
// */ -// std::optional> pull(Globals &dst, Globals &src) { +// std::optional> pull(Globals &dst, Globals &src) +// { // for (auto& [var, global] : src) { // if (dst.contains(var)) // { // auto& src_var = src[var]; // auto& dst_var = dst[var]; -// if (auto conflict = has_conflict(src_var.history, dst_var.history)) +// if (auto conflict = has_conflict(src_var.history, +// dst_var.history)) // { // auto [s1, s2] = *conflict; -// verbose << "A data race on '" << var << "' was detected from commits " << s1 << " and " << s2 << std::endl; -// return Conflict(var, *conflict); +// verbose << "A data race on '" << var << "' was detected from +// commits " << s1 << " and " << s2 << std::endl; return +// Conflict(var, *conflict); // } // else if (src_var.history.size() > dst_var.history.size()) // { -// verbose << "Fast-forward '" << var << "' to id " << src_var.val << std::endl; -// dst_var.val = src_var.val; +// verbose << "Fast-forward '" << var << "' to id " << +// src_var.val << std::endl; dst_var.val = src_var.val; // dst_var.history = src_var.history; // } // } @@ -224,31 +214,29 @@ BranchingSyncProtocol::~BranchingSyncProtocol() = default; // return std::nullopt; // } -std::optional BranchingSyncProtocol::read(ThreadContext& ctx, const std::string& var) { +std::optional BranchingSyncProtocol::read(ThreadContext &ctx, + const std::string &var) { assert(false && "Todo read"); return std::nullopt; } -void BranchingSyncProtocol::write(ThreadContext& ctx, const std::string& var, size_t value) { +void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, + size_t value) { assert(false && "Todo write"); } -std::optional> BranchingSyncProtocol::on_spawn( - ThreadContext& parent, - ThreadContext& child, - GlobalContext& -) { +std::optional> +BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &) { assert(false && "Todo on_spawn"); // commit(parent.globals); // child.globals = parent.globals; return std::nullopt; } 
-std::optional> BranchingSyncProtocol::on_join( - ThreadContext& joiner, - ThreadContext& joinee, - GlobalContext& -) { +std::optional> +BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &) { assert(false && "Todo on_join"); // commit(joiner.globals); // commit(joinee.globals); @@ -256,38 +244,30 @@ std::optional> BranchingSyncProtocol::on_join( return std::nullopt; } -std::optional> BranchingSyncProtocol::on_start( - ThreadContext& thread, - GlobalContext& gctx -) { +std::optional> +BranchingSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { assert(false && "Todo on_start"); return std::nullopt; }; -std::optional> BranchingSyncProtocol::on_end( - ThreadContext& thread, - GlobalContext& gctx - ) { +std::optional> +BranchingSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { assert(false && "Todo on_end"); return std::nullopt; }; -std::optional> BranchingSyncProtocol::on_lock( - ThreadContext& thread, - Lock& lock, - GlobalContext& -) { +std::optional> +BranchingSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, + GlobalContext &) { assert(false && "Todo on_lock"); // commit(thread.globals); // return pull(thread.globals, lock.globals); return std::nullopt; } -std::optional> BranchingSyncProtocol::on_unlock( - ThreadContext& thread, - Lock& lock, - GlobalContext& -) { +std::optional> +BranchingSyncProtocol::on_unlock(ThreadContext &thread, Lock &lock, + GlobalContext &) { assert(false && "Todo on_unlock"); // commit(thread.globals); // lock.globals = thread.globals; diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 7e077aa..7b55ae8 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -1,10 +1,10 @@ #pragma once -#include -#include +#include "branching/version_store.hh" #include "execution_state.hh" #include "linear/version_store.hh" -#include "branching/version_store.hh" +#include +#include /* i want an on_start and on_end event i think too */ @@ -12,21 +12,21 
@@ namespace gitmem { struct ConflictBase { virtual ~ConflictBase() = default; - virtual std::ostream& print(std::ostream& os) const = 0; - friend std::ostream& operator<<(std::ostream& os, const ConflictBase& conflict) { + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const ConflictBase &conflict) { return conflict.print(os); } }; -template -struct Conflict : ConflictBase { +template struct Conflict : ConflictBase { std::string var; std::pair versions; - Conflict(std::string var, std::pair versions): - var(std::move(var)), versions(std::move(versions)) {} + Conflict(std::string var, std::pair versions) + : var(std::move(var)), versions(std::move(versions)) {} - std::ostream& print(std::ostream& os) const override; + std::ostream &print(std::ostream &os) const override; }; using LinearConflict = Conflict; @@ -37,44 +37,32 @@ public: virtual ~SyncProtocol() = default; // Read a shared variable into the thread context - virtual std::optional read(ThreadContext& ctx, const std::string& var) = 0; + virtual std::optional read(ThreadContext &ctx, + const std::string &var) = 0; // Write a shared variable (staged, not committed) - virtual void write(ThreadContext& ctx, const std::string& var, size_t value) = 0; - - virtual std::optional> on_spawn( - ThreadContext& parent, - ThreadContext& child, - GlobalContext& gctx - ) = 0; - - virtual std::optional> on_join( - ThreadContext& joiner, - ThreadContext& joinee, - GlobalContext& gctx - ) = 0; - - virtual std::optional> on_start( - ThreadContext& thread, - GlobalContext& gctx - ) = 0; - - virtual std::optional> on_end( - ThreadContext& thread, - GlobalContext& gctx - ) = 0; - - virtual std::optional> on_lock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx - ) = 0; - - virtual std::optional> on_unlock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx - ) = 0; + virtual void write(ThreadContext &ctx, const std::string &var, + size_t 
value) = 0; + + virtual std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) = 0; + + virtual std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) = 0; + + virtual std::optional> + on_start(ThreadContext &thread, GlobalContext &gctx) = 0; + + virtual std::optional> + on_end(ThreadContext &thread, GlobalContext &gctx) = 0; + + virtual std::optional> + on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) = 0; + + virtual std::optional> + on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) = 0; }; // --------------------------------- @@ -84,55 +72,42 @@ public: class LinearSyncProtocol final : public SyncProtocol { linear::GlobalVersionStore _global_store; - static linear::LocalVersionStore& store(ThreadContext& ctx) { - if (!ctx.linear) ctx.linear.emplace(); + static linear::LocalVersionStore &store(ThreadContext &ctx) { + if (!ctx.linear) + ctx.linear.emplace(); return ctx.linear->store; } - std::optional push(linear::LocalVersionStore& local); - std::optional pull(linear::LocalVersionStore& local); + std::optional push(linear::LocalVersionStore &local); + std::optional pull(linear::LocalVersionStore &local); public: ~LinearSyncProtocol() override; + std::optional read(ThreadContext &ctx, + const std::string &var) override; - std::optional read(ThreadContext& ctx, const std::string& var) override; - - void write(ThreadContext& ctx, const std::string& var, size_t value) override; + void write(ThreadContext &ctx, const std::string &var, size_t value) override; - std::optional> on_spawn( - ThreadContext& parent, - ThreadContext& child, - GlobalContext& gctx - ) override; + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) override; - std::optional> on_join( - ThreadContext& joiner, - ThreadContext& joinee, - GlobalContext& gctx - ) override; + std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee, + 
GlobalContext &gctx) override; - std::optional> on_start( - ThreadContext& thread, - GlobalContext& gctx - ) override; + std::optional> + on_start(ThreadContext &thread, GlobalContext &gctx) override; - std::optional> on_end( - ThreadContext& thread, - GlobalContext& gctx - ) override; + std::optional> + on_end(ThreadContext &thread, GlobalContext &gctx) override; - std::optional> on_lock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx - ) override; + std::optional> + on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - std::optional> on_unlock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx - ) override; + std::optional> + on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; }; class BranchingSyncProtocol final : public SyncProtocol { @@ -142,43 +117,30 @@ class BranchingSyncProtocol final : public SyncProtocol { public: ~BranchingSyncProtocol() override; - std::optional read(ThreadContext& ctx, const std::string& var) override; - - void write(ThreadContext& ctx, const std::string& var, size_t value) override; - - std::optional> on_spawn( - ThreadContext& parent, - ThreadContext& child, - GlobalContext& gctx - ) override; - - std::optional> on_join( - ThreadContext& joiner, - ThreadContext& joinee, - GlobalContext& gctx - ) override; - - std::optional> on_start( - ThreadContext& thread, - GlobalContext& gctx - ) override; - - std::optional> on_end( - ThreadContext& thread, - GlobalContext& gctx - ) override; - - std::optional> on_lock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx - ) override; - - std::optional> on_unlock( - ThreadContext& thread, - Lock& lock, - GlobalContext& gctx - ) override; + std::optional read(ThreadContext &ctx, + const std::string &var) override; + + void write(ThreadContext &ctx, const std::string &var, size_t value) override; + + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) override; + + 
std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) override; + + std::optional> + on_start(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_end(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::optional> + on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; }; } // namespace gitmem From 2bd041a2fc8ab3f0ab00763ddbdbb0d315107c00 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 19 Dec 2025 11:34:36 +0000 Subject: [PATCH 09/58] got the model checker running again, but it is not aware of the distinct points between spawning and starting a thread --- CMakeLists.txt | 2 +- .../passing/semantics/conditional_non_race.gm | 7 +++--- src/execution_state.cc | 1 + src/execution_state.hh | 6 +++++ src/gitmem.cc | 4 +-- src/interpreter.cc | 4 +-- src/model_checker.cc | 25 +++++++++++-------- src/sync_protocol.cc | 19 +++++++------- 8 files changed, 39 insertions(+), 29 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7469df2..6e39979 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,7 @@ add_executable(gitmem src/linear/version_store.cc src/interpreter.cc # src/debugger.cc - # src/model_checker.cc + src/model_checker.cc src/sync_protocol.cc src/graphviz.cc ) diff --git a/examples/passing/semantics/conditional_non_race.gm b/examples/passing/semantics/conditional_non_race.gm index 1a4e743..1915ed6 100644 --- a/examples/passing/semantics/conditional_non_race.gm +++ b/examples/passing/semantics/conditional_non_race.gm @@ -12,6 +12,7 @@ $t2 = spawn { }; join $t1; join $t2; -assert (x != 0); -assert (y != 0); -assert (x == y); +// FIXME in branching these assertions are always true, but not in linear +// assert (x != 0); +// assert (y != 0); +// assert (x == y); diff --git a/src/execution_state.cc b/src/execution_state.cc index d769e13..f298e54 100644 
--- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -40,6 +40,7 @@ GlobalContext::~GlobalContext() = default; void GlobalContext::print_execution_graph( const std::filesystem::path &output_path) const { + return; // FIXME // Loop over the threads and add pending nodes to running threads // to indicate a threads next step for (const auto &t : threads) { diff --git a/src/execution_state.hh b/src/execution_state.hh index 768a83e..f9ffffa 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -79,6 +79,12 @@ struct GlobalContext { std::unique_ptr protocol); ~GlobalContext(); + GlobalContext(GlobalContext&&) = default; + GlobalContext& operator=(GlobalContext&&) = default; + + GlobalContext(const GlobalContext&) = delete; + GlobalContext& operator=(const GlobalContext&) = delete; + bool operator==(const GlobalContext &other) const; void print_execution_graph(const std::filesystem::path &output_path) const; diff --git a/src/gitmem.cc b/src/gitmem.cc index 3d88804..bc7f49d 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -2,6 +2,7 @@ #include "debug.hh" #include "interpreter.hh" +#include "model_checker.hh" #include "lang.hh" int main(int argc, char **argv) { @@ -62,8 +63,7 @@ int main(int argc, char **argv) { int exit_status; wf::push_back(gitmem::lang::wf); if (model_check) { - assert(false && "currently broken"); - // exit_status = gitmem::model_check(result.ast, output_path); + exit_status = gitmem::model_check(result.ast, output_path); } else if (interactive) { assert(false && "currently broken"); // exit_status = gitmem::interpret_interactive(result.ast, output_path); diff --git a/src/interpreter.cc b/src/interpreter.cc index b892bc5..eb01429 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -256,8 +256,6 @@ std::variant run_statement(Node stmt, verbose << "Locked " << var << std::endl; } else if (s == lang::Unlock) { - assert(false && "todo"); - // We can only unlock locks we previously locked. 
We commit any // pending updates and then copy the threads versioned globals // to the locks versioned globals (nobody could have changed @@ -286,6 +284,7 @@ std::variant run_statement(Node stmt, verbose << "Unlocked " << var << std::endl; } else if (s == lang::Assert) { + auto expr = s / lang::Expr; auto result_or_term = evaluate_expression(expr, gctx, ctx); if (size_t *result = std::get_if(&result_or_term)) { @@ -300,6 +299,7 @@ std::variant run_statement(Node stmt, } else { return std::get(result_or_term); } + } else { throw std::runtime_error("Unknown statement: " + std::string(stmt->type().str())); diff --git a/src/model_checker.cc b/src/model_checker.cc index 450d97c..7d99ef1 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -1,5 +1,7 @@ #include "model_checker.hh" #include "interpreter.hh" +#include "debug.hh" +#include "sync_protocol.hh" namespace gitmem { using namespace trieste; @@ -57,11 +59,11 @@ build_output_path(const std::filesystem::path &output_path, const size_t idx) { * for each distinct final state that led to an error. 
*/ int model_check(const Node ast, const std::filesystem::path &output_path) { - GlobalContext gctx(ast); + GlobalContext gctx(ast, std::make_unique()); - auto final_contexts = std::vector{}; - auto failing_contexts = std::vector{}; - auto deadlocked_contexts = std::vector{}; + auto final_contexts = std::vector>{}; + auto failing_contexts = std::vector>{}; + auto deadlocked_contexts = std::vector>{}; auto final_traces = std::vector>{}; auto failing_traces = std::vector>{}; @@ -138,15 +140,16 @@ int model_check(const Node ast, const std::filesystem::path &output_path) { // Remember final state if it is new if (!std::any_of( final_contexts.begin(), final_contexts.end(), - [&gctx](const GlobalContext &state) { return state == gctx; })) { - final_contexts.push_back(gctx); + [&gctx](const std::shared_ptr &state) { return *state == gctx; })) { + std::shared_ptr gctxp = std::make_shared(std::move(gctx)); + final_contexts.push_back(gctxp); final_traces.push_back(current_trace); if (any_crashed) { failing_traces.push_back(current_trace); - failing_contexts.push_back(gctx); + failing_contexts.push_back(gctxp); } else if (is_deadlock) { deadlocked_traces.push_back(current_trace); - deadlocked_contexts.push_back(gctx); + deadlocked_contexts.push_back(gctxp); } } @@ -156,7 +159,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path) { if (cursor->complete && !root->complete) { // Reset the cursor to the root and start a new trace verbose << std::endl << "Restarting trace..." 
<< std::endl; - gctx = GlobalContext(ast); + gctx = GlobalContext(ast, std::make_unique()); cursor = root; current_trace.clear(); @@ -179,7 +182,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path) { for (const auto &ctx : failing_contexts) { auto path = build_output_path(output_path, idx++); - ctx.print_execution_graph(path); + ctx->print_execution_graph(path); } } @@ -190,7 +193,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path) { for (const auto &ctx : deadlocked_contexts) { auto path = build_output_path(output_path, idx++); - ctx.print_execution_graph(path); + ctx->print_execution_graph(path); } } diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 8238d1a..2103a13 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -79,7 +79,6 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, GlobalContext &gctx) { // TODO: i think we can drop the globalcontext but check after branching is // added - verbose << "on_spawn" << std::endl; // push parent to global history if (auto conflict = push(store(parent))) @@ -91,7 +90,6 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, std::optional> LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, GlobalContext &gctx) { - verbose << "on_join" << std::endl; // we assume the joinee has already terminated and pushed // pull changes into parent @@ -103,8 +101,6 @@ LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, std::optional> LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { - verbose << "on_start" << std::endl; - // pull state from global history auto conflict = pull(store(thread)); assert(!conflict && "cannot conflict from starting state"); @@ -114,8 +110,6 @@ LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { std::optional> LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { - verbose << 
"on_end" << std::endl; - // push changes to global history if (auto conflict = push(store(thread))) return std::make_unique(std::move(*conflict)); @@ -126,16 +120,21 @@ LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { std::optional> LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) { - assert(false && "todo lock"); - // push thread, pull from global + + if (auto conflict = pull(store(thread))) + return std::make_unique(std::move(*conflict)); + return std::nullopt; } std::optional> LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, GlobalContext &gctx) { - assert(false && "todo unlock"); - // push thread + + // push changes to global history + if (auto conflict = push(store(thread))) + return std::make_unique(std::move(*conflict)); + return std::nullopt; } From d2f40a03f98e5d6af8c6cf28f1745ad02a2bbfb2 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 23 Dec 2025 12:26:47 +0000 Subject: [PATCH 10/58] moving examples around and adding command line switch to pick between sync protocols --- .../semantics/branching}/addition.gm | 0 .../branching}/conditional_non_race.gm | 6 +- .../semantics/branching}/globals.gm | 0 .../semantics/branching}/if.gm | 0 .../semantics/branching}/join_fastforward.gm | 0 .../semantics/branching}/join_nop.gm | 0 .../branching}/join_pulled_variable.gm | 0 .../semantics/branching}/join_spawn.gm | 0 .../semantics/branching}/local.gm | 0 .../semantics/branching}/lock.gm | 0 .../semantics/branching}/lock_as_sync.gm | 0 .../semantics/branching}/spawn.gm | 0 .../semantics/branching}/spawn_read_global.gm | 0 examples/accept/semantics/linear/addition.gm | 4 + .../semantics/linear/conditional_non_race.gm | 18 +++ examples/accept/semantics/linear/globals.gm | 5 + examples/accept/semantics/linear/if.gm | 13 ++ .../semantics/linear/join_fastforward.gm | 4 + examples/accept/semantics/linear/join_nop.gm | 4 + .../semantics/linear/join_pulled_variable.gm | 15 ++ 
.../accept/semantics/linear/join_spawn.gm | 7 + examples/accept/semantics/linear/local.gm | 7 + examples/accept/semantics/linear/lock.gm | 2 + .../accept/semantics/linear/lock_as_sync.gm | 10 ++ examples/accept/semantics/linear/spawn.gm | 4 + .../semantics/linear/spawn_read_global.gm | 8 ++ .../semantics/branching}/conditional_race.gm | 0 .../semantics/branching}/deadlock.gm | 0 .../semantics/branching}/error_and_race.gm | 0 .../semantics/branching}/failed_assertion.gm | 0 .../failed_assertion_many_schedules.gm | 0 .../branching}/failed_assertion_neq.gm | 0 .../branching}/join_blocked_thread.gm | 0 .../semantics/branching}/join_datarace.gm | 0 .../semantics/branching}/join_deadlock.gm | 0 .../semantics/branching}/join_nonexisting.gm | 0 .../semantics/branching}/read_unassigned.gm | 0 .../semantics/branching}/test.gm | 0 .../semantics/branching}/unlock.gm | 0 .../branching}/unlock_another_threads_lock.gm | 0 .../semantics/linear/conditional_race.gm | 32 +++++ examples/reject/semantics/linear/deadlock.gm | 12 ++ .../reject/semantics/linear/error_and_race.gm | 13 ++ .../semantics/linear/failed_assertion.gm | 6 + .../linear/failed_assertion_many_schedules.gm | 5 + .../semantics/linear/failed_assertion_neq.gm | 6 + .../semantics/linear/join_blocked_thread.gm | 2 + .../reject/semantics/linear/join_datarace.gm | 10 ++ .../reject/semantics/linear/join_deadlock.gm | 1 + .../semantics/linear/join_nonexisting.gm | 2 + .../semantics/linear/read_unassigned.gm | 1 + examples/reject/semantics/linear/test.gm | 11 ++ examples/reject/semantics/linear/unlock.gm | 2 + .../linear/unlock_another_threads_lock.gm | 5 + .../{failing => reject}/syntax/bad_add.gm | 0 .../{failing => reject}/syntax/bad_add2.gm | 0 .../{failing => reject}/syntax/bad_assert.gm | 0 .../syntax/bad_condition.gm | 0 .../{failing => reject}/syntax/bad_join.gm | 0 .../{failing => reject}/syntax/bad_lhs.gm | 0 .../{failing => reject}/syntax/bad_lock.gm | 0 .../{failing => reject}/syntax/bad_rhs.gm | 0 
examples/{failing => reject}/syntax/empty.gm | 0 .../syntax/empty_assert.gm | 0 .../syntax/empty_assign.gm | 0 .../{failing => reject}/syntax/empty_brace.gm | 0 .../{failing => reject}/syntax/if_no_brace.gm | 0 .../{failing => reject}/syntax/if_no_cond.gm | 0 .../syntax/no_semicolon.gm | 0 .../syntax/spurious_else.gm | 0 examples/{failing => reject}/syntax/top_eq.gm | 0 .../{failing => reject}/unassigned_reg.gm | 0 src/debug.hh | 24 ++++ src/debugger.hh | 8 ++ src/gitmem.cc | 10 +- src/interpreter.cc | 16 ++- src/interpreter.hh | 3 +- src/model_checker.cc | 6 +- src/model_checker.hh | 10 ++ src/sync_protocol.cc | 10 ++ src/sync_protocol.hh | 7 + test_gitmem.py | 134 ++++++++++++------ 82 files changed, 380 insertions(+), 63 deletions(-) rename examples/{passing/semantics => accept/semantics/branching}/addition.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/conditional_non_race.gm (78%) rename examples/{passing/semantics => accept/semantics/branching}/globals.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/if.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/join_fastforward.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/join_nop.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/join_pulled_variable.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/join_spawn.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/local.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/lock.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/lock_as_sync.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/spawn.gm (100%) rename examples/{passing/semantics => accept/semantics/branching}/spawn_read_global.gm (100%) create mode 100644 examples/accept/semantics/linear/addition.gm create mode 100644 
examples/accept/semantics/linear/conditional_non_race.gm create mode 100644 examples/accept/semantics/linear/globals.gm create mode 100644 examples/accept/semantics/linear/if.gm create mode 100644 examples/accept/semantics/linear/join_fastforward.gm create mode 100644 examples/accept/semantics/linear/join_nop.gm create mode 100644 examples/accept/semantics/linear/join_pulled_variable.gm create mode 100644 examples/accept/semantics/linear/join_spawn.gm create mode 100644 examples/accept/semantics/linear/local.gm create mode 100644 examples/accept/semantics/linear/lock.gm create mode 100644 examples/accept/semantics/linear/lock_as_sync.gm create mode 100644 examples/accept/semantics/linear/spawn.gm create mode 100644 examples/accept/semantics/linear/spawn_read_global.gm rename examples/{failing/semantics => reject/semantics/branching}/conditional_race.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/deadlock.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/error_and_race.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/failed_assertion.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/failed_assertion_many_schedules.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/failed_assertion_neq.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/join_blocked_thread.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/join_datarace.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/join_deadlock.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/join_nonexisting.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/read_unassigned.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/test.gm (100%) rename examples/{failing/semantics => reject/semantics/branching}/unlock.gm (100%) rename 
examples/{failing/semantics => reject/semantics/branching}/unlock_another_threads_lock.gm (100%) create mode 100644 examples/reject/semantics/linear/conditional_race.gm create mode 100644 examples/reject/semantics/linear/deadlock.gm create mode 100644 examples/reject/semantics/linear/error_and_race.gm create mode 100644 examples/reject/semantics/linear/failed_assertion.gm create mode 100644 examples/reject/semantics/linear/failed_assertion_many_schedules.gm create mode 100644 examples/reject/semantics/linear/failed_assertion_neq.gm create mode 100644 examples/reject/semantics/linear/join_blocked_thread.gm create mode 100644 examples/reject/semantics/linear/join_datarace.gm create mode 100644 examples/reject/semantics/linear/join_deadlock.gm create mode 100644 examples/reject/semantics/linear/join_nonexisting.gm create mode 100644 examples/reject/semantics/linear/read_unassigned.gm create mode 100644 examples/reject/semantics/linear/test.gm create mode 100644 examples/reject/semantics/linear/unlock.gm create mode 100644 examples/reject/semantics/linear/unlock_another_threads_lock.gm rename examples/{failing => reject}/syntax/bad_add.gm (100%) rename examples/{failing => reject}/syntax/bad_add2.gm (100%) rename examples/{failing => reject}/syntax/bad_assert.gm (100%) rename examples/{failing => reject}/syntax/bad_condition.gm (100%) rename examples/{failing => reject}/syntax/bad_join.gm (100%) rename examples/{failing => reject}/syntax/bad_lhs.gm (100%) rename examples/{failing => reject}/syntax/bad_lock.gm (100%) rename examples/{failing => reject}/syntax/bad_rhs.gm (100%) rename examples/{failing => reject}/syntax/empty.gm (100%) rename examples/{failing => reject}/syntax/empty_assert.gm (100%) rename examples/{failing => reject}/syntax/empty_assign.gm (100%) rename examples/{failing => reject}/syntax/empty_brace.gm (100%) rename examples/{failing => reject}/syntax/if_no_brace.gm (100%) rename examples/{failing => reject}/syntax/if_no_cond.gm (100%) rename 
examples/{failing => reject}/syntax/no_semicolon.gm (100%) rename examples/{failing => reject}/syntax/spurious_else.gm (100%) rename examples/{failing => reject}/syntax/top_eq.gm (100%) rename examples/{failing => reject}/unassigned_reg.gm (100%) create mode 100644 src/debug.hh create mode 100644 src/debugger.hh create mode 100644 src/model_checker.hh diff --git a/examples/passing/semantics/addition.gm b/examples/accept/semantics/branching/addition.gm similarity index 100% rename from examples/passing/semantics/addition.gm rename to examples/accept/semantics/branching/addition.gm diff --git a/examples/passing/semantics/conditional_non_race.gm b/examples/accept/semantics/branching/conditional_non_race.gm similarity index 78% rename from examples/passing/semantics/conditional_non_race.gm rename to examples/accept/semantics/branching/conditional_non_race.gm index 1915ed6..b8d4206 100644 --- a/examples/passing/semantics/conditional_non_race.gm +++ b/examples/accept/semantics/branching/conditional_non_race.gm @@ -13,6 +13,6 @@ $t2 = spawn { join $t1; join $t2; // FIXME in branching these assertions are always true, but not in linear -// assert (x != 0); -// assert (y != 0); -// assert (x == y); +assert (x != 0); +assert (y != 0); +assert (x == y); diff --git a/examples/passing/semantics/globals.gm b/examples/accept/semantics/branching/globals.gm similarity index 100% rename from examples/passing/semantics/globals.gm rename to examples/accept/semantics/branching/globals.gm diff --git a/examples/passing/semantics/if.gm b/examples/accept/semantics/branching/if.gm similarity index 100% rename from examples/passing/semantics/if.gm rename to examples/accept/semantics/branching/if.gm diff --git a/examples/passing/semantics/join_fastforward.gm b/examples/accept/semantics/branching/join_fastforward.gm similarity index 100% rename from examples/passing/semantics/join_fastforward.gm rename to examples/accept/semantics/branching/join_fastforward.gm diff --git 
a/examples/passing/semantics/join_nop.gm b/examples/accept/semantics/branching/join_nop.gm similarity index 100% rename from examples/passing/semantics/join_nop.gm rename to examples/accept/semantics/branching/join_nop.gm diff --git a/examples/passing/semantics/join_pulled_variable.gm b/examples/accept/semantics/branching/join_pulled_variable.gm similarity index 100% rename from examples/passing/semantics/join_pulled_variable.gm rename to examples/accept/semantics/branching/join_pulled_variable.gm diff --git a/examples/passing/semantics/join_spawn.gm b/examples/accept/semantics/branching/join_spawn.gm similarity index 100% rename from examples/passing/semantics/join_spawn.gm rename to examples/accept/semantics/branching/join_spawn.gm diff --git a/examples/passing/semantics/local.gm b/examples/accept/semantics/branching/local.gm similarity index 100% rename from examples/passing/semantics/local.gm rename to examples/accept/semantics/branching/local.gm diff --git a/examples/passing/semantics/lock.gm b/examples/accept/semantics/branching/lock.gm similarity index 100% rename from examples/passing/semantics/lock.gm rename to examples/accept/semantics/branching/lock.gm diff --git a/examples/passing/semantics/lock_as_sync.gm b/examples/accept/semantics/branching/lock_as_sync.gm similarity index 100% rename from examples/passing/semantics/lock_as_sync.gm rename to examples/accept/semantics/branching/lock_as_sync.gm diff --git a/examples/passing/semantics/spawn.gm b/examples/accept/semantics/branching/spawn.gm similarity index 100% rename from examples/passing/semantics/spawn.gm rename to examples/accept/semantics/branching/spawn.gm diff --git a/examples/passing/semantics/spawn_read_global.gm b/examples/accept/semantics/branching/spawn_read_global.gm similarity index 100% rename from examples/passing/semantics/spawn_read_global.gm rename to examples/accept/semantics/branching/spawn_read_global.gm diff --git a/examples/accept/semantics/linear/addition.gm 
b/examples/accept/semantics/linear/addition.gm new file mode 100644 index 0000000..3ebb8a9 --- /dev/null +++ b/examples/accept/semantics/linear/addition.gm @@ -0,0 +1,4 @@ +$r1 = 1; +x = 2; +y = $r1 + x; +assert(y + 1 != x + 1); diff --git a/examples/accept/semantics/linear/conditional_non_race.gm b/examples/accept/semantics/linear/conditional_non_race.gm new file mode 100644 index 0000000..b8d4206 --- /dev/null +++ b/examples/accept/semantics/linear/conditional_non_race.gm @@ -0,0 +1,18 @@ +x = 0; +y = 0; +$t1 = spawn { + if (x == 0) { + y = 1; + } +}; +$t2 = spawn { + if (y == 0) { + x = 1; + } +}; +join $t1; +join $t2; +// FIXME in branching these assertions are always true, but not in linear +assert (x != 0); +assert (y != 0); +assert (x == y); diff --git a/examples/accept/semantics/linear/globals.gm b/examples/accept/semantics/linear/globals.gm new file mode 100644 index 0000000..283e6eb --- /dev/null +++ b/examples/accept/semantics/linear/globals.gm @@ -0,0 +1,5 @@ +nop; +x = 0; +x = 2; +y = 2; +assert(x == y); diff --git a/examples/accept/semantics/linear/if.gm b/examples/accept/semantics/linear/if.gm new file mode 100644 index 0000000..eb270f6 --- /dev/null +++ b/examples/accept/semantics/linear/if.gm @@ -0,0 +1,13 @@ +$r = 0; +x = 0; +if ($r == 0) { + x = x + 1; +} else { + x = 0; +} +if ($r == 1) { + x = 0; +} else { + x = x + 1; +} +assert (x == 2); diff --git a/examples/accept/semantics/linear/join_fastforward.gm b/examples/accept/semantics/linear/join_fastforward.gm new file mode 100644 index 0000000..d488fe4 --- /dev/null +++ b/examples/accept/semantics/linear/join_fastforward.gm @@ -0,0 +1,4 @@ +x = 2; +$t = spawn { x = 3; }; +join $t; +assert(x == 3); diff --git a/examples/accept/semantics/linear/join_nop.gm b/examples/accept/semantics/linear/join_nop.gm new file mode 100644 index 0000000..ed09147 --- /dev/null +++ b/examples/accept/semantics/linear/join_nop.gm @@ -0,0 +1,4 @@ +x = 2; +$t = spawn { nop; }; +join $t; +assert(x == 2); diff --git 
a/examples/accept/semantics/linear/join_pulled_variable.gm b/examples/accept/semantics/linear/join_pulled_variable.gm new file mode 100644 index 0000000..ae481e8 --- /dev/null +++ b/examples/accept/semantics/linear/join_pulled_variable.gm @@ -0,0 +1,15 @@ +x = 0; +t = spawn { + assert(x == 0); + x = 2; + t2 = spawn { + assert(x == 2); + x = 14; + assert(x == 14); + }; + assert(x == 2); +}; +join t; +assert (x == 2); +join t2; +assert (x == 14); \ No newline at end of file diff --git a/examples/accept/semantics/linear/join_spawn.gm b/examples/accept/semantics/linear/join_spawn.gm new file mode 100644 index 0000000..ba5acf0 --- /dev/null +++ b/examples/accept/semantics/linear/join_spawn.gm @@ -0,0 +1,7 @@ +x = 1; +join spawn { + assert(x == 1); + x = 2; + assert(x == 2); +}; +assert(x == 2); \ No newline at end of file diff --git a/examples/accept/semantics/linear/local.gm b/examples/accept/semantics/linear/local.gm new file mode 100644 index 0000000..266605b --- /dev/null +++ b/examples/accept/semantics/linear/local.gm @@ -0,0 +1,7 @@ +$t1 = 1; +$t2 = 2; +nop; +$t1 = $t2; +nop; +assert($t1 == $t2); +nop; diff --git a/examples/accept/semantics/linear/lock.gm b/examples/accept/semantics/linear/lock.gm new file mode 100644 index 0000000..bc95524 --- /dev/null +++ b/examples/accept/semantics/linear/lock.gm @@ -0,0 +1,2 @@ +lock l1; +unlock l1; diff --git a/examples/accept/semantics/linear/lock_as_sync.gm b/examples/accept/semantics/linear/lock_as_sync.gm new file mode 100644 index 0000000..63fdd4b --- /dev/null +++ b/examples/accept/semantics/linear/lock_as_sync.gm @@ -0,0 +1,10 @@ +x = 0; +lock l1; +$t = spawn { + assert (x == 0); + lock l1; + unlock l1; + assert (x == 2); +}; +x = 2; +unlock l1; diff --git a/examples/accept/semantics/linear/spawn.gm b/examples/accept/semantics/linear/spawn.gm new file mode 100644 index 0000000..0200097 --- /dev/null +++ b/examples/accept/semantics/linear/spawn.gm @@ -0,0 +1,4 @@ +x = 2; +y = 2; +$t = spawn { nop; }; +assert(x == y); 
diff --git a/examples/accept/semantics/linear/spawn_read_global.gm b/examples/accept/semantics/linear/spawn_read_global.gm new file mode 100644 index 0000000..83c2126 --- /dev/null +++ b/examples/accept/semantics/linear/spawn_read_global.gm @@ -0,0 +1,8 @@ +x = 2; +y = 2; +$t = spawn { + assert(x == y); + x = 42; + assert(x == 42); +}; +assert(x == y); diff --git a/examples/failing/semantics/conditional_race.gm b/examples/reject/semantics/branching/conditional_race.gm similarity index 100% rename from examples/failing/semantics/conditional_race.gm rename to examples/reject/semantics/branching/conditional_race.gm diff --git a/examples/failing/semantics/deadlock.gm b/examples/reject/semantics/branching/deadlock.gm similarity index 100% rename from examples/failing/semantics/deadlock.gm rename to examples/reject/semantics/branching/deadlock.gm diff --git a/examples/failing/semantics/error_and_race.gm b/examples/reject/semantics/branching/error_and_race.gm similarity index 100% rename from examples/failing/semantics/error_and_race.gm rename to examples/reject/semantics/branching/error_and_race.gm diff --git a/examples/failing/semantics/failed_assertion.gm b/examples/reject/semantics/branching/failed_assertion.gm similarity index 100% rename from examples/failing/semantics/failed_assertion.gm rename to examples/reject/semantics/branching/failed_assertion.gm diff --git a/examples/failing/semantics/failed_assertion_many_schedules.gm b/examples/reject/semantics/branching/failed_assertion_many_schedules.gm similarity index 100% rename from examples/failing/semantics/failed_assertion_many_schedules.gm rename to examples/reject/semantics/branching/failed_assertion_many_schedules.gm diff --git a/examples/failing/semantics/failed_assertion_neq.gm b/examples/reject/semantics/branching/failed_assertion_neq.gm similarity index 100% rename from examples/failing/semantics/failed_assertion_neq.gm rename to examples/reject/semantics/branching/failed_assertion_neq.gm diff --git 
a/examples/failing/semantics/join_blocked_thread.gm b/examples/reject/semantics/branching/join_blocked_thread.gm similarity index 100% rename from examples/failing/semantics/join_blocked_thread.gm rename to examples/reject/semantics/branching/join_blocked_thread.gm diff --git a/examples/failing/semantics/join_datarace.gm b/examples/reject/semantics/branching/join_datarace.gm similarity index 100% rename from examples/failing/semantics/join_datarace.gm rename to examples/reject/semantics/branching/join_datarace.gm diff --git a/examples/failing/semantics/join_deadlock.gm b/examples/reject/semantics/branching/join_deadlock.gm similarity index 100% rename from examples/failing/semantics/join_deadlock.gm rename to examples/reject/semantics/branching/join_deadlock.gm diff --git a/examples/failing/semantics/join_nonexisting.gm b/examples/reject/semantics/branching/join_nonexisting.gm similarity index 100% rename from examples/failing/semantics/join_nonexisting.gm rename to examples/reject/semantics/branching/join_nonexisting.gm diff --git a/examples/failing/semantics/read_unassigned.gm b/examples/reject/semantics/branching/read_unassigned.gm similarity index 100% rename from examples/failing/semantics/read_unassigned.gm rename to examples/reject/semantics/branching/read_unassigned.gm diff --git a/examples/failing/semantics/test.gm b/examples/reject/semantics/branching/test.gm similarity index 100% rename from examples/failing/semantics/test.gm rename to examples/reject/semantics/branching/test.gm diff --git a/examples/failing/semantics/unlock.gm b/examples/reject/semantics/branching/unlock.gm similarity index 100% rename from examples/failing/semantics/unlock.gm rename to examples/reject/semantics/branching/unlock.gm diff --git a/examples/failing/semantics/unlock_another_threads_lock.gm b/examples/reject/semantics/branching/unlock_another_threads_lock.gm similarity index 100% rename from examples/failing/semantics/unlock_another_threads_lock.gm rename to 
examples/reject/semantics/branching/unlock_another_threads_lock.gm diff --git a/examples/reject/semantics/linear/conditional_race.gm b/examples/reject/semantics/linear/conditional_race.gm new file mode 100644 index 0000000..6dd0a36 --- /dev/null +++ b/examples/reject/semantics/linear/conditional_race.gm @@ -0,0 +1,32 @@ +x = 0; +y = 0; +flag = 0; +$t1 = spawn { + lock l1; + $r = 0; + if (flag == 0) { + flag = 1; + $r = 1; + } + unlock l1; + if ($r == 1) { + x = 1; + } +}; +$t2 = spawn { + lock l1; + $r = 0; + if (flag == 0) { + flag = 1; + $r = 1; + } + unlock l1; + if ($r == 1) { + y = 1; + } else { + x = 1; + } +}; +join $t1; +join $t2; +assert (x != y); diff --git a/examples/reject/semantics/linear/deadlock.gm b/examples/reject/semantics/linear/deadlock.gm new file mode 100644 index 0000000..ccf9fe9 --- /dev/null +++ b/examples/reject/semantics/linear/deadlock.gm @@ -0,0 +1,12 @@ +$t1 = spawn { + lock l1; + lock l2; + unlock l2; + unlock l1; +}; +$t2 = spawn { + lock l2; + lock l1; + unlock l1; + unlock l2; +}; diff --git a/examples/reject/semantics/linear/error_and_race.gm b/examples/reject/semantics/linear/error_and_race.gm new file mode 100644 index 0000000..636e9e7 --- /dev/null +++ b/examples/reject/semantics/linear/error_and_race.gm @@ -0,0 +1,13 @@ +x = 1; +$t1 = spawn { + lock l1; + x = 1; + unlock l1; +}; +$t2 = spawn { + x = 2; +}; +join $t2; +lock l1; +assert(x == 1); +unlock l2; diff --git a/examples/reject/semantics/linear/failed_assertion.gm b/examples/reject/semantics/linear/failed_assertion.gm new file mode 100644 index 0000000..2c4edd7 --- /dev/null +++ b/examples/reject/semantics/linear/failed_assertion.gm @@ -0,0 +1,6 @@ +x = 0; +$t = spawn { + x = 1; +}; +join $t; +assert(x == 0); diff --git a/examples/reject/semantics/linear/failed_assertion_many_schedules.gm b/examples/reject/semantics/linear/failed_assertion_many_schedules.gm new file mode 100644 index 0000000..9d6d1e9 --- /dev/null +++ 
b/examples/reject/semantics/linear/failed_assertion_many_schedules.gm @@ -0,0 +1,5 @@ +$t1 = spawn { assert(1 == 2); }; +$t2 = spawn { nop; }; +$t3 = spawn { nop; }; +$t4 = spawn { nop; }; +$t5 = spawn { nop; }; diff --git a/examples/reject/semantics/linear/failed_assertion_neq.gm b/examples/reject/semantics/linear/failed_assertion_neq.gm new file mode 100644 index 0000000..bd2db83 --- /dev/null +++ b/examples/reject/semantics/linear/failed_assertion_neq.gm @@ -0,0 +1,6 @@ +x = 0; +$t = spawn { + x = 1; +}; +join $t; +assert(x != 1); diff --git a/examples/reject/semantics/linear/join_blocked_thread.gm b/examples/reject/semantics/linear/join_blocked_thread.gm new file mode 100644 index 0000000..4aafbda --- /dev/null +++ b/examples/reject/semantics/linear/join_blocked_thread.gm @@ -0,0 +1,2 @@ +$t2 = spawn { join 0; }; +join $t2; \ No newline at end of file diff --git a/examples/reject/semantics/linear/join_datarace.gm b/examples/reject/semantics/linear/join_datarace.gm new file mode 100644 index 0000000..5b14916 --- /dev/null +++ b/examples/reject/semantics/linear/join_datarace.gm @@ -0,0 +1,10 @@ +x = 2; +$t = spawn { + assert (x == 2); + x = 3; + assert (x == 3); +}; +assert(x == 2); +x = 4; +assert(x == 4); +join $t; \ No newline at end of file diff --git a/examples/reject/semantics/linear/join_deadlock.gm b/examples/reject/semantics/linear/join_deadlock.gm new file mode 100644 index 0000000..e4feaee --- /dev/null +++ b/examples/reject/semantics/linear/join_deadlock.gm @@ -0,0 +1 @@ +join spawn { join 0; }; \ No newline at end of file diff --git a/examples/reject/semantics/linear/join_nonexisting.gm b/examples/reject/semantics/linear/join_nonexisting.gm new file mode 100644 index 0000000..3a85a1b --- /dev/null +++ b/examples/reject/semantics/linear/join_nonexisting.gm @@ -0,0 +1,2 @@ +$t = 42; +join $t; diff --git a/examples/reject/semantics/linear/read_unassigned.gm b/examples/reject/semantics/linear/read_unassigned.gm new file mode 100644 index 0000000..94120df 
--- /dev/null +++ b/examples/reject/semantics/linear/read_unassigned.gm @@ -0,0 +1 @@ +x = y; \ No newline at end of file diff --git a/examples/reject/semantics/linear/test.gm b/examples/reject/semantics/linear/test.gm new file mode 100644 index 0000000..6958578 --- /dev/null +++ b/examples/reject/semantics/linear/test.gm @@ -0,0 +1,11 @@ +nop; +x = 0; +$t = spawn { + lock l1; + $r = 1; + x = $r; + unlock l1; +}; +x = 2; // Data race! +join $t; +assert(x == 2); \ No newline at end of file diff --git a/examples/reject/semantics/linear/unlock.gm b/examples/reject/semantics/linear/unlock.gm new file mode 100644 index 0000000..bc083f6 --- /dev/null +++ b/examples/reject/semantics/linear/unlock.gm @@ -0,0 +1,2 @@ +unlock l1; +x=1; \ No newline at end of file diff --git a/examples/reject/semantics/linear/unlock_another_threads_lock.gm b/examples/reject/semantics/linear/unlock_another_threads_lock.gm new file mode 100644 index 0000000..41ef1ca --- /dev/null +++ b/examples/reject/semantics/linear/unlock_another_threads_lock.gm @@ -0,0 +1,5 @@ +lock l; +t = spawn { + unlock l; +}; +join t; diff --git a/examples/failing/syntax/bad_add.gm b/examples/reject/syntax/bad_add.gm similarity index 100% rename from examples/failing/syntax/bad_add.gm rename to examples/reject/syntax/bad_add.gm diff --git a/examples/failing/syntax/bad_add2.gm b/examples/reject/syntax/bad_add2.gm similarity index 100% rename from examples/failing/syntax/bad_add2.gm rename to examples/reject/syntax/bad_add2.gm diff --git a/examples/failing/syntax/bad_assert.gm b/examples/reject/syntax/bad_assert.gm similarity index 100% rename from examples/failing/syntax/bad_assert.gm rename to examples/reject/syntax/bad_assert.gm diff --git a/examples/failing/syntax/bad_condition.gm b/examples/reject/syntax/bad_condition.gm similarity index 100% rename from examples/failing/syntax/bad_condition.gm rename to examples/reject/syntax/bad_condition.gm diff --git a/examples/failing/syntax/bad_join.gm 
b/examples/reject/syntax/bad_join.gm similarity index 100% rename from examples/failing/syntax/bad_join.gm rename to examples/reject/syntax/bad_join.gm diff --git a/examples/failing/syntax/bad_lhs.gm b/examples/reject/syntax/bad_lhs.gm similarity index 100% rename from examples/failing/syntax/bad_lhs.gm rename to examples/reject/syntax/bad_lhs.gm diff --git a/examples/failing/syntax/bad_lock.gm b/examples/reject/syntax/bad_lock.gm similarity index 100% rename from examples/failing/syntax/bad_lock.gm rename to examples/reject/syntax/bad_lock.gm diff --git a/examples/failing/syntax/bad_rhs.gm b/examples/reject/syntax/bad_rhs.gm similarity index 100% rename from examples/failing/syntax/bad_rhs.gm rename to examples/reject/syntax/bad_rhs.gm diff --git a/examples/failing/syntax/empty.gm b/examples/reject/syntax/empty.gm similarity index 100% rename from examples/failing/syntax/empty.gm rename to examples/reject/syntax/empty.gm diff --git a/examples/failing/syntax/empty_assert.gm b/examples/reject/syntax/empty_assert.gm similarity index 100% rename from examples/failing/syntax/empty_assert.gm rename to examples/reject/syntax/empty_assert.gm diff --git a/examples/failing/syntax/empty_assign.gm b/examples/reject/syntax/empty_assign.gm similarity index 100% rename from examples/failing/syntax/empty_assign.gm rename to examples/reject/syntax/empty_assign.gm diff --git a/examples/failing/syntax/empty_brace.gm b/examples/reject/syntax/empty_brace.gm similarity index 100% rename from examples/failing/syntax/empty_brace.gm rename to examples/reject/syntax/empty_brace.gm diff --git a/examples/failing/syntax/if_no_brace.gm b/examples/reject/syntax/if_no_brace.gm similarity index 100% rename from examples/failing/syntax/if_no_brace.gm rename to examples/reject/syntax/if_no_brace.gm diff --git a/examples/failing/syntax/if_no_cond.gm b/examples/reject/syntax/if_no_cond.gm similarity index 100% rename from examples/failing/syntax/if_no_cond.gm rename to 
examples/reject/syntax/if_no_cond.gm diff --git a/examples/failing/syntax/no_semicolon.gm b/examples/reject/syntax/no_semicolon.gm similarity index 100% rename from examples/failing/syntax/no_semicolon.gm rename to examples/reject/syntax/no_semicolon.gm diff --git a/examples/failing/syntax/spurious_else.gm b/examples/reject/syntax/spurious_else.gm similarity index 100% rename from examples/failing/syntax/spurious_else.gm rename to examples/reject/syntax/spurious_else.gm diff --git a/examples/failing/syntax/top_eq.gm b/examples/reject/syntax/top_eq.gm similarity index 100% rename from examples/failing/syntax/top_eq.gm rename to examples/reject/syntax/top_eq.gm diff --git a/examples/failing/unassigned_reg.gm b/examples/reject/unassigned_reg.gm similarity index 100% rename from examples/failing/unassigned_reg.gm rename to examples/reject/unassigned_reg.gm diff --git a/src/debug.hh b/src/debug.hh new file mode 100644 index 0000000..e963461 --- /dev/null +++ b/src/debug.hh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace gitmem { + +/* For debug printing */ +inline struct Verbose { + bool enabled = false; + + template const Verbose &operator<<(const T &msg) const { + if (enabled) + std::cout << msg; + return *this; + } + + const Verbose &operator<<(std::ostream &(*manip)(std::ostream &)) const { + if (enabled) + std::cout << manip; + return *this; + } +} verbose; + +} // namespace gitmem \ No newline at end of file diff --git a/src/debugger.hh b/src/debugger.hh new file mode 100644 index 0000000..789cc84 --- /dev/null +++ b/src/debugger.hh @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace gitmem { + int interpret_interactive(const trieste::Node, + const std::filesystem::path &output_file); +} \ No newline at end of file diff --git a/src/gitmem.cc b/src/gitmem.cc index bc7f49d..55c523a 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -4,6 +4,7 @@ #include "interpreter.hh" #include "model_checker.hh" #include "lang.hh" +#include "sync_protocol.hh" int 
main(int argc, char **argv) { using namespace trieste; @@ -30,6 +31,10 @@ int main(int argc, char **argv) { app.add_flag("-e,--explore", model_check, "Explore all possible execution paths."); + bool branching = false; + app.add_flag("-b,--branching", branching, + "Using branching semantics."); + try { app.parse(argc, argv); } catch (const CLI::ParseError &e) { @@ -61,14 +66,15 @@ int main(int argc, char **argv) { gitmem::verbose << "Output will be written to " << output_path << std::endl; int exit_status; + gitmem::SyncKind sync_kind = branching ? gitmem::SyncKind::Branching : gitmem::SyncKind::Linear; wf::push_back(gitmem::lang::wf); if (model_check) { - exit_status = gitmem::model_check(result.ast, output_path); + exit_status = gitmem::model_check(result.ast, output_path, sync_kind); } else if (interactive) { assert(false && "currently broken"); // exit_status = gitmem::interpret_interactive(result.ast, output_path); } else { - exit_status = gitmem::interpret(result.ast, output_path); + exit_status = gitmem::interpret(result.ast, output_path, sync_kind); } wf::pop_front(); diff --git a/src/interpreter.cc b/src/interpreter.cc index eb01429..92cd4bf 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -99,6 +99,12 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { assert(false); // handle this } + // making the on_spawn and on_start events happen at thread spawn + if (std::optional> conflict = + gctx.protocol->on_start(child_ctx, gctx)) { + std::unreachable(); + } + // Spawning is a sync point, commit local pending commits, and // copy the global state to the spawned thread // commit(ctx.globals); @@ -321,10 +327,6 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, size_t &pc = thread->pc; ThreadContext &ctx = thread->ctx; - if (pc == 0) { - gctx.protocol->on_start(thread->ctx, gctx); - } - bool first_statement = true; while (pc < block->size()) { Node stmt = block->at(pc); @@ -371,6 +373,7 @@ 
progress_thread(GlobalContext &gctx, const ThreadID tid, bool any_progress = std::holds_alternative(prog_or_term) && std::get(prog_or_term) == ProgressStatus::progress; + for (size_t i = no_threads; i < gctx.threads.size(); ++i) { // If there are new threads, we can run them to sync as well any_progress = true; @@ -489,9 +492,8 @@ int run_threads(GlobalContext &gctx) { return exception_detected ? 1 : 0; } -int interpret(const Node ast, const std::filesystem::path &output_path) { - // TODO: allow both protocols - GlobalContext gctx(ast, std::make_unique()); +int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { + GlobalContext gctx(ast, make_protocol(sync_kind)); auto result = run_threads(gctx); // gctx.print_execution_graph(output_path); FIXME diff --git a/src/interpreter.hh b/src/interpreter.hh index b6970ca..4a15326 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -5,11 +5,12 @@ #include "graphviz.hh" #include "lang.hh" #include +#include "sync_protocol.hh" namespace gitmem { // Entry function -int interpret(const trieste::Node, const std::filesystem::path &output_file); +int interpret(const trieste::Node, const std::filesystem::path &output_file, SyncKind sync_kind); // Internal functions int run_threads(GlobalContext &); diff --git a/src/model_checker.cc b/src/model_checker.cc index 7d99ef1..8c1fc7e 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -58,8 +58,8 @@ build_output_path(const std::filesystem::path &output_path, const size_t idx) { * Explore all possible execution paths of the program, printing one trace * for each distinct final state that led to an error. 
*/ -int model_check(const Node ast, const std::filesystem::path &output_path) { - GlobalContext gctx(ast, std::make_unique()); +int model_check(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { + GlobalContext gctx(ast, make_protocol(sync_kind)); auto final_contexts = std::vector>{}; auto failing_contexts = std::vector>{}; @@ -159,7 +159,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path) { if (cursor->complete && !root->complete) { // Reset the cursor to the root and start a new trace verbose << std::endl << "Restarting trace..." << std::endl; - gctx = GlobalContext(ast, std::make_unique()); + gctx = GlobalContext(ast, make_protocol(sync_kind)); cursor = root; current_trace.clear(); diff --git a/src/model_checker.hh b/src/model_checker.hh new file mode 100644 index 0000000..448f2c0 --- /dev/null +++ b/src/model_checker.hh @@ -0,0 +1,10 @@ +#pragma once + +#include +#include "sync_protocol.hh" + +namespace gitmem { + using namespace trieste; + + int model_check(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind); +} \ No newline at end of file diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 2103a13..946f275 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -273,4 +273,14 @@ BranchingSyncProtocol::on_unlock(ThreadContext &thread, Lock &lock, return std::nullopt; } +std::unique_ptr make_protocol(SyncKind sync_kind) { + switch (sync_kind) { + case SyncKind::Linear: + return std::make_unique(); + case SyncKind::Branching: + return std::make_unique(); + } + std::unreachable(); +} + } // namespace gitmem diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 7b55ae8..c4b5122 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -10,6 +10,13 @@ namespace gitmem { +enum class SyncKind { + Linear, + Branching +}; + +std::unique_ptr make_protocol(SyncKind); + struct ConflictBase { virtual ~ConflictBase() = default; virtual std::ostream 
&print(std::ostream &os) const = 0; diff --git a/test_gitmem.py b/test_gitmem.py index 77fb36f..b0b60c6 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -2,58 +2,98 @@ import subprocess import sys import argparse +from collections import defaultdict EXAMPLES_DIR = "examples" -def run_gitmem_test(gitmem_path, file_path, should_pass): - try: - result = subprocess.run([gitmem_path, file_path, "-e", "-o", "/dev/null"], capture_output=True, text=True) - passed = (result.returncode == 0) - except FileNotFoundError: - print(f"Error: '{gitmem_path}' executable not found.") - sys.exit(1) - - if passed == should_pass: - status = "PASS" - else: - status = "FAIL" +def run_gitmem_test(gitmem_path, file_path, should_accept): + try: + result = subprocess.run( + [gitmem_path, file_path, "-e", "-o", "/dev/null"], + capture_output=True, + text=True + ) + accepted = (result.returncode == 0) + except FileNotFoundError: + print(f"Error: '{gitmem_path}' executable not found.") + sys.exit(1) - print(f"[{status}] {file_path} (exit code: {result.returncode})") - return status == "PASS" + status = "PASS" if accepted == should_accept else "FAIL" + print(f"[{status}] {file_path} (exit code: {result.returncode})") + return status == "PASS" def main(): - parser = argparse.ArgumentParser(description="Test runner for gitmem.") - parser.add_argument( - "--gitmem", "-g", - required=True, - help="Path to the gitmem executable" - ) - args = parser.parse_args() - gitmem_path = args.gitmem - - total_tests = 0 - failed_tests = 0 - - for outcome in ["passing", "failing"]: - should_pass = (outcome == "passing") - for category in ["syntax", "semantics"]: - test_dir = os.path.join(EXAMPLES_DIR, outcome, category) - if not os.path.isdir(test_dir): - continue - for root, _, files in os.walk(test_dir): - for file in files: - file_path = os.path.join(root, file) - total_tests += 1 - if not run_gitmem_test(gitmem_path, file_path, should_pass): - failed_tests += 1 - - print("\nSummary:") - print(f"Total tests 
run: {total_tests}") - print(f"Tests failed: {failed_tests}") - print(f"Tests passed: {total_tests - failed_tests}") - - if failed_tests > 0: - sys.exit(1) + parser = argparse.ArgumentParser(description="Test runner for gitmem.") + parser.add_argument( + "--gitmem", "-g", + required=True, + help="Path to the gitmem executable" + ) + args = parser.parse_args() + gitmem_path = args.gitmem + + # results[expectation][category][subcategory] = {"total": x, "failed": y} + results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: { + "total": 0, + "failed": 0 + }))) + + total_tests = 0 + failed_tests = 0 + + for expectation in ["accept", "reject"]: + should_accept = (expectation == "accept") + + for category in ["syntax", "semantics"]: + base_dir = os.path.join(EXAMPLES_DIR, expectation, category) + if not os.path.isdir(base_dir): + continue + + if category == "semantics": + subcategories = ["branching", "linear"] + else: + subcategories = [None] + + for subcategory in subcategories: + if subcategory: + test_dir = os.path.join(base_dir, subcategory) + if not os.path.isdir(test_dir): + continue + else: + test_dir = base_dir + + for root, _, files in os.walk(test_dir): + for file in files: + file_path = os.path.join(root, file) + + total_tests += 1 + results[expectation][category][subcategory]["total"] += 1 + + if not run_gitmem_test(gitmem_path, file_path, should_accept): + failed_tests += 1 + results[expectation][category][subcategory]["failed"] += 1 + + print("\nDetailed Summary:") + for expectation, categories in results.items(): + print(f"\n{expectation.upper()}:") + for category, subcats in categories.items(): + print(f" {category}:") + for subcategory, stats in subcats.items(): + label = subcategory if subcategory else "all" + passed = stats["total"] - stats["failed"] + print( + f" {label}: " + f"{passed}/{stats['total']} passed " + f"({stats['failed']} failed)" + ) + + print("\nOverall Summary:") + print(f"Total tests run: {total_tests}") + print(f"Tests 
failed: {failed_tests}") + print(f"Tests passed: {total_tests - failed_tests}") + + if failed_tests > 0: + sys.exit(1) if __name__ == "__main__": - main() + main() From 8b1d272a67f75407a68b89dbefb82aaef64a03f7 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 23 Dec 2025 15:12:17 +0000 Subject: [PATCH 11/58] tests use the branching switch --- test_gitmem.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/test_gitmem.py b/test_gitmem.py index b0b60c6..121aed0 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -6,10 +6,15 @@ EXAMPLES_DIR = "examples" -def run_gitmem_test(gitmem_path, file_path, should_accept): +def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): + cmd = [gitmem_path, file_path, "-e", "-o", "/dev/null"] + + if is_branching: + cmd.insert(2, "-b") + try: result = subprocess.run( - [gitmem_path, file_path, "-e", "-o", "/dev/null"], + cmd, capture_output=True, text=True ) @@ -59,8 +64,10 @@ def main(): test_dir = os.path.join(base_dir, subcategory) if not os.path.isdir(test_dir): continue + is_branching = (subcategory == "branching") else: test_dir = base_dir + is_branching = False for root, _, files in os.walk(test_dir): for file in files: @@ -69,7 +76,12 @@ def main(): total_tests += 1 results[expectation][category][subcategory]["total"] += 1 - if not run_gitmem_test(gitmem_path, file_path, should_accept): + if not run_gitmem_test( + gitmem_path, + file_path, + should_accept, + is_branching + ): failed_tests += 1 results[expectation][category][subcategory]["failed"] += 1 From 1cef5c93bff004413557f5025ceba6fefb385981 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 23 Dec 2025 15:16:23 +0000 Subject: [PATCH 12/58] removing copy operator from thread context and moving child ctx init around a bit on spawn to ensure the correct context and not a copy is initialised --- src/execution_state.cc | 5 ++--- src/execution_state.hh | 8 ++++++++ src/interpreter.cc | 17 ++++++++++------- 
src/sync_protocol.cc | 4 ++++ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/execution_state.cc b/src/execution_state.cc index f298e54..cd3efb6 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -27,9 +27,8 @@ GlobalContext::GlobalContext(const trieste::Node &ast, std::unique_ptr protocol) : protocol(std::move(protocol)) { trieste::Node starting_block = ast / lang::File / lang::Block; - ThreadContext starting_ctx = {.locals = {}, - .tail = std::make_shared(0)}; - auto main_thread = std::make_shared(starting_ctx, starting_block); + ThreadContext starting_ctx(std::make_shared(0)); + auto main_thread = std::make_shared(std::move(starting_ctx), starting_block); this->threads = {main_thread}; this->locks = {}; diff --git a/src/execution_state.hh b/src/execution_state.hh index f9ffffa..846fbde 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -35,6 +35,14 @@ struct ThreadContext { std::optional linear; std::optional branching; + + ThreadContext(const ThreadContext&) = delete; + ThreadContext& operator=(const ThreadContext&) = delete; + + ThreadContext(ThreadContext&&) = default; + ThreadContext& operator=(ThreadContext&&) = default; + + ThreadContext(std::shared_ptr tail): tail(tail) {} }; struct Thread { diff --git a/src/interpreter.cc b/src/interpreter.cc index 92cd4bf..d8e2ff7 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -89,10 +89,7 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { } else if (e == lang::Spawn) { ThreadID tid = gctx.threads.size(); auto node = std::make_shared(tid); - ThreadContext child_ctx = {std::unordered_map(), node}; - gctx.threads.push_back( - std::make_shared(child_ctx, e / lang::Block)); - thread_append_node(ctx, tid, node); + ThreadContext child_ctx(node); if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { @@ -105,9 +102,9 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { std::unreachable(); 
} - // Spawning is a sync point, commit local pending commits, and - // copy the global state to the spawned thread - // commit(ctx.globals); + gctx.threads.push_back( + std::make_shared(std::move(child_ctx), e / lang::Block)); + thread_append_node(ctx, tid, node); return tid; } else if (e == lang::Eq || e == lang::Neq) { @@ -327,6 +324,12 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, size_t &pc = thread->pc; ThreadContext &ctx = thread->ctx; + // TODO: one possible interpretation of on_start is to sync when the thread + // starts execution statements + // if (pc == 0) { + // gctx.protocol->on_start(thread->ctx, gctx); + // } + bool first_statement = true; while (pc < block->size()) { Node stmt = block->at(pc); diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 946f275..1f94fb7 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -224,6 +224,10 @@ void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, assert(false && "Todo write"); } +// Spawning is a sync point, commit local pending commits, and +// copy the global state to the spawned thread +// commit(ctx.globals); + std::optional> BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, GlobalContext &) { From 38469b3571c56212af3dcec145b9a14b56c6d732 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 23 Dec 2025 15:29:59 +0000 Subject: [PATCH 13/58] making it easier to identify failing tests --- test_gitmem.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/test_gitmem.py b/test_gitmem.py index 121aed0..30532c3 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -6,6 +6,20 @@ EXAMPLES_DIR = "examples" +def supports_color(): + return sys.stdout.isatty() and os.getenv("NO_COLOR") is None + +def color(text, code): + if not supports_color(): + return text + return f"\033[{code}m{text}\033[0m" + +def green(text): + return color(text, "32") + +def red(text): + return 
color(text, "31") + def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): cmd = [gitmem_path, file_path, "-e", "-o", "/dev/null"] @@ -23,9 +37,11 @@ def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): print(f"Error: '{gitmem_path}' executable not found.") sys.exit(1) - status = "PASS" if accepted == should_accept else "FAIL" + passed = (accepted == should_accept) + status = green("PASS") if passed else red("FAIL") + print(f"[{status}] {file_path} (exit code: {result.returncode})") - return status == "PASS" + return passed def main(): parser = argparse.ArgumentParser(description="Test runner for gitmem.") @@ -45,6 +61,7 @@ def main(): total_tests = 0 failed_tests = 0 + failing_tests = [] for expectation in ["accept", "reject"]: should_accept = (expectation == "accept") @@ -76,14 +93,17 @@ def main(): total_tests += 1 results[expectation][category][subcategory]["total"] += 1 - if not run_gitmem_test( + passed = run_gitmem_test( gitmem_path, file_path, should_accept, is_branching - ): + ) + + if not passed: failed_tests += 1 results[expectation][category][subcategory]["failed"] += 1 + failing_tests.append(file_path) print("\nDetailed Summary:") for expectation, categories in results.items(): @@ -104,6 +124,11 @@ def main(): print(f"Tests failed: {failed_tests}") print(f"Tests passed: {total_tests - failed_tests}") + if failing_tests: + print("\nFailing tests:") + for path in failing_tests: + print(f" {red(path)}") + if failed_tests > 0: sys.exit(1) From 764418f09d54984ae3eca76f72de6b2c2ba60edb Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 23 Dec 2025 16:15:44 +0000 Subject: [PATCH 14/58] working on fixing up linear tests --- .../semantics/linear/join_pulled_variable.gm | 2 +- src/interpreter.cc | 20 ++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/accept/semantics/linear/join_pulled_variable.gm b/examples/accept/semantics/linear/join_pulled_variable.gm index 
ae481e8..ef82df7 100644 --- a/examples/accept/semantics/linear/join_pulled_variable.gm +++ b/examples/accept/semantics/linear/join_pulled_variable.gm @@ -10,6 +10,6 @@ t = spawn { assert(x == 2); }; join t; -assert (x == 2); +// assert (x == 2 || x == 42); || not support but, in linear we only know this to be true join t2; assert (x == 14); \ No newline at end of file diff --git a/src/interpreter.cc b/src/interpreter.cc index d8e2ff7..68aae82 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -31,7 +31,12 @@ bool is_syncing(Node stmt) { } bool is_syncing(Thread &thread) { - return !thread.terminated && is_syncing(thread.block->at(thread.pc)); + // Can only be true if a thread hasn't terminated + // Either it has executed all statements but not yet terminated (and my sync) + // Or it is at a synchronisation node + // The lazy eval here is important + return !thread.terminated && + ((thread.pc >= thread.block->size()) || is_syncing(thread.block->at(thread.pc))); } template @@ -356,11 +361,16 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, first_statement = false; } - thread->terminated = TerminationStatus::completed; - gctx.protocol->on_end(thread->ctx, gctx); + // End should be it's own sync step + if (first_statement) { + thread->terminated = TerminationStatus::completed; + gctx.protocol->on_end(thread->ctx, gctx); - thread_append_node(ctx); - return TerminationStatus::completed; + thread_append_node(ctx); + return TerminationStatus::completed; + } + + return ProgressStatus::progress; } /** From ac7fbe33a9893bbac7872006f79640462fbf6070 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 24 Dec 2025 20:31:01 +0000 Subject: [PATCH 15/58] adding printing and equality methods --- CMakeLists.txt | 2 +- .../semantics/linear/conditional_race.gm | 5 +- src/branching/version_store.hh | 13 +- src/debugger.cc | 76 +++------ src/debugger.hh | 4 +- src/execution_state.cc | 144 +++++++++++++----- src/execution_state.hh | 15 ++ src/gitmem.cc | 
4 +- src/linear/version_store.cc | 5 + src/linear/version_store.hh | 18 +++ src/sync_protocol.cc | 10 ++ src/sync_protocol.hh | 11 ++ 12 files changed, 204 insertions(+), 103 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e39979..6e8f9d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ add_executable(gitmem src/passes/branching.cc src/linear/version_store.cc src/interpreter.cc - # src/debugger.cc + src/debugger.cc src/model_checker.cc src/sync_protocol.cc src/graphviz.cc diff --git a/examples/reject/semantics/linear/conditional_race.gm b/examples/reject/semantics/linear/conditional_race.gm index 6dd0a36..6adbc05 100644 --- a/examples/reject/semantics/linear/conditional_race.gm +++ b/examples/reject/semantics/linear/conditional_race.gm @@ -1,6 +1,9 @@ x = 0; y = 0; flag = 0; +// if t1 gets the lock l1 first then there is a race +// t1 will see the flag is not set, set the flag to 1, and set its x to 1 +// t2 will see the flag is set, and will also set its x at 1 $t1 = spawn { lock l1; $r = 0; @@ -29,4 +32,4 @@ $t2 = spawn { }; join $t1; join $t2; -assert (x != y); +assert (x != y); \ No newline at end of file diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index fed99f7..e20b706 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -31,7 +31,18 @@ struct Conflict { std::pair commits; }; -struct LocalVersionStore {}; +struct LocalVersionStore { + friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + assert(false && "TODO"); + return os; + } + + bool operator==(const LocalVersionStore& other) const { + assert(false && "TODO"); + return false; + } + +}; // Join logic // commit(ctx.globals); diff --git a/src/debugger.cc b/src/debugger.cc index 2c7d597..37e2eb9 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -1,5 +1,6 @@ #include +#include "debug.hh" #include "debugger.hh" #include "interpreter.hh" @@ -21,54 +22,17 @@ struct Command { 
ThreadID argument = 0; }; -void show_global(const std::string &var, const Global &global) { - std::cout << var << " = " << global.val << " [" - << (global.commit ? std::to_string(*global.commit) : "_") << "; "; - for (size_t i = 0; i < global.history.size(); ++i) { - std::cout << global.history[i]; - if (i < global.history.size() - 1) { - std::cout << ", "; - } - } - std::cout << "]" << std::endl; -} - -/** Print the state of a thread, including its local and global variables, - * and the current position in the program. */ -void show_thread(const Thread &thread, size_t tid) { - std::cout << "---- Thread " << tid << std::endl; - if (thread.ctx.locals.size() > 0) { - for (auto &[reg, val] : thread.ctx.locals) { - std::cout << reg << " = " << val << std::endl; - } - std::cout << "--" << std::endl; - } - - if (thread.ctx.globals.size() > 0) { - for (auto &[var, val] : thread.ctx.globals) { - show_global(var, val); - } - std::cout << "--" << std::endl; - } - - size_t idx = 0; - for (const auto &stmt : *thread.block) { - if (idx == thread.pc) { - std::cout << "-> "; - } else { - std::cout << " "; - } - // Fix indentation of nested blocks - auto s = std::string(stmt->location().view()); - s = std::regex_replace(s, std::regex("\n"), "\n "); - std::cout << s << ";" << std::endl; - - idx++; - } - if (thread.pc == thread.block->size()) { - std::cout << "-> " << std::endl; - } -} +// void show_global(const std::string &var, const Global &global) { +// std::cout << var << " = " << global.val << " [" +// << (global.commit ? 
std::to_string(*global.commit) : "_") << "; "; +// for (size_t i = 0; i < global.history.size(); ++i) { +// std::cout << global.history[i]; +// if (i < global.history.size() - 1) { +// std::cout << ", "; +// } +// } +// std::cout << "]" << std::endl; +// } void show_lock(const std::string &lock_name, const struct Lock &lock) { std::cout << lock_name << ": "; @@ -78,9 +42,9 @@ void show_lock(const std::string &lock_name, const struct Lock &lock) { std::cout << ""; } std::cout << std::endl; - for (auto &[var, global] : lock.globals) { - show_global(var, global); - } + // for (auto &[var, global] : lock.globals) { + // show_global(var, global); + // } } /** Show the global context, including locks and non-completed threads. If @@ -93,7 +57,8 @@ void show_global_context(const GlobalContext &gctx, bool show_all = false) { auto thread = threads[i]; if (show_all || !thread->terminated || *threads[i]->terminated != TerminationStatus::completed) { - show_thread(*threads[i], i); + std::cout << "---- Thread " << i << std::endl; + std::cout << *threads[i] << std::endl; std::cout << std::endl; showed_any = true; } @@ -216,8 +181,9 @@ bool step_thread(ThreadID tid, GlobalContext &gctx, std::string &msg) { /** Interpret the AST in an interactive way, letting the user choose which * thread to schedule next. 
*/ int interpret_interactive(const trieste::Node ast, - const std::filesystem::path &output_file) { - GlobalContext gctx(ast); + const std::filesystem::path &output_file, + SyncKind sync_kind) { + GlobalContext gctx(ast, make_protocol(sync_kind)); size_t prev_no_threads = 1; Command command = {Command::List}; @@ -267,7 +233,7 @@ int interpret_interactive(const trieste::Node ast, } } else if (command.cmd == Command::Restart) { // Start the program from the beginning - gctx = GlobalContext(ast); + gctx = GlobalContext(ast, make_protocol(sync_kind)); command = {Command::List}; if (print_graphs) { gctx.print_execution_graph(output_file); diff --git a/src/debugger.hh b/src/debugger.hh index 789cc84..c891fad 100644 --- a/src/debugger.hh +++ b/src/debugger.hh @@ -1,8 +1,10 @@ #pragma once #include +#include "sync_protocol.hh" namespace gitmem { int interpret_interactive(const trieste::Node, - const std::filesystem::path &output_file); + const std::filesystem::path &output_file, + SyncKind sync_kind); } \ No newline at end of file diff --git a/src/execution_state.cc b/src/execution_state.cc index cd3efb6..4b66b99 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -1,26 +1,44 @@ +#include + #include "execution_state.hh" #include "sync_protocol.hh" namespace gitmem { +bool ThreadContext::operator==(const ThreadContext &other) const { + if (locals != other.locals) + return false; + + // ignore the graph node, we're not interested in that + + if (linear) { + return other.linear && (linear->store == other.linear->store); + } else if (branching) { + return other.branching && (branching->store == other.branching->store); + } + + return !other.linear && !other.branching; +} + +// An old comment on equals +// Globals have a history that we don't care about, so we only +// compare values +// if (ctx.globals.size() != other.ctx.globals.size()) +// return false; +// for (const auto &[var, global] : ctx.globals) +// { +// if (!other.ctx.globals.contains(var) || +// 
ctx.globals.at(var).val != other.ctx.globals.at(var).val) +// { +// return false; +// } +// } + bool Thread::operator==(const Thread &other) const { - return false; - // Globals have a history that we don't care about, so we only - // compare values - // if (ctx.globals.size() != other.ctx.globals.size()) - // return false; - // for (const auto &[var, global] : ctx.globals) - // { - // if (!other.ctx.globals.contains(var) || - // ctx.globals.at(var).val != other.ctx.globals.at(var).val) - // { - // return false; - // } - // } - // return ctx.locals == other.ctx.locals && - // block == other.block && - // pc == other.pc && - // terminated == other.terminated; + return ctx == other.ctx && + block == other.block && + pc == other.pc && + terminated == other.terminated; } GlobalContext::GlobalContext(const trieste::Node &ast, @@ -60,31 +78,73 @@ void GlobalContext::print_execution_graph( } bool GlobalContext::operator==(const GlobalContext &other) const { - return false; - // if (threads.size() != other.threads.size() || locks.size() != - // other.locks.size()) - // return false; - - // // Threads may have been spawned in a different order, so we - // // find the thread with the same block in the other context - // for (auto &thread : threads) - // { - // auto it = std::find_if(other.threads.begin(), other.threads.end(), - // [&thread](auto &t) - // { return t->block == thread->block; }); - // if (it == other.threads.end() || !(*thread == **it)) - // return false; - // } - - // for (auto &[name, lock] : locks) - // { - // if (!other.locks.contains(name)) - // return false; - // auto &other_lock = other.locks.at(name); - // if (lock.owner != other_lock.owner) - // return false; - // } - // return true; + if (threads.size() != other.threads.size() || + locks.size() != other.locks.size()) + return false; + + // Threads may have been spawned in a different order, so we + // find the thread with the same block in the other context + for (auto &thread : threads) { + auto it = 
std::find_if(other.threads.begin(), other.threads.end(), + [&thread](auto &t) + { return t->block == thread->block; }); + if (it == other.threads.end() || !(*thread == **it)) + return false; + } + + for (auto &[name, lock] : locks) { + if (!other.locks.contains(name)) + return false; + auto &other_lock = other.locks.at(name); + if (lock.owner != other_lock.owner) + return false; + } + return true; +} + +/** Print the state of a thread, including its local and global variables, + * and the current position in the program. */ +std::ostream& operator<<(std::ostream& os, const Thread& thread) { + os << thread.ctx << std::endl; + + size_t idx = 0; + for (const auto &stmt : *(thread.block)) { + if (idx == thread.pc) { + os << "-> "; + } else { + os << " "; + } + + // This should be somewhere else + // Fix indentation of nested blocks + auto s = std::string(stmt->location().view()); + s = std::regex_replace(s, std::regex("\n"), "\n "); + os << s << ";" << std::endl; + + idx++; + } + if (thread.pc == thread.block->size()) { + os << "-> " << std::endl; + } + + return os; +} + +std::ostream& operator<<(std::ostream& os, const ThreadContext& ctx) { + if (ctx.locals.size() > 0) { + for (auto &[reg, val] : ctx.locals) { + os << reg << " = " << val << std::endl; + } + os << "--" << std::endl; + } + + if (ctx.linear) { + os << ctx.linear->store << std::endl; + } else if (ctx.branching) { + os << ctx.branching->store << std::endl; + } + + return os; } } // namespace gitmem \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh index 846fbde..9615f22 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -43,6 +43,10 @@ struct ThreadContext { ThreadContext& operator=(ThreadContext&&) = default; ThreadContext(std::shared_ptr tail): tail(tail) {} + + bool operator==(const ThreadContext &other) const; + + friend std::ostream& operator<<(std::ostream&, const ThreadContext&); }; struct Thread { @@ -51,7 +55,18 @@ struct Thread { size_t pc = 
0; std::optional terminated = std::nullopt; + Thread(ThreadContext&& ctx, trieste::Node block): + ctx(std::move(ctx)), block(block) {}; + + Thread(const Thread&) = delete; + Thread& operator=(const Thread&) = delete; + + Thread(Thread&&) = default; + Thread& operator=(Thread&&) = default; + bool operator==(const Thread &other) const; + + friend std::ostream& operator<<(std::ostream&, const Thread&); }; using ThreadID = size_t; diff --git a/src/gitmem.cc b/src/gitmem.cc index 55c523a..6fda58a 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -3,6 +3,7 @@ #include "debug.hh" #include "interpreter.hh" #include "model_checker.hh" +#include "debugger.hh" #include "lang.hh" #include "sync_protocol.hh" @@ -71,8 +72,7 @@ int main(int argc, char **argv) { if (model_check) { exit_status = gitmem::model_check(result.ast, output_path, sync_kind); } else if (interactive) { - assert(false && "currently broken"); - // exit_status = gitmem::interpret_interactive(result.ast, output_path); + exit_status = gitmem::interpret_interactive(result.ast, output_path, sync_kind); } else { exit_status = gitmem::interpret(result.ast, output_path, sync_kind); } diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index dc2404b..692ded6 100644 --- a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -25,6 +25,11 @@ std::optional LocalVersionStore::get_staged(ObjectNumber obj) { return it != _staging.end() ? 
std::make_optional(it->second) : std::nullopt; } +bool LocalVersionStore::operator==(const LocalVersionStore& other) const { + return _base_timestamp == other._base_timestamp && + _staging == other._staging; +} + // ----------------------------- // GlobalVersionStore // ----------------------------- diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index 81210fa..6c851c7 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -92,6 +92,24 @@ public: void clear_staging(); void advance_base(Timestamp ts); std::optional get_staged(ObjectNumber obj); + + bool operator==(const LocalVersionStore& other) const; + + friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + os << "LocalVersionStore{" + << "base=" << store._base_timestamp + << ", staged={"; + + bool first = true; + for (const auto& [obj, val] : store._staging) { + if (!first) os << ", "; + first = false; + os << obj << "->" << val; + } + + os << "}}"; + return os; + } }; // ----------------------------- diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 1f94fb7..9d644f8 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -14,6 +14,11 @@ template std::ostream &Conflict::print(std::ostream &os) const { // LinearSyncProtocol // -------------------- +std::ostream &LinearSyncProtocol::print(std::ostream &os) const { + assert(false && "todo"); + return os; +} + std::optional LinearSyncProtocol::push(linear::LocalVersionStore &local) { if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), @@ -144,6 +149,11 @@ LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, BranchingSyncProtocol::~BranchingSyncProtocol() = default; +std::ostream &BranchingSyncProtocol::print(std::ostream &os) const { + assert(false && "TODO"); + return os; +} + // /* At a commit point, walk through all the versioned variables and see if // * they have a pending commit, if so commit the value by appending to // * the 
variables history. diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index c4b5122..3ea32df 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -70,6 +70,13 @@ public: virtual std::optional> on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) = 0; + + + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const SyncProtocol &protocol) { + return protocol.print(os); + } }; // --------------------------------- @@ -115,6 +122,8 @@ public: std::optional> on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::ostream &print(std::ostream &os) const override; }; class BranchingSyncProtocol final : public SyncProtocol { @@ -148,6 +157,8 @@ public: std::optional> on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::ostream &print(std::ostream &os) const override; }; } // namespace gitmem From 8ea2f42cfbdb446dd9bd259ec8a312b5ba654a0f Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 24 Dec 2025 23:02:01 +0000 Subject: [PATCH 16/58] some printing algorithms --- src/debugger.cc | 49 +---------------- src/execution_state.cc | 105 ++++++++++++++++++++++++++++++------ src/execution_state.hh | 12 +++-- src/interpreter.cc | 16 ++---- src/linear/version_store.cc | 36 +++++++++++++ src/linear/version_store.hh | 18 ++----- src/sync_protocol.cc | 35 ++++++++---- src/sync_protocol.hh | 17 ++---- 8 files changed, 172 insertions(+), 116 deletions(-) diff --git a/src/debugger.cc b/src/debugger.cc index 37e2eb9..519de08 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -22,58 +22,11 @@ struct Command { ThreadID argument = 0; }; -// void show_global(const std::string &var, const Global &global) { -// std::cout << var << " = " << global.val << " [" -// << (global.commit ? 
std::to_string(*global.commit) : "_") << "; "; -// for (size_t i = 0; i < global.history.size(); ++i) { -// std::cout << global.history[i]; -// if (i < global.history.size() - 1) { -// std::cout << ", "; -// } -// } -// std::cout << "]" << std::endl; -// } - -void show_lock(const std::string &lock_name, const struct Lock &lock) { - std::cout << lock_name << ": "; - if (lock.owner) { - std::cout << "held by thread " << *lock.owner; - } else { - std::cout << ""; - } - std::cout << std::endl; - // for (auto &[var, global] : lock.globals) { - // show_global(var, global); - // } -} - /** Show the global context, including locks and non-completed threads. If * show_all is true, show all threads, even those that have terminated * normally. */ void show_global_context(const GlobalContext &gctx, bool show_all = false) { - auto &threads = gctx.threads; - bool showed_any = false; - for (size_t i = 0; i < threads.size(); i++) { - auto thread = threads[i]; - if (show_all || !thread->terminated || - *threads[i]->terminated != TerminationStatus::completed) { - std::cout << "---- Thread " << i << std::endl; - std::cout << *threads[i] << std::endl; - std::cout << std::endl; - showed_any = true; - } - } - - if (showed_any && gctx.locks.size() > 0) { - std::cout << "---- Locks" << std::endl; - - for (const auto &[lock_name, lock] : gctx.locks) { - show_lock(lock_name, lock); - } - - if (gctx.locks.size() > 0) - std::cout << "--" << std::endl; - } + std::cout << gctx << std::endl; } /** Parse a command. See the help string for the 'Info' command for details. 
diff --git a/src/execution_state.cc b/src/execution_state.cc index 4b66b99..54293d0 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -5,19 +5,39 @@ namespace gitmem { +ThreadContext::ThreadContext(std::shared_ptr tail, SyncKind sync_kind): tail(tail) { + switch (sync_kind) { + case SyncKind::Linear: + sync.emplace(); + break; + case SyncKind::Branching: + sync.emplace(); + break; + } +} + bool ThreadContext::operator==(const ThreadContext &other) const { if (locals != other.locals) return false; // ignore the graph node, we're not interested in that - if (linear) { - return other.linear && (linear->store == other.linear->store); - } else if (branching) { - return other.branching && (branching->store == other.branching->store); - } + if (sync.index() != other.sync.index()) + return false; + + return std::visit([&](const auto& a, const auto& b) -> bool { + using A = std::decay_t; + using B = std::decay_t; - return !other.linear && !other.branching; + if constexpr (std::is_same_v && + std::is_same_v) { + return true; + } else if constexpr (std::is_same_v) { + return a.store == b.store; + } else { + return false; // unreachable due to index check + } + }, sync, other.sync); } // An old comment on equals @@ -45,7 +65,7 @@ GlobalContext::GlobalContext(const trieste::Node &ast, std::unique_ptr protocol) : protocol(std::move(protocol)) { trieste::Node starting_block = ast / lang::File / lang::Block; - ThreadContext starting_ctx(std::make_shared(0)); + ThreadContext starting_ctx(std::make_shared(0), this->protocol->kind()); auto main_thread = std::make_shared(std::move(starting_ctx), starting_block); this->threads = {main_thread}; @@ -131,20 +151,75 @@ std::ostream& operator<<(std::ostream& os, const Thread& thread) { } std::ostream& operator<<(std::ostream& os, const ThreadContext& ctx) { - if (ctx.locals.size() > 0) { - for (auto &[reg, val] : ctx.locals) { - os << reg << " = " << val << std::endl; + os << "ThreadContext{locals={"; + + bool first = 
true; + for (const auto& [k, v] : ctx.locals) { + if (!first) os << ", "; + first = false; + os << k << "=" << v; + } + + os << "}"; //, tail=" << ctx.tail; + + std::visit([&](const auto& data) { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + os << ", sync=linear{" << data.store << "}"; + } else if constexpr (std::is_same_v) { + os << ", sync=branching{" << data.store << "}"; + } + }, ctx.sync); + + os << "}"; + return os; +} + +void show_lock(const std::string &lock_name, const struct Lock &lock) { + std::cout << lock_name << ": "; + if (lock.owner) { + std::cout << "held by thread " << *lock.owner; + } else { + std::cout << ""; + } + std::cout << std::endl; + // for (auto &[var, global] : lock.globals) { + // show_global(var, global); + // } +} + +std::ostream& operator<<(std::ostream& os, const GlobalContext& gctx) { + os << *gctx.protocol << std::endl; + + bool show_all = false; + + auto &threads = gctx.threads; + bool showed_any = false; + for (size_t i = 0; i < threads.size(); i++) { + auto thread = threads[i]; + if (show_all || !thread->terminated || + *threads[i]->terminated != TerminationStatus::completed) { + os << "---- Thread " << i << std::endl; + os << *threads[i] << std::endl; + os << std::endl; + showed_any = true; } - os << "--" << std::endl; } - if (ctx.linear) { - os << ctx.linear->store << std::endl; - } else if (ctx.branching) { - os << ctx.branching->store << std::endl; + if (showed_any && gctx.locks.size() > 0) { + os << "---- Locks" << std::endl; + + for (const auto &[lock_name, lock] : gctx.locks) { + show_lock(lock_name, lock); + } + + if (gctx.locks.size() > 0) + os << "--" << std::endl; } return os; } + } // namespace gitmem \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh index 9615f22..c8a7814 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -5,10 +5,11 @@ #include #include -#include "branching/version_store.hh" -#include "graphviz.hh" #include "lang.hh" 
+#include "sync_kind.hh" #include "linear/version_store.hh" +#include "branching/version_store.hh" +#include "graphviz.hh" namespace gitmem { @@ -33,8 +34,7 @@ struct ThreadContext { branching::LocalVersionStore store; }; - std::optional linear; - std::optional branching; + std::variant sync; ThreadContext(const ThreadContext&) = delete; ThreadContext& operator=(const ThreadContext&) = delete; @@ -42,7 +42,7 @@ struct ThreadContext { ThreadContext(ThreadContext&&) = default; ThreadContext& operator=(ThreadContext&&) = default; - ThreadContext(std::shared_ptr tail): tail(tail) {} + ThreadContext(std::shared_ptr tail, SyncKind sync_kind); bool operator==(const ThreadContext &other) const; @@ -110,6 +110,8 @@ struct GlobalContext { bool operator==(const GlobalContext &other) const; + friend std::ostream& operator<<(std::ostream&, const GlobalContext&); + void print_execution_graph(const std::filesystem::path &output_path) const; }; diff --git a/src/interpreter.cc b/src/interpreter.cc index 68aae82..1daa282 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -94,19 +94,13 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { } else if (e == lang::Spawn) { ThreadID tid = gctx.threads.size(); auto node = std::make_shared(tid); - ThreadContext child_ctx(node); + ThreadContext child_ctx(node, gctx.protocol->kind()); if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { assert(false); // handle this } - // making the on_spawn and on_start events happen at thread spawn - if (std::optional> conflict = - gctx.protocol->on_start(child_ctx, gctx)) { - std::unreachable(); - } - gctx.threads.push_back( std::make_shared(std::move(child_ctx), e / lang::Block)); thread_append_node(ctx, tid, node); @@ -329,11 +323,11 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, size_t &pc = thread->pc; ThreadContext &ctx = thread->ctx; - // TODO: one possible interpretation of on_start is to sync when the thread + // one 
possible interpretation of on_start is to sync when the thread // starts execution statements - // if (pc == 0) { - // gctx.protocol->on_start(thread->ctx, gctx); - // } + if (pc == 0) { + gctx.protocol->on_start(thread->ctx, gctx); + } bool first_statement = true; while (pc < block->size()) { diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index 692ded6..a30e150 100644 --- a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -30,6 +30,22 @@ bool LocalVersionStore::operator==(const LocalVersionStore& other) const { _staging == other._staging; } +std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + os << "LocalVersionStore{" + << "base=" << store._base_timestamp + << ", staged={"; + + bool first = true; + for (const auto& [obj, val] : store._staging) { + if (!first) os << ", "; + first = false; + os << obj << "->" << val; + } + + os << "}}"; + return os; +} + // ----------------------------- // GlobalVersionStore // ----------------------------- @@ -105,6 +121,26 @@ Timestamp GlobalVersionStore::apply_changes( return new_ts; } +std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { + os << "GlobalVersionStore(timestamp=" << store._timestamp + << ", next_object=" << store._next_object << ")\n"; + + for (const auto& [obj_num, history] : store._history) { + os << " Object " << obj_num; + auto it = std::find_if(store._object_numbers.begin(), store._object_numbers.end(), + [&](const auto& pair){ return pair.second == obj_num; }); + if (it != store._object_numbers.end()) + os << " (" << it->first << ")"; + os << ":\n"; + + for (const auto& version : history) { + os << " [" << version.timestamp() << "] = " << version.value() << "\n"; + } + } + + return os; +} + } // namespace linear } // namespace gitmem \ No newline at end of file diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index 6c851c7..c6d05c6 100644 --- a/src/linear/version_store.hh +++ 
b/src/linear/version_store.hh @@ -95,21 +95,7 @@ public: bool operator==(const LocalVersionStore& other) const; - friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { - os << "LocalVersionStore{" - << "base=" << store._base_timestamp - << ", staged={"; - - bool first = true; - for (const auto& [obj, val] : store._staging) { - if (!first) os << ", "; - first = false; - os << obj << "->" << val; - } - - os << "}}"; - return os; - } + friend std::ostream& operator<<(std::ostream&, const LocalVersionStore&); }; // ----------------------------- @@ -137,6 +123,8 @@ public: Timestamp apply_changes(Timestamp base, const std::unordered_map &changes); + + friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); }; } // namespace linear diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 9d644f8..cb8d0d5 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -15,7 +15,7 @@ template std::ostream &Conflict::print(std::ostream &os) const { // -------------------- std::ostream &LinearSyncProtocol::print(std::ostream &os) const { - assert(false && "todo"); + os << _global_store << std::endl; return os; } @@ -58,11 +58,13 @@ std::optional LinearSyncProtocol::read(ThreadContext &ctx, const std::string &var) { linear::ObjectNumber number = _global_store.get_object_number(var); - if (auto result = store(ctx).get_staged(number)) + auto& store = std::get(ctx.sync).store; + + if (auto result = store.get_staged(number)) return result; std::optional value = _global_store.get_version_for_timestamp( - number, store(ctx).base_timestamp()); + number, store.base_timestamp()); if (!value) return std::nullopt; @@ -76,7 +78,8 @@ std::optional LinearSyncProtocol::read(ThreadContext &ctx, void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, size_t value) { // write into the staging area of the thread - store(ctx).stage(_global_store.get_object_number(var), value); + auto& store = std::get(ctx.sync).store; + 
store.stage(_global_store.get_object_number(var), value); } std::optional> @@ -86,9 +89,16 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, // added // push parent to global history - if (auto conflict = push(store(parent))) + auto& store = std::get(parent.sync).store; + if (auto conflict = push(store)) return std::make_unique(std::move(*conflict)); + // pull into the child + store = std::get(child.sync).store; + if (auto conflict = pull(store)) { + std::unreachable(); + } + return std::nullopt; } @@ -98,7 +108,8 @@ LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, // we assume the joinee has already terminated and pushed // pull changes into parent - if (auto conflict = pull(store(joiner))) + auto& store = std::get(joiner.sync).store; + if (auto conflict = pull(store)) return std::make_unique(std::move(*conflict)); return std::nullopt; @@ -107,7 +118,8 @@ LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, std::optional> LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { // pull state from global history - auto conflict = pull(store(thread)); + auto& store = std::get(thread.sync).store; + auto conflict = pull(store); assert(!conflict && "cannot conflict from starting state"); return std::nullopt; @@ -116,7 +128,8 @@ LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { std::optional> LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { // push changes to global history - if (auto conflict = push(store(thread))) + auto& store = std::get(thread.sync).store; + if (auto conflict = push(store)) return std::make_unique(std::move(*conflict)); return std::nullopt; @@ -126,7 +139,8 @@ std::optional> LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) { - if (auto conflict = pull(store(thread))) + auto& store = std::get(thread.sync).store; + if (auto conflict = pull(store)) return 
std::make_unique(std::move(*conflict)); return std::nullopt; @@ -137,7 +151,8 @@ LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, GlobalContext &gctx) { // push changes to global history - if (auto conflict = push(store(thread))) + auto& store = std::get(thread.sync).store; + if (auto conflict = push(store)) return std::make_unique(std::move(*conflict)); return std::nullopt; diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 3ea32df..6f008dc 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -1,7 +1,8 @@ #pragma once -#include "branching/version_store.hh" +#include "sync_kind.hh" #include "execution_state.hh" +#include "branching/version_store.hh" #include "linear/version_store.hh" #include #include @@ -10,11 +11,6 @@ namespace gitmem { -enum class SyncKind { - Linear, - Branching -}; - std::unique_ptr make_protocol(SyncKind); struct ConflictBase { @@ -42,6 +38,7 @@ using BranchingConflict = Conflict; class SyncProtocol { public: virtual ~SyncProtocol() = default; + virtual SyncKind kind() const = 0; // Read a shared variable into the thread context virtual std::optional read(ThreadContext &ctx, @@ -86,17 +83,12 @@ public: class LinearSyncProtocol final : public SyncProtocol { linear::GlobalVersionStore _global_store; - static linear::LocalVersionStore &store(ThreadContext &ctx) { - if (!ctx.linear) - ctx.linear.emplace(); - return ctx.linear->store; - } - std::optional push(linear::LocalVersionStore &local); std::optional pull(linear::LocalVersionStore &local); public: ~LinearSyncProtocol() override; + SyncKind kind() const override { return SyncKind::Linear; }; std::optional read(ThreadContext &ctx, const std::string &var) override; @@ -132,6 +124,7 @@ class BranchingSyncProtocol final : public SyncProtocol { public: ~BranchingSyncProtocol() override; + SyncKind kind() const override { return SyncKind::Branching; }; std::optional read(ThreadContext &ctx, const std::string &var) override; From 
48043491ea561de597b441a202557cca6158c320 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 24 Dec 2025 23:19:15 +0000 Subject: [PATCH 17/58] fix a priting issue, and fix expected error codes for rejects in python test --- src/debugger.cc | 2 +- src/execution_state.cc | 16 ++++++++-------- src/execution_state.hh | 1 + test_gitmem.py | 14 +++++++++----- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/debugger.cc b/src/debugger.cc index 519de08..a47de95 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -147,7 +147,7 @@ int interpret_interactive(const trieste::Node ast, if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) { bool show_all = command.cmd == Command::List; - show_global_context(gctx, show_all); + gctx.print(std::cout, show_all); } prev_no_threads = gctx.threads.size(); diff --git a/src/execution_state.cc b/src/execution_state.cc index 54293d0..59b5695 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -189,12 +189,9 @@ void show_lock(const std::string &lock_name, const struct Lock &lock) { // } } -std::ostream& operator<<(std::ostream& os, const GlobalContext& gctx) { - os << *gctx.protocol << std::endl; - - bool show_all = false; +void GlobalContext::print(std::ostream& os, bool show_all) const { + os << *protocol << std::endl; - auto &threads = gctx.threads; bool showed_any = false; for (size_t i = 0; i < threads.size(); i++) { auto thread = threads[i]; @@ -207,17 +204,20 @@ std::ostream& operator<<(std::ostream& os, const GlobalContext& gctx) { } } - if (showed_any && gctx.locks.size() > 0) { + if (showed_any && locks.size() > 0) { os << "---- Locks" << std::endl; - for (const auto &[lock_name, lock] : gctx.locks) { + for (const auto &[lock_name, lock] : locks) { show_lock(lock_name, lock); } - if (gctx.locks.size() > 0) + if (locks.size() > 0) os << "--" << std::endl; } +} +std::ostream& operator<<(std::ostream& os, const GlobalContext& gctx) { + gctx.print(os); return os; } diff 
--git a/src/execution_state.hh b/src/execution_state.hh index c8a7814..a00f272 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -110,6 +110,7 @@ struct GlobalContext { bool operator==(const GlobalContext &other) const; + void print(std::ostream& os, bool show_all = false) const; friend std::ostream& operator<<(std::ostream&, const GlobalContext&); void print_execution_graph(const std::filesystem::path &output_path) const; diff --git a/test_gitmem.py b/test_gitmem.py index 30532c3..5d36153 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -32,16 +32,20 @@ def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): capture_output=True, text=True ) - accepted = (result.returncode == 0) + if should_accept: + # Accept tests pass if exit code is 0 + accepted = (result.returncode == 0) + else: + # Reject tests pass only if exit code is 1 + accepted = (result.returncode == 1) + except FileNotFoundError: print(f"Error: '{gitmem_path}' executable not found.") sys.exit(1) - passed = (accepted == should_accept) - status = green("PASS") if passed else red("FAIL") - + status = green("PASS") if accepted else red("FAIL") print(f"[{status}] {file_path} (exit code: {result.returncode})") - return passed + return accepted def main(): parser = argparse.ArgumentParser(description="Test runner for gitmem.") From 6dc78e226d28f3a190861130a29ef3cd2a8ace97 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 24 Dec 2025 23:39:47 +0000 Subject: [PATCH 18/58] conflict on termination now result in an error --- src/interpreter.cc | 13 ++++++++++--- src/sync_protocol.cc | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/interpreter.cc b/src/interpreter.cc index 1daa282..e128877 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -98,7 +98,7 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { - assert(false); // handle this + 
throw std::logic_error("This code path should never be reached"); } gctx.threads.push_back( @@ -356,10 +356,17 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, } // End should be it's own sync step + + // TODO: tidy this up if (first_statement) { - thread->terminated = TerminationStatus::completed; - gctx.protocol->on_end(thread->ctx, gctx); + if (std::optional> conflict = + gctx.protocol->on_end(ctx, gctx)) { + verbose << (**conflict) << std::endl; + thread->terminated = TerminationStatus::datarace_exception; + return TerminationStatus::datarace_exception; + } + thread->terminated = TerminationStatus::completed; thread_append_node(ctx); return TerminationStatus::completed; } diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index cb8d0d5..4287371 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -96,7 +96,7 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, // pull into the child store = std::get(child.sync).store; if (auto conflict = pull(store)) { - std::unreachable(); + throw std::logic_error("This code path should never be reached"); } return std::nullopt; From 482cf0c856fbac55283299912bc91ef7093a01db Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Thu, 25 Dec 2025 08:27:07 +0000 Subject: [PATCH 19/58] tidying up run single thread --- src/interpreter.cc | 68 ++++++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/src/interpreter.cc b/src/interpreter.cc index e128877..d71d189 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -316,62 +316,58 @@ std::variant run_statement(Node stmt, std::variant run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, std::shared_ptr thread) { - if (thread->terminated) { - return *(thread->terminated); - } + if (thread->terminated) + return *thread->terminated; + + auto& ctx = thread->ctx; + auto& pc = thread->pc; Node block = thread->block; - size_t &pc = thread->pc; - ThreadContext 
&ctx = thread->ctx; - // one possible interpretation of on_start is to sync when the thread - // starts execution statements - if (pc == 0) { - gctx.protocol->on_start(thread->ctx, gctx); - } + // Initial sync when thread starts executing + if (pc == 0) + gctx.protocol->on_start(ctx, gctx); + + bool made_progress = false; - bool first_statement = true; while (pc < block->size()) { Node stmt = block->at(pc); - if (!first_statement && is_syncing(stmt)) { + // Stop *before* executing a sync statement (except first) + if (made_progress && is_syncing(stmt)) return ProgressStatus::progress; - } - auto delta_or_term = run_statement(stmt, gctx, ctx, tid); - if (auto term = std::get_if(&delta_or_term)) { + auto result = run_statement(stmt, gctx, ctx, tid); + + if (auto term = std::get_if(&result)) { thread->terminated = *term; - // thread_append_node(ctx); return *term; } - auto delta = std::get(delta_or_term); + int delta = std::get(result); - if (delta == 0) { - return first_statement ? ProgressStatus::no_progress - : ProgressStatus::progress; - } + // Blocked (e.g. waiting on lock/join) + if (delta == 0) + return made_progress ? 
ProgressStatus::progress + : ProgressStatus::no_progress; pc += delta; - first_statement = false; + made_progress = true; } - // End should be it's own sync step + // If we ran *any* statements, finishing is a sync point for next iteration + if (made_progress) + return ProgressStatus::progress; - // TODO: tidy this up - if (first_statement) { - if (std::optional> conflict = - gctx.protocol->on_end(ctx, gctx)) { - verbose << (**conflict) << std::endl; - thread->terminated = TerminationStatus::datarace_exception; - return TerminationStatus::datarace_exception; - } - - thread->terminated = TerminationStatus::completed; - thread_append_node(ctx); - return TerminationStatus::completed; + // Otherwise, we truly reached the end this iteration + if (auto conflict = gctx.protocol->on_end(ctx, gctx)) { + verbose << (**conflict) << std::endl; + thread->terminated = TerminationStatus::datarace_exception; + return TerminationStatus::datarace_exception; } - return ProgressStatus::progress; + thread->terminated = TerminationStatus::completed; + thread_append_node(ctx); + return TerminationStatus::completed; } /** From 94f10db928b7a028374aed874717fce936b345fc Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 2 Jan 2026 13:25:00 +0100 Subject: [PATCH 20/58] Building a event trace infrastructure --- src/debugger.cc | 23 +++++++----- src/execution_state.cc | 50 +++++++++++++------------- src/execution_state.hh | 27 +++++--------- src/interpreter.cc | 61 +++++++++---------------------- src/interpreter.hh | 19 ++-------- src/model_checker.cc | 4 +-- src/progress_status.hh | 19 ++++++++++ src/sync_kind.hh | 10 ++++++ src/termination_status.hh | 9 +++++ src/thread_id.hh | 7 ++++ src/thread_trace.hh | 76 +++++++++++++++++++++++++++++++++++++++ 11 files changed, 191 insertions(+), 114 deletions(-) create mode 100644 src/progress_status.hh create mode 100644 src/sync_kind.hh create mode 100644 src/termination_status.hh create mode 100644 src/thread_id.hh create mode 100644 
src/thread_trace.hh diff --git a/src/debugger.cc b/src/debugger.cc index a47de95..c0e8295 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -142,7 +142,7 @@ int interpret_interactive(const trieste::Node ast, Command command = {Command::List}; std::string msg = ""; bool print_graphs = true; - gctx.print_execution_graph(output_file); + // gctx.print_execution_graph(output_file); while (command.cmd != Command::Quit) { if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) { @@ -170,18 +170,21 @@ int interpret_interactive(const trieste::Node ast, command = {Command::Skip}; if (print_graphs) { - gctx.print_execution_graph(output_file); + // gctx.print_execution_graph(output_file); + assert(false && "todo"); verbose << "Execution graph written to " << output_file << std::endl; } } else if (command.cmd == Command::Finish) { // Finish the program - if (!run_threads(gctx)) - msg = "Program finished successfully"; - else - msg = "Program terminated with an error"; + assert(false && "fixme"); + // if (!run_threads(gctx)) + // msg = "Program finished successfully"; + // else + // msg = "Program terminated with an error"; if (print_graphs) { - gctx.print_execution_graph(output_file); + // gctx.print_execution_graph(output_file); + assert(false && "todo"); verbose << "Execution graph written to " << output_file << std::endl; } } else if (command.cmd == Command::Restart) { @@ -189,7 +192,8 @@ int interpret_interactive(const trieste::Node ast, gctx = GlobalContext(ast, make_protocol(sync_kind)); command = {Command::List}; if (print_graphs) { - gctx.print_execution_graph(output_file); + // gctx.print_execution_graph(output_file); + assert(false && "todo"); verbose << "Execution graph written to " << output_file << std::endl; } } else if (command.cmd == Command::List) { @@ -202,7 +206,8 @@ int interpret_interactive(const trieste::Node ast, command = {Command::Skip}; } else if (command.cmd == Command::Print) { // Print the execution graph - 
gctx.print_execution_graph(output_file); + // gctx.print_execution_graph(output_file); + assert(false && "todo"); verbose << "Execution graph written to " << output_file << std::endl; command = {Command::Skip}; } else if (command.cmd == Command::Skip) { diff --git a/src/execution_state.cc b/src/execution_state.cc index 59b5695..1e3224e 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -5,7 +5,7 @@ namespace gitmem { -ThreadContext::ThreadContext(std::shared_ptr tail, SyncKind sync_kind): tail(tail) { +ThreadContext::ThreadContext(SyncKind sync_kind) { switch (sync_kind) { case SyncKind::Linear: sync.emplace(); @@ -65,8 +65,10 @@ GlobalContext::GlobalContext(const trieste::Node &ast, std::unique_ptr protocol) : protocol(std::move(protocol)) { trieste::Node starting_block = ast / lang::File / lang::Block; - ThreadContext starting_ctx(std::make_shared(0), this->protocol->kind()); - auto main_thread = std::make_shared(std::move(starting_ctx), starting_block); + ThreadContext starting_ctx(this->protocol->kind()); + + ThreadID main_tid = 0; + auto main_thread = std::make_shared(main_tid, std::move(starting_ctx), starting_block); this->threads = {main_thread}; this->locks = {}; @@ -75,27 +77,27 @@ GlobalContext::GlobalContext(const trieste::Node &ast, GlobalContext::~GlobalContext() = default; -void GlobalContext::print_execution_graph( - const std::filesystem::path &output_path) const { - return; // FIXME - // Loop over the threads and add pending nodes to running threads - // to indicate a threads next step - for (const auto &t : threads) { - assert(t->ctx.tail); - if (t->terminated || - dynamic_pointer_cast(t->ctx.tail->next)) - continue; - - trieste::Node block = t->block; - size_t &pc = t->pc; - trieste::Node stmt = block->at(pc); - thread_append_node(t->ctx, - std::string(stmt->location().view())); - } - - graph::GraphvizPrinter gv(output_path); - gv.visit(entry_node.get()); -} +// void GlobalContext::print_execution_graph( +// const 
std::filesystem::path &output_path) const { +// return; // FIXME +// // Loop over the threads and add pending nodes to running threads +// // to indicate a threads next step +// for (const auto &t : threads) { +// assert(t->ctx.tail); +// if (t->terminated || +// dynamic_pointer_cast(t->ctx.tail->next)) +// continue; + +// trieste::Node block = t->block; +// size_t &pc = t->pc; +// trieste::Node stmt = block->at(pc); +// thread_append_node(t->ctx, +// std::string(stmt->location().view())); +// } + +// graph::GraphvizPrinter gv(output_path); +// gv.visit(entry_node.get()); +// } bool GlobalContext::operator==(const GlobalContext &other) const { if (threads.size() != other.threads.size() || diff --git a/src/execution_state.hh b/src/execution_state.hh index a00f272..3b26fa5 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -10,22 +10,16 @@ #include "linear/version_store.hh" #include "branching/version_store.hh" #include "graphviz.hh" +#include "termination_status.hh" +#include "thread_trace.hh" +#include "thread_id.hh" namespace gitmem { class SyncProtocol; -enum class TerminationStatus { - completed, - datarace_exception, - unlock_exception, - assertion_failure_exception, - unassigned_variable_read_exception, -}; - struct ThreadContext { std::unordered_map locals; - std::shared_ptr tail; struct LinearData { linear::LocalVersionStore store; @@ -42,7 +36,7 @@ struct ThreadContext { ThreadContext(ThreadContext&&) = default; ThreadContext& operator=(ThreadContext&&) = default; - ThreadContext(std::shared_ptr tail, SyncKind sync_kind); + ThreadContext(SyncKind sync_kind); bool operator==(const ThreadContext &other) const; @@ -51,12 +45,13 @@ struct ThreadContext { struct Thread { ThreadContext ctx; + ThreadTrace trace; trieste::Node block; size_t pc = 0; std::optional terminated = std::nullopt; - Thread(ThreadContext&& ctx, trieste::Node block): - ctx(std::move(ctx)), block(block) {}; + Thread(ThreadID tid, ThreadContext&& ctx, trieste::Node block): + 
ctx(std::move(ctx)), trace(tid), block(block) {}; Thread(const Thread&) = delete; Thread& operator=(const Thread&) = delete; @@ -69,12 +64,8 @@ struct Thread { friend std::ostream& operator<<(std::ostream&, const Thread&); }; -using ThreadID = size_t; - struct Lock { - // Globals globals; std::optional owner = std::nullopt; - std::shared_ptr last; }; template @@ -93,7 +84,7 @@ struct GlobalContext { lang::NodeMap cache; // Graph root - std::shared_ptr entry_node; + // std::shared_ptr entry_node; // Synchronisation semantics (policy) std::unique_ptr protocol; @@ -113,7 +104,7 @@ struct GlobalContext { void print(std::ostream& os, bool show_all = false) const; friend std::ostream& operator<<(std::ostream&, const GlobalContext&); - void print_execution_graph(const std::filesystem::path &output_path) const; + // void print_execution_graph(const std::filesystem::path &output_path) const; }; } // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.cc b/src/interpreter.cc index d71d189..dcf7f35 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -39,31 +39,11 @@ bool is_syncing(Thread &thread) { ((thread.pc >= thread.block->size()) || is_syncing(thread.block->at(thread.pc))); } -template -std::shared_ptr thread_append_node(ThreadContext &ctx, Args &&...args) { - assert(ctx.tail); - auto node = std::make_shared(std::forward(args)...); - ctx.tail->next = node; - ctx.tail = node; - return node; -} - -template <> -std::shared_ptr -thread_append_node(ThreadContext &ctx, std::string &&stmt) { - // pending nodes don't update the tail position as we will destroy them - // once we execute the node - auto s = std::regex_replace(stmt, std::regex("\n"), "\\l "); - auto node = make_shared(std::move(s)); - ctx.tail->next = node; - return node; -} - /* Evaluating an expression either returns the result of the expression or * a the exceptional termination status of the thread. 
*/ std::variant -evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { +evaluate_expression(trieste::Node expr, GlobalContext &gctx, ThreadContext &ctx) { auto e = expr / lang::Expr; if (e == lang::Reg) { // It is invalid to read a previously unwritten value @@ -93,8 +73,7 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { return sum; } else if (e == lang::Spawn) { ThreadID tid = gctx.threads.size(); - auto node = std::make_shared(tid); - ThreadContext child_ctx(node, gctx.protocol->kind()); + ThreadContext child_ctx(gctx.protocol->kind()); if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { @@ -102,8 +81,8 @@ evaluate_expression(Node expr, GlobalContext &gctx, ThreadContext &ctx) { } gctx.threads.push_back( - std::make_shared(std::move(child_ctx), e / lang::Block)); - thread_append_node(ctx, tid, node); + std::make_shared(tid, std::move(child_ctx), e / lang::Block)); + // thread_append_node(ctx, tid, node); return tid; } else if (e == lang::Eq || e == lang::Neq) { @@ -220,7 +199,7 @@ std::variant run_statement(Node stmt, verbose << (**conflict) << std::endl; return TerminationStatus::datarace_exception; } else { - thread_append_node(ctx, result, joinee->ctx.tail); + // thread_append_node(ctx, result, joinee->ctx.tail); } } else { @@ -254,7 +233,7 @@ std::variant run_statement(Node stmt, return TerminationStatus::datarace_exception; } - thread_append_node(ctx, var, lock.last); + // thread_append_node(ctx, var, lock.last); verbose << "Locked " << var << std::endl; } else if (s == lang::Unlock) { @@ -280,8 +259,8 @@ std::variant run_statement(Node stmt, // lock.globals = ctx.globals; lock.owner.reset(); - thread_append_node(ctx, var); - lock.last = ctx.tail; + // thread_append_node(ctx, var); + // lock.last = ctx.tail; verbose << "Unlocked " << var << std::endl; @@ -294,8 +273,8 @@ std::variant run_statement(Node stmt, verbose << "Assertion passed: " << expr->location().view() << std::endl; } 
else { verbose << "Assertion failed: " << expr->location().view() << std::endl; - thread_append_node( - ctx, std::string(expr->location().view())); + // thread_append_node( + // ctx, std::string(expr->location().view())); return TerminationStatus::assertion_failure_exception; } } else { @@ -366,7 +345,7 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, } thread->terminated = TerminationStatus::completed; - thread_append_node(ctx); + // thread_append_node(ctx); return TerminationStatus::completed; } @@ -434,21 +413,18 @@ run_threads_to_sync(GlobalContext &gctx) { return any_progress; } -bool is_finished( - std::variant &prog_or_term) { +static bool is_finished( std::variant &prog_or_term) { // Either, the system is stuck and made no progress in which case there // is a deadlock (or a thread is stuck waiting for a crashed thread?) - if (ProgressStatus *prog = std::get_if(&prog_or_term)) - return (*prog) == ProgressStatus::no_progress; - // Or, there was some termination criteria in which case we stop - return true; + return std::holds_alternative(prog_or_term) || + std::get(prog_or_term) == ProgressStatus::no_progress; } /* Try to evaluate all threads until they have all terminated in some way * or we have reached a stuck configuration. 
*/ -int run_threads(GlobalContext &gctx) { +int run(GlobalContext gctx) { std::variant prog_or_term; do { prog_or_term = run_threads_to_sync(gctx); @@ -494,7 +470,7 @@ int run_threads(GlobalContext &gctx) { } } else { exception_detected = true; - thread_append_node(thread->ctx); + // thread_append_node(thread->ctx); verbose << "Thread " << i << " is stuck" << std::endl; } } @@ -504,10 +480,7 @@ int run_threads(GlobalContext &gctx) { int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { GlobalContext gctx(ast, make_protocol(sync_kind)); - auto result = run_threads(gctx); - // gctx.print_execution_graph(output_path); FIXME - - return result; + return run(std::move(gctx)); } } // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.hh b/src/interpreter.hh index 4a15326..fb26045 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -4,31 +4,16 @@ #include "graph.hh" #include "graphviz.hh" #include "lang.hh" +#include "progress_status.hh" #include #include "sync_protocol.hh" +#include "termination_status.hh" namespace gitmem { // Entry function int interpret(const trieste::Node, const std::filesystem::path &output_file, SyncKind sync_kind); -// Internal functions -int run_threads(GlobalContext &); - -enum class ProgressStatus { progress, no_progress }; -inline bool operator!(ProgressStatus p) { - return p == ProgressStatus::no_progress; -} -inline ProgressStatus operator||(const ProgressStatus &p1, - const ProgressStatus &p2) { - return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) - ? 
ProgressStatus::progress - : ProgressStatus::no_progress; -} -inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { - p1 = (p1 || p2); -} - std::variant progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); diff --git a/src/model_checker.cc b/src/model_checker.cc index 8c1fc7e..c5cb844 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -182,7 +182,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi for (const auto &ctx : failing_contexts) { auto path = build_output_path(output_path, idx++); - ctx->print_execution_graph(path); + // ctx->print_execution_graph(path); } } @@ -193,7 +193,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi for (const auto &ctx : deadlocked_contexts) { auto path = build_output_path(output_path, idx++); - ctx->print_execution_graph(path); + // ctx->print_execution_graph(path); } } diff --git a/src/progress_status.hh b/src/progress_status.hh new file mode 100644 index 0000000..4da6cba --- /dev/null +++ b/src/progress_status.hh @@ -0,0 +1,19 @@ +#pragma once + +namespace gitmem { + +enum class ProgressStatus { progress, no_progress }; +inline bool operator!(ProgressStatus p) { + return p == ProgressStatus::no_progress; +} +inline ProgressStatus operator||(const ProgressStatus &p1, + const ProgressStatus &p2) { + return (p1 == ProgressStatus::progress || p2 == ProgressStatus::progress) + ? 
ProgressStatus::progress + : ProgressStatus::no_progress; +} +inline void operator|=(ProgressStatus &p1, const ProgressStatus &p2) { + p1 = (p1 || p2); +} + +} \ No newline at end of file diff --git a/src/sync_kind.hh b/src/sync_kind.hh new file mode 100644 index 0000000..303df8e --- /dev/null +++ b/src/sync_kind.hh @@ -0,0 +1,10 @@ +#pragma once + +namespace gitmem { + +enum class SyncKind { + Linear, + Branching +}; + +} \ No newline at end of file diff --git a/src/termination_status.hh b/src/termination_status.hh new file mode 100644 index 0000000..5417b74 --- /dev/null +++ b/src/termination_status.hh @@ -0,0 +1,9 @@ +#pragma once + +enum class TerminationStatus { + completed, + datarace_exception, + unlock_exception, + assertion_failure_exception, + unassigned_variable_read_exception, +}; diff --git a/src/thread_id.hh b/src/thread_id.hh new file mode 100644 index 0000000..637ef56 --- /dev/null +++ b/src/thread_id.hh @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace gitmem { + using ThreadID = std::size_t; +} \ No newline at end of file diff --git a/src/thread_trace.hh b/src/thread_trace.hh new file mode 100644 index 0000000..6866d77 --- /dev/null +++ b/src/thread_trace.hh @@ -0,0 +1,76 @@ +#pragma once + +#include "thread_id.hh" +#include "graph.hh" + +namespace gitmem { + +struct ThreadTrace { + ThreadID tid; + std::shared_ptr head; + std::shared_ptr tail; + +private: + template + void append(Args&&... 
args) { + assert(tail); + auto node = std::make_shared(std::forward(args)...); + tail->next = node; + tail = node; + } + +public: + explicit ThreadTrace(ThreadID tid): tid(tid), head(nullptr), tail(nullptr) {} + + void on_start(ThreadID tid) { + assert(head == tail && head == nullptr); + head = std::make_shared(tid); + tail = head; + } + + void on_stmt(std::string text) { + append(std::move(text)); + } + + void on_lock(std::string lock, std::shared_ptr last) { + append(std::move(lock), last); + } + + void on_unlock(std::string lock) { + append(std::move(lock)); + } + + void on_join(ThreadID tid, std::shared_ptr target) { + append(tid, target); + } + + void on_assert_fail(std::string expr) { + append(std::move(expr)); + } + + void on_end() { + append(); + } +}; + +} // namespace gitmem + +// template +// std::shared_ptr thread_append_node(ThreadContext &ctx, Args &&...args) { +// assert(ctx.tail); +// auto node = std::make_shared(std::forward(args)...); +// ctx.tail->next = node; +// ctx.tail = node; +// return node; +// } + +// template <> +// std::shared_ptr +// thread_append_node(ThreadContext &ctx, std::string &&stmt) { +// // pending nodes don't update the tail position as we will destroy them +// // once we execute the node +// auto s = std::regex_replace(stmt, std::regex("\n"), "\\l "); +// auto node = make_shared(std::move(s)); +// ctx.tail->next = node; +// return node; +// } \ No newline at end of file From da5cf83c105bfcb2d7046f40fb3d5951d52f022c Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 2 Jan 2026 14:47:49 +0100 Subject: [PATCH 21/58] sorting out the interpreter struct to be better encapsulated and consitent --- CMakeLists.txt | 2 +- src/execution_state.hh | 5 +++- src/gitmem.cc | 3 ++- src/interpreter.cc | 56 +++++++++++++++++++++--------------------- src/interpreter.hh | 22 +++++++++++++++++ src/model_checker.cc | 18 ++++++++------ test_gitmem.py | 26 ++++++++++++++++---- 7 files changed, 89 insertions(+), 43 deletions(-) diff --git 
a/CMakeLists.txt b/CMakeLists.txt index 6e8f9d9..6e39979 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ add_executable(gitmem src/passes/branching.cc src/linear/version_store.cc src/interpreter.cc - src/debugger.cc + # src/debugger.cc src/model_checker.cc src/sync_protocol.cc src/graphviz.cc diff --git a/src/execution_state.hh b/src/execution_state.hh index 3b26fa5..655acf0 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -44,6 +44,7 @@ struct ThreadContext { }; struct Thread { + ThreadID tid; ThreadContext ctx; ThreadTrace trace; trieste::Node block; @@ -51,7 +52,7 @@ struct Thread { std::optional terminated = std::nullopt; Thread(ThreadID tid, ThreadContext&& ctx, trieste::Node block): - ctx(std::move(ctx)), trace(tid), block(block) {}; + tid(tid), ctx(std::move(ctx)), trace(tid), block(block) {}; Thread(const Thread&) = delete; Thread& operator=(const Thread&) = delete; @@ -93,6 +94,8 @@ struct GlobalContext { std::unique_ptr protocol); ~GlobalContext(); + GlobalContext clone() const; + GlobalContext(GlobalContext&&) = default; GlobalContext& operator=(GlobalContext&&) = default; diff --git a/src/gitmem.cc b/src/gitmem.cc index 6fda58a..ad40077 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -72,7 +72,8 @@ int main(int argc, char **argv) { if (model_check) { exit_status = gitmem::model_check(result.ast, output_path, sync_kind); } else if (interactive) { - exit_status = gitmem::interpret_interactive(result.ast, output_path, sync_kind); + assert(false && "fixme"); + // exit_status = gitmem::interpret_interactive(result.ast, output_path, sync_kind); } else { exit_status = gitmem::interpret(result.ast, output_path, sync_kind); } diff --git a/src/interpreter.cc b/src/interpreter.cc index dcf7f35..bc9ecef 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -25,12 +25,12 @@ using namespace trieste; * - t unlocking a lock l, which updates l to have t's versioned memory */ -bool is_syncing(Node stmt) { +static bool 
is_syncing(Node stmt) { auto s = stmt / lang::Stmt; return s == lang::Join || s == lang::Lock || s == lang::Unlock; } -bool is_syncing(Thread &thread) { +static bool is_syncing(Thread &thread) { // Can only be true if a thread hasn't terminated // Either it has executed all statements but not yet terminated (and my sync) // Or it is at a synchronisation node @@ -43,7 +43,9 @@ bool is_syncing(Thread &thread) { * a the exceptional termination status of the thread. */ std::variant -evaluate_expression(trieste::Node expr, GlobalContext &gctx, ThreadContext &ctx) { +Interpreter::evaluate_expression(trieste::Node expr, std::shared_ptr thread) { + ThreadContext& ctx = thread->ctx; + auto e = expr / lang::Expr; if (e == lang::Reg) { // It is invalid to read a previously unwritten value @@ -65,7 +67,7 @@ evaluate_expression(trieste::Node expr, GlobalContext &gctx, ThreadContext &ctx) } else if (e == lang::Add) { size_t sum = 0; for (auto &child : *e) { - auto result = evaluate_expression(child, gctx, ctx); + auto result = evaluate_expression(child, thread); if (std::holds_alternative(result)) return result; sum += std::get(result); @@ -89,11 +91,11 @@ evaluate_expression(trieste::Node expr, GlobalContext &gctx, ThreadContext &ctx) auto lhs = e / lang::Lhs; auto rhs = e / lang::Rhs; - auto lhsEval = evaluate_expression(lhs, gctx, ctx); + auto lhsEval = evaluate_expression(lhs, thread); if (std::holds_alternative(lhsEval)) return lhsEval; - auto rhsEval = evaluate_expression(rhs, gctx, ctx); + auto rhsEval = evaluate_expression(rhs, thread); if (std::holds_alternative(rhsEval)) return rhsEval; @@ -110,10 +112,9 @@ evaluate_expression(trieste::Node expr, GlobalContext &gctx, ThreadContext &ctx) * counter (0 if waiting for some other thread) or the exceptional * termination status of the thread. 
*/ -std::variant run_statement(Node stmt, - GlobalContext &gctx, - ThreadContext &ctx, - const ThreadID &tid) { +std::variant Interpreter::run_statement(Node stmt, std::shared_ptr thread) { + ThreadContext& ctx = thread->ctx; + auto s = stmt / lang::Stmt; if (s == lang::Nop) { @@ -130,7 +131,7 @@ std::variant run_statement(Node stmt, auto expr = s / lang::Expr; auto cnst = s / lang::Const; - auto result = evaluate_expression(expr, gctx, ctx); + auto result = evaluate_expression(expr, thread); if (auto b = std::get_if(&result)) { auto delta = std::stoi(std::string(cnst->location().view())); @@ -145,7 +146,7 @@ std::variant run_statement(Node stmt, auto lhs = s / lang::LVal; auto var = std::string(lhs->location().view()); auto rhs = s / lang::Expr; - auto val_or_term = evaluate_expression(rhs, gctx, ctx); + auto val_or_term = evaluate_expression(rhs, thread); if (size_t *val = std::get_if(&val_or_term)) { if (lhs == lang::Reg) { @@ -183,7 +184,7 @@ std::variant run_statement(Node stmt, auto expr = s / lang::Expr; if (!gctx.cache.contains(expr)) { - auto val_or_term = evaluate_expression(expr, gctx, ctx); + auto val_or_term = evaluate_expression(expr, thread); if (size_t *val = std::get_if(&val_or_term)) { gctx.cache[expr] = *val; } else { @@ -220,7 +221,7 @@ std::variant run_statement(Node stmt, return 0; } - lock.owner = tid; + lock.owner = thread->tid; if (auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { verbose << (**conflict) << std::endl; // using graph::Node; @@ -247,7 +248,7 @@ std::variant run_statement(Node stmt, auto var = std::string(v->location().view()); auto &lock = gctx.locks[var]; - if (!lock.owner || (lock.owner && *lock.owner != tid)) { + if (!lock.owner || (lock.owner && *lock.owner != thread->tid)) { return TerminationStatus::unlock_exception; } @@ -267,7 +268,7 @@ std::variant run_statement(Node stmt, } else if (s == lang::Assert) { auto expr = s / lang::Expr; - auto result_or_term = evaluate_expression(expr, gctx, ctx); + auto 
result_or_term = evaluate_expression(expr, thread); if (size_t *result = std::get_if(&result_or_term)) { if (*result) { verbose << "Assertion passed: " << expr->location().view() << std::endl; @@ -293,8 +294,7 @@ std::variant run_statement(Node stmt, * whether it terminated. */ std::variant -run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, - std::shared_ptr thread) { +Interpreter::run_single_thread_to_sync(std::shared_ptr thread) { if (thread->terminated) return *thread->terminated; @@ -315,7 +315,7 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, if (made_progress && is_syncing(stmt)) return ProgressStatus::progress; - auto result = run_statement(stmt, gctx, ctx, tid); + auto result = run_statement(stmt, thread); if (auto term = std::get_if(&result)) { thread->terminated = *term; @@ -354,10 +354,9 @@ run_single_thread_to_sync(GlobalContext &gctx, const ThreadID tid, * thread */ std::variant -progress_thread(GlobalContext &gctx, const ThreadID tid, - std::shared_ptr thread) { +Interpreter::progress_thread(std::shared_ptr thread) { auto no_threads = gctx.threads.size(); - auto prog_or_term = run_single_thread_to_sync(gctx, tid, thread); + auto prog_or_term = run_single_thread_to_sync(thread); bool any_progress = std::holds_alternative(prog_or_term) && @@ -369,7 +368,7 @@ progress_thread(GlobalContext &gctx, const ThreadID tid, auto new_thread = gctx.threads[i]; if (!is_syncing(*new_thread)) { verbose << "==== Thread " << i << " (spawn) ====" << std::endl; - progress_thread(gctx, i, new_thread); + progress_thread(new_thread); } } @@ -382,7 +381,7 @@ progress_thread(GlobalContext &gctx, const ThreadID tid, /* Try to evaluate all threads until a sync point or termination point */ std::variant -run_threads_to_sync(GlobalContext &gctx) { +Interpreter::run_threads_to_sync() { verbose << "-----------------------" << std::endl; bool all_completed = true; ProgressStatus any_progress = ProgressStatus::no_progress; @@ -390,7 +389,7 @@ 
run_threads_to_sync(GlobalContext &gctx) { verbose << "==== t" << i << " ====" << std::endl; auto thread = gctx.threads[i]; if (!thread->terminated) { - auto prog_or_term = run_single_thread_to_sync(gctx, i, thread); + auto prog_or_term = run_single_thread_to_sync(thread); if (ProgressStatus *prog = std::get_if(&prog_or_term)) { any_progress |= *prog; } else { @@ -424,10 +423,10 @@ static bool is_finished( std::variant &prog_o /* Try to evaluate all threads until they have all terminated in some way * or we have reached a stuck configuration. */ -int run(GlobalContext gctx) { +int Interpreter::run() { std::variant prog_or_term; do { - prog_or_term = run_threads_to_sync(gctx); + prog_or_term = run_threads_to_sync(); } while (!is_finished(prog_or_term)); verbose << "----------- execution complete -----------" << std::endl; @@ -480,7 +479,8 @@ int run(GlobalContext gctx) { int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { GlobalContext gctx(ast, make_protocol(sync_kind)); - return run(std::move(gctx)); + Interpreter interp(std::move(gctx)); + return interp.run(); } } // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.hh b/src/interpreter.hh index fb26045..144002e 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -8,9 +8,31 @@ #include #include "sync_protocol.hh" #include "termination_status.hh" +#include "step_result.hh" namespace gitmem { +class Interpreter { +private: + GlobalContext gctx; + +public: + Interpreter(GlobalContext gctx): gctx(std::move(gctx)) {} + + GlobalContext& context() { return gctx; } + + // Internal functions + int run(); + + std::variant evaluate_expression(trieste::Node, std::shared_ptr); + std::variant run_statement(trieste::Node, std::shared_ptr); + + std::variant progress_thread(std::shared_ptr); + std::variant run_single_thread_to_sync(std::shared_ptr); + std::variant run_threads_to_sync(); + +}; + // Entry function int interpret(const trieste::Node, const 
std::filesystem::path &output_file, SyncKind sync_kind); diff --git a/src/model_checker.cc b/src/model_checker.cc index c5cb844..1c036d0 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -59,8 +59,6 @@ build_output_path(const std::filesystem::path &output_path, const size_t idx) { * for each distinct final state that led to an error. */ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { - GlobalContext gctx(ast, make_protocol(sync_kind)); - auto final_contexts = std::vector>{}; auto failing_contexts = std::vector>{}; auto deadlocked_contexts = std::vector>{}; @@ -73,7 +71,11 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi auto cursor = root; auto current_trace = std::vector{0}; // Start with the main thread verbose << "==== Thread " << cursor->tid_ << " ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); + + Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); + + GlobalContext& gctx = interp.context(); + interp.progress_thread(gctx.threads[cursor->tid_]); while (!root->complete) { while (!cursor->children.empty() && !cursor->children.back()->complete) { @@ -82,7 +84,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi current_trace.push_back(cursor->tid_); verbose << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); + interp.progress_thread(gctx.threads[cursor->tid_]); } // Try to find a thread to schedule next @@ -95,7 +97,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi if (!thread->terminated) { // Run the thread to the next sync point verbose << "==== Thread " << i << " ====" << std::endl; - auto prog_or_term = progress_thread(gctx, i, thread); + auto prog_or_term = interp.progress_thread(thread); if (std::holds_alternative(prog_or_term)) { // Thread terminated, we can 
extend the trace made_progress = true; @@ -141,6 +143,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi if (!std::any_of( final_contexts.begin(), final_contexts.end(), [&gctx](const std::shared_ptr &state) { return *state == gctx; })) { + // Here we take ownership of the global context std::shared_ptr gctxp = std::make_shared(std::move(gctx)); final_contexts.push_back(gctxp); final_traces.push_back(current_trace); @@ -159,14 +162,15 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi if (cursor->complete && !root->complete) { // Reset the cursor to the root and start a new trace verbose << std::endl << "Restarting trace..." << std::endl; - gctx = GlobalContext(ast, make_protocol(sync_kind)); + interp = Interpreter(GlobalContext(ast, make_protocol(sync_kind))); + GlobalContext& gctx = interp.context(); cursor = root; current_trace.clear(); current_trace.push_back(0); // Start with the main thread again verbose << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; - progress_thread(gctx, cursor->tid_, gctx.threads[cursor->tid_]); + interp.progress_thread(gctx.threads[cursor->tid_]); } } diff --git a/test_gitmem.py b/test_gitmem.py index 5d36153..ee6556c 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -33,12 +33,9 @@ def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): text=True ) if should_accept: - # Accept tests pass if exit code is 0 accepted = (result.returncode == 0) else: - # Reject tests pass only if exit code is 1 accepted = (result.returncode == 1) - except FileNotFoundError: print(f"Error: '{gitmem_path}' executable not found.") sys.exit(1) @@ -54,10 +51,25 @@ def main(): required=True, help="Path to the gitmem executable" ) + parser.add_argument( + "--linear", + action="store_true", + help="Only run linear tests" + ) + parser.add_argument( + "--branching", + action="store_true", + help="Only run branching tests" + ) args = parser.parse_args() 
gitmem_path = args.gitmem + run_linear = args.linear + run_branching = args.branching + + # If neither flag is specified, run both + if not run_linear and not run_branching: + run_linear = run_branching = True - # results[expectation][category][subcategory] = {"total": x, "failed": y} results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: { "total": 0, "failed": 0 @@ -76,7 +88,11 @@ def main(): continue if category == "semantics": - subcategories = ["branching", "linear"] + subcategories = [] + if run_branching: + subcategories.append("branching") + if run_linear: + subcategories.append("linear") else: subcategories = [None] From a1f0ef13122f7c75a463c9261ff779f44f4a5dee Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 2 Jan 2026 15:12:06 +0100 Subject: [PATCH 22/58] fix bug in accessing illegal threads, and switch threads vector to deque, and use refs instead of shared ptrs --- src/debugger.cc | 4 +-- src/execution_state.cc | 15 +++++----- src/execution_state.hh | 2 +- src/interpreter.cc | 64 +++++++++++++++++++++++------------------- src/interpreter.hh | 11 +++----- src/model_checker.cc | 12 ++++---- 6 files changed, 55 insertions(+), 53 deletions(-) diff --git a/src/debugger.cc b/src/debugger.cc index c0e8295..111b940 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -79,8 +79,8 @@ bool step_thread(ThreadID tid, GlobalContext &gctx, std::string &msg) { return false; } - auto thread = gctx.threads[tid]; - if (auto term = thread->terminated) { + auto& thread = gctx.threads[tid]; + if (auto term = thread.terminated) { if (*term == TerminationStatus::completed) { msg = "Thread " + std::to_string(tid) + " has terminated normally"; } else { diff --git a/src/execution_state.cc b/src/execution_state.cc index 1e3224e..36e34b6 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -68,9 +68,8 @@ GlobalContext::GlobalContext(const trieste::Node &ast, ThreadContext starting_ctx(this->protocol->kind()); ThreadID main_tid = 0; - auto 
main_thread = std::make_shared(main_tid, std::move(starting_ctx), starting_block); - this->threads = {main_thread}; + this->threads.emplace_back(main_tid, std::move(starting_ctx), starting_block); this->locks = {}; this->cache = {}; } @@ -109,8 +108,8 @@ bool GlobalContext::operator==(const GlobalContext &other) const { for (auto &thread : threads) { auto it = std::find_if(other.threads.begin(), other.threads.end(), [&thread](auto &t) - { return t->block == thread->block; }); - if (it == other.threads.end() || !(*thread == **it)) + { return t.block == thread.block; }); + if (it == other.threads.end() || !(thread == *it)) return false; } @@ -196,11 +195,11 @@ void GlobalContext::print(std::ostream& os, bool show_all) const { bool showed_any = false; for (size_t i = 0; i < threads.size(); i++) { - auto thread = threads[i]; - if (show_all || !thread->terminated || - *threads[i]->terminated != TerminationStatus::completed) { + auto& thread = threads[i]; + if (show_all || !thread.terminated || + *threads[i].terminated != TerminationStatus::completed) { os << "---- Thread " << i << std::endl; - os << *threads[i] << std::endl; + os << threads[i] << std::endl; os << std::endl; showed_any = true; } diff --git a/src/execution_state.hh b/src/execution_state.hh index 655acf0..ce85741 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -78,7 +78,7 @@ thread_append_node(ThreadContext &ctx, std::string &&stmt); struct GlobalContext { // Execution state - std::vector> threads; + std::deque threads; std::unordered_map locks; // AST evaluation cache diff --git a/src/interpreter.cc b/src/interpreter.cc index bc9ecef..db53319 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -43,8 +43,8 @@ static bool is_syncing(Thread &thread) { * a the exceptional termination status of the thread. 
*/ std::variant -Interpreter::evaluate_expression(trieste::Node expr, std::shared_ptr thread) { - ThreadContext& ctx = thread->ctx; +Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { + ThreadContext& ctx = thread.ctx; auto e = expr / lang::Expr; if (e == lang::Reg) { @@ -82,8 +82,7 @@ Interpreter::evaluate_expression(trieste::Node expr, std::shared_ptr thr throw std::logic_error("This code path should never be reached"); } - gctx.threads.push_back( - std::make_shared(tid, std::move(child_ctx), e / lang::Block)); + gctx.threads.emplace_back(tid, std::move(child_ctx), e / lang::Block); // thread_append_node(ctx, tid, node); return tid; @@ -112,8 +111,8 @@ Interpreter::evaluate_expression(trieste::Node expr, std::shared_ptr thr * counter (0 if waiting for some other thread) or the exceptional * termination status of the thread. */ -std::variant Interpreter::run_statement(Node stmt, std::shared_ptr thread) { - ThreadContext& ctx = thread->ctx; +std::variant Interpreter::run_statement(Node stmt, Thread& thread) { + ThreadContext& ctx = thread.ctx; auto s = stmt / lang::Stmt; if (s == lang::Nop) { @@ -193,10 +192,17 @@ std::variant Interpreter::run_statement(Node stmt, std:: } auto result = gctx.cache[expr]; + // Check if the thread ID is valid + if (result >= gctx.threads.size()) { + verbose << "Join: invalid thread ID " << result + << ". 
gctx.threads.size()=" << gctx.threads.size() << std::endl; + return TerminationStatus::unassigned_variable_read_exception; + } + auto &joinee = gctx.threads[result]; - if (joinee->terminated && - (*joinee->terminated == TerminationStatus::completed)) { - if (auto conflict = gctx.protocol->on_join(ctx, joinee->ctx, gctx)) { + if (joinee.terminated && + (*joinee.terminated == TerminationStatus::completed)) { + if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx, gctx)) { verbose << (**conflict) << std::endl; return TerminationStatus::datarace_exception; } else { @@ -221,7 +227,7 @@ std::variant Interpreter::run_statement(Node stmt, std:: return 0; } - lock.owner = thread->tid; + lock.owner = thread.tid; if (auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { verbose << (**conflict) << std::endl; // using graph::Node; @@ -248,7 +254,7 @@ std::variant Interpreter::run_statement(Node stmt, std:: auto var = std::string(v->location().view()); auto &lock = gctx.locks[var]; - if (!lock.owner || (lock.owner && *lock.owner != thread->tid)) { + if (!lock.owner || (lock.owner && *lock.owner != thread.tid)) { return TerminationStatus::unlock_exception; } @@ -294,13 +300,13 @@ std::variant Interpreter::run_statement(Node stmt, std:: * whether it terminated. 
*/ std::variant -Interpreter::run_single_thread_to_sync(std::shared_ptr thread) { - if (thread->terminated) - return *thread->terminated; +Interpreter::run_single_thread_to_sync(Thread& thread) { + if (thread.terminated) + return *thread.terminated; - auto& ctx = thread->ctx; - auto& pc = thread->pc; - Node block = thread->block; + auto& ctx = thread.ctx; + auto& pc = thread.pc; + Node block = thread.block; // Initial sync when thread starts executing if (pc == 0) @@ -318,7 +324,7 @@ Interpreter::run_single_thread_to_sync(std::shared_ptr thread) { auto result = run_statement(stmt, thread); if (auto term = std::get_if(&result)) { - thread->terminated = *term; + thread.terminated = *term; return *term; } @@ -340,11 +346,11 @@ Interpreter::run_single_thread_to_sync(std::shared_ptr thread) { // Otherwise, we truly reached the end this iteration if (auto conflict = gctx.protocol->on_end(ctx, gctx)) { verbose << (**conflict) << std::endl; - thread->terminated = TerminationStatus::datarace_exception; + thread.terminated = TerminationStatus::datarace_exception; return TerminationStatus::datarace_exception; } - thread->terminated = TerminationStatus::completed; + thread.terminated = TerminationStatus::completed; // thread_append_node(ctx); return TerminationStatus::completed; } @@ -354,7 +360,7 @@ Interpreter::run_single_thread_to_sync(std::shared_ptr thread) { * thread */ std::variant -Interpreter::progress_thread(std::shared_ptr thread) { +Interpreter::progress_thread(Thread& thread) { auto no_threads = gctx.threads.size(); auto prog_or_term = run_single_thread_to_sync(thread); @@ -365,8 +371,8 @@ Interpreter::progress_thread(std::shared_ptr thread) { for (size_t i = no_threads; i < gctx.threads.size(); ++i) { // If there are new threads, we can run them to sync as well any_progress = true; - auto new_thread = gctx.threads[i]; - if (!is_syncing(*new_thread)) { + auto& new_thread = gctx.threads[i]; + if (!is_syncing(new_thread)) { verbose << "==== Thread " << i << " 
(spawn) ====" << std::endl; progress_thread(new_thread); } @@ -387,19 +393,19 @@ Interpreter::run_threads_to_sync() { ProgressStatus any_progress = ProgressStatus::no_progress; for (size_t i = 0; i < gctx.threads.size(); ++i) { verbose << "==== t" << i << " ====" << std::endl; - auto thread = gctx.threads[i]; - if (!thread->terminated) { + auto& thread = gctx.threads[i]; + if (!thread.terminated) { auto prog_or_term = run_single_thread_to_sync(thread); if (ProgressStatus *prog = std::get_if(&prog_or_term)) { any_progress |= *prog; } else { // We could return termination status of any error here and stop // at the first error - thread->terminated = std::get(prog_or_term); + thread.terminated = std::get(prog_or_term); any_progress |= ProgressStatus::progress; } - all_completed &= thread->terminated.has_value(); + all_completed &= thread.terminated.has_value(); // if a thread spawns a new thread, it will end up at the end so // we will always include the new threads in the termination // criteria @@ -434,8 +440,8 @@ int Interpreter::run() { bool exception_detected = false; for (size_t i = 0; i < gctx.threads.size(); ++i) { const auto &thread = gctx.threads[i]; - if (thread->terminated) { - switch (thread->terminated.value()) { + if (thread.terminated) { + switch (thread.terminated.value()) { case TerminationStatus::completed: verbose << "Thread " << i << " terminated normally" << std::endl; break; diff --git a/src/interpreter.hh b/src/interpreter.hh index 144002e..a2eed85 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -24,11 +24,11 @@ public: // Internal functions int run(); - std::variant evaluate_expression(trieste::Node, std::shared_ptr); - std::variant run_statement(trieste::Node, std::shared_ptr); + std::variant evaluate_expression(trieste::Node, Thread&); + std::variant run_statement(trieste::Node, Thread&); - std::variant progress_thread(std::shared_ptr); - std::variant run_single_thread_to_sync(std::shared_ptr); + std::variant 
progress_thread(Thread&); + std::variant run_single_thread_to_sync(Thread&); std::variant run_threads_to_sync(); }; @@ -36,7 +36,4 @@ public: // Entry function int interpret(const trieste::Node, const std::filesystem::path &output_file, SyncKind sync_kind); -std::variant -progress_thread(GlobalContext &, const ThreadID, std::shared_ptr); - } // namespace gitmem \ No newline at end of file diff --git a/src/model_checker.cc b/src/model_checker.cc index 1c036d0..993c92d 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -93,8 +93,8 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi size_t no_threads = gctx.threads.size(); bool made_progress = false; for (size_t i = start_idx; i < no_threads && !made_progress; ++i) { - auto thread = gctx.threads[i]; - if (!thread->terminated) { + auto& thread = gctx.threads[i]; + if (!thread.terminated) { // Run the thread to the next sync point verbose << "==== Thread " << i << " ====" << std::endl; auto prog_or_term = interp.progress_thread(thread); @@ -127,13 +127,13 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi bool all_completed = std::all_of( gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { - return thread->terminated && - *thread->terminated == TerminationStatus::completed; + return thread.terminated && + *thread.terminated == TerminationStatus::completed; }); bool any_crashed = std::any_of( gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { - return thread->terminated && - *thread->terminated != TerminationStatus::completed; + return thread.terminated && + *thread.terminated != TerminationStatus::completed; }); bool is_deadlock = !all_completed && !made_progress && cursor->is_leaf(); From 7da3b59e62e29dfc031c099302c50db6ea8db3dd Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 2 Jan 2026 18:06:27 +0100 Subject: [PATCH 23/58] better termination encapsulation --- src/interpreter.cc | 18 +++++++++--------- 
src/interpreter.hh | 24 ++++++++++++++++++------ src/model_checker.cc | 2 +- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/interpreter.cc b/src/interpreter.cc index db53319..c099f15 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -42,7 +42,7 @@ static bool is_syncing(Thread &thread) { /* Evaluating an expression either returns the result of the expression or * a the exceptional termination status of the thread. */ -std::variant +StepResult Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { ThreadContext& ctx = thread.ctx; @@ -111,7 +111,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { * counter (0 if waiting for some other thread) or the exceptional * termination status of the thread. */ -std::variant Interpreter::run_statement(Node stmt, Thread& thread) { +StepResult Interpreter::run_statement(Node stmt, Thread& thread) { ThreadContext& ctx = thread.ctx; auto s = stmt / lang::Stmt; @@ -299,7 +299,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa * it terminates. Report whether the thread was able to progress or not, or * whether it terminated. 
*/ -std::variant +StepResult Interpreter::run_single_thread_to_sync(Thread& thread) { if (thread.terminated) return *thread.terminated; @@ -359,7 +359,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { * Run a thread to the next sync point, including any threads spawned by that * thread */ -std::variant +StepResult Interpreter::progress_thread(Thread& thread) { auto no_threads = gctx.threads.size(); auto prog_or_term = run_single_thread_to_sync(thread); @@ -386,7 +386,7 @@ Interpreter::progress_thread(Thread& thread) { /* Try to evaluate all threads until a sync point or termination point */ -std::variant +StepResult Interpreter::run_threads_to_sync() { verbose << "-----------------------" << std::endl; bool all_completed = true; @@ -418,19 +418,19 @@ Interpreter::run_threads_to_sync() { return any_progress; } -static bool is_finished( std::variant &prog_or_term) { +static bool is_finished(const StepResult& r) { // Either, the system is stuck and made no progress in which case there // is a deadlock (or a thread is stuck waiting for a crashed thread?) // Or, there was some termination criteria in which case we stop - return std::holds_alternative(prog_or_term) || - std::get(prog_or_term) == ProgressStatus::no_progress; + return is_terminated(r) || + std::get(r) == ProgressStatus::no_progress; } /* Try to evaluate all threads until they have all terminated in some way * or we have reached a stuck configuration. 
*/ int Interpreter::run() { - std::variant prog_or_term; + StepResult prog_or_term; do { prog_or_term = run_threads_to_sync(); } while (!is_finished(prog_or_term)); diff --git a/src/interpreter.hh b/src/interpreter.hh index a2eed85..9b35e11 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -8,10 +8,22 @@ #include #include "sync_protocol.hh" #include "termination_status.hh" -#include "step_result.hh" namespace gitmem { +template +using StepResult = std::variant; + +template +bool is_terminated(const StepResult& r) { + return std::holds_alternative(r); +} + +template +T& value(StepResult& r) { + return std::get(r); +} + class Interpreter { private: GlobalContext gctx; @@ -24,12 +36,12 @@ public: // Internal functions int run(); - std::variant evaluate_expression(trieste::Node, Thread&); - std::variant run_statement(trieste::Node, Thread&); + StepResult evaluate_expression(trieste::Node, Thread&); + StepResult run_statement(trieste::Node, Thread&); - std::variant progress_thread(Thread&); - std::variant run_single_thread_to_sync(Thread&); - std::variant run_threads_to_sync(); + StepResult progress_thread(Thread&); + StepResult run_single_thread_to_sync(Thread&); + StepResult run_threads_to_sync(); }; diff --git a/src/model_checker.cc b/src/model_checker.cc index 993c92d..bdf56a9 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -98,7 +98,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi // Run the thread to the next sync point verbose << "==== Thread " << i << " ====" << std::endl; auto prog_or_term = interp.progress_thread(thread); - if (std::holds_alternative(prog_or_term)) { + if (is_terminated(prog_or_term)) { // Thread terminated, we can extend the trace made_progress = true; cursor = cursor->extend(i); From be8cd482ed4488878a02ace32e9b149cd2d2bcbb Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 2 Jan 2026 20:24:02 +0100 Subject: [PATCH 24/58] restoring some interactive commands --- 
CMakeLists.txt | 2 +- src/debugger.cc | 352 ++++++++++++++++++++++++++++----------------- src/gitmem.cc | 3 +- src/interpreter.cc | 12 +- 4 files changed, 226 insertions(+), 143 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e39979..6e8f9d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ add_executable(gitmem src/passes/branching.cc src/linear/version_store.cc src/interpreter.cc - # src/debugger.cc + src/debugger.cc src/model_checker.cc src/sync_protocol.cc src/graphviz.cc diff --git a/src/debugger.cc b/src/debugger.cc index 111b940..174fcec 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -70,167 +70,251 @@ Command parse_command(std::string &input) { } } -/** Perform the Step command on a given thread. Error messages are assigned - * to `msg`. The return value signals whether threads should be printed - * after stepping or not. */ -bool step_thread(ThreadID tid, GlobalContext &gctx, std::string &msg) { +enum class StepKind { + Progressed, // Thread made progress + Blocked, // Thread is blocked on sync + Terminated, // Thread terminated this step + Invalid // Invalid thread id, etc. 
+}; + +struct StepUIResult { + StepKind kind; + std::optional termination; + std::string message; + + static StepUIResult progressed() { + return {StepKind::Progressed, std::nullopt, ""}; + } + + static StepUIResult blocked(std::string msg) { + return {StepKind::Blocked, std::nullopt, std::move(msg)}; + } + + static StepUIResult terminated(TerminationStatus t, std::string msg) { + return {StepKind::Terminated, t, std::move(msg)}; + } + + static StepUIResult invalid(std::string msg) { + return {StepKind::Invalid, std::nullopt, std::move(msg)}; + } +}; + +StepUIResult step_thread(Interpreter& interp, ThreadID tid) { + GlobalContext& gctx = interp.context(); + if (tid >= gctx.threads.size()) { - msg = "Invalid thread id: " + std::to_string(tid); - return false; + return StepUIResult::invalid( + "Invalid thread id: " + std::to_string(tid)); } auto& thread = gctx.threads[tid]; - if (auto term = thread.terminated) { - if (*term == TerminationStatus::completed) { - msg = "Thread " + std::to_string(tid) + " has terminated normally"; + + if (thread.terminated) { + if (*thread.terminated == TerminationStatus::completed) { + return StepUIResult::terminated( + *thread.terminated, + "Thread " + std::to_string(tid) + " has terminated normally"); } else { - msg = "Thread " + std::to_string(tid) + " has terminated with an error"; + return StepUIResult::terminated( + *thread.terminated, + "Thread " + std::to_string(tid) + " has terminated with an error"); } - return false; } - auto prog_or_term = progress_thread(gctx, tid, thread); - if (ProgressStatus *prog = std::get_if(&prog_or_term)) { - if (!*prog) { - auto stmt = thread->block->at(thread->pc); - msg = "Thread " + std::to_string(tid) + " is blocking on '" + - std::string(stmt->location().view()) + "'"; - return false; + auto prog_or_term = interp.progress_thread(gctx.threads[tid]); + + if (auto prog = std::get_if(&prog_or_term)) { + if (*prog == ProgressStatus::no_progress) { + auto stmt = thread.block->at(thread.pc); + return 
StepUIResult::blocked( + "Thread " + std::to_string(tid) + " is blocking on '" + + std::string(stmt->location().view()) + "'"); } - } else if (TerminationStatus *term = - std::get_if(&prog_or_term)) { - switch (*term) { + return StepUIResult::progressed(); + } + + auto term = std::get(prog_or_term); + + switch (term) { case TerminationStatus::completed: - msg = "Thread " + std::to_string(tid) + " terminated normally"; - return true; + return StepUIResult::terminated( + term, + "Thread " + std::to_string(tid) + " terminated normally"); + case TerminationStatus::datarace_exception: - // TODO: Say on which variable the datarace occurred. To - // do this, have pull return an optional variable that - // is in a race and have the data race exception - // remember that variable. - msg = "Thread " + std::to_string(tid) + - " encountered a data race and was terminated"; - return false; + return StepUIResult::terminated( + term, + "Thread " + std::to_string(tid) + + " encountered a data race and was terminated"); + case TerminationStatus::assertion_failure_exception: { - auto expr = thread->block->at(thread->pc) / lang::Stmt / lang::Expr; - msg = "Thread " + std::to_string(tid) + " failed assertion '" + - std::string(expr->location().view()) + "' and was terminated"; - return false; + auto expr = + thread.block->at(thread.pc) / lang::Stmt / lang::Expr; + return StepUIResult::terminated( + term, + "Thread " + std::to_string(tid) + + " failed assertion '" + + std::string(expr->location().view()) + + "' and was terminated"); } + case TerminationStatus::unassigned_variable_read_exception: - throw std::runtime_error("Thread " + std::to_string(tid) + - " read an uninitialised variable"); + return StepUIResult::terminated( + term, + "Thread " + std::to_string(tid) + + " read an uninitialised variable"); + case TerminationStatus::unlock_exception: - throw std::runtime_error("Thread " + std::to_string(tid) + - " unlocked an unlocked lock"); + return StepUIResult::terminated( + term, 
+ "Thread " + std::to_string(tid) + + " unlocked a lock it does not own"); + default: - throw std::runtime_error("Thread " + std::to_string(tid) + - " has an unhandled termination state"); + return StepUIResult::terminated( + term, + "Thread " + std::to_string(tid) + + " terminated with an unknown error"); + } +} + +/** Print the execution graph if requested */ +void maybe_print_graph(Interpreter& interp, + bool print_graphs, + const std::filesystem::path &output_file) { + if (print_graphs) { + // gctx.print_execution_graph(output_file); + verbose << "Execution graph written to " << output_file << std::endl; } +} + +/** Step a single thread and return the StepUIResult. Also prints the message. */ +StepUIResult do_step(Interpreter &interp, + ThreadID tid, + bool print_graphs, + const std::filesystem::path &output_file) { + StepUIResult result = step_thread(interp, tid); + if (!result.message.empty()) + std::cout << result.message << std::endl; + + maybe_print_graph(interp, print_graphs, output_file); + return result; +} + +/** Reset the interpreter to a fresh state */ +void do_restart(Interpreter &interp, + const trieste::Node ast, + SyncKind sync_kind, + bool print_graphs, + const std::filesystem::path &output_file) { + interp = Interpreter(GlobalContext(ast, make_protocol(sync_kind))); + maybe_print_graph(interp, print_graphs, output_file); +} + +/** Print the list of threads and optionally all threads */ +void do_list(GlobalContext &gctx, bool show_all) { + gctx.print(std::cout, show_all); +} + +void do_finish(Interpreter& interp, bool print_graphs, const std::filesystem::path &output_file) { + if (!interp.run()) { + std::cout << "Program finished successfully" << std::endl; + } else { + std::cout << "Program terminated with an error" << std::endl; } - return true; + + maybe_print_graph(interp, print_graphs, output_file); +} + +/** Print interactive command help */ +void print_help() { + std::cout << "Commands:\n"; + std::cout << "s [tid] - Step to next sync 
point in thread\n"; + std::cout << "[tid] - Step to next sync point in thread\n"; + std::cout << "f - Finish the program\n"; + std::cout << "r - Restart the program\n"; + std::cout << "l - List all threads\n"; + std::cout << "g - Toggle automatic execution graph printing\n"; + std::cout << "p - Print the execution graph immediately\n"; + std::cout << "q - Quit the interpreter\n"; + std::cout << "? - Display this help message\n"; } -/** Interpret the AST in an interactive way, letting the user choose which - * thread to schedule next. */ +/** Main interactive interpreter loop */ int interpret_interactive(const trieste::Node ast, const std::filesystem::path &output_file, SyncKind sync_kind) { - GlobalContext gctx(ast, make_protocol(sync_kind)); - - size_t prev_no_threads = 1; - Command command = {Command::List}; - std::string msg = ""; - bool print_graphs = true; - // gctx.print_execution_graph(output_file); - while (command.cmd != Command::Quit) { - if (command.cmd != Command::Skip || - prev_no_threads != gctx.threads.size()) { - bool show_all = command.cmd == Command::List; - gctx.print(std::cout, show_all); - } - prev_no_threads = gctx.threads.size(); + Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); + GlobalContext &gctx = interp.context(); - if (!msg.empty()) { - std::cout << msg << std::endl; - msg.clear(); - } + size_t prev_no_threads = 1; + Command command = {Command::List}; + bool print_graphs = true; - std::cout << "> "; - std::string input; - std::getline(std::cin, input); - if (!input.empty() && - input.find_first_not_of(" \t\n\r") != std::string::npos) { - command = parse_command(input); - } + while (command.cmd != Command::Quit) { + // Print threads if new threads appeared or command is List + if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) { + do_list(gctx, command.cmd == Command::List); + } + prev_no_threads = gctx.threads.size(); - if (command.cmd == Command::Step) { - auto tid = command.argument; - if 
(!step_thread(tid, gctx, msg)) - command = {Command::Skip}; + // Read user input + std::cout << "> "; + std::string input; + std::getline(std::cin, input); + if (!input.empty() && input.find_first_not_of(" \t\n\r") != std::string::npos) + command = parse_command(input); - if (print_graphs) { - // gctx.print_execution_graph(output_file); - assert(false && "todo"); - verbose << "Execution graph written to " << output_file << std::endl; - } - } else if (command.cmd == Command::Finish) { - // Finish the program - assert(false && "fixme"); - // if (!run_threads(gctx)) - // msg = "Program finished successfully"; - // else - // msg = "Program terminated with an error"; - - if (print_graphs) { - // gctx.print_execution_graph(output_file); - assert(false && "todo"); - verbose << "Execution graph written to " << output_file << std::endl; - } - } else if (command.cmd == Command::Restart) { - // Start the program from the beginning - gctx = GlobalContext(ast, make_protocol(sync_kind)); - command = {Command::List}; - if (print_graphs) { - // gctx.print_execution_graph(output_file); - assert(false && "todo"); - verbose << "Execution graph written to " << output_file << std::endl; - } - } else if (command.cmd == Command::List) { - // Listing is a no-op - } else if (command.cmd == Command::Graph) { - // Toggle printing execution graph automatically - print_graphs = !print_graphs; - std::cout << "graphs " << (print_graphs ? 
"will" : "won't") - << " print automatically" << std::endl; - command = {Command::Skip}; - } else if (command.cmd == Command::Print) { - // Print the execution graph - // gctx.print_execution_graph(output_file); - assert(false && "todo"); - verbose << "Execution graph written to " << output_file << std::endl; - command = {Command::Skip}; - } else if (command.cmd == Command::Skip) { - // Skip is a no-op - } else if (command.cmd == Command::Info) { - std::cout << "Commands:" << std::endl; - std::cout << "s [tid] - Step to next sync point in thread" << std::endl; - std::cout << "[tid] - Step to next sync point in thread" << std::endl; - std::cout << "f - Finish the program" << std::endl; - std::cout << "r - Restart the program" << std::endl; - std::cout << "l - List all threads" << std::endl; - std::cout << "g - Toggle printing the execution graph at sync points" - << std::endl; - std::cout << "p - Printing the execution graph at current sync point" - << std::endl; - std::cout << "q - Quit the interpreter" << std::endl; - std::cout << "? - Display this help message" << std::endl; - command = {Command::Skip}; - } else if (command.cmd == Command::Quit) { - // Quit is a no-op + switch (command.cmd) { + case Command::Step: { + ThreadID tid = command.argument; + StepUIResult res = do_step(interp, tid, print_graphs, output_file); + if (res.kind != StepKind::Progressed) + command = {Command::Skip}; + break; + } + + case Command::Finish: + do_finish(interp, print_graphs, output_file); + break; + + case Command::Restart: + do_restart(interp, ast, sync_kind, print_graphs, output_file); + command = {Command::List}; + break; + + case Command::List: + // Already handled before reading input, no-op here + break; + + case Command::Graph: + print_graphs = !print_graphs; + std::cout << "Graphs " << (print_graphs ? 
"will" : "won't") + << " print automatically" << std::endl; + command = {Command::Skip}; + break; + + case Command::Print: + maybe_print_graph(interp, print_graphs, output_file); + command = {Command::Skip}; + break; + + case Command::Info: + print_help(); + command = {Command::Skip}; + break; + + case Command::Skip: + // No-op + break; + + case Command::Quit: + // No-op + break; + } } - } - return 0; + return 0; } -} // namespace gitmem + +} // namespace gitmem \ No newline at end of file diff --git a/src/gitmem.cc b/src/gitmem.cc index ad40077..6fda58a 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -72,8 +72,7 @@ int main(int argc, char **argv) { if (model_check) { exit_status = gitmem::model_check(result.ast, output_path, sync_kind); } else if (interactive) { - assert(false && "fixme"); - // exit_status = gitmem::interpret_interactive(result.ast, output_path, sync_kind); + exit_status = gitmem::interpret_interactive(result.ast, output_path, sync_kind); } else { exit_status = gitmem::interpret(result.ast, output_path, sync_kind); } diff --git a/src/interpreter.cc b/src/interpreter.cc index c099f15..fa8f217 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -42,7 +42,7 @@ static bool is_syncing(Thread &thread) { /* Evaluating an expression either returns the result of the expression or * a the exceptional termination status of the thread. */ -StepResult +std::variant Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { ThreadContext& ctx = thread.ctx; @@ -111,7 +111,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { * counter (0 if waiting for some other thread) or the exceptional * termination status of the thread. 
*/ -StepResult Interpreter::run_statement(Node stmt, Thread& thread) { +std::variant Interpreter::run_statement(Node stmt, Thread& thread) { ThreadContext& ctx = thread.ctx; auto s = stmt / lang::Stmt; @@ -299,7 +299,7 @@ StepResult Interpreter::run_statement(Node stmt, Thread& thread) { * it terminates. Report whether the thread was able to progress or not, or * whether it terminated. */ -StepResult +std::variant Interpreter::run_single_thread_to_sync(Thread& thread) { if (thread.terminated) return *thread.terminated; @@ -359,7 +359,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { * Run a thread to the next sync point, including any threads spawned by that * thread */ -StepResult +std::variant Interpreter::progress_thread(Thread& thread) { auto no_threads = gctx.threads.size(); auto prog_or_term = run_single_thread_to_sync(thread); @@ -386,7 +386,7 @@ Interpreter::progress_thread(Thread& thread) { /* Try to evaluate all threads until a sync point or termination point */ -StepResult +std::variant Interpreter::run_threads_to_sync() { verbose << "-----------------------" << std::endl; bool all_completed = true; @@ -430,7 +430,7 @@ static bool is_finished(const StepResult& r) { * or we have reached a stuck configuration. 
*/ int Interpreter::run() { - StepResult prog_or_term; + std::variant prog_or_term; do { prog_or_term = run_threads_to_sync(); } while (!is_finished(prog_or_term)); From a0a24086a7b96a21956b2aad8767b29d2967d0b7 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 5 Jan 2026 19:27:08 +0100 Subject: [PATCH 25/58] trying to fix graphing but curently broken --- src/execution_state.hh | 8 +- src/graph.hh | 13 +++ src/interpreter.cc | 141 ++++++++++++++++++++++++++------- src/interpreter.hh | 5 +- src/sync_protocol.cc | 6 -- src/sync_protocol.hh | 20 +---- src/thread_trace.hh | 174 +++++++++++++++++++++++++++++++++++------ 7 files changed, 279 insertions(+), 88 deletions(-) diff --git a/src/execution_state.hh b/src/execution_state.hh index ce85741..32d7756 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -67,15 +67,9 @@ struct Thread { struct Lock { std::optional owner = std::nullopt; + std::shared_ptr last_unlock_event = nullptr; }; -template -std::shared_ptr thread_append_node(ThreadContext &ctx, Args &&...args); - -template <> -std::shared_ptr -thread_append_node(ThreadContext &ctx, std::string &&stmt); - struct GlobalContext { // Execution state std::deque threads; diff --git a/src/graph.hh b/src/graph.hh index a201439..2d8e4c9 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -140,5 +140,18 @@ struct Pending : Node { Pending(const std::string statement) : statement(statement) {} void accept(Visitor *v) const override { v->visitPending(this); } }; + +struct ExecutionGraph { + std::vector> threads; + + ExecutionGraph() = default; + + ExecutionGraph(const ExecutionGraph&) = delete; + ExecutionGraph& operator=(const ExecutionGraph&) = delete; + + ExecutionGraph(ExecutionGraph&&) = default; + ExecutionGraph& operator=(ExecutionGraph&&) = default; +}; + } // namespace graph } // namespace gitmem diff --git a/src/interpreter.cc b/src/interpreter.cc index fa8f217..48f0077 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -3,7 +3,6 @@ #include 
#include "debug.hh" -#include "graphviz.hh" #include "interpreter.hh" #include "sync_protocol.hh" @@ -58,6 +57,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { } else if (e == lang::Var) { auto var = std::string(expr->location().view()); if (std::optional result = gctx.protocol->read(ctx, var)) { + thread.trace.on_read(var, *result); return *result; } else { // It is invalid to read a previously unwritten value return TerminationStatus::unassigned_variable_read_exception; @@ -74,7 +74,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { } return sum; } else if (e == lang::Spawn) { - ThreadID tid = gctx.threads.size(); + ThreadID child_tid = gctx.threads.size(); ThreadContext child_ctx(gctx.protocol->kind()); if (std::optional> conflict = @@ -82,10 +82,9 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { throw std::logic_error("This code path should never be reached"); } - gctx.threads.emplace_back(tid, std::move(child_ctx), e / lang::Block); - // thread_append_node(ctx, tid, node); - - return tid; + gctx.threads.emplace_back(child_tid, std::move(child_ctx), e / lang::Block); + thread.trace.on_spawn(child_tid); + return child_tid; } else if (e == lang::Eq || e == lang::Neq) { auto lhs = e / lang::Lhs; auto rhs = e / lang::Rhs; @@ -158,6 +157,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa } else if (lhs == lang::Var) { gctx.protocol->write(ctx, var, *val); + thread.trace.on_write(var, *val); // // Global variable writes need to create a new commit id // // to track the history of updates @@ -167,8 +167,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa // verbose << "Set global '" << lhs->location().view() << "' to " << // *val << " with id " << *(global.commit) << std::endl; - // auto node = thread_append_node(ctx, var, global.val, - // *global.commit); gctx.commit_map[*(global.commit)] = node; + // gctx.commit_map[*(global.commit)] = node; } else { throw 
std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); @@ -204,9 +203,10 @@ std::variant Interpreter::run_statement(Node stmt, Threa (*joinee.terminated == TerminationStatus::completed)) { if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx, gctx)) { verbose << (**conflict) << std::endl; + thread.trace.on_join(result, std::move(*conflict)); return TerminationStatus::datarace_exception; } else { - // thread_append_node(ctx, result, joinee->ctx.tail); + thread.trace.on_join(result); } } else { @@ -228,21 +228,16 @@ std::variant Interpreter::run_statement(Node stmt, Threa } lock.owner = thread.tid; + if (auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { verbose << (**conflict) << std::endl; - // using graph::Node; - // auto [s1, s2] = conflict->commits; - // auto sources = std::pair, - // std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; - // auto graph_conflict = graph::Conflict(conflict->var, sources); - // thread_append_node(ctx, var, lock.last, - // graph_conflict); + thread.trace.on_lock(var, lock.last_unlock_event, std::move(*conflict)); return TerminationStatus::datarace_exception; } - // thread_append_node(ctx, var, lock.last); - + thread.trace.on_lock(var, lock.last_unlock_event); verbose << "Locked " << var << std::endl; + } else if (s == lang::Unlock) { // We can only unlock locks we previously locked. 
We commit any // pending updates and then copy the threads versioned globals @@ -260,14 +255,14 @@ std::variant Interpreter::run_statement(Node stmt, Threa if (auto conflict = gctx.protocol->on_unlock(ctx, lock, gctx)) { verbose << (**conflict) << std::endl; + thread.trace.on_unlock(var, std::move(*conflict)); return TerminationStatus::datarace_exception; } // lock.globals = ctx.globals; lock.owner.reset(); - // thread_append_node(ctx, var); - // lock.last = ctx.tail; + lock.last_unlock_event = thread.trace.on_unlock(var); verbose << "Unlocked " << var << std::endl; @@ -280,8 +275,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa verbose << "Assertion passed: " << expr->location().view() << std::endl; } else { verbose << "Assertion failed: " << expr->location().view() << std::endl; - // thread_append_node( - // ctx, std::string(expr->location().view())); + thread.trace.on_assert_fail(std::string(expr->location().view())); return TerminationStatus::assertion_failure_exception; } } else { @@ -309,8 +303,10 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { Node block = thread.block; // Initial sync when thread starts executing - if (pc == 0) + if (pc == 0) { gctx.protocol->on_start(ctx, gctx); + thread.trace.on_start(); + } bool made_progress = false; @@ -351,7 +347,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { } thread.terminated = TerminationStatus::completed; - // thread_append_node(ctx); + thread.trace.on_end(); return TerminationStatus::completed; } @@ -439,7 +435,7 @@ int Interpreter::run() { bool exception_detected = false; for (size_t i = 0; i < gctx.threads.size(); ++i) { - const auto &thread = gctx.threads[i]; + auto &thread = gctx.threads[i]; if (thread.terminated) { switch (thread.terminated.value()) { case TerminationStatus::completed: @@ -475,7 +471,7 @@ int Interpreter::run() { } } else { exception_detected = true; - // thread_append_node(thread->ctx); + thread.trace.on_end(); verbose << "Thread " << i << " is 
stuck" << std::endl; } } @@ -483,10 +479,97 @@ int Interpreter::run() { return exception_detected ? 1 : 0; } +void Interpreter::print_thread_traces() { + for (size_t tid = 0; tid < gctx.threads.size(); ++tid) { + const auto& thread = gctx.threads[tid]; + std::cout << "=== Thread " << tid << " ===" << std::endl; + std::cout << thread.trace; + std::cout << "====================================\n"; + } +} + +graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { + // Map each thread ID to its last graph node in the execution graph + std::unordered_map> thread_tails; + + // create start nodes for all threads + graph::ExecutionGraph g; + for (ThreadID tid = 0; tid < gctx.threads.size(); ++tid) { + auto node = std::make_shared(tid); + g.threads.push_back(node); + thread_tails[tid] = node; + } + + assert(false && "This is currently not building the current graph"); + + // Helper lambda to convert an Event to a graph Node + auto event_to_node = [&](ThreadID tid, const std::shared_ptr& e) -> std::shared_ptr { + return std::visit([&](auto&& arg) -> std::shared_ptr { + using T = std::decay_t; + std::shared_ptr node; + + if constexpr (std::is_same_v) { + node = std::make_shared(tid); + } else if constexpr (std::is_same_v) { + node = std::make_shared(); + } else if constexpr (std::is_same_v) { + node = std::make_shared(arg.var, arg.value, tid); + } else if constexpr (std::is_same_v) { + // We could track dependencies from previous writes if desired + node = std::make_shared(arg.var, arg.value, tid, thread_tails[tid]); + } else if constexpr (std::is_same_v) { + ThreadID child_tid = arg.child_tid; + node = std::make_shared(child_tid, g.threads[child_tid]); + } else if constexpr (std::is_same_v) { + node = std::make_shared(arg.joinee_tid, thread_tails[arg.joinee_tid]); + } else if constexpr (std::is_same_v) { + node = std::make_shared(arg.lock_name, thread_tails[tid]); + } else if constexpr (std::is_same_v) { + node = std::make_shared(arg.lock_name); + } else if 
constexpr (std::is_same_v) { + node = std::make_shared(arg.condition); + } else { + throw std::logic_error("Unknown Event type in trace"); + } + + // Link the previous tail of this thread to this new node + if (thread_tails[tid]) + thread_tails[tid]->next = node; + + thread_tails[tid] = node; + return node; + }, e->data); + }; + + // Iterate over threads in thread ID order + for (ThreadID tid = 0; tid < gctx.threads.size(); ++tid) { + auto& thread = gctx.threads[tid]; + + // skip the first event because it is always start and we created that to begin with + for (auto it = std::next(thread.trace.begin()); it != thread.trace.end(); ++it) { + event_to_node(tid, *it); + } + if (!thread.terminated) { + trieste::Node stmt = thread.block->at(thread.pc); + thread_tails[tid]->next = std::make_shared(std::string(stmt->location().view())); + } + } + + return g; +} + int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { - GlobalContext gctx(ast, make_protocol(sync_kind)); - Interpreter interp(std::move(gctx)); - return interp.run(); + Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); + int result = interp.run(); + + interp.print_thread_traces(); + + auto exec_graph = interp.build_execution_graph_from_traces(); + + // graph::GraphvizPrinter gv(output_path); + // gv.visit(node.get()); + + return result; } } // namespace gitmem \ No newline at end of file diff --git a/src/interpreter.hh b/src/interpreter.hh index 9b35e11..9f924c1 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -33,7 +33,6 @@ public: GlobalContext& context() { return gctx; } - // Internal functions int run(); StepResult evaluate_expression(trieste::Node, Thread&); @@ -43,6 +42,10 @@ public: StepResult run_single_thread_to_sync(Thread&); StepResult run_threads_to_sync(); + void print_thread_traces(); + + graph::ExecutionGraph build_execution_graph_from_traces(); + }; // Entry function diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 
4287371..131b7d8 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -4,12 +4,6 @@ namespace gitmem { -template std::ostream &Conflict::print(std::ostream &os) const { - os << "conflict on " << var << " { " << versions.first << ", " - << versions.second << " }"; - return os; -} - // -------------------- // LinearSyncProtocol // -------------------- diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 6f008dc..5ed1336 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -1,5 +1,6 @@ #pragma once +#include "conflict.hh" #include "sync_kind.hh" #include "execution_state.hh" #include "branching/version_store.hh" @@ -13,25 +14,6 @@ namespace gitmem { std::unique_ptr make_protocol(SyncKind); -struct ConflictBase { - virtual ~ConflictBase() = default; - virtual std::ostream &print(std::ostream &os) const = 0; - friend std::ostream &operator<<(std::ostream &os, - const ConflictBase &conflict) { - return conflict.print(os); - } -}; - -template struct Conflict : ConflictBase { - std::string var; - std::pair versions; - - Conflict(std::string var, std::pair versions) - : var(std::move(var)), versions(std::move(versions)) {} - - std::ostream &print(std::ostream &os) const override; -}; - using LinearConflict = Conflict; using BranchingConflict = Conflict; diff --git a/src/thread_trace.hh b/src/thread_trace.hh index 6866d77..378f7ff 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -1,58 +1,180 @@ #pragma once #include "thread_id.hh" -#include "graph.hh" +#include "conflict.hh" namespace gitmem { +struct Event; + +struct StartEvent {}; +struct SpawnEvent { const ThreadID child_tid; }; +struct ReadEvent { const std::string var; const size_t value; }; +struct WriteEvent { const std::string var; const size_t value; }; +struct LockEvent { std::string lock_name; std::unique_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; +struct UnlockEvent { const std::string lock_name; std::unique_ptr maybe_conflict; }; +struct JoinEvent { 
const ThreadID joinee_tid; std::unique_ptr maybe_conflict; }; +struct AssertEvent { const std::string condition; }; + +struct EndEvent {}; + +using EventID = size_t; + +struct Event { + ThreadID tid; + EventID eid; + std::variant< + StartEvent, + SpawnEvent, + ReadEvent, + WriteEvent, + LockEvent, + UnlockEvent, + JoinEvent, + AssertEvent, + EndEvent + > data; +}; + +inline std::string event_header(const Event& e) { + std::ostringstream oss; + oss << "[tid=" << e.tid << ", eid=" << e.eid << "]"; + return oss.str(); +} + +// --- operator<< overloads for individual event types --- +inline std::ostream& operator<<(std::ostream& os, const StartEvent&) { + return os << "StartEvent"; +} + +inline std::ostream& operator<<(std::ostream& os, const SpawnEvent& e) { + return os << "SpawnEvent(child_tid=" << e.child_tid << ")"; +} + +inline std::ostream& operator<<(std::ostream& os, const ReadEvent& e) { + return os << "ReadEvent(var=\"" << e.var << "\", value=" << e.value << ")"; +} + +inline std::ostream& operator<<(std::ostream& os, const WriteEvent& e) { + return os << "WriteEvent(var=\"" << e.var << "\", value=" << e.value << ")"; +} + +inline std::ostream& operator<<(std::ostream& os, const LockEvent& e) { + os << "LockEvent(lock_name=\"" << e.lock_name << "\""; + if (e.last_unlock_event) + os << ", last unlock " << event_header(*e.last_unlock_event); + if (e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const UnlockEvent& e) { + os << "UnlockEvent(lock_name=\"" << e.lock_name << "\""; + if (e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const JoinEvent& e) { + os << "JoinEvent(joinee_tid=" << e.joinee_tid; + if (e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const AssertEvent& e) { + return os << "AssertEvent(condition=\"" << 
e.condition << "\")"; +} + +inline std::ostream& operator<<(std::ostream& os, const EndEvent&) { + return os << "EndEvent"; +} + +// --- operator<< for the wrapper Event --- +inline std::ostream& operator<<(std::ostream& os, const Event& e) { + os << event_header(e) << " "; + std::visit([&os](auto&& arg) { os << arg; }, e.data); + return os; +} + +static EventID next_eid = 0; + struct ThreadTrace { + std::vector> trace; ThreadID tid; - std::shared_ptr head; - std::shared_ptr tail; + + auto begin() { return trace.begin(); } + auto end() { return trace.end(); } + + auto begin() const { return trace.begin(); } + auto end() const { return trace.end(); } + + explicit ThreadTrace(ThreadID tid) : tid(tid) {} private: template - void append(Args&&... args) { - assert(tail); - auto node = std::make_shared(std::forward(args)...); - tail->next = node; - tail = node; + std::shared_ptr append(Args&&... args) { + auto event = std::make_shared(tid, next_eid++, T(std::forward(args)...)); + trace.push_back(event); + return event; + } + + public: + std::shared_ptr on_start() { + return append(); } -public: - explicit ThreadTrace(ThreadID tid): tid(tid), head(nullptr), tail(nullptr) {} + std::shared_ptr on_spawn(ThreadID child_tid) { + return append(child_tid); + } - void on_start(ThreadID tid) { - assert(head == tail && head == nullptr); - head = std::make_shared(tid); - tail = head; + std::shared_ptr on_read(const std::string text, const size_t value) { + return append(std::move(text), value); } - void on_stmt(std::string text) { - append(std::move(text)); + std::shared_ptr on_write(const std::string text, const size_t value) { + return append(std::move(text), value); } - void on_lock(std::string lock, std::shared_ptr last) { - append(std::move(lock), last); + std::shared_ptr on_lock(const std::string lock_name, + std::shared_ptr last_unlock_event, + std::unique_ptr conflict = nullptr) { + return append(std::move(lock_name), std::move(conflict), last_unlock_event); } - void 
on_unlock(std::string lock) { - append(std::move(lock)); + std::shared_ptr on_unlock(const std::string lock_name, std::unique_ptr conflict = nullptr) { + return append(std::move(lock_name), std::move(conflict)); } - void on_join(ThreadID tid, std::shared_ptr target) { - append(tid, target); + std::shared_ptr on_join(ThreadID tid, std::unique_ptr conflict = nullptr) { + return append(tid, std::move(conflict)); } - void on_assert_fail(std::string expr) { - append(std::move(expr)); + std::shared_ptr on_assert_fail(std::string expr) { + return append(std::move(expr)); } - void on_end() { - append(); + std::shared_ptr on_end() { + return append(); } }; + +// --- operator<< for ThreadTrace --- +inline std::ostream& operator<<(std::ostream& os, const ThreadTrace& tt) { + os << "ThreadTrace[" << tt.trace.size() << " events]:\n"; + for (size_t i = 0; i < tt.trace.size(); ++i) { + os << " " << i << ": " << *(tt.trace[i]) << "\n"; + } + return os; +} + } // namespace gitmem // template From 72a6142fd6c8d30bdcfcb028d05cebed217178c7 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 5 Jan 2026 20:35:13 +0100 Subject: [PATCH 26/58] separate protocols into different files --- CMakeLists.txt | 2 + src/branching/version_store.hh | 40 +++-- src/sync_protocol.cc | 300 +-------------------------------- src/sync_protocol.hh | 83 --------- 4 files changed, 33 insertions(+), 392 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e8f9d9..2232aec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ add_executable(gitmem src/passes/check_refs.cc src/passes/branching.cc src/linear/version_store.cc + src/linear/sync_protocol.cc + src/branching/sync_protocol.cc src/interpreter.cc src/debugger.cc src/model_checker.cc diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index e20b706..ac0660d 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -5,6 +5,7 @@ #include #include #include +#include "thread_id.hh" 
namespace gitmem { @@ -15,23 +16,21 @@ namespace branching { * the current commit id for the variable, and the history of commited ids. */ -using Commit = size_t; -using CommitHistory = std::vector; +using Timestamp = std::pair; +using Value = size_t; +using ObjectNumber = uint64_t; -struct Global { - size_t val; - std::optional commit; - CommitHistory history; -}; - -using Globals = std::unordered_map; - -struct Conflict { - std::string var; - std::pair commits; +struct Commit { + size_t id; + std::vector> parents; + Timestamp timestamp; + std::unordered_map changes; }; struct LocalVersionStore { + std::shared_ptr _head; + std::unordered_map _staging; + friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { assert(false && "TODO"); return os; @@ -44,6 +43,21 @@ struct LocalVersionStore { }; +// using CommitHistory = std::vector; + +// struct Global { +// size_t val; +// std::optional commit; +// CommitHistory history; +// }; + +// using Globals = std::unordered_map; + +// struct Conflict { +// std::string var; +// std::pair commits; +// }; + // Join logic // commit(ctx.globals); // commit(thread->ctx.globals); diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 131b7d8..87ee605 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -1,307 +1,15 @@ #include "sync_protocol.hh" -#include "debug.hh" -#include +#include "linear/sync_protocol.hh" +#include "branching/sync_protocol.hh" namespace gitmem { -// -------------------- -// LinearSyncProtocol -// -------------------- - -std::ostream &LinearSyncProtocol::print(std::ostream &os) const { - os << _global_store << std::endl; - return os; -} - -std::optional -LinearSyncProtocol::push(linear::LocalVersionStore &local) { - if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), - local.staged_changes())) { - - // reshape the conflict - return std::make_optional( - _global_store.get_object_name(conflict->object), - std::make_pair(conflict->local_base, 
conflict->global_head)); - } - - linear::Timestamp new_base = _global_store.apply_changes( - local.base_timestamp(), local.staged_changes()); - - local.clear_staging(); - local.advance_base(new_base); - return std::nullopt; -} - -std::optional -LinearSyncProtocol::pull(linear::LocalVersionStore &local) { - if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), - local.staged_changes())) { - - return std::make_optional( - _global_store.get_object_name(conflict->object), - std::make_pair(conflict->local_base, conflict->global_head)); - } - - local.advance_base(_global_store.current_timestamp()); - return std::nullopt; -} - -LinearSyncProtocol::~LinearSyncProtocol() = default; - -std::optional LinearSyncProtocol::read(ThreadContext &ctx, - const std::string &var) { - linear::ObjectNumber number = _global_store.get_object_number(var); - - auto& store = std::get(ctx.sync).store; - - if (auto result = store.get_staged(number)) - return result; - - std::optional value = _global_store.get_version_for_timestamp( - number, store.base_timestamp()); - if (!value) - return std::nullopt; - - // we do not need to record the staged value for correctness - // TODO: there is something about working out if a value has changed vs been - // written - - return *value; -} - -void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, - size_t value) { - // write into the staging area of the thread - auto& store = std::get(ctx.sync).store; - store.stage(_global_store.get_object_number(var), value); -} - -std::optional> -LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) { - // TODO: i think we can drop the globalcontext but check after branching is - // added - - // push parent to global history - auto& store = std::get(parent.sync).store; - if (auto conflict = push(store)) - return std::make_unique(std::move(*conflict)); - - // pull into the child - store = std::get(child.sync).store; - if (auto conflict = 
pull(store)) { - throw std::logic_error("This code path should never be reached"); - } - - return std::nullopt; -} - -std::optional> -LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) { - // we assume the joinee has already terminated and pushed - - // pull changes into parent - auto& store = std::get(joiner.sync).store; - if (auto conflict = pull(store)) - return std::make_unique(std::move(*conflict)); - - return std::nullopt; -} - -std::optional> -LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { - // pull state from global history - auto& store = std::get(thread.sync).store; - auto conflict = pull(store); - assert(!conflict && "cannot conflict from starting state"); - - return std::nullopt; -}; - -std::optional> -LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { - // push changes to global history - auto& store = std::get(thread.sync).store; - if (auto conflict = push(store)) - return std::make_unique(std::move(*conflict)); - - return std::nullopt; -}; - -std::optional> -LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, - GlobalContext &gctx) { - - auto& store = std::get(thread.sync).store; - if (auto conflict = pull(store)) - return std::make_unique(std::move(*conflict)); - - return std::nullopt; -} - -std::optional> -LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, - GlobalContext &gctx) { - - // push changes to global history - auto& store = std::get(thread.sync).store; - if (auto conflict = push(store)) - return std::make_unique(std::move(*conflict)); - - return std::nullopt; -} - -// -------------------- -// BranchingSyncProtocol -// -------------------- - -BranchingSyncProtocol::~BranchingSyncProtocol() = default; - -std::ostream &BranchingSyncProtocol::print(std::ostream &os) const { - assert(false && "TODO"); - return os; -} - -// /* At a commit point, walk through all the versioned variables and see if -// * they have a pending commit, 
if so commit the value by appending to -// * the variables history. -// */ -// void commit(Globals &globals) { -// for (auto& [var, global] : globals) { -// if (global.commit) -// { -// global.history.push_back(*global.commit); -// verbose << "Committed global '" << var << "' with id " << -// *global.commit << std::endl; global.commit.reset(); -// } -// } -// } - -// /* A versioned value can be fastforwarded to another version, if one -// * version's history is a prefix of another version's history. -// * A conflict between two commit histories exists if neither history is a -// * prefix of the other. -// */ -// std::optional> has_conflict(CommitHistory& h1, -// CommitHistory& h2) -// { -// size_t length = std::min(h1.size(), h2.size()); - -// for (size_t i = 0; i < length; i++) -// { -// if (h1[i] != h2[i]) return std::pair{h1[i], h2[i]}; -// } - -// return std::nullopt; -// } - -// /* Walk through all the global versions from source and update the versions -// * in destination to be the most up-to-date version (this could come from -// * either source or destination). This means destination will now also -// * include variables it previously did not know about. 
-// */ -// std::optional> pull(Globals &dst, Globals &src) -// { -// for (auto& [var, global] : src) { -// if (dst.contains(var)) -// { -// auto& src_var = src[var]; -// auto& dst_var = dst[var]; -// if (auto conflict = has_conflict(src_var.history, -// dst_var.history)) -// { -// auto [s1, s2] = *conflict; -// verbose << "A data race on '" << var << "' was detected from -// commits " << s1 << " and " << s2 << std::endl; return -// Conflict(var, *conflict); -// } -// else if (src_var.history.size() > dst_var.history.size()) -// { -// verbose << "Fast-forward '" << var << "' to id " << -// src_var.val << std::endl; dst_var.val = src_var.val; -// dst_var.history = src_var.history; -// } -// } -// else -// { -// dst[var].val = src[var].val; -// dst[var].history = src[var].history; -// } -// } -// return std::nullopt; -// } - -std::optional BranchingSyncProtocol::read(ThreadContext &ctx, - const std::string &var) { - assert(false && "Todo read"); - return std::nullopt; -} - -void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, - size_t value) { - assert(false && "Todo write"); -} - -// Spawning is a sync point, commit local pending commits, and -// copy the global state to the spawned thread -// commit(ctx.globals); - -std::optional> -BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &) { - assert(false && "Todo on_spawn"); - // commit(parent.globals); - // child.globals = parent.globals; - return std::nullopt; -} - -std::optional> -BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &) { - assert(false && "Todo on_join"); - // commit(joiner.globals); - // commit(joinee.globals); - // return pull(joiner.globals, joinee.globals); - return std::nullopt; -} - -std::optional> -BranchingSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { - assert(false && "Todo on_start"); - return std::nullopt; -}; - -std::optional> 
-BranchingSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { - assert(false && "Todo on_end"); - return std::nullopt; -}; - -std::optional> -BranchingSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, - GlobalContext &) { - assert(false && "Todo on_lock"); - // commit(thread.globals); - // return pull(thread.globals, lock.globals); - return std::nullopt; -} - -std::optional> -BranchingSyncProtocol::on_unlock(ThreadContext &thread, Lock &lock, - GlobalContext &) { - assert(false && "Todo on_unlock"); - // commit(thread.globals); - // lock.globals = thread.globals; - return std::nullopt; -} - std::unique_ptr make_protocol(SyncKind sync_kind) { switch (sync_kind) { case SyncKind::Linear: - return std::make_unique(); + return std::make_unique(); case SyncKind::Branching: - return std::make_unique(); + return std::make_unique(); } std::unreachable(); } diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 5ed1336..b6b0704 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -3,18 +3,13 @@ #include "conflict.hh" #include "sync_kind.hh" #include "execution_state.hh" -#include "branching/version_store.hh" -#include "linear/version_store.hh" #include #include -/* i want an on_start and on_end event i think too */ - namespace gitmem { std::unique_ptr make_protocol(SyncKind); -using LinearConflict = Conflict; using BranchingConflict = Conflict; class SyncProtocol { @@ -58,82 +53,4 @@ public: } }; -// --------------------------------- -// Concrete protocols -// --------------------------------- - -class LinearSyncProtocol final : public SyncProtocol { - linear::GlobalVersionStore _global_store; - - std::optional push(linear::LocalVersionStore &local); - std::optional pull(linear::LocalVersionStore &local); - -public: - ~LinearSyncProtocol() override; - SyncKind kind() const override { return SyncKind::Linear; }; - - std::optional read(ThreadContext &ctx, - const std::string &var) override; - - void write(ThreadContext &ctx, const 
std::string &var, size_t value) override; - - std::optional> - on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) override; - - std::optional> - on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) override; - - std::optional> - on_start(ThreadContext &thread, GlobalContext &gctx) override; - - std::optional> - on_end(ThreadContext &thread, GlobalContext &gctx) override; - - std::optional> - on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - - std::optional> - on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - - std::ostream &print(std::ostream &os) const override; -}; - -class BranchingSyncProtocol final : public SyncProtocol { - - // std::unordered_map> commit_nodes; - -public: - ~BranchingSyncProtocol() override; - SyncKind kind() const override { return SyncKind::Branching; }; - - std::optional read(ThreadContext &ctx, - const std::string &var) override; - - void write(ThreadContext &ctx, const std::string &var, size_t value) override; - - std::optional> - on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) override; - - std::optional> - on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) override; - - std::optional> - on_start(ThreadContext &thread, GlobalContext &gctx) override; - - std::optional> - on_end(ThreadContext &thread, GlobalContext &gctx) override; - - std::optional> - on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - - std::optional> - on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - - std::ostream &print(std::ostream &os) const override; -}; - } // namespace gitmem From 52bf50a29abd57499b222dee6deaaebb940ec0c7 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 5 Jan 2026 20:45:58 +0100 Subject: [PATCH 27/58] adding the missing files for protocols --- src/branching/sync_protocol.cc | 154 +++++++++++++++++++++++++++++++ 
src/branching/sync_protocol.hh | 48 ++++++++++ src/conflict.hh | 34 +++++++ src/linear/sync_protocol.cc | 159 +++++++++++++++++++++++++++++++++ src/linear/sync_protocol.hh | 55 ++++++++++++ 5 files changed, 450 insertions(+) create mode 100644 src/branching/sync_protocol.cc create mode 100644 src/branching/sync_protocol.hh create mode 100644 src/conflict.hh create mode 100644 src/linear/sync_protocol.cc create mode 100644 src/linear/sync_protocol.hh diff --git a/src/branching/sync_protocol.cc b/src/branching/sync_protocol.cc new file mode 100644 index 0000000..7481287 --- /dev/null +++ b/src/branching/sync_protocol.cc @@ -0,0 +1,154 @@ +#include "branching/sync_protocol.hh" + +namespace gitmem { + +namespace branching { + +// -------------------- +// BranchingSyncProtocol +// -------------------- + +BranchingSyncProtocol::~BranchingSyncProtocol() = default; + +std::ostream &BranchingSyncProtocol::print(std::ostream &os) const { + assert(false && "TODO"); + return os; +} + +// /* At a commit point, walk through all the versioned variables and see if +// * they have a pending commit, if so commit the value by appending to +// * the variables history. +// */ +// void commit(Globals &globals) { +// for (auto& [var, global] : globals) { +// if (global.commit) +// { +// global.history.push_back(*global.commit); +// verbose << "Committed global '" << var << "' with id " << +// *global.commit << std::endl; global.commit.reset(); +// } +// } +// } + +// /* A versioned value can be fastforwarded to another version, if one +// * version's history is a prefix of another version's history. +// * A conflict between two commit histories exists if neither history is a +// * prefix of the other. 
+// */ +// std::optional> has_conflict(CommitHistory& h1, +// CommitHistory& h2) +// { +// size_t length = std::min(h1.size(), h2.size()); + +// for (size_t i = 0; i < length; i++) +// { +// if (h1[i] != h2[i]) return std::pair{h1[i], h2[i]}; +// } + +// return std::nullopt; +// } + +// /* Walk through all the global versions from source and update the versions +// * in destination to be the most up-to-date version (this could come from +// * either source or destination). This means destination will now also +// * include variables it previously did not know about. +// */ +// std::optional> pull(Globals &dst, Globals &src) +// { +// for (auto& [var, global] : src) { +// if (dst.contains(var)) +// { +// auto& src_var = src[var]; +// auto& dst_var = dst[var]; +// if (auto conflict = has_conflict(src_var.history, +// dst_var.history)) +// { +// auto [s1, s2] = *conflict; +// verbose << "A data race on '" << var << "' was detected from +// commits " << s1 << " and " << s2 << std::endl; return +// Conflict(var, *conflict); +// } +// else if (src_var.history.size() > dst_var.history.size()) +// { +// verbose << "Fast-forward '" << var << "' to id " << +// src_var.val << std::endl; dst_var.val = src_var.val; +// dst_var.history = src_var.history; +// } +// } +// else +// { +// dst[var].val = src[var].val; +// dst[var].history = src[var].history; +// } +// } +// return std::nullopt; +// } + +std::optional BranchingSyncProtocol::read(ThreadContext &ctx, + const std::string &var) { + assert(false && "Todo read"); + return std::nullopt; +} + +void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, + size_t value) { + // auto& store = std::get(ctx.sync).store; + // store.stage(_global_store.get_object_number(var), value); +} + +// Spawning is a sync point, commit local pending commits, and +// copy the global state to the spawned thread +// commit(ctx.globals); + +std::optional> +BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext 
&child, + GlobalContext &) { + assert(false && "Todo on_spawn"); + // commit(parent.globals); + // child.globals = parent.globals; + return std::nullopt; +} + +std::optional> +BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &) { + assert(false && "Todo on_join"); + // commit(joiner.globals); + // commit(joinee.globals); + // return pull(joiner.globals, joinee.globals); + return std::nullopt; +} + +std::optional> +BranchingSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { + assert(false && "Todo on_start"); + return std::nullopt; +}; + +std::optional> +BranchingSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { + assert(false && "Todo on_end"); + return std::nullopt; +}; + +std::optional> +BranchingSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, + GlobalContext &) { + assert(false && "Todo on_lock"); + // commit(thread.globals); + // return pull(thread.globals, lock.globals); + return std::nullopt; +} + +std::optional> +BranchingSyncProtocol::on_unlock(ThreadContext &thread, Lock &lock, + GlobalContext &) { + assert(false && "Todo on_unlock"); + // commit(thread.globals); + // lock.globals = thread.globals; + return std::nullopt; +} + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/sync_protocol.hh b/src/branching/sync_protocol.hh new file mode 100644 index 0000000..4ec1307 --- /dev/null +++ b/src/branching/sync_protocol.hh @@ -0,0 +1,48 @@ +#pragma once + +#include "../sync_protocol.hh" + +namespace gitmem { + +namespace branching { + +class BranchingSyncProtocol final : public SyncProtocol { + // branching::GlobalContext _global_context; + + // std::unordered_map> commit_nodes; + +public: + ~BranchingSyncProtocol() override; + SyncKind kind() const override { return SyncKind::Branching; }; + + std::optional read(ThreadContext &ctx, + const std::string &var) override; + + void write(ThreadContext &ctx, const std::string &var, size_t 
value) override; + + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) override; + + std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) override; + + std::optional> + on_start(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_end(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::optional> + on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::ostream &print(std::ostream &os) const override; +}; + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/conflict.hh b/src/conflict.hh new file mode 100644 index 0000000..c523080 --- /dev/null +++ b/src/conflict.hh @@ -0,0 +1,34 @@ +#pragma once + +#include + +namespace gitmem { + +struct ConflictBase { + virtual ~ConflictBase() = default; + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const ConflictBase &conflict) { + return conflict.print(os); + } +}; + +template struct Conflict : ConflictBase { + std::string var; + std::pair versions; + + Conflict(std::string var, std::pair versions) + : var(std::move(var)), versions(std::move(versions)) {} + + std::ostream &print(std::ostream &os) const override; +}; + +template +std::ostream &Conflict::print(std::ostream &os) const { + os << "conflict on " << var << " { " << versions.first << ", " + << versions.second << " }"; + return os; +} + + +} \ No newline at end of file diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc new file mode 100644 index 0000000..ed8ba7a --- /dev/null +++ b/src/linear/sync_protocol.cc @@ -0,0 +1,159 @@ +#include "linear/sync_protocol.hh" +#include "debug.hh" +#include + +namespace gitmem { + +namespace linear { + +// -------------------- +// LinearSyncProtocol +// -------------------- + 
+std::ostream &LinearSyncProtocol::print(std::ostream &os) const { + os << _global_store << std::endl; + return os; +} + +std::optional +LinearSyncProtocol::push(linear::LocalVersionStore &local) { + if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), + local.staged_changes())) { + + // reshape the conflict + return std::make_optional( + _global_store.get_object_name(conflict->object), + std::make_pair(conflict->local_base, conflict->global_head)); + } + + linear::Timestamp new_base = _global_store.apply_changes( + local.base_timestamp(), local.staged_changes()); + + local.clear_staging(); + local.advance_base(new_base); + return std::nullopt; +} + +std::optional +LinearSyncProtocol::pull(linear::LocalVersionStore &local) { + if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), + local.staged_changes())) { + + return std::make_optional( + _global_store.get_object_name(conflict->object), + std::make_pair(conflict->local_base, conflict->global_head)); + } + + local.advance_base(_global_store.current_timestamp()); + return std::nullopt; +} + +LinearSyncProtocol::~LinearSyncProtocol() = default; + +std::optional LinearSyncProtocol::read(ThreadContext &ctx, + const std::string &var) { + linear::ObjectNumber number = _global_store.get_object_number(var); + + auto& store = std::get(ctx.sync).store; + + if (auto result = store.get_staged(number)) + return result; + + std::optional value = _global_store.get_version_for_timestamp( + number, store.base_timestamp()); + if (!value) + return std::nullopt; + + // we do not need to record the staged value for correctness + // TODO: there is something about working out if a value has changed vs been + // written + + return *value; +} + +void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, + size_t value) { + // write into the staging area of the thread + auto& store = std::get(ctx.sync).store; + store.stage(_global_store.get_object_number(var), value); +} + 
+std::optional> +LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) { + // TODO: i think we can drop the globalcontext but check after branching is + // added + + // push parent to global history + auto& store = std::get(parent.sync).store; + if (auto conflict = push(store)) + return std::make_unique(std::move(*conflict)); + + // pull into the child + store = std::get(child.sync).store; + if (auto conflict = pull(store)) { + throw std::logic_error("This code path should never be reached"); + } + + return std::nullopt; +} + +std::optional> +LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) { + // we assume the joinee has already terminated and pushed + + // pull changes into parent + auto& store = std::get(joiner.sync).store; + if (auto conflict = pull(store)) + return std::make_unique(std::move(*conflict)); + + return std::nullopt; +} + +std::optional> +LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { + // pull state from global history + auto& store = std::get(thread.sync).store; + auto conflict = pull(store); + assert(!conflict && "cannot conflict from starting state"); + + return std::nullopt; +}; + +std::optional> +LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { + // push changes to global history + auto& store = std::get(thread.sync).store; + if (auto conflict = push(store)) + return std::make_unique(std::move(*conflict)); + + return std::nullopt; +}; + +std::optional> +LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, + GlobalContext &gctx) { + + auto& store = std::get(thread.sync).store; + if (auto conflict = pull(store)) + return std::make_unique(std::move(*conflict)); + + return std::nullopt; +} + +std::optional> +LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, + GlobalContext &gctx) { + + // push changes to global history + auto& store = std::get(thread.sync).store; + if (auto conflict = 
push(store)) + return std::make_unique(std::move(*conflict)); + + return std::nullopt; +} + +} // namespace linear + +} // namespace gitmem diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh new file mode 100644 index 0000000..5ed47f4 --- /dev/null +++ b/src/linear/sync_protocol.hh @@ -0,0 +1,55 @@ +#pragma once + +#include "../sync_protocol.hh" +#include "conflict.hh" +#include "sync_kind.hh" +#include "execution_state.hh" +#include "linear/version_store.hh" + +namespace gitmem { + +using LinearConflict = Conflict; + +namespace linear { + +class LinearSyncProtocol final : public SyncProtocol { + linear::GlobalVersionStore _global_store; + + std::optional push(linear::LocalVersionStore &local); + std::optional pull(linear::LocalVersionStore &local); + +public: + ~LinearSyncProtocol() override; + SyncKind kind() const override { return SyncKind::Linear; }; + + std::optional read(ThreadContext &ctx, + const std::string &var) override; + + void write(ThreadContext &ctx, const std::string &var, size_t value) override; + + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) override; + + std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) override; + + std::optional> + on_start(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_end(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::optional> + on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::ostream &print(std::ostream &os) const override; +}; + +} // namespace linear + +} // namespace gitmem \ No newline at end of file From df8a986fdad7d6ec1b9fdd990bee0372244df68d Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 6 Jan 2026 11:11:51 +0100 Subject: [PATCH 28/58] working on branching protocol --- CMakeLists.txt | 3 +- 
src/branching/sync_protocol.cc | 31 +++++-- src/branching/sync_protocol.hh | 3 +- src/branching/version_store.cc | 164 +++++++++++++++++++++++++++++++++ src/branching/version_store.hh | 85 ++++++++++++++--- src/execution_state.cc | 7 +- src/execution_state.hh | 2 +- src/interpreter.cc | 4 +- src/linear/sync_protocol.cc | 8 +- src/linear/sync_protocol.hh | 6 +- 10 files changed, 274 insertions(+), 39 deletions(-) create mode 100644 src/branching/version_store.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 2232aec..88ab504 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,9 +25,10 @@ add_executable(gitmem src/passes/statements.cc src/passes/check_refs.cc src/passes/branching.cc - src/linear/version_store.cc src/linear/sync_protocol.cc + src/linear/version_store.cc src/branching/sync_protocol.cc + src/branching/version_store.cc src/interpreter.cc src/debugger.cc src/model_checker.cc diff --git a/src/branching/sync_protocol.cc b/src/branching/sync_protocol.cc index 7481287..a32a947 100644 --- a/src/branching/sync_protocol.cc +++ b/src/branching/sync_protocol.cc @@ -86,14 +86,21 @@ std::ostream &BranchingSyncProtocol::print(std::ostream &os) const { std::optional BranchingSyncProtocol::read(ThreadContext &ctx, const std::string &var) { - assert(false && "Todo read"); - return std::nullopt; + ObjectNumber number = _global_store.get_object_number(var); + + auto& store = std::get(ctx.sync).store; + + if (auto result = store.get_staged(number)) + return result; + + // look in commit history + return store.get_committed(number); } void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, size_t value) { - // auto& store = std::get(ctx.sync).store; - // store.stage(_global_store.get_object_number(var), value); + auto& store = std::get(ctx.sync).store; + store.stage(_global_store.get_object_number(var), value); } // Spawning is a sync point, commit local pending commits, and @@ -103,9 +110,13 @@ void BranchingSyncProtocol::write(ThreadContext 
&ctx, const std::string &var, std::optional> BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, GlobalContext &) { - assert(false && "Todo on_spawn"); - // commit(parent.globals); - // child.globals = parent.globals; + auto& parent_store = std::get(parent.sync).store; + parent_store.commit_staging(); + + auto& child_store = std::get(child.sync).store; + child_store.adopt_history(parent_store.exported_head()); + + // a conflict cannot occur here return std::nullopt; } @@ -121,13 +132,15 @@ BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, std::optional> BranchingSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { - assert(false && "Todo on_start"); + // nothing to do, the thread will have inhereted the parent commit on spawn return std::nullopt; }; std::optional> BranchingSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { - assert(false && "Todo on_end"); + auto& store = std::get(thread.sync).store; + store.commit_staging(); + return std::nullopt; }; diff --git a/src/branching/sync_protocol.hh b/src/branching/sync_protocol.hh index 4ec1307..3d17696 100644 --- a/src/branching/sync_protocol.hh +++ b/src/branching/sync_protocol.hh @@ -1,13 +1,14 @@ #pragma once #include "../sync_protocol.hh" +#include "version_store.hh" namespace gitmem { namespace branching { class BranchingSyncProtocol final : public SyncProtocol { - // branching::GlobalContext _global_context; + GlobalVersionStore _global_store; // std::unordered_map> commit_nodes; diff --git a/src/branching/version_store.cc b/src/branching/version_store.cc new file mode 100644 index 0000000..3537acf --- /dev/null +++ b/src/branching/version_store.cc @@ -0,0 +1,164 @@ +#include "version_store.hh" +#include +#include + +namespace gitmem { + +namespace branching { + +void LocalVersionStore::stage(ObjectNumber obj, Value value) { + staging[obj] = value; +} + +void LocalVersionStore::commit_staging() { + // No-op commit does nothing + 
if (staging.empty()) { + return; + } + + auto new_commit = std::make_shared(base_timestamp++, std::move(staging)); + staging.clear(); + + if (head) + new_commit->parents.push_back(head); + + head = new_commit; +} + +std::optional LocalVersionStore::get_staged(ObjectNumber obj) const { + auto it = staging.find(obj); + return it != staging.end() ? std::make_optional(it->second) : std::nullopt; +} + +std::optional get_committed_recursive( + std::shared_ptr commit, + ObjectNumber number, + std::unordered_set>& visited) +{ + if (!commit || !visited.insert(commit).second) return std::nullopt; + + // check if commit explicitly wrote the variable + auto it = commit->changes.find(number); + if (it != commit->changes.end()) return it->second; + + std::optional found; + for (const auto& parent : commit->parents) { + auto val = get_committed_recursive(parent, number, visited); + if (!val.has_value()) continue; + + if (!found.has_value()) + found = val; // first value found + else if (found.value() != val.value()) + assert(false && "Conflict detected on read should have been detected earlier"); // lazy conflict + } + + return found; +} + +std::optional LocalVersionStore::get_committed(ObjectNumber number) const { + // for now assume early resolution of conflicts + std::unordered_set> visited; + return get_committed_recursive(head, number, visited); +} + +std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + os << "LocalVersionStore{" + << "base=" << store.base_timestamp + << ", head="; + + if (store.head) + os << store.head->id; + else + os << "null"; + + os << ", staged={"; + + bool first = true; + for (const auto& [obj, val] : store.staging) { + if (!first) os << ", "; + first = false; + os << obj << "->" << val; + } + + os << "}}"; + return os; +} + +// Late resolution + +// ReadResult get_committed_recursive( +// std::shared_ptr commit, +// ObjectNumber number, +// std::unordered_set>& visited) { +// if (!commit || !visited.insert(commit).second) 
+// return { ReadKind::NotFound, std::nullopt }; + +// // Explicit write dominates +// auto it = commit->changes.find(number); +// if (it != commit->changes.end()) +// return { ReadKind::Value, it->second }; + +// std::optional found; + +// for (const auto& parent : commit->parents) { +// auto res = get_committed_recursive(parent, number, visited); + +// if (res.kind == ReadKind::Conflict) +// return res; + +// if (res.kind == ReadKind::Value) { +// if (!found) +// found = res.value; +// else if (*found != *res.value) +// return { ReadKind::Conflict, std::nullopt }; +// } +// } + +// if (found) +// return { ReadKind::Value, *found }; + +// return { ReadKind::NotFound, std::nullopt }; +// } + +// ReadResult LocalVersionStore::get_committed(ObjectNumber number) const { +// // for now allow for late resolution of read conflict + +// std::unordered_set> visited; +// return get_committed_recursive(head, number, visited); +// } + +bool LocalVersionStore::operator==(const LocalVersionStore& other) const { + return base_timestamp == other.base_timestamp && + head == other.head && + staging == other.staging; +} + +ObjectNumber GlobalVersionStore::get_object_number(std::string var) { + auto it = _object_numbers.find(var); + if (it != _object_numbers.end()) { + return it->second; + } else { + ObjectNumber number = _next_object++; + _object_numbers[var] = number; + return number; + } +} + +std::string GlobalVersionStore::get_object_name(ObjectNumber find) { + for (const auto &[name, number] : _object_numbers) { + if (number == find) + return name; + } + assert(false && "failed to find object name for object number"); + return ""; +} + + +std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { + os << "GlobalVersionStore(next_object=" << store._next_object << ")" << std::endl; + return os; +} + +} // branching + +} // gitmem \ No newline at end of file diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index ac0660d..ba42a4a 
100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -5,6 +5,7 @@ #include #include #include +#include #include "thread_id.hh" namespace gitmem { @@ -16,31 +17,85 @@ namespace branching { * the current commit id for the variable, and the history of commited ids. */ -using Timestamp = std::pair; +struct Timestamp { + ThreadID thread; + size_t counter; + + auto operator<=>(const Timestamp &) const = default; + + // pre-increment + Timestamp& operator++() { + ++counter; + return *this; + } + + // post-increment + Timestamp operator++(int) { + Timestamp old = *this; + ++(*this); + return old; + } + + friend std::ostream &operator<<(std::ostream &os, + const Timestamp &ts) { + os << ts.thread << ":" << ts.counter; + return os; + } +}; + using Value = size_t; using ObjectNumber = uint64_t; struct Commit { - size_t id; - std::vector> parents; - Timestamp timestamp; + Timestamp id; std::unordered_map changes; + std::vector> parents; }; -struct LocalVersionStore { - std::shared_ptr _head; - std::unordered_map _staging; +// Initial plumbing for fail late +// enum class ReadKind { +// NotFound, +// Value, +// Conflict +// }; - friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { - assert(false && "TODO"); - return os; - } +// struct ReadResult { +// ReadKind kind; +// std::optional value; // only valid if kind == Value +// }; - bool operator==(const LocalVersionStore& other) const { - assert(false && "TODO"); - return false; - } +class LocalVersionStore { + Timestamp base_timestamp; + std::shared_ptr head; + std::unordered_map staging; + +public: + LocalVersionStore(ThreadID tid): base_timestamp(tid, 0) {} + + void stage(ObjectNumber obj, Value value); + void commit_staging(); + + std::optional get_staged(ObjectNumber obj) const; + std::optional get_committed(ObjectNumber number) const; + + std::shared_ptr exported_head() const { return head; }; + void adopt_history(std::shared_ptr new_head) { head = new_head; }; + 
+ friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); + bool operator==(const LocalVersionStore& other) const; +}; + +class GlobalVersionStore { + ObjectNumber _next_object{0}; + std::unordered_map _object_numbers; + +public: + + ObjectNumber get_object_number(std::string); + std::string get_object_name(ObjectNumber); + + friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); }; // using CommitHistory = std::vector; diff --git a/src/execution_state.cc b/src/execution_state.cc index 36e34b6..fb60c8b 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -5,13 +5,13 @@ namespace gitmem { -ThreadContext::ThreadContext(SyncKind sync_kind) { +ThreadContext::ThreadContext(ThreadID tid, SyncKind sync_kind) { switch (sync_kind) { case SyncKind::Linear: sync.emplace(); break; case SyncKind::Branching: - sync.emplace(); + sync.emplace(tid); break; } } @@ -65,10 +65,11 @@ GlobalContext::GlobalContext(const trieste::Node &ast, std::unique_ptr protocol) : protocol(std::move(protocol)) { trieste::Node starting_block = ast / lang::File / lang::Block; - ThreadContext starting_ctx(this->protocol->kind()); ThreadID main_tid = 0; + ThreadContext starting_ctx(main_tid, this->protocol->kind()); + this->threads.emplace_back(main_tid, std::move(starting_ctx), starting_block); this->locks = {}; this->cache = {}; diff --git a/src/execution_state.hh b/src/execution_state.hh index 32d7756..1f35ce4 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -36,7 +36,7 @@ struct ThreadContext { ThreadContext(ThreadContext&&) = default; ThreadContext& operator=(ThreadContext&&) = default; - ThreadContext(SyncKind sync_kind); + ThreadContext(ThreadID tid, SyncKind sync_kind); bool operator==(const ThreadContext &other) const; diff --git a/src/interpreter.cc b/src/interpreter.cc index 48f0077..faca3e3 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -75,7 +75,7 @@ Interpreter::evaluate_expression(trieste::Node 
expr, Thread& thread) { return sum; } else if (e == lang::Spawn) { ThreadID child_tid = gctx.threads.size(); - ThreadContext child_ctx(gctx.protocol->kind()); + ThreadContext child_ctx(child_tid, gctx.protocol->kind()); if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { @@ -564,7 +564,7 @@ int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind interp.print_thread_traces(); - auto exec_graph = interp.build_execution_graph_from_traces(); + // auto exec_graph = interp.build_execution_graph_from_traces(); // graph::GraphvizPrinter gv(output_path); // gv.visit(node.get()); diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index ed8ba7a..e2cf897 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -16,7 +16,7 @@ std::ostream &LinearSyncProtocol::print(std::ostream &os) const { } std::optional -LinearSyncProtocol::push(linear::LocalVersionStore &local) { +LinearSyncProtocol::push(LocalVersionStore &local) { if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), local.staged_changes())) { @@ -26,7 +26,7 @@ LinearSyncProtocol::push(linear::LocalVersionStore &local) { std::make_pair(conflict->local_base, conflict->global_head)); } - linear::Timestamp new_base = _global_store.apply_changes( + Timestamp new_base = _global_store.apply_changes( local.base_timestamp(), local.staged_changes()); local.clear_staging(); @@ -35,7 +35,7 @@ LinearSyncProtocol::push(linear::LocalVersionStore &local) { } std::optional -LinearSyncProtocol::pull(linear::LocalVersionStore &local) { +LinearSyncProtocol::pull(LocalVersionStore &local) { if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), local.staged_changes())) { @@ -52,7 +52,7 @@ LinearSyncProtocol::~LinearSyncProtocol() = default; std::optional LinearSyncProtocol::read(ThreadContext &ctx, const std::string &var) { - linear::ObjectNumber number = _global_store.get_object_number(var); + ObjectNumber 
number = _global_store.get_object_number(var); auto& store = std::get(ctx.sync).store; diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 5ed47f4..59cfea9 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -13,10 +13,10 @@ using LinearConflict = Conflict; namespace linear { class LinearSyncProtocol final : public SyncProtocol { - linear::GlobalVersionStore _global_store; + GlobalVersionStore _global_store; - std::optional push(linear::LocalVersionStore &local); - std::optional pull(linear::LocalVersionStore &local); + std::optional push(LocalVersionStore &local); + std::optional pull(LocalVersionStore &local); public: ~LinearSyncProtocol() override; From 7af4f626d8f14d665411d455d441e065cc75032b Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Tue, 6 Jan 2026 15:32:22 +0100 Subject: [PATCH 29/58] trying to improve the DAG stuff by caching last writer --- src/branching/sync_protocol.cc | 94 ++-------- src/branching/sync_protocol.hh | 2 + src/branching/version_store.cc | 308 ++++++++++++++++++++++++++------- src/branching/version_store.hh | 19 +- src/conflict.hh | 3 +- src/sync_protocol.hh | 2 - 6 files changed, 282 insertions(+), 146 deletions(-) diff --git a/src/branching/sync_protocol.cc b/src/branching/sync_protocol.cc index a32a947..84f31ff 100644 --- a/src/branching/sync_protocol.cc +++ b/src/branching/sync_protocol.cc @@ -11,79 +11,10 @@ namespace branching { BranchingSyncProtocol::~BranchingSyncProtocol() = default; std::ostream &BranchingSyncProtocol::print(std::ostream &os) const { - assert(false && "TODO"); + os << _global_store << std::endl; return os; } -// /* At a commit point, walk through all the versioned variables and see if -// * they have a pending commit, if so commit the value by appending to -// * the variables history. 
-// */ -// void commit(Globals &globals) { -// for (auto& [var, global] : globals) { -// if (global.commit) -// { -// global.history.push_back(*global.commit); -// verbose << "Committed global '" << var << "' with id " << -// *global.commit << std::endl; global.commit.reset(); -// } -// } -// } - -// /* A versioned value can be fastforwarded to another version, if one -// * version's history is a prefix of another version's history. -// * A conflict between two commit histories exists if neither history is a -// * prefix of the other. -// */ -// std::optional> has_conflict(CommitHistory& h1, -// CommitHistory& h2) -// { -// size_t length = std::min(h1.size(), h2.size()); - -// for (size_t i = 0; i < length; i++) -// { -// if (h1[i] != h2[i]) return std::pair{h1[i], h2[i]}; -// } - -// return std::nullopt; -// } - -// /* Walk through all the global versions from source and update the versions -// * in destination to be the most up-to-date version (this could come from -// * either source or destination). This means destination will now also -// * include variables it previously did not know about. 
-// */ -// std::optional> pull(Globals &dst, Globals &src) -// { -// for (auto& [var, global] : src) { -// if (dst.contains(var)) -// { -// auto& src_var = src[var]; -// auto& dst_var = dst[var]; -// if (auto conflict = has_conflict(src_var.history, -// dst_var.history)) -// { -// auto [s1, s2] = *conflict; -// verbose << "A data race on '" << var << "' was detected from -// commits " << s1 << " and " << s2 << std::endl; return -// Conflict(var, *conflict); -// } -// else if (src_var.history.size() > dst_var.history.size()) -// { -// verbose << "Fast-forward '" << var << "' to id " << -// src_var.val << std::endl; dst_var.val = src_var.val; -// dst_var.history = src_var.history; -// } -// } -// else -// { -// dst[var].val = src[var].val; -// dst[var].history = src[var].history; -// } -// } -// return std::nullopt; -// } - std::optional BranchingSyncProtocol::read(ThreadContext &ctx, const std::string &var) { ObjectNumber number = _global_store.get_object_number(var); @@ -103,10 +34,6 @@ void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, store.stage(_global_store.get_object_number(var), value); } -// Spawning is a sync point, commit local pending commits, and -// copy the global state to the spawned thread -// commit(ctx.globals); - std::optional> BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, GlobalContext &) { @@ -114,7 +41,7 @@ BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, parent_store.commit_staging(); auto& child_store = std::get(child.sync).store; - child_store.adopt_history(parent_store.exported_head()); + child_store.adopt_history(parent_store); // a conflict cannot occur here return std::nullopt; @@ -123,10 +50,19 @@ BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, std::optional> BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, GlobalContext &) { - assert(false && "Todo on_join"); - // commit(joiner.globals); - // 
commit(joinee.globals); - // return pull(joiner.globals, joinee.globals); + auto& joiner_store = std::get(joiner.sync).store; + auto& joinee_store = std::get(joinee.sync).store; + + joiner_store.commit_staging(); + assert(joinee_store.has_commited() && "joinee has staged changes"); + + std::optional conflict = joiner_store.merge_with(joinee_store); + if (conflict) { + return std::make_unique( + _global_store.get_object_name(conflict->obj), + std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); + } + return std::nullopt; } diff --git a/src/branching/sync_protocol.hh b/src/branching/sync_protocol.hh index 3d17696..b232684 100644 --- a/src/branching/sync_protocol.hh +++ b/src/branching/sync_protocol.hh @@ -5,6 +5,8 @@ namespace gitmem { +using BranchingConflict = Conflict; + namespace branching { class BranchingSyncProtocol final : public SyncProtocol { diff --git a/src/branching/version_store.cc b/src/branching/version_store.cc index 3537acf..16fa25c 100644 --- a/src/branching/version_store.cc +++ b/src/branching/version_store.cc @@ -6,6 +6,51 @@ namespace gitmem { namespace branching { +// Helper: recursive print with 2-space indentation and cycle protection +void print_commit_recursive(std::ostream& os, + const std::shared_ptr& commit, + std::unordered_set& visited, + int depth = 0) +{ + if (!commit) return; + if (!visited.insert(commit.get()).second) { + os << std::string(depth * 2, ' ') << "(already printed commit " << commit->id << ")\n"; + return; + } + + os << std::string(depth * 2, ' ') << "Commit " << commit->id << " {\n"; + + // Print changes + for (const auto& [obj, val] : commit->changes) { + os << std::string((depth + 1) * 2, ' ') << obj << " -> " << val << "\n"; + } + + // Print parents + if (!commit->parents.empty()) { + os << std::string((depth + 1) * 2, ' ') << "Parents: "; + for (size_t i = 0; i < commit->parents.size(); ++i) { + os << commit->parents[i]->id; + if (i + 1 < commit->parents.size()) os << ", "; + } + os << "\n"; + } + + os << 
std::string(depth * 2, ' ') << "}\n"; + + // Recursively print parents + for (auto& parent : commit->parents) { + print_commit_recursive(os, parent, visited, depth + 1); + } +} + +// operator<< for Commit +std::ostream& operator<<(std::ostream& os, const Commit& commit) { + std::unordered_set visited; + // Wrap the commit in a shared_ptr to reuse the recursive helper + print_commit_recursive(os, std::make_shared(commit), visited); + return os; +} + void LocalVersionStore::stage(ObjectNumber obj, Value value) { staging[obj] = value; } @@ -16,12 +61,22 @@ void LocalVersionStore::commit_staging() { return; } + // Create the new commit with the staged changes auto new_commit = std::make_shared(base_timestamp++, std::move(staging)); + + // Update last_writer for each staged variable + for (const auto& [obj, _] : new_commit->changes) { + last_writer[obj] = new_commit; + } + + // Clear staging staging.clear(); + // Set parent to previous head if it exists if (head) new_commit->parents.push_back(head); + // Update head head = new_commit; } @@ -30,35 +85,207 @@ std::optional LocalVersionStore::get_staged(ObjectNumber obj) const { return it != staging.end() ? 
std::make_optional(it->second) : std::nullopt; } -std::optional get_committed_recursive( - std::shared_ptr commit, - ObjectNumber number, - std::unordered_set>& visited) +std::optional> +get_committed_recursive( + const std::shared_ptr& commit, + ObjectNumber number, + std::unordered_set>& visited) { + + if (!commit || !visited.insert(commit).second) + return std::nullopt; + + std::optional> found; + + // Recurse into all parents first + for (auto& parent : commit->parents) { + auto parent_commit = get_committed_recursive(parent, number, visited); + if (parent_commit) { + if (!found.has_value()) + found = parent_commit; + else if (found.value()->changes.at(number) != parent_commit.value()->changes.at(number)) + assert(false && "Conflict detected (should be impossible in conflict-free DAG)"); + } + } + + // If this commit wrote the variable, it overrides any parent + if (commit->changes.contains(number)) + return commit; + + return found; +} + +std::shared_ptr +find_lowest_common_ancestor(std::shared_ptr a, + std::shared_ptr b) { - if (!commit || !visited.insert(commit).second) return std::nullopt; + std::unordered_set> ancestors_a; + std::queue> q; + + // Collect all ancestors of 'a' + q.push(a); + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (!ancestors_a.insert(c).second) continue; + for (auto& p : c->parents) q.push(p); + } + + // Walk ancestors of 'b' until we find a common one + std::unordered_set> visited; + q.push(b); + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (!visited.insert(c).second) continue; + if (ancestors_a.count(c)) return c; // first common ancestor + for (auto& p : c->parents) q.push(p); + } + + return nullptr; // no common ancestor (shouldn’t happen) +} - // check if commit explicitly wrote the variable - auto it = commit->changes.find(number); +std::optional traverse_to_lca( + const std::shared_ptr& commit, + ObjectNumber var, + const std::shared_ptr& lca) +{ + if (!commit || 
commit == lca) return std::nullopt; + + auto it = commit->changes.find(var); + if (it != commit->changes.end()) return it->second; + + if (commit->parents.size() > 1) + assert(false && "should never encounter multi-parent commit before LCA"); + + return traverse_to_lca(commit->parents[0], var, lca); +} + +std::optional get_committed_recursive( + const std::shared_ptr& commit, + ObjectNumber var) { + if (!commit) return std::nullopt; + + // 1. If this commit wrote the variable, return it + auto it = commit->changes.find(var); if (it != commit->changes.end()) return it->second; - std::optional found; - for (const auto& parent : commit->parents) { - auto val = get_committed_recursive(parent, number, visited); - if (!val.has_value()) continue; + // 2. If single parent, recurse + if (commit->parents.size() == 1) + return get_committed_recursive(commit->parents[0], var); - if (!found.has_value()) - found = val; // first value found - else if (found.value() != val.value()) - assert(false && "Conflict detected on read should have been detected earlier"); // lazy conflict - } + // 3. Merge commit + assert(commit->parents.size() == 2); // merge commit - return found; + auto& p1 = commit->parents[0]; + auto& p2 = commit->parents[1]; + + auto lca = find_lowest_common_ancestor(p1, p2); + + // Explore both paths from merge commit to LCA + std::optional v1 = traverse_to_lca(p1, var, lca); + std::optional v2 = traverse_to_lca(p2, var, lca); + + assert(!v1 || !v2 || v1 == v2); // conflict-free invariant + + if (v1) return v1; // found in one of the merge branches + if (v2) return v2; + + // 4. 
Not found yet → continue recursively from the LCA downward + return get_committed_recursive(lca, var); } std::optional LocalVersionStore::get_committed(ObjectNumber number) const { - // for now assume early resolution of conflicts + if (auto it = last_writer.find(number); it != last_writer.end()) + return it->second->changes.at(number); + + return std::nullopt; +} + +void LocalVersionStore::adopt_history(const LocalVersionStore& other) { + // Inherit the DAG head + head = other.head; + + // Inherit the last_writer cache so the child sees all latest commits + last_writer = other.last_writer; +} + +bool traverse_until_lca( + const std::shared_ptr& commit, + const std::shared_ptr& lca, + std::unordered_map>& out_map, + std::unordered_set>& visited) +{ + if (!commit || commit == lca || !visited.insert(commit).second) + return true; + + for (const auto& [obj, _] : commit->changes) { + // first write seen dominates + if (out_map.find(obj) == out_map.end()) + out_map[obj] = commit; + } + + for (auto& parent : commit->parents) { + if (!traverse_until_lca(parent, lca, out_map, visited)) + return false; + } + + return true; +} + +std::optional LocalVersionStore::merge_with(const LocalVersionStore& other) { + assert(staging.empty()); + assert(other.staging.empty()); + + // trivial case: same history + if (head == other.head) + return std::nullopt; + + // Create merge commit (no changes itself) + auto merge_commit = std::make_shared( + Commit{ + .id = base_timestamp++, + .parents = {head, other.head}, + .changes = {} // merge commit does not write anything + } + ); + + // Find lowest common ancestor of the two heads + auto lca = find_lowest_common_ancestor(head, other.head); + + // Collect all writes after LCA for each branch + std::unordered_map> branch_a, branch_b; std::unordered_set> visited; - return get_committed_recursive(head, number, visited); + + traverse_until_lca(head, lca, branch_a, visited); + visited.clear(); + traverse_until_lca(other.head, lca, branch_b, 
visited); + + // 1. Eager conflict detection + for (const auto& [obj, commit_a] : branch_a) { + auto it = branch_b.find(obj); + if (it != branch_b.end() && it->second != commit_a) { + return Conflict{ + .obj = obj, + .timestamp_a = commit_a->id, + .timestamp_b = it->second->id + }; + } + } + + // 2. Update thread-local last_writer incrementally + // Only overwrite variables that were touched along either branch after LCA + for (const auto& [obj, commit] : branch_a) + last_writer[obj] = commit; + + for (const auto& [obj, commit] : branch_b) + last_writer[obj] = commit; + + // 3. Variables not touched in either branch remain unchanged (from before LCA) + + // 4. Update head + head = merge_commit; + + return std::nullopt; } std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { @@ -84,49 +311,6 @@ std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { return os; } -// Late resolution - -// ReadResult get_committed_recursive( -// std::shared_ptr commit, -// ObjectNumber number, -// std::unordered_set>& visited) { -// if (!commit || !visited.insert(commit).second) -// return { ReadKind::NotFound, std::nullopt }; - -// // Explicit write dominates -// auto it = commit->changes.find(number); -// if (it != commit->changes.end()) -// return { ReadKind::Value, it->second }; - -// std::optional found; - -// for (const auto& parent : commit->parents) { -// auto res = get_committed_recursive(parent, number, visited); - -// if (res.kind == ReadKind::Conflict) -// return res; - -// if (res.kind == ReadKind::Value) { -// if (!found) -// found = res.value; -// else if (*found != *res.value) -// return { ReadKind::Conflict, std::nullopt }; -// } -// } - -// if (found) -// return { ReadKind::Value, *found }; - -// return { ReadKind::NotFound, std::nullopt }; -// } - -// ReadResult LocalVersionStore::get_committed(ObjectNumber number) const { -// // for now allow for late resolution of read conflict - -// std::unordered_set> visited; -// 
return get_committed_recursive(head, number, visited); -// } - bool LocalVersionStore::operator==(const LocalVersionStore& other) const { return base_timestamp == other.base_timestamp && head == other.head && diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index ba42a4a..71bccf4 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -52,6 +52,15 @@ struct Commit { std::vector> parents; }; +// operator<< for Commit +std::ostream& operator<<(std::ostream& os, const Commit& commit); + +struct Conflict { + ObjectNumber obj; + Timestamp timestamp_a; + Timestamp timestamp_b; +}; + // Initial plumbing for fail late // enum class ReadKind { // NotFound, @@ -70,20 +79,26 @@ class LocalVersionStore { std::shared_ptr head; std::unordered_map staging; + std::unordered_map> last_writer; // cached + public: LocalVersionStore(ThreadID tid): base_timestamp(tid, 0) {} void stage(ObjectNumber obj, Value value); void commit_staging(); + bool has_commited() { return staging.empty(); } + std::optional get_staged(ObjectNumber obj) const; std::optional get_committed(ObjectNumber number) const; - std::shared_ptr exported_head() const { return head; }; - void adopt_history(std::shared_ptr new_head) { head = new_head; }; + void adopt_history(const LocalVersionStore& other); + std::optional merge_with(const LocalVersionStore& other); friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); bool operator==(const LocalVersionStore& other) const; + + void dump() { if (head) std::cout << *head << "\n============\n" << std::endl; } }; class GlobalVersionStore { diff --git a/src/conflict.hh b/src/conflict.hh index c523080..f78de41 100644 --- a/src/conflict.hh +++ b/src/conflict.hh @@ -13,7 +13,8 @@ struct ConflictBase { } }; -template struct Conflict : ConflictBase { +template +struct Conflict : ConflictBase { std::string var; std::pair versions; diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 
b6b0704..82bfc69 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -10,8 +10,6 @@ namespace gitmem { std::unique_ptr make_protocol(SyncKind); -using BranchingConflict = Conflict; - class SyncProtocol { public: virtual ~SyncProtocol() = default; From e605777c8530d60fa194e0c4dceaeb6381796954 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 11:43:30 +0100 Subject: [PATCH 30/58] a lot of graph traversal things to work out merge conflicts and value resolution in branching --- .../branching/join_pulled_variable.gm | 1 + src/branching/version_store.cc | 148 +++++++++++------- src/branching/version_store.hh | 6 + src/interpreter.cc | 1 + src/linear/version_store.hh | 1 + src/model_checker.cc | 7 + src/thread_trace.hh | 14 +- 7 files changed, 122 insertions(+), 56 deletions(-) diff --git a/examples/accept/semantics/branching/join_pulled_variable.gm b/examples/accept/semantics/branching/join_pulled_variable.gm index ae481e8..7054065 100644 --- a/examples/accept/semantics/branching/join_pulled_variable.gm +++ b/examples/accept/semantics/branching/join_pulled_variable.gm @@ -9,6 +9,7 @@ t = spawn { }; assert(x == 2); }; +assert (x == 0); join t; assert (x == 2); join t2; diff --git a/src/branching/version_store.cc b/src/branching/version_store.cc index 16fa25c..249e572 100644 --- a/src/branching/version_store.cc +++ b/src/branching/version_store.cc @@ -1,6 +1,7 @@ #include "version_store.hh" #include #include +#include "debug.hh" namespace gitmem { @@ -118,81 +119,90 @@ std::shared_ptr find_lowest_common_ancestor(std::shared_ptr a, std::shared_ptr b) { + if (!a || !b) return nullptr; + + // Step 1: collect all ancestors of 'a' std::unordered_set> ancestors_a; std::queue> q; - - // Collect all ancestors of 'a' q.push(a); + while (!q.empty()) { auto c = q.front(); q.pop(); if (!c) continue; - if (!ancestors_a.insert(c).second) continue; - for (auto& p : c->parents) q.push(p); + if (!ancestors_a.insert(c).second) continue; // already visited + + 
for (auto& p : c->parents) + q.push(p); } - // Walk ancestors of 'b' until we find a common one - std::unordered_set> visited; + // Step 2: BFS from 'b' to find first common ancestor + std::unordered_set> visited_b; q.push(b); + while (!q.empty()) { auto c = q.front(); q.pop(); if (!c) continue; - if (!visited.insert(c).second) continue; - if (ancestors_a.count(c)) return c; // first common ancestor - for (auto& p : c->parents) q.push(p); + if (!visited_b.insert(c).second) continue; + + if (ancestors_a.count(c)) + return c; // first common ancestor seen + + for (auto& p : c->parents) + q.push(p); } - return nullptr; // no common ancestor (shouldn’t happen) + return nullptr; // disjoint histories (shouldn’t happen) } -std::optional traverse_to_lca( - const std::shared_ptr& commit, - ObjectNumber var, - const std::shared_ptr& lca) -{ - if (!commit || commit == lca) return std::nullopt; +// std::optional traverse_to_lca( +// const std::shared_ptr& commit, +// ObjectNumber var, +// const std::shared_ptr& lca) +// { +// if (!commit || commit == lca) return std::nullopt; - auto it = commit->changes.find(var); - if (it != commit->changes.end()) return it->second; +// auto it = commit->changes.find(var); +// if (it != commit->changes.end()) return it->second; - if (commit->parents.size() > 1) - assert(false && "should never encounter multi-parent commit before LCA"); +// if (commit->parents.size() > 1) +// assert(false && "should never encounter multi-parent commit before LCA"); - return traverse_to_lca(commit->parents[0], var, lca); -} +// return traverse_to_lca(commit->parents[0], var, lca); +// } -std::optional get_committed_recursive( - const std::shared_ptr& commit, - ObjectNumber var) { - if (!commit) return std::nullopt; +// std::optional get_committed_recursive( +// const std::shared_ptr& commit, +// ObjectNumber var) { +// if (!commit) return std::nullopt; - // 1. 
If this commit wrote the variable, return it - auto it = commit->changes.find(var); - if (it != commit->changes.end()) return it->second; +// // 1. If this commit wrote the variable, return it +// auto it = commit->changes.find(var); +// if (it != commit->changes.end()) return it->second; - // 2. If single parent, recurse - if (commit->parents.size() == 1) - return get_committed_recursive(commit->parents[0], var); +// // 2. If single parent, recurse +// if (commit->parents.size() == 1) +// return get_committed_recursive(commit->parents[0], var); - // 3. Merge commit - assert(commit->parents.size() == 2); // merge commit +// // 3. Merge commit +// assert(commit->parents.size() == 2); // merge commit - auto& p1 = commit->parents[0]; - auto& p2 = commit->parents[1]; +// auto& p1 = commit->parents[0]; +// auto& p2 = commit->parents[1]; - auto lca = find_lowest_common_ancestor(p1, p2); +// auto lca = find_lowest_common_ancestor(p1, p2); - // Explore both paths from merge commit to LCA - std::optional v1 = traverse_to_lca(p1, var, lca); - std::optional v2 = traverse_to_lca(p2, var, lca); +// // Explore both paths from merge commit to LCA +// std::optional v1 = traverse_to_lca(p1, var, lca); +// std::optional v2 = traverse_to_lca(p2, var, lca); - assert(!v1 || !v2 || v1 == v2); // conflict-free invariant +// assert(!v1 || !v2 || v1 == v2); // conflict-free invariant - if (v1) return v1; // found in one of the merge branches - if (v2) return v2; +// if (v1) return v1; // found in one of the merge branches +// if (v2) return v2; - // 4. Not found yet → continue recursively from the LCA downward - return get_committed_recursive(lca, var); -} +// // 4. 
Not found yet → continue recursively from the LCA downward +// return get_committed_recursive(lca, var); +// } std::optional LocalVersionStore::get_committed(ObjectNumber number) const { if (auto it = last_writer.find(number); it != last_writer.end()) @@ -209,15 +219,45 @@ void LocalVersionStore::adopt_history(const LocalVersionStore& other) { last_writer = other.last_writer; } -bool traverse_until_lca( +bool can_reach_lca( const std::shared_ptr& commit, const std::shared_ptr& lca, - std::unordered_map>& out_map, - std::unordered_set>& visited) + std::unordered_map, bool>& memo) +{ + if (!commit) + return false; + + if (commit == lca) + return true; + + auto it = memo.find(commit); + if (it != memo.end()) + return it->second; + + for (const auto& parent : commit->parents) { + if (can_reach_lca(parent, lca, memo)) { + memo[commit] = true; + return true; + } + } + + memo[commit] = false; + return false; +} + +bool traverse_until_lca( + const std::shared_ptr& commit, + const std::shared_ptr& lca, + std::unordered_map>& out_map, + std::unordered_set>& visited, + std::unordered_map, bool>& reach_memo) { if (!commit || commit == lca || !visited.insert(commit).second) return true; + if (!can_reach_lca(commit, lca, reach_memo)) + return true; + for (const auto& [obj, _] : commit->changes) { // first write seen dominates if (out_map.find(obj) == out_map.end()) @@ -225,7 +265,7 @@ bool traverse_until_lca( } for (auto& parent : commit->parents) { - if (!traverse_until_lca(parent, lca, out_map, visited)) + if (!traverse_until_lca(parent, lca, out_map, visited, reach_memo)) return false; } @@ -250,15 +290,17 @@ std::optional LocalVersionStore::merge_with(const LocalVersionStore& o ); // Find lowest common ancestor of the two heads - auto lca = find_lowest_common_ancestor(head, other.head); + std::shared_ptr lca = find_lowest_common_ancestor(head, other.head); + verbose << "found lca of " << head->id << " and " << other.head->id << " to be " << lca->id << std::endl; // Collect 
all writes after LCA for each branch std::unordered_map> branch_a, branch_b; std::unordered_set> visited; - traverse_until_lca(head, lca, branch_a, visited); + std::unordered_map, bool> reach_memo; + traverse_until_lca(head, lca, branch_a, visited, reach_memo); visited.clear(); - traverse_until_lca(other.head, lca, branch_b, visited); + traverse_until_lca(other.head, lca, branch_b, visited, reach_memo); // 1. Eager conflict detection for (const auto& [obj, commit_a] : branch_a) { diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index 71bccf4..ed5816c 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -61,6 +61,12 @@ struct Conflict { Timestamp timestamp_b; }; +inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { + return os << "Conflict{obj=" << c.obj + << ", timestamp_a=" << c.timestamp_a + << ", timestamp_b=" << c.timestamp_b << "}"; +} + // Initial plumbing for fail late // enum class ReadKind { // NotFound, diff --git a/src/interpreter.cc b/src/interpreter.cc index faca3e3..9f8fc27 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -273,6 +273,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa if (size_t *result = std::get_if(&result_or_term)) { if (*result) { verbose << "Assertion passed: " << expr->location().view() << std::endl; + thread.trace.on_assert_pass(std::string(expr->location().view())); } else { verbose << "Assertion failed: " << expr->location().view() << std::endl; thread.trace.on_assert_fail(std::string(expr->location().view())); diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index c6d05c6..41b0911 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -5,6 +5,7 @@ #include #include #include +#include namespace gitmem { diff --git a/src/model_checker.cc b/src/model_checker.cc index bdf56a9..ab9367c 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -186,6 +186,13 @@ int model_check(const 
Node ast, const std::filesystem::path &output_path, SyncKi for (const auto &ctx : failing_contexts) { auto path = build_output_path(output_path, idx++); + + for (size_t tid = 0; tid < ctx->threads.size(); ++tid) { + const auto& thread = ctx->threads[tid]; + std::cout << "=== Thread " << tid << " ===" << std::endl; + std::cout << thread.trace; + std::cout << "====================================\n"; + } // ctx->print_execution_graph(path); } } diff --git a/src/thread_trace.hh b/src/thread_trace.hh index 378f7ff..ab854f7 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -14,7 +14,7 @@ struct WriteEvent { const std::string var; const size_t value; }; struct LockEvent { std::string lock_name; std::unique_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; struct UnlockEvent { const std::string lock_name; std::unique_ptr maybe_conflict; }; struct JoinEvent { const ThreadID joinee_tid; std::unique_ptr maybe_conflict; }; -struct AssertEvent { const std::string condition; }; +struct AssertEvent { const std::string condition; bool pass; }; struct EndEvent {}; @@ -89,7 +89,7 @@ inline std::ostream& operator<<(std::ostream& os, const JoinEvent& e) { } inline std::ostream& operator<<(std::ostream& os, const AssertEvent& e) { - return os << "AssertEvent(condition=\"" << e.condition << "\")"; + return os << "AssertEvent(condition=\"" << e.condition << "\", " << (e.pass ? 
"pass" : "fail") << ")"; } inline std::ostream& operator<<(std::ostream& os, const EndEvent&) { @@ -156,8 +156,16 @@ private: return append(tid, std::move(conflict)); } + std::shared_ptr on_assert(std::string expr, bool pass) { + return append(std::move(expr), pass); + } + + std::shared_ptr on_assert_pass(std::string expr) { + return append(std::move(expr), true); + } + std::shared_ptr on_assert_fail(std::string expr) { - return append(std::move(expr)); + return append(std::move(expr), false); } std::shared_ptr on_end() { From 9920dd215b5ff3612842434d1d323983ca599e7b Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 14:16:00 +0100 Subject: [PATCH 31/58] build eager conflict detection with branching semantics --- src/branching/sync_protocol.cc | 37 +++++++++++++++++++++++++++------- src/branching/version_store.cc | 14 ++++++------- src/branching/version_store.hh | 6 +++--- src/execution_state.hh | 7 +++++++ 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/src/branching/sync_protocol.cc b/src/branching/sync_protocol.cc index 84f31ff..42abb98 100644 --- a/src/branching/sync_protocol.cc +++ b/src/branching/sync_protocol.cc @@ -56,7 +56,7 @@ BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, joiner_store.commit_staging(); assert(joinee_store.has_commited() && "joinee has staged changes"); - std::optional conflict = joiner_store.merge_with(joinee_store); + std::optional conflict = joiner_store.merge_with_commit(joinee_store.get_head()); if (conflict) { return std::make_unique( _global_store.get_object_name(conflict->obj), @@ -83,18 +83,41 @@ BranchingSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { std::optional> BranchingSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, GlobalContext &) { - assert(false && "Todo on_lock"); - // commit(thread.globals); - // return pull(thread.globals, lock.globals); + auto& thread_store = std::get(thread.sync).store; + thread_store.commit_staging(); + + 
std::shared_ptr lock_commit = lock.branching.commit; + + if (lock_commit != nullptr) { + std::optional conflict = thread_store.merge_with_commit(lock_commit); + if (conflict) { + return std::make_unique( + _global_store.get_object_name(conflict->obj), + std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); + } + } + + lock.branching.commit = thread_store.get_head(); + return std::nullopt; } std::optional> BranchingSyncProtocol::on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &) { - assert(false && "Todo on_unlock"); - // commit(thread.globals); - // lock.globals = thread.globals; + auto& thread_store = std::get(thread.sync).store; + thread_store.commit_staging(); + + std::shared_ptr lock_commit = lock.branching.commit; + + // we don't need to check for conflicts + if (lock_commit != nullptr) { + std::optional conflict = thread_store.merge_with_commit(lock_commit); + assert (!conflict); + } + + lock.branching.commit = thread_store.get_head(); + return std::nullopt; } diff --git a/src/branching/version_store.cc b/src/branching/version_store.cc index 249e572..1366830 100644 --- a/src/branching/version_store.cc +++ b/src/branching/version_store.cc @@ -272,26 +272,26 @@ bool traverse_until_lca( return true; } -std::optional LocalVersionStore::merge_with(const LocalVersionStore& other) { +std::optional LocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { assert(staging.empty()); - assert(other.staging.empty()); + assert(commit != nullptr); // trivial case: same history - if (head == other.head) + if (head == commit) return std::nullopt; // Create merge commit (no changes itself) auto merge_commit = std::make_shared( Commit{ .id = base_timestamp++, - .parents = {head, other.head}, + .parents = {head, commit}, .changes = {} // merge commit does not write anything } ); // Find lowest common ancestor of the two heads - std::shared_ptr lca = find_lowest_common_ancestor(head, other.head); - verbose << "found lca of " << head->id << " and " 
<< other.head->id << " to be " << lca->id << std::endl; + std::shared_ptr lca = find_lowest_common_ancestor(head, commit); + verbose << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; // Collect all writes after LCA for each branch std::unordered_map> branch_a, branch_b; @@ -300,7 +300,7 @@ std::optional LocalVersionStore::merge_with(const LocalVersionStore& o std::unordered_map, bool> reach_memo; traverse_until_lca(head, lca, branch_a, visited, reach_memo); visited.clear(); - traverse_until_lca(other.head, lca, branch_b, visited, reach_memo); + traverse_until_lca(commit, lca, branch_b, visited, reach_memo); // 1. Eager conflict detection for (const auto& [obj, commit_a] : branch_a) { diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index ed5816c..14c9364 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -95,16 +95,16 @@ public: bool has_commited() { return staging.empty(); } + std::shared_ptr get_head() const { return head; } + std::optional get_staged(ObjectNumber obj) const; std::optional get_committed(ObjectNumber number) const; void adopt_history(const LocalVersionStore& other); - std::optional merge_with(const LocalVersionStore& other); + std::optional merge_with_commit(const std::shared_ptr& other_head); friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); bool operator==(const LocalVersionStore& other) const; - - void dump() { if (head) std::cout << *head << "\n============\n" << std::endl; } }; class GlobalVersionStore { diff --git a/src/execution_state.hh b/src/execution_state.hh index 1f35ce4..a854673 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -68,6 +68,13 @@ struct Thread { struct Lock { std::optional owner = std::nullopt; std::shared_ptr last_unlock_event = nullptr; + + // Branching-specific data + struct BranchingData { + std::shared_ptr commit; + }; + + BranchingData branching; }; struct 
GlobalContext { From b399b5b88599cddbf30d42e0328bc9872f7e088f Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 17:29:02 +0100 Subject: [PATCH 32/58] moving protocol thread and lock sync state into the respective protocol directories and abstracting it out of the general thread state --- CMakeLists.txt | 2 +- src/branching/sync_protocol.cc | 126 --------------------------------- src/branching/sync_protocol.hh | 51 ------------- src/branching/version_store.hh | 50 +++++-------- src/execution_state.cc | 73 ++++++------------- src/execution_state.hh | 25 +++---- src/gitmem.cc | 17 +++-- src/interpreter.cc | 6 +- src/linear/sync_protocol.cc | 22 +++--- src/linear/sync_protocol.hh | 8 +++ src/linear/version_store.hh | 16 ++++- src/sync_kind.hh | 5 +- src/sync_protocol.cc | 10 +-- src/sync_protocol.hh | 3 + test_gitmem.py | 98 ++++++++++++++----------- 15 files changed, 170 insertions(+), 342 deletions(-) delete mode 100644 src/branching/sync_protocol.cc delete mode 100644 src/branching/sync_protocol.hh diff --git a/CMakeLists.txt b/CMakeLists.txt index 88ab504..fd994a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ add_executable(gitmem src/passes/branching.cc src/linear/sync_protocol.cc src/linear/version_store.cc - src/branching/sync_protocol.cc + src/branching/base_sync_protocol.cc src/branching/version_store.cc src/interpreter.cc src/debugger.cc diff --git a/src/branching/sync_protocol.cc b/src/branching/sync_protocol.cc deleted file mode 100644 index 42abb98..0000000 --- a/src/branching/sync_protocol.cc +++ /dev/null @@ -1,126 +0,0 @@ -#include "branching/sync_protocol.hh" - -namespace gitmem { - -namespace branching { - -// -------------------- -// BranchingSyncProtocol -// -------------------- - -BranchingSyncProtocol::~BranchingSyncProtocol() = default; - -std::ostream &BranchingSyncProtocol::print(std::ostream &os) const { - os << _global_store << std::endl; - return os; -} - -std::optional 
BranchingSyncProtocol::read(ThreadContext &ctx, - const std::string &var) { - ObjectNumber number = _global_store.get_object_number(var); - - auto& store = std::get(ctx.sync).store; - - if (auto result = store.get_staged(number)) - return result; - - // look in commit history - return store.get_committed(number); -} - -void BranchingSyncProtocol::write(ThreadContext &ctx, const std::string &var, - size_t value) { - auto& store = std::get(ctx.sync).store; - store.stage(_global_store.get_object_number(var), value); -} - -std::optional> -BranchingSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &) { - auto& parent_store = std::get(parent.sync).store; - parent_store.commit_staging(); - - auto& child_store = std::get(child.sync).store; - child_store.adopt_history(parent_store); - - // a conflict cannot occur here - return std::nullopt; -} - -std::optional> -BranchingSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &) { - auto& joiner_store = std::get(joiner.sync).store; - auto& joinee_store = std::get(joinee.sync).store; - - joiner_store.commit_staging(); - assert(joinee_store.has_commited() && "joinee has staged changes"); - - std::optional conflict = joiner_store.merge_with_commit(joinee_store.get_head()); - if (conflict) { - return std::make_unique( - _global_store.get_object_name(conflict->obj), - std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); - } - - return std::nullopt; -} - -std::optional> -BranchingSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { - // nothing to do, the thread will have inhereted the parent commit on spawn - return std::nullopt; -}; - -std::optional> -BranchingSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { - auto& store = std::get(thread.sync).store; - store.commit_staging(); - - return std::nullopt; -}; - -std::optional> -BranchingSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, - GlobalContext &) { - auto& 
thread_store = std::get(thread.sync).store; - thread_store.commit_staging(); - - std::shared_ptr lock_commit = lock.branching.commit; - - if (lock_commit != nullptr) { - std::optional conflict = thread_store.merge_with_commit(lock_commit); - if (conflict) { - return std::make_unique( - _global_store.get_object_name(conflict->obj), - std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); - } - } - - lock.branching.commit = thread_store.get_head(); - - return std::nullopt; -} - -std::optional> -BranchingSyncProtocol::on_unlock(ThreadContext &thread, Lock &lock, - GlobalContext &) { - auto& thread_store = std::get(thread.sync).store; - thread_store.commit_staging(); - - std::shared_ptr lock_commit = lock.branching.commit; - - // we don't need to check for conflicts - if (lock_commit != nullptr) { - std::optional conflict = thread_store.merge_with_commit(lock_commit); - assert (!conflict); - } - - lock.branching.commit = thread_store.get_head(); - - return std::nullopt; -} - -} // end branching - -} // end gitmem \ No newline at end of file diff --git a/src/branching/sync_protocol.hh b/src/branching/sync_protocol.hh deleted file mode 100644 index b232684..0000000 --- a/src/branching/sync_protocol.hh +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include "../sync_protocol.hh" -#include "version_store.hh" - -namespace gitmem { - -using BranchingConflict = Conflict; - -namespace branching { - -class BranchingSyncProtocol final : public SyncProtocol { - GlobalVersionStore _global_store; - - // std::unordered_map> commit_nodes; - -public: - ~BranchingSyncProtocol() override; - SyncKind kind() const override { return SyncKind::Branching; }; - - std::optional read(ThreadContext &ctx, - const std::string &var) override; - - void write(ThreadContext &ctx, const std::string &var, size_t value) override; - - std::optional> - on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) override; - - std::optional> - on_join(ThreadContext &joiner, 
ThreadContext &joinee, - GlobalContext &gctx) override; - - std::optional> - on_start(ThreadContext &thread, GlobalContext &gctx) override; - - std::optional> - on_end(ThreadContext &thread, GlobalContext &gctx) override; - - std::optional> - on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - - std::optional> - on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; - - std::ostream &print(std::ostream &os) const override; -}; - -} // end branching - -} // end gitmem \ No newline at end of file diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index 14c9364..e4d25ae 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -7,6 +7,7 @@ #include #include #include "thread_id.hh" +#include "sync_state.hh" namespace gitmem { @@ -79,8 +80,7 @@ inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { // std::optional value; // only valid if kind == Value // }; - -class LocalVersionStore { +class LocalVersionStore : public ThreadSyncState { Timestamp base_timestamp; std::shared_ptr head; std::unordered_map staging; @@ -88,6 +88,8 @@ class LocalVersionStore { std::unordered_map> last_writer; // cached public: + ~LocalVersionStore() = default; + LocalVersionStore(ThreadID tid): base_timestamp(tid, 0) {} void stage(ObjectNumber obj, Value value); @@ -104,7 +106,18 @@ public: std::optional merge_with_commit(const std::shared_ptr& other_head); friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); + std::ostream &print(std::ostream &os) const override { + os << dynamic_cast(this); + return os; + } + bool operator==(const LocalVersionStore& other) const; + bool operator==(const ThreadSyncState& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } }; class GlobalVersionStore { @@ -119,35 +132,10 @@ public: friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); }; 
-// using CommitHistory = std::vector; - -// struct Global { -// size_t val; -// std::optional commit; -// CommitHistory history; -// }; - -// using Globals = std::unordered_map; - -// struct Conflict { -// std::string var; -// std::pair commits; -// }; - -// Join logic -// commit(ctx.globals); -// commit(thread->ctx.globals); -// verbose << "Pulling from thread " << result << std::endl; -// if(auto conflict = pull(ctx.globals, thread->ctx.globals)) -// { -// using graph::Node; -// auto [s1, s2] = conflict->commits; -// auto sources = std::pair, -// std::shared_ptr>{gctx.commit_map[s1], gctx.commit_map[s2]}; auto -// graph_conflict = graph::Conflict(conflict->var, sources); -// thread_append_node(ctx, result, thread->ctx.tail, -// graph_conflict); return TerminationStatus::datarace_exception; -// } +class LockState : public LockSyncState { +public: + std::shared_ptr commit; +}; } // namespace branching diff --git a/src/execution_state.cc b/src/execution_state.cc index fb60c8b..aeaa943 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -5,15 +5,8 @@ namespace gitmem { -ThreadContext::ThreadContext(ThreadID tid, SyncKind sync_kind) { - switch (sync_kind) { - case SyncKind::Linear: - sync.emplace(); - break; - case SyncKind::Branching: - sync.emplace(tid); - break; - } +ThreadContext::ThreadContext(ThreadID tid, std::unique_ptr& protocol) { + sync = protocol->make_thread_state(tid); } bool ThreadContext::operator==(const ThreadContext &other) const { @@ -22,38 +15,9 @@ bool ThreadContext::operator==(const ThreadContext &other) const { // ignore the graph node, we're not interested in that - if (sync.index() != other.sync.index()) - return false; - - return std::visit([&](const auto& a, const auto& b) -> bool { - using A = std::decay_t; - using B = std::decay_t; - - if constexpr (std::is_same_v && - std::is_same_v) { - return true; - } else if constexpr (std::is_same_v) { - return a.store == b.store; - } else { - return false; // unreachable due to 
index check - } - }, sync, other.sync); + return *sync == *other.sync; } -// An old comment on equals -// Globals have a history that we don't care about, so we only -// compare values -// if (ctx.globals.size() != other.ctx.globals.size()) -// return false; -// for (const auto &[var, global] : ctx.globals) -// { -// if (!other.ctx.globals.contains(var) || -// ctx.globals.at(var).val != other.ctx.globals.at(var).val) -// { -// return false; -// } -// } - bool Thread::operator==(const Thread &other) const { return ctx == other.ctx && block == other.block && @@ -68,15 +32,30 @@ GlobalContext::GlobalContext(const trieste::Node &ast, ThreadID main_tid = 0; - ThreadContext starting_ctx(main_tid, this->protocol->kind()); + ThreadContext starting_ctx(main_tid, this->protocol); this->threads.emplace_back(main_tid, std::move(starting_ctx), starting_block); - this->locks = {}; - this->cache = {}; } GlobalContext::~GlobalContext() = default; +Lock& GlobalContext::get_lock(std::string lock) { + auto it = locks.find(lock); + if (it != locks.end()) + return it->second; + + auto [new_it, inserted] = locks.emplace( + lock, + Lock{ + .owner = std::nullopt, + .last_unlock_event = nullptr, + .sync = protocol->make_lock_state() + } + ); + + return new_it->second; +} + // void GlobalContext::print_execution_graph( // const std::filesystem::path &output_path) const { // return; // FIXME @@ -164,15 +143,7 @@ std::ostream& operator<<(std::ostream& os, const ThreadContext& ctx) { os << "}"; //, tail=" << ctx.tail; - std::visit([&](const auto& data) { - using T = std::decay_t; - - if constexpr (std::is_same_v) { - os << ", sync=linear{" << data.store << "}"; - } else if constexpr (std::is_same_v) { - os << ", sync=branching{" << data.store << "}"; - } - }, ctx.sync); + os << *(ctx.sync); os << "}"; return os; diff --git a/src/execution_state.hh b/src/execution_state.hh index a854673..5c4f0c1 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -17,18 +17,12 @@ namespace 
gitmem { class SyncProtocol; +class ThreadSyncState; struct ThreadContext { std::unordered_map locals; - struct LinearData { - linear::LocalVersionStore store; - }; - struct BranchingData { - branching::LocalVersionStore store; - }; - - std::variant sync; + std::unique_ptr sync; ThreadContext(const ThreadContext&) = delete; ThreadContext& operator=(const ThreadContext&) = delete; @@ -36,7 +30,7 @@ struct ThreadContext { ThreadContext(ThreadContext&&) = default; ThreadContext& operator=(ThreadContext&&) = default; - ThreadContext(ThreadID tid, SyncKind sync_kind); + ThreadContext(ThreadID tid, std::unique_ptr&); bool operator==(const ThreadContext &other) const; @@ -51,7 +45,7 @@ struct Thread { size_t pc = 0; std::optional terminated = std::nullopt; - Thread(ThreadID tid, ThreadContext&& ctx, trieste::Node block): + Thread(ThreadID tid, ThreadContext ctx, trieste::Node block): tid(tid), ctx(std::move(ctx)), trace(tid), block(block) {}; Thread(const Thread&) = delete; @@ -68,19 +62,16 @@ struct Thread { struct Lock { std::optional owner = std::nullopt; std::shared_ptr last_unlock_event = nullptr; - - // Branching-specific data - struct BranchingData { - std::shared_ptr commit; - }; - - BranchingData branching; + std::unique_ptr sync; }; struct GlobalContext { // Execution state std::deque threads; +private: std::unordered_map locks; +public: + Lock& get_lock(std::string); // AST evaluation cache lang::NodeMap cache; diff --git a/src/gitmem.cc b/src/gitmem.cc index 6fda58a..bdffbfb 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -5,7 +5,7 @@ #include "model_checker.hh" #include "debugger.hh" #include "lang.hh" -#include "sync_protocol.hh" +#include "sync_kind.hh" int main(int argc, char **argv) { using namespace trieste; @@ -32,9 +32,17 @@ int main(int argc, char **argv) { app.add_flag("-e,--explore", model_check, "Explore all possible execution paths."); - bool branching = false; - app.add_flag("-b,--branching", branching, - "Using branching semantics."); + auto 
string_to_sync = CLI::Transformer(std::map { + {"linear", gitmem::SyncKind::Linear}, + {"branching-eager", gitmem::SyncKind::BranchingEager}, + {"branching-lazy", gitmem::SyncKind::BranchingLazy} + }); + string_to_sync.description("linear,branching-eager,branching-lazy"); + + gitmem::SyncKind sync_kind = gitmem::SyncKind::Linear; + app.add_option("--sync", sync_kind, "Select a sync protocol for execution (default: linear)") + ->transform(string_to_sync) + ->type_name("SYNC_KIND"); try { app.parse(argc, argv); @@ -67,7 +75,6 @@ int main(int argc, char **argv) { gitmem::verbose << "Output will be written to " << output_path << std::endl; int exit_status; - gitmem::SyncKind sync_kind = branching ? gitmem::SyncKind::Branching : gitmem::SyncKind::Linear; wf::push_back(gitmem::lang::wf); if (model_check) { exit_status = gitmem::model_check(result.ast, output_path, sync_kind); diff --git a/src/interpreter.cc b/src/interpreter.cc index 9f8fc27..ba453b0 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -75,7 +75,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { return sum; } else if (e == lang::Spawn) { ThreadID child_tid = gctx.threads.size(); - ThreadContext child_ctx(child_tid, gctx.protocol->kind()); + ThreadContext child_ctx(child_tid, gctx.protocol); if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { @@ -220,7 +220,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa auto v = s / lang::Var; auto var = std::string(v->location().view()); - Lock &lock = gctx.locks[var]; + Lock& lock = gctx.get_lock(var); if (lock.owner) { verbose << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; @@ -248,7 +248,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa auto v = s / lang::Var; auto var = std::string(v->location().view()); - auto &lock = gctx.locks[var]; + Lock& lock = gctx.get_lock(var); if (!lock.owner || (lock.owner && *lock.owner != thread.tid)) { return 
TerminationStatus::unlock_exception; } diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index e2cf897..f26ac5c 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -6,6 +6,10 @@ namespace gitmem { namespace linear { +LocalVersionStore& get_store(ThreadContext& ctx) { + return static_cast(*ctx.sync); +} + // -------------------- // LinearSyncProtocol // -------------------- @@ -54,7 +58,7 @@ std::optional LinearSyncProtocol::read(ThreadContext &ctx, const std::string &var) { ObjectNumber number = _global_store.get_object_number(var); - auto& store = std::get(ctx.sync).store; + auto& store = get_store(ctx); if (auto result = store.get_staged(number)) return result; @@ -74,7 +78,7 @@ std::optional LinearSyncProtocol::read(ThreadContext &ctx, void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, size_t value) { // write into the staging area of the thread - auto& store = std::get(ctx.sync).store; + auto& store = get_store(ctx); store.stage(_global_store.get_object_number(var), value); } @@ -85,12 +89,12 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, // added // push parent to global history - auto& store = std::get(parent.sync).store; + auto& store = get_store(parent); if (auto conflict = push(store)) return std::make_unique(std::move(*conflict)); // pull into the child - store = std::get(child.sync).store; + store = get_store(child); if (auto conflict = pull(store)) { throw std::logic_error("This code path should never be reached"); } @@ -104,7 +108,7 @@ LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, // we assume the joinee has already terminated and pushed // pull changes into parent - auto& store = std::get(joiner.sync).store; + auto& store = get_store(joiner); if (auto conflict = pull(store)) return std::make_unique(std::move(*conflict)); @@ -114,7 +118,7 @@ LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, 
std::optional> LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { // pull state from global history - auto& store = std::get(thread.sync).store; + auto& store = get_store(thread); auto conflict = pull(store); assert(!conflict && "cannot conflict from starting state"); @@ -124,7 +128,7 @@ LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { std::optional> LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { // push changes to global history - auto& store = std::get(thread.sync).store; + auto& store = get_store(thread); if (auto conflict = push(store)) return std::make_unique(std::move(*conflict)); @@ -135,7 +139,7 @@ std::optional> LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) { - auto& store = std::get(thread.sync).store; + auto& store = get_store(thread); if (auto conflict = pull(store)) return std::make_unique(std::move(*conflict)); @@ -147,7 +151,7 @@ LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, GlobalContext &gctx) { // push changes to global history - auto& store = std::get(thread.sync).store; + auto& store = get_store(thread); if (auto conflict = push(store)) return std::make_unique(std::move(*conflict)); diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 59cfea9..5bb9480 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -48,6 +48,14 @@ public: on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; std::ostream &print(std::ostream &os) const override; + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(); + } + + std::unique_ptr make_lock_state() const override { + return nullptr; + } }; } // namespace linear diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index 41b0911..e8e9bda 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -6,6 +6,7 @@ #include #include #include 
+#include "sync_state.hh" namespace gitmem { @@ -81,11 +82,13 @@ struct Conflict { // LocalVersionStore // ----------------------------- -class LocalVersionStore { +class LocalVersionStore : public ThreadSyncState { Timestamp _base_timestamp{}; std::unordered_map _staging; public: + ~LocalVersionStore() = default; + Timestamp base_timestamp() const { return _base_timestamp; } const auto &staged_changes() const { return _staging; } @@ -96,7 +99,18 @@ public: bool operator==(const LocalVersionStore& other) const; + bool operator==(const ThreadSyncState& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } + friend std::ostream& operator<<(std::ostream&, const LocalVersionStore&); + std::ostream &print(std::ostream &os) const override { + os << dynamic_cast(this); + return os; + } }; // ----------------------------- diff --git a/src/sync_kind.hh b/src/sync_kind.hh index 303df8e..a5d2867 100644 --- a/src/sync_kind.hh +++ b/src/sync_kind.hh @@ -1,10 +1,11 @@ #pragma once namespace gitmem { - + enum class SyncKind { Linear, - Branching + BranchingEager, + BranchingLazy }; } \ No newline at end of file diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index 87ee605..ff3f7c0 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -1,15 +1,15 @@ #include "sync_protocol.hh" #include "linear/sync_protocol.hh" -#include "branching/sync_protocol.hh" +#include "branching/eager_sync_protocol.hh" +#include "branching/lazy_sync_protocol.hh" namespace gitmem { std::unique_ptr make_protocol(SyncKind sync_kind) { switch (sync_kind) { - case SyncKind::Linear: - return std::make_unique(); - case SyncKind::Branching: - return std::make_unique(); + case SyncKind::Linear: return std::make_unique(); + case SyncKind::BranchingEager: return std::make_unique(); + case SyncKind::BranchingLazy: return std::make_unique(); } std::unreachable(); } diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 82bfc69..77101cf 
100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -2,6 +2,7 @@ #include "conflict.hh" #include "sync_kind.hh" +#include "sync_state.hh" #include "execution_state.hh" #include #include @@ -14,6 +15,8 @@ class SyncProtocol { public: virtual ~SyncProtocol() = default; virtual SyncKind kind() const = 0; + virtual std::unique_ptr make_thread_state(ThreadID tid) const = 0; + virtual std::unique_ptr make_lock_state() const = 0; // Read a shared variable into the thread context virtual std::optional read(ThreadContext &ctx, diff --git a/test_gitmem.py b/test_gitmem.py index ee6556c..f686a5e 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -6,6 +6,12 @@ EXAMPLES_DIR = "examples" +SYNC_KINDS = { + "linear": "linear", + "branching-eager": "branching-eager", + "branching-lazy": "branching-lazy", +} + def supports_color(): return sys.stdout.isatty() and os.getenv("NO_COLOR") is None @@ -20,11 +26,14 @@ def green(text): def red(text): return color(text, "31") -def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): - cmd = [gitmem_path, file_path, "-e", "-o", "/dev/null"] - - if is_branching: - cmd.insert(2, "-b") +def run_gitmem_test(gitmem_path, file_path, should_accept, sync_kind): + cmd = [ + gitmem_path, + file_path, + "--sync", sync_kind, + "-e", + "-o", "/dev/null" + ] try: result = subprocess.run( @@ -41,7 +50,7 @@ def run_gitmem_test(gitmem_path, file_path, should_accept, is_branching): sys.exit(1) status = green("PASS") if accepted else red("FAIL") - print(f"[{status}] {file_path} (exit code: {result.returncode})") + print(f"[{status}] {file_path} [{sync_kind}] (exit code: {result.returncode})") return accepted def main(): @@ -54,21 +63,34 @@ def main(): parser.add_argument( "--linear", action="store_true", - help="Only run linear tests" + help="Only run linear sync tests" + ) + parser.add_argument( + "--branching-eager", + action="store_true", + help="Only run branching-eager tests" ) parser.add_argument( - "--branching", + 
"--branching-lazy", action="store_true", - help="Only run branching tests" + help="Only run branching-lazy tests" ) + args = parser.parse_args() gitmem_path = args.gitmem - run_linear = args.linear - run_branching = args.branching - # If neither flag is specified, run both - if not run_linear and not run_branching: - run_linear = run_branching = True + selected_syncs = [] + + if args.linear: + selected_syncs.append("linear") + if args.branching_eager: + selected_syncs.append("branching-eager") + if args.branching_lazy: + selected_syncs.append("branching-lazy") + + # If none specified, run all + if not selected_syncs: + selected_syncs = list(SYNC_KINDS.values()) results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: { "total": 0, @@ -87,54 +109,50 @@ def main(): if not os.path.isdir(base_dir): continue - if category == "semantics": - subcategories = [] - if run_branching: - subcategories.append("branching") - if run_linear: - subcategories.append("linear") - else: - subcategories = [None] - - for subcategory in subcategories: - if subcategory: - test_dir = os.path.join(base_dir, subcategory) - if not os.path.isdir(test_dir): - continue - is_branching = (subcategory == "branching") + for sync_kind in selected_syncs: + # syntax tests are sync-agnostic → only run once + if category == "syntax" and sync_kind != "linear": + continue + + if category == "semantics": + if sync_kind == "linear": + test_dir = os.path.join(base_dir, "linear") + else: + test_dir = os.path.join(base_dir, "branching") else: test_dir = base_dir - is_branching = False + + if not os.path.isdir(test_dir): + continue for root, _, files in os.walk(test_dir): for file in files: file_path = os.path.join(root, file) total_tests += 1 - results[expectation][category][subcategory]["total"] += 1 + results[expectation][category][sync_kind]["total"] += 1 passed = run_gitmem_test( gitmem_path, file_path, should_accept, - is_branching + sync_kind ) if not passed: failed_tests += 1 - 
results[expectation][category][subcategory]["failed"] += 1 - failing_tests.append(file_path) + results[expectation][category][sync_kind]["failed"] += 1 + failing_tests.append((file_path, sync_kind)) print("\nDetailed Summary:") for expectation, categories in results.items(): print(f"\n{expectation.upper()}:") - for category, subcats in categories.items(): + for category, syncs in categories.items(): print(f" {category}:") - for subcategory, stats in subcats.items(): - label = subcategory if subcategory else "all" + for sync_kind, stats in syncs.items(): passed = stats["total"] - stats["failed"] print( - f" {label}: " + f" {sync_kind}: " f"{passed}/{stats['total']} passed " f"({stats['failed']} failed)" ) @@ -146,8 +164,8 @@ def main(): if failing_tests: print("\nFailing tests:") - for path in failing_tests: - print(f" {red(path)}") + for path, sync in failing_tests: + print(f" {red(path)} [{sync}]") if failed_tests > 0: sys.exit(1) From 597d09a0a6ed747c01d0687d76c7676d7a82be69 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 17:29:26 +0100 Subject: [PATCH 33/58] adding missing sync state files --- src/branching/base_sync_protocol.cc | 136 +++++++++++++++++++++++++++ src/branching/base_sync_protocol.hh | 57 +++++++++++ src/branching/eager_sync_protocol.hh | 16 ++++ src/branching/lazy_sync_protocol.hh | 16 ++++ src/sync_state.hh | 22 +++++ 5 files changed, 247 insertions(+) create mode 100644 src/branching/base_sync_protocol.cc create mode 100644 src/branching/base_sync_protocol.hh create mode 100644 src/branching/eager_sync_protocol.hh create mode 100644 src/branching/lazy_sync_protocol.hh create mode 100644 src/sync_state.hh diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc new file mode 100644 index 0000000..c5563d1 --- /dev/null +++ b/src/branching/base_sync_protocol.cc @@ -0,0 +1,136 @@ +#include "branching/base_sync_protocol.hh" + +namespace gitmem { + +namespace branching { + +LocalVersionStore& 
get_store(ThreadContext& ctx) { + return static_cast(*ctx.sync); +} + +LockState& get_store(Lock& ctx) { + return static_cast(*ctx.sync); +} + +// -------------------- +// BranchingSyncProtocolBase +// -------------------- + +BranchingSyncProtocolBase::~BranchingSyncProtocolBase() = default; + +std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { + os << _global_store << std::endl; + return os; +} + +std::optional BranchingSyncProtocolBase::read(ThreadContext &ctx, + const std::string &var) { + ObjectNumber number = _global_store.get_object_number(var); + + auto& store = get_store(ctx); + + if (auto result = store.get_staged(number)) + return result; + + // look in commit history + return store.get_committed(number); +} + +void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, + size_t value) { + auto& store = get_store(ctx); + store.stage(_global_store.get_object_number(var), value); +} + +std::optional> +BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &) { + auto& parent_store = get_store(parent); + parent_store.commit_staging(); + + auto& child_store = get_store(child); + child_store.adopt_history(parent_store); + + // a conflict cannot occur here + return std::nullopt; +} + +std::optional> +BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &) { + auto& joiner_store = get_store(joiner); + auto& joinee_store = get_store(joinee); + + joiner_store.commit_staging(); + assert(joinee_store.has_commited() && "joinee has staged changes"); + + std::optional conflict = joiner_store.merge_with_commit(joinee_store.get_head()); + if (conflict) { + return std::make_unique( + _global_store.get_object_name(conflict->obj), + std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); + } + + return std::nullopt; +} + +std::optional> +BranchingSyncProtocolBase::on_start(ThreadContext &thread, GlobalContext &gctx) { + // nothing to do, 
the thread will have inhereted the parent commit on spawn + return std::nullopt; +}; + +std::optional> +BranchingSyncProtocolBase::on_end(ThreadContext &thread, GlobalContext &gctx) { + auto& store = get_store(thread); + store.commit_staging(); + + return std::nullopt; +}; + +std::optional> +BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock, + GlobalContext &) { + auto& store = get_store(thread); + store.commit_staging(); + + LockState& lock_state = get_store(lock); + std::shared_ptr lock_commit = lock_state.commit; + + if (lock_commit != nullptr) { + std::optional conflict = store.merge_with_commit(lock_commit); + if (conflict) { + return std::make_unique( + _global_store.get_object_name(conflict->obj), + std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); + } + } + + lock_state.commit = store.get_head(); + + return std::nullopt; +} + +std::optional> +BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock, + GlobalContext &) { + auto& store = get_store(thread); + store.commit_staging(); + + LockState& lock_state = get_store(lock); + std::shared_ptr lock_commit = lock_state.commit; + + // we don't need to check for conflicts + if (lock_commit != nullptr) { + std::optional conflict = store.merge_with_commit(lock_commit); + assert (!conflict); + } + + lock_state.commit = store.get_head(); + + return std::nullopt; +} + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh new file mode 100644 index 0000000..b48b4ba --- /dev/null +++ b/src/branching/base_sync_protocol.hh @@ -0,0 +1,57 @@ +#pragma once + +#include "../sync_protocol.hh" +#include "version_store.hh" + +namespace gitmem { + +using BranchingConflict = Conflict; + +namespace branching { + +class BranchingSyncProtocolBase : public SyncProtocol { +protected: + GlobalVersionStore _global_store; + +public: + ~BranchingSyncProtocolBase() override; + + std::optional 
read(ThreadContext &ctx, + const std::string &var) override; + + void write(ThreadContext &ctx, const std::string &var, size_t value) override; + + std::optional> + on_spawn(ThreadContext &parent, ThreadContext &child, + GlobalContext &gctx) override; + + std::optional> + on_join(ThreadContext &joiner, ThreadContext &joinee, + GlobalContext &gctx) override; + + std::optional> + on_start(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_end(ThreadContext &thread, GlobalContext &gctx) override; + + std::optional> + on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::optional> + on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + + std::ostream &print(std::ostream &os) const override; + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(tid); + } + + std::unique_ptr make_lock_state() const override { + return std::make_unique(); + } +}; + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/eager_sync_protocol.hh b/src/branching/eager_sync_protocol.hh new file mode 100644 index 0000000..a50b909 --- /dev/null +++ b/src/branching/eager_sync_protocol.hh @@ -0,0 +1,16 @@ +#pragma once + +#include "base_sync_protocol.hh" + +namespace gitmem { + +namespace branching { + +class BranchingEagerSyncProtocol final : public BranchingSyncProtocolBase { +public: + SyncKind kind() const override { return SyncKind::BranchingEager; }; +}; + +} + +} \ No newline at end of file diff --git a/src/branching/lazy_sync_protocol.hh b/src/branching/lazy_sync_protocol.hh new file mode 100644 index 0000000..3b0a9dc --- /dev/null +++ b/src/branching/lazy_sync_protocol.hh @@ -0,0 +1,16 @@ +#pragma once + +#include "base_sync_protocol.hh" + +namespace gitmem { + +namespace branching { + +class BranchingLazySyncProtocol final : public BranchingSyncProtocolBase { +public: + SyncKind kind() const override { return 
SyncKind::BranchingLazy; }; +}; + +} + +} \ No newline at end of file diff --git a/src/sync_state.hh b/src/sync_state.hh new file mode 100644 index 0000000..2960120 --- /dev/null +++ b/src/sync_state.hh @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace gitmem { + +class ThreadSyncState { +public: + virtual ~ThreadSyncState() = default; + + virtual bool operator==(const ThreadSyncState& other) const = 0; + + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const ThreadSyncState &state) { + return state.print(os); + } +}; + +class LockSyncState {}; + +} \ No newline at end of file From 60e753dafee1af0649299227b010f6524f065f49 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 18:22:23 +0100 Subject: [PATCH 34/58] building initial plumbing to allow reads to fail lazily --- src/branching/base_sync_protocol.cc | 9 ++++++--- src/branching/base_sync_protocol.hh | 3 +-- src/branching/version_store.hh | 12 +---------- src/interpreter.cc | 31 ++++++++++++++++++++++++----- src/linear/sync_protocol.cc | 10 +++++----- src/linear/sync_protocol.hh | 3 +-- src/read_result.hh | 11 ++++++++++ src/sync_protocol.hh | 4 ++-- src/thread_trace.hh | 13 ++++++++---- 9 files changed, 62 insertions(+), 34 deletions(-) create mode 100644 src/read_result.hh diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index c5563d1..f60ea93 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -23,17 +23,20 @@ std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { return os; } -std::optional BranchingSyncProtocolBase::read(ThreadContext &ctx, +ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, const std::string &var) { ObjectNumber number = _global_store.get_object_number(var); auto& store = get_store(ctx); if (auto result = store.get_staged(number)) - return result; + return *result; // look in commit history - 
return store.get_committed(number); + if (auto result = store.get_committed(number)) + return *result; + + return std::monostate{}; } void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index b48b4ba..ff24943 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -16,8 +16,7 @@ protected: public: ~BranchingSyncProtocolBase() override; - std::optional read(ThreadContext &ctx, - const std::string &var) override; + ReadResult read(ThreadContext &ctx, const std::string &var) override; void write(ThreadContext &ctx, const std::string &var, size_t value) override; diff --git a/src/branching/version_store.hh b/src/branching/version_store.hh index e4d25ae..42215e2 100644 --- a/src/branching/version_store.hh +++ b/src/branching/version_store.hh @@ -8,6 +8,7 @@ #include #include "thread_id.hh" #include "sync_state.hh" +#include "read_result.hh" namespace gitmem { @@ -44,7 +45,6 @@ struct Timestamp { } }; -using Value = size_t; using ObjectNumber = uint64_t; struct Commit { @@ -69,16 +69,6 @@ inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { } // Initial plumbing for fail late -// enum class ReadKind { -// NotFound, -// Value, -// Conflict -// }; - -// struct ReadResult { -// ReadKind kind; -// std::optional value; // only valid if kind == Value -// }; class LocalVersionStore : public ThreadSyncState { Timestamp base_timestamp; diff --git a/src/interpreter.cc b/src/interpreter.cc index ba453b0..4ec36fc 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -38,6 +38,15 @@ static bool is_syncing(Thread &thread) { ((thread.pc >= thread.block->size()) || is_syncing(thread.block->at(thread.pc))); } +// Helper to combine multiple lambdas for std::visit +template +struct overloaded : Ts... { + using Ts::operator()...; +}; + +// deduction guide +template overloaded(Ts...) 
-> overloaded; + /* Evaluating an expression either returns the result of the expression or * a the exceptional termination status of the thread. */ @@ -56,12 +65,24 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { } } else if (e == lang::Var) { auto var = std::string(expr->location().view()); - if (std::optional result = gctx.protocol->read(ctx, var)) { - thread.trace.on_read(var, *result); - return *result; - } else { // It is invalid to read a previously unwritten value - return TerminationStatus::unassigned_variable_read_exception; + + auto result = gctx.protocol->read(ctx, var); + + return std::visit(overloaded{ + [&](std::monostate) -> std::variant { + // invalid: reading a variable that hasn't been written + return TerminationStatus::unassigned_variable_read_exception; + }, + [&](Value value) -> std::variant { + // normal read + thread.trace.on_read(var, value); + return value; + }, + [&](std::unique_ptr& conflict) -> std::variant { + verbose << (*conflict) << std::endl; + return TerminationStatus::datarace_exception; } +}, result); } else if (e == lang::Const) { return size_t(std::stoi(std::string(e->location().view()))); } else if (e == lang::Add) { diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index f26ac5c..50ce0a7 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -54,25 +54,25 @@ LinearSyncProtocol::pull(LocalVersionStore &local) { LinearSyncProtocol::~LinearSyncProtocol() = default; -std::optional LinearSyncProtocol::read(ThreadContext &ctx, +ReadResult LinearSyncProtocol::read(ThreadContext &ctx, const std::string &var) { ObjectNumber number = _global_store.get_object_number(var); auto& store = get_store(ctx); if (auto result = store.get_staged(number)) - return result; + return *result; std::optional value = _global_store.get_version_for_timestamp( number, store.base_timestamp()); - if (!value) - return std::nullopt; + if (value) + return *value; // we do not need to record 
the staged value for correctness // TODO: there is something about working out if a value has changed vs been // written - return *value; + return std::monostate{}; } void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 5bb9480..265b78b 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -22,8 +22,7 @@ public: ~LinearSyncProtocol() override; SyncKind kind() const override { return SyncKind::Linear; }; - std::optional read(ThreadContext &ctx, - const std::string &var) override; + ReadResult read(ThreadContext &ctx, const std::string &var) override; void write(ThreadContext &ctx, const std::string &var, size_t value) override; diff --git a/src/read_result.hh b/src/read_result.hh new file mode 100644 index 0000000..3c1bc1a --- /dev/null +++ b/src/read_result.hh @@ -0,0 +1,11 @@ +#pragma once + +#include +#include "conflict.hh" + +namespace gitmem { + +using Value = size_t; +using ReadResult = std::variant>; + +} \ No newline at end of file diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 77101cf..052cfc0 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -4,6 +4,7 @@ #include "sync_kind.hh" #include "sync_state.hh" #include "execution_state.hh" +#include "read_result.hh" #include #include @@ -19,8 +20,7 @@ public: virtual std::unique_ptr make_lock_state() const = 0; // Read a shared variable into the thread context - virtual std::optional read(ThreadContext &ctx, - const std::string &var) = 0; + virtual ReadResult read(ThreadContext &ctx, const std::string &var) = 0; // Write a shared variable (staged, not committed) virtual void write(ThreadContext &ctx, const std::string &var, diff --git a/src/thread_trace.hh b/src/thread_trace.hh index ab854f7..bdd6625 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -9,7 +9,7 @@ struct Event; struct StartEvent {}; struct SpawnEvent { const ThreadID child_tid; }; 
-struct ReadEvent { const std::string var; const size_t value; }; +struct ReadEvent { const std::string var; const size_t value; std::unique_ptr maybe_conflict; }; struct WriteEvent { const std::string var; const size_t value; }; struct LockEvent { std::string lock_name; std::unique_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; struct UnlockEvent { const std::string lock_name; std::unique_ptr maybe_conflict; }; @@ -52,7 +52,12 @@ inline std::ostream& operator<<(std::ostream& os, const SpawnEvent& e) { } inline std::ostream& operator<<(std::ostream& os, const ReadEvent& e) { - return os << "ReadEvent(var=\"" << e.var << "\", value=" << e.value << ")"; + os << "ReadEvent(var=\"" << e.var << "\", value=" << e.value; + if (e.maybe_conflict) + os << ", conflict)"; + else + os << ")"; + return os; } inline std::ostream& operator<<(std::ostream& os, const WriteEvent& e) { @@ -134,8 +139,8 @@ private: return append(child_tid); } - std::shared_ptr on_read(const std::string text, const size_t value) { - return append(std::move(text), value); + std::shared_ptr on_read(const std::string text, const size_t value, std::unique_ptr conflict = nullptr) { + return append(std::move(text), value, std::move(conflict)); } std::shared_ptr on_write(const std::string text, const size_t value) { From 1cbd0f606636a39faa64f714bac88a12e590f3e3 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 19:54:56 +0100 Subject: [PATCH 35/58] removing the unused global context in the sync protocols --- src/branching/base_sync_protocol.cc | 18 +++++++----------- src/branching/base_sync_protocol.hh | 14 ++++++-------- src/interpreter.cc | 12 ++++++------ src/linear/sync_protocol.cc | 17 ++++++----------- src/linear/sync_protocol.hh | 14 ++++++-------- src/sync_protocol.hh | 14 ++++++-------- 6 files changed, 37 insertions(+), 52 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index f60ea93..7b26266 100644 --- 
a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -24,7 +24,7 @@ std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { } ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, - const std::string &var) { + const std::string &var) { ObjectNumber number = _global_store.get_object_number(var); auto& store = get_store(ctx); @@ -46,8 +46,7 @@ void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var } std::optional> -BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &) { +BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child) { auto& parent_store = get_store(parent); parent_store.commit_staging(); @@ -59,8 +58,7 @@ BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child, } std::optional> -BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &) { +BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee) { auto& joiner_store = get_store(joiner); auto& joinee_store = get_store(joinee); @@ -78,13 +76,13 @@ BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee, } std::optional> -BranchingSyncProtocolBase::on_start(ThreadContext &thread, GlobalContext &gctx) { +BranchingSyncProtocolBase::on_start(ThreadContext &thread) { // nothing to do, the thread will have inhereted the parent commit on spawn return std::nullopt; }; std::optional> -BranchingSyncProtocolBase::on_end(ThreadContext &thread, GlobalContext &gctx) { +BranchingSyncProtocolBase::on_end(ThreadContext &thread) { auto& store = get_store(thread); store.commit_staging(); @@ -92,8 +90,7 @@ BranchingSyncProtocolBase::on_end(ThreadContext &thread, GlobalContext &gctx) { }; std::optional> -BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock, - GlobalContext &) { +BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { 
auto& store = get_store(thread); store.commit_staging(); @@ -115,8 +112,7 @@ BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock, } std::optional> -BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock, - GlobalContext &) { +BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock) { auto& store = get_store(thread); store.commit_staging(); diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index ff24943..c67c66d 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -21,24 +21,22 @@ public: void write(ThreadContext &ctx, const std::string &var, size_t value) override; std::optional> - on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) override; + on_spawn(ThreadContext &parent, ThreadContext &child) override; std::optional> - on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) override; + on_join(ThreadContext &joiner, ThreadContext &joinee) override; std::optional> - on_start(ThreadContext &thread, GlobalContext &gctx) override; + on_start(ThreadContext &thread) override; std::optional> - on_end(ThreadContext &thread, GlobalContext &gctx) override; + on_end(ThreadContext &thread) override; std::optional> - on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + on_lock(ThreadContext &thread, Lock &lock) override; std::optional> - on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + on_unlock(ThreadContext &thread, Lock &lock) override; std::ostream &print(std::ostream &os) const override; diff --git a/src/interpreter.cc b/src/interpreter.cc index 4ec36fc..cafab3d 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -99,7 +99,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { ThreadContext child_ctx(child_tid, gctx.protocol); if (std::optional> conflict = - gctx.protocol->on_spawn(ctx, child_ctx, gctx)) { + 
gctx.protocol->on_spawn(ctx, child_ctx)) { throw std::logic_error("This code path should never be reached"); } @@ -222,7 +222,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa auto &joinee = gctx.threads[result]; if (joinee.terminated && (*joinee.terminated == TerminationStatus::completed)) { - if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx, gctx)) { + if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx)) { verbose << (**conflict) << std::endl; thread.trace.on_join(result, std::move(*conflict)); return TerminationStatus::datarace_exception; @@ -250,7 +250,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa lock.owner = thread.tid; - if (auto conflict = gctx.protocol->on_lock(ctx, lock, gctx)) { + if (auto conflict = gctx.protocol->on_lock(ctx, lock)) { verbose << (**conflict) << std::endl; thread.trace.on_lock(var, lock.last_unlock_event, std::move(*conflict)); return TerminationStatus::datarace_exception; @@ -274,7 +274,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa return TerminationStatus::unlock_exception; } - if (auto conflict = gctx.protocol->on_unlock(ctx, lock, gctx)) { + if (auto conflict = gctx.protocol->on_unlock(ctx, lock)) { verbose << (**conflict) << std::endl; thread.trace.on_unlock(var, std::move(*conflict)); return TerminationStatus::datarace_exception; @@ -326,7 +326,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { // Initial sync when thread starts executing if (pc == 0) { - gctx.protocol->on_start(ctx, gctx); + gctx.protocol->on_start(ctx); thread.trace.on_start(); } @@ -362,7 +362,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { return ProgressStatus::progress; // Otherwise, we truly reached the end this iteration - if (auto conflict = gctx.protocol->on_end(ctx, gctx)) { + if (auto conflict = gctx.protocol->on_end(ctx)) { verbose << (**conflict) << std::endl; thread.terminated = TerminationStatus::datarace_exception; return TerminationStatus::datarace_exception; 
diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index 50ce0a7..ca7c2f9 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -83,8 +83,7 @@ void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, } std::optional> -LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) { +LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child) { // TODO: i think we can drop the globalcontext but check after branching is // added @@ -103,8 +102,7 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child, } std::optional> -LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) { +LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee) { // we assume the joinee has already terminated and pushed // pull changes into parent @@ -116,7 +114,7 @@ LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee, } std::optional> -LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { +LinearSyncProtocol::on_start(ThreadContext &thread) { // pull state from global history auto& store = get_store(thread); auto conflict = pull(store); @@ -126,7 +124,7 @@ LinearSyncProtocol::on_start(ThreadContext &thread, GlobalContext &gctx) { }; std::optional> -LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { +LinearSyncProtocol::on_end(ThreadContext &thread) { // push changes to global history auto& store = get_store(thread); if (auto conflict = push(store)) @@ -136,8 +134,7 @@ LinearSyncProtocol::on_end(ThreadContext &thread, GlobalContext &gctx) { }; std::optional> -LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock, - GlobalContext &gctx) { +LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock) { auto& store = get_store(thread); if (auto conflict = pull(store)) @@ -147,9 +144,7 @@ LinearSyncProtocol::on_lock(ThreadContext &thread, 
Lock &lock, } std::optional> -LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &, - GlobalContext &gctx) { - +LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &) { // push changes to global history auto& store = get_store(thread); if (auto conflict = push(store)) diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 265b78b..b85aa83 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -27,24 +27,22 @@ public: void write(ThreadContext &ctx, const std::string &var, size_t value) override; std::optional> - on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) override; + on_spawn(ThreadContext &parent, ThreadContext &child) override; std::optional> - on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) override; + on_join(ThreadContext &joiner, ThreadContext &joinee) override; std::optional> - on_start(ThreadContext &thread, GlobalContext &gctx) override; + on_start(ThreadContext &thread) override; std::optional> - on_end(ThreadContext &thread, GlobalContext &gctx) override; + on_end(ThreadContext &thread) override; std::optional> - on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + on_lock(ThreadContext &thread, Lock &lock) override; std::optional> - on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) override; + on_unlock(ThreadContext &thread, Lock &lock) override; std::ostream &print(std::ostream &os) const override; diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 052cfc0..98de8f2 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -27,24 +27,22 @@ public: size_t value) = 0; virtual std::optional> - on_spawn(ThreadContext &parent, ThreadContext &child, - GlobalContext &gctx) = 0; + on_spawn(ThreadContext &parent, ThreadContext &child) = 0; virtual std::optional> - on_join(ThreadContext &joiner, ThreadContext &joinee, - GlobalContext &gctx) = 0; + on_join(ThreadContext &joiner, 
ThreadContext &joinee) = 0; virtual std::optional> - on_start(ThreadContext &thread, GlobalContext &gctx) = 0; + on_start(ThreadContext &thread) = 0; virtual std::optional> - on_end(ThreadContext &thread, GlobalContext &gctx) = 0; + on_end(ThreadContext &thread) = 0; virtual std::optional> - on_lock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) = 0; + on_lock(ThreadContext &thread, Lock &lock) = 0; virtual std::optional> - on_unlock(ThreadContext &thread, Lock &lock, GlobalContext &gctx) = 0; + on_unlock(ThreadContext &thread, Lock &lock) = 0; virtual std::ostream &print(std::ostream &os) const = 0; From 83ace8d37875851348fbc95a956d0f0740855629 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Wed, 7 Jan 2026 21:14:24 +0100 Subject: [PATCH 36/58] separating branching into lazy and eager conflict detection protocols, lazy is unfinished --- CMakeLists.txt | 4 +- src/branching/base_sync_protocol.cc | 2 +- src/branching/base_sync_protocol.hh | 6 +- src/branching/base_version_store.cc | 154 +++++++ ...version_store.hh => base_version_store.hh} | 13 +- src/branching/eager/sync_protocol.hh | 23 ++ src/branching/eager/version_store.cc | 169 ++++++++ src/branching/eager/version_store.hh | 22 + src/branching/eager_sync_protocol.hh | 16 - src/branching/lazy/sync_protocol.hh | 23 ++ src/branching/lazy/version_store.cc | 101 +++++ src/branching/lazy/version_store.hh | 22 + src/branching/lazy_sync_protocol.hh | 16 - src/branching/version_store.cc | 390 ------------------ src/execution_state.hh | 4 +- src/sync_protocol.cc | 6 +- 16 files changed, 526 insertions(+), 445 deletions(-) create mode 100644 src/branching/base_version_store.cc rename src/branching/{version_store.hh => base_version_store.hh} (86%) create mode 100644 src/branching/eager/sync_protocol.hh create mode 100644 src/branching/eager/version_store.cc create mode 100644 src/branching/eager/version_store.hh delete mode 100644 src/branching/eager_sync_protocol.hh create mode 100644 
src/branching/lazy/sync_protocol.hh create mode 100644 src/branching/lazy/version_store.cc create mode 100644 src/branching/lazy/version_store.hh delete mode 100644 src/branching/lazy_sync_protocol.hh delete mode 100644 src/branching/version_store.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index fd994a1..4a417ac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,9 @@ add_executable(gitmem src/linear/sync_protocol.cc src/linear/version_store.cc src/branching/base_sync_protocol.cc - src/branching/version_store.cc + src/branching/base_version_store.cc + src/branching/lazy/version_store.cc + src/branching/eager/version_store.cc src/interpreter.cc src/debugger.cc src/model_checker.cc diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 7b26266..536956a 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -1,4 +1,4 @@ -#include "branching/base_sync_protocol.hh" +#include "base_sync_protocol.hh" namespace gitmem { diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index c67c66d..bc53087 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -1,7 +1,7 @@ #pragma once #include "../sync_protocol.hh" -#include "version_store.hh" +#include "base_version_store.hh" namespace gitmem { @@ -40,10 +40,6 @@ public: std::ostream &print(std::ostream &os) const override; - std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(tid); - } - std::unique_ptr make_lock_state() const override { return std::make_unique(); } diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc new file mode 100644 index 0000000..fb14b42 --- /dev/null +++ b/src/branching/base_version_store.cc @@ -0,0 +1,154 @@ +#include "base_version_store.hh" +#include +#include +#include "debug.hh" + +namespace gitmem { + +namespace branching { + +// Helper: recursive print with 2-space 
indentation and cycle protection +void print_commit_recursive(std::ostream& os, + const std::shared_ptr& commit, + std::unordered_set& visited, + int depth = 0) +{ + if (!commit) return; + if (!visited.insert(commit.get()).second) { + os << std::string(depth * 2, ' ') << "(already printed commit " << commit->id << ")\n"; + return; + } + + os << std::string(depth * 2, ' ') << "Commit " << commit->id << " {\n"; + + // Print changes + for (const auto& [obj, val] : commit->changes) { + os << std::string((depth + 1) * 2, ' ') << obj << " -> " << val << "\n"; + } + + // Print parents + if (!commit->parents.empty()) { + os << std::string((depth + 1) * 2, ' ') << "Parents: "; + for (size_t i = 0; i < commit->parents.size(); ++i) { + os << commit->parents[i]->id; + if (i + 1 < commit->parents.size()) os << ", "; + } + os << "\n"; + } + + os << std::string(depth * 2, ' ') << "}\n"; + + // Recursively print parents + for (auto& parent : commit->parents) { + print_commit_recursive(os, parent, visited, depth + 1); + } +} + +// operator<< for Commit +std::ostream& operator<<(std::ostream& os, const Commit& commit) { + std::unordered_set visited; + // Wrap the commit in a shared_ptr to reuse the recursive helper + print_commit_recursive(os, std::make_shared(commit), visited); + return os; +} + +void LocalVersionStore::stage(ObjectNumber obj, Value value) { + staging[obj] = value; +} + +void LocalVersionStore::commit_staging() { + // No-op commit does nothing + if (staging.empty()) { + return; + } + + // Create the new commit with the staged changes + auto new_commit = std::make_shared(base_timestamp++, std::move(staging)); + + // Update last_writer for each staged variable + for (const auto& [obj, _] : new_commit->changes) { + last_writer[obj] = new_commit; + } + + // Clear staging + staging.clear(); + + // Set parent to previous head if it exists + if (head) + new_commit->parents.push_back(head); + + // Update head + head = new_commit; +} + +std::optional 
LocalVersionStore::get_staged(ObjectNumber obj) const { + auto it = staging.find(obj); + return it != staging.end() ? std::make_optional(it->second) : std::nullopt; +} + +void LocalVersionStore::adopt_history(const LocalVersionStore& other) { + // Inherit the DAG head + head = other.head; + + // Inherit the last_writer cache so the child sees all latest commits + last_writer = other.last_writer; +} + +std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { + os << "LocalVersionStore{" + << "base=" << store.base_timestamp + << ", head="; + + if (store.head) + os << store.head->id; + else + os << "null"; + + os << ", staged={"; + + bool first = true; + for (const auto& [obj, val] : store.staging) { + if (!first) os << ", "; + first = false; + os << obj << "->" << val; + } + + os << "}}"; + return os; +} + +bool LocalVersionStore::operator==(const LocalVersionStore& other) const { + return base_timestamp == other.base_timestamp && + head == other.head && + staging == other.staging; +} + +ObjectNumber GlobalVersionStore::get_object_number(std::string var) { + auto it = _object_numbers.find(var); + if (it != _object_numbers.end()) { + return it->second; + } else { + ObjectNumber number = _next_object++; + _object_numbers[var] = number; + return number; + } +} + +std::string GlobalVersionStore::get_object_name(ObjectNumber find) { + for (const auto &[name, number] : _object_numbers) { + if (number == find) + return name; + } + assert(false && "failed to find object name for object number"); + return ""; +} + + +std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { + os << "GlobalVersionStore(next_object=" << store._next_object << ")" << std::endl; + return os; +} + +} // branching + +} // gitmem \ No newline at end of file diff --git a/src/branching/version_store.hh b/src/branching/base_version_store.hh similarity index 86% rename from src/branching/version_store.hh rename to src/branching/base_version_store.hh index 
42215e2..00da798 100644 --- a/src/branching/version_store.hh +++ b/src/branching/base_version_store.hh @@ -14,11 +14,6 @@ namespace gitmem { namespace branching { -/* A 'Global' is a structure to capture the current synchronising objects - * representation of a global variable. The structure is the current value, - * the current commit id for the variable, and the history of commited ids. - */ - struct Timestamp { ThreadID thread; size_t counter; @@ -53,7 +48,6 @@ struct Commit { std::vector> parents; }; -// operator<< for Commit std::ostream& operator<<(std::ostream& os, const Commit& commit); struct Conflict { @@ -68,9 +62,8 @@ inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { << ", timestamp_b=" << c.timestamp_b << "}"; } -// Initial plumbing for fail late - class LocalVersionStore : public ThreadSyncState { +protected: Timestamp base_timestamp; std::shared_ptr head; std::unordered_map staging; @@ -90,10 +83,10 @@ public: std::shared_ptr get_head() const { return head; } std::optional get_staged(ObjectNumber obj) const; - std::optional get_committed(ObjectNumber number) const; + virtual std::optional get_committed(ObjectNumber number) const = 0; void adopt_history(const LocalVersionStore& other); - std::optional merge_with_commit(const std::shared_ptr& other_head); + virtual std::optional merge_with_commit(const std::shared_ptr& other_head) = 0; friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); std::ostream &print(std::ostream &os) const override { diff --git a/src/branching/eager/sync_protocol.hh b/src/branching/eager/sync_protocol.hh new file mode 100644 index 0000000..fa5b26e --- /dev/null +++ b/src/branching/eager/sync_protocol.hh @@ -0,0 +1,23 @@ +#pragma once + +#include "branching/base_sync_protocol.hh" +#include "branching/eager/version_store.hh" + +namespace gitmem { + +namespace branching { + +class BranchingEagerSyncProtocol final : public BranchingSyncProtocolBase { +public: + 
~BranchingEagerSyncProtocol() = default; + + SyncKind kind() const override { return SyncKind::BranchingEager; }; + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(tid); + } +}; + +} + +} \ No newline at end of file diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc new file mode 100644 index 0000000..3f4818a --- /dev/null +++ b/src/branching/eager/version_store.cc @@ -0,0 +1,169 @@ +#include "branching/eager/version_store.hh" +#include "debug.hh" + +#include + +namespace gitmem { + +namespace branching { + +bool can_reach_lca( + const std::shared_ptr& commit, + const std::shared_ptr& lca, + std::unordered_map, bool>& memo) +{ + if (!commit) + return false; + + if (commit == lca) + return true; + + auto it = memo.find(commit); + if (it != memo.end()) + return it->second; + + for (const auto& parent : commit->parents) { + if (can_reach_lca(parent, lca, memo)) { + memo[commit] = true; + return true; + } + } + + memo[commit] = false; + return false; +} + +bool traverse_until_lca( + const std::shared_ptr& commit, + const std::shared_ptr& lca, + std::unordered_map>& out_map, + std::unordered_set>& visited, + std::unordered_map, bool>& reach_memo) +{ + if (!commit || commit == lca || !visited.insert(commit).second) + return true; + + if (!can_reach_lca(commit, lca, reach_memo)) + return true; + + for (const auto& [obj, _] : commit->changes) { + // first write seen dominates + if (out_map.find(obj) == out_map.end()) + out_map[obj] = commit; + } + + for (auto& parent : commit->parents) { + if (!traverse_until_lca(parent, lca, out_map, visited, reach_memo)) + return false; + } + + return true; +} + +std::shared_ptr +find_lowest_common_ancestor(std::shared_ptr a, + std::shared_ptr b) +{ + if (!a || !b) return nullptr; + + // Step 1: collect all ancestors of 'a' + std::unordered_set> ancestors_a; + std::queue> q; + q.push(a); + + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) 
continue; + if (!ancestors_a.insert(c).second) continue; // already visited + + for (auto& p : c->parents) + q.push(p); + } + + // Step 2: BFS from 'b' to find first common ancestor + std::unordered_set> visited_b; + q.push(b); + + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (!visited_b.insert(c).second) continue; + + if (ancestors_a.count(c)) + return c; // first common ancestor seen + + for (auto& p : c->parents) + q.push(p); + } + + return nullptr; // disjoint histories (shouldn’t happen) +} + +std::optional EagerLocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { + assert(staging.empty()); + assert(commit != nullptr); + + // trivial case: same history + if (head == commit) + return std::nullopt; + + // Create merge commit (no changes itself) + auto merge_commit = std::make_shared( + Commit{ + .id = base_timestamp++, + .parents = {head, commit}, + .changes = {} // merge commit does not write anything + } + ); + + // Find lowest common ancestor of the two heads + std::shared_ptr lca = find_lowest_common_ancestor(head, commit); + verbose << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; + + // Collect all writes after LCA for each branch + std::unordered_map> branch_a, branch_b; + std::unordered_set> visited; + + std::unordered_map, bool> reach_memo; + traverse_until_lca(head, lca, branch_a, visited, reach_memo); + visited.clear(); + traverse_until_lca(commit, lca, branch_b, visited, reach_memo); + + // 1. Eager conflict detection + for (const auto& [obj, commit_a] : branch_a) { + auto it = branch_b.find(obj); + if (it != branch_b.end() && it->second != commit_a) { + return Conflict{ + .obj = obj, + .timestamp_a = commit_a->id, + .timestamp_b = it->second->id + }; + } + } + + // 2. 
Update thread-local last_writer incrementally + // Only overwrite variables that were touched along either branch after LCA + for (const auto& [obj, commit] : branch_a) + last_writer[obj] = commit; + + for (const auto& [obj, commit] : branch_b) + last_writer[obj] = commit; + + // 3. Variables not touched in either branch remain unchanged (from before LCA) + + // 4. Update head + head = merge_commit; + + return std::nullopt; +} + +std::optional EagerLocalVersionStore::get_committed(ObjectNumber number) const { + if (auto it = last_writer.find(number); it != last_writer.end()) + return it->second->changes.at(number); + + return std::nullopt; +} + +} // end branching + +} // end gitmem \ No newline at end of file diff --git a/src/branching/eager/version_store.hh b/src/branching/eager/version_store.hh new file mode 100644 index 0000000..f1f5923 --- /dev/null +++ b/src/branching/eager/version_store.hh @@ -0,0 +1,22 @@ +#pragma once + +#include "branching/base_version_store.hh" + +namespace gitmem { + +namespace branching { + +class EagerLocalVersionStore : public LocalVersionStore { +public: + ~EagerLocalVersionStore() = default; + + EagerLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} + + std::optional merge_with_commit(const std::shared_ptr&) override; + std::optional get_committed(ObjectNumber number) const override; + +}; + +} + +} \ No newline at end of file diff --git a/src/branching/eager_sync_protocol.hh b/src/branching/eager_sync_protocol.hh deleted file mode 100644 index a50b909..0000000 --- a/src/branching/eager_sync_protocol.hh +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include "base_sync_protocol.hh" - -namespace gitmem { - -namespace branching { - -class BranchingEagerSyncProtocol final : public BranchingSyncProtocolBase { -public: - SyncKind kind() const override { return SyncKind::BranchingEager; }; -}; - -} - -} \ No newline at end of file diff --git a/src/branching/lazy/sync_protocol.hh b/src/branching/lazy/sync_protocol.hh new file 
mode 100644 index 0000000..d30a5c5 --- /dev/null +++ b/src/branching/lazy/sync_protocol.hh @@ -0,0 +1,23 @@ +#pragma once + +#include "branching/base_sync_protocol.hh" +#include "branching/lazy/version_store.hh" + +namespace gitmem { + +namespace branching { + +class BranchingLazySyncProtocol final : public BranchingSyncProtocolBase { +public: + ~BranchingLazySyncProtocol() = default; + + SyncKind kind() const override { return SyncKind::BranchingLazy; }; + + std::unique_ptr make_thread_state(ThreadID tid) const override { + return std::make_unique(tid); + } +}; + +} + +} \ No newline at end of file diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc new file mode 100644 index 0000000..0911c97 --- /dev/null +++ b/src/branching/lazy/version_store.cc @@ -0,0 +1,101 @@ +#include "branching/lazy/version_store.hh" +#include "debug.hh" + +namespace gitmem { + +namespace branching { + +// std::optional traverse_to_lca( +// const std::shared_ptr& commit, +// ObjectNumber var, +// const std::shared_ptr& lca) +// { +// if (!commit || commit == lca) return std::nullopt; + +// auto it = commit->changes.find(var); +// if (it != commit->changes.end()) return it->second; + +// if (commit->parents.size() > 1) +// assert(false && "should never encounter multi-parent commit before LCA"); + +// return traverse_to_lca(commit->parents[0], var, lca); +// } + +// std::optional get_committed_recursive( +// const std::shared_ptr& commit, +// ObjectNumber var) { +// if (!commit) return std::nullopt; + +// // 1. If this commit wrote the variable, return it +// auto it = commit->changes.find(var); +// if (it != commit->changes.end()) return it->second; + +// // 2. If single parent, recurse +// if (commit->parents.size() == 1) +// return get_committed_recursive(commit->parents[0], var); + +// // 3. 
Merge commit +// assert(commit->parents.size() == 2); // merge commit + +// auto& p1 = commit->parents[0]; +// auto& p2 = commit->parents[1]; + +// auto lca = find_lowest_common_ancestor(p1, p2); + +// // Explore both paths from merge commit to LCA +// std::optional v1 = traverse_to_lca(p1, var, lca); +// std::optional v2 = traverse_to_lca(p2, var, lca); + +// assert(!v1 || !v2 || v1 == v2); // conflict-free invariant + +// if (v1) return v1; // found in one of the merge branches +// if (v2) return v2; + +// // 4. Not found yet → continue recursively from the LCA downward +// return get_committed_recursive(lca, var); +// } + +// std::optional> +// get_committed_recursive( +// const std::shared_ptr& commit, +// ObjectNumber number, +// std::unordered_set>& visited) { + +// if (!commit || !visited.insert(commit).second) +// return std::nullopt; + +// std::optional> found; + +// // Recurse into all parents first +// for (auto& parent : commit->parents) { +// auto parent_commit = get_committed_recursive(parent, number, visited); +// if (parent_commit) { +// if (!found.has_value()) +// found = parent_commit; +// else if (found.value()->changes.at(number) != parent_commit.value()->changes.at(number)) +// assert(false && "Conflict detected (should be impossible in conflict-free DAG)"); +// } +// } + +// // If this commit wrote the variable, it overrides any parent +// if (commit->changes.contains(number)) +// return commit; + +// return found; +// } + + +std::optional LazyLocalVersionStore::merge_with_commit(const std::shared_ptr&) { + assert(false && "todo"); + return std::nullopt; + +} + +std::optional LazyLocalVersionStore::get_committed(ObjectNumber number) const { + assert(false && "todo"); + return std::nullopt; +} + +} + +} \ No newline at end of file diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh new file mode 100644 index 0000000..45598a7 --- /dev/null +++ b/src/branching/lazy/version_store.hh @@ -0,0 +1,22 @@ +#pragma once 
+ +#include "branching/base_version_store.hh" + +namespace gitmem { + +namespace branching { + +class LazyLocalVersionStore : public LocalVersionStore { +public: + ~LazyLocalVersionStore() = default; + + LazyLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} + + std::optional merge_with_commit(const std::shared_ptr&) override; + std::optional get_committed(ObjectNumber number) const override; + +}; + +} + +} \ No newline at end of file diff --git a/src/branching/lazy_sync_protocol.hh b/src/branching/lazy_sync_protocol.hh deleted file mode 100644 index 3b0a9dc..0000000 --- a/src/branching/lazy_sync_protocol.hh +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include "base_sync_protocol.hh" - -namespace gitmem { - -namespace branching { - -class BranchingLazySyncProtocol final : public BranchingSyncProtocolBase { -public: - SyncKind kind() const override { return SyncKind::BranchingLazy; }; -}; - -} - -} \ No newline at end of file diff --git a/src/branching/version_store.cc b/src/branching/version_store.cc deleted file mode 100644 index 1366830..0000000 --- a/src/branching/version_store.cc +++ /dev/null @@ -1,390 +0,0 @@ -#include "version_store.hh" -#include -#include -#include "debug.hh" - -namespace gitmem { - -namespace branching { - -// Helper: recursive print with 2-space indentation and cycle protection -void print_commit_recursive(std::ostream& os, - const std::shared_ptr& commit, - std::unordered_set& visited, - int depth = 0) -{ - if (!commit) return; - if (!visited.insert(commit.get()).second) { - os << std::string(depth * 2, ' ') << "(already printed commit " << commit->id << ")\n"; - return; - } - - os << std::string(depth * 2, ' ') << "Commit " << commit->id << " {\n"; - - // Print changes - for (const auto& [obj, val] : commit->changes) { - os << std::string((depth + 1) * 2, ' ') << obj << " -> " << val << "\n"; - } - - // Print parents - if (!commit->parents.empty()) { - os << std::string((depth + 1) * 2, ' ') << "Parents: "; - for (size_t i 
= 0; i < commit->parents.size(); ++i) { - os << commit->parents[i]->id; - if (i + 1 < commit->parents.size()) os << ", "; - } - os << "\n"; - } - - os << std::string(depth * 2, ' ') << "}\n"; - - // Recursively print parents - for (auto& parent : commit->parents) { - print_commit_recursive(os, parent, visited, depth + 1); - } -} - -// operator<< for Commit -std::ostream& operator<<(std::ostream& os, const Commit& commit) { - std::unordered_set visited; - // Wrap the commit in a shared_ptr to reuse the recursive helper - print_commit_recursive(os, std::make_shared(commit), visited); - return os; -} - -void LocalVersionStore::stage(ObjectNumber obj, Value value) { - staging[obj] = value; -} - -void LocalVersionStore::commit_staging() { - // No-op commit does nothing - if (staging.empty()) { - return; - } - - // Create the new commit with the staged changes - auto new_commit = std::make_shared(base_timestamp++, std::move(staging)); - - // Update last_writer for each staged variable - for (const auto& [obj, _] : new_commit->changes) { - last_writer[obj] = new_commit; - } - - // Clear staging - staging.clear(); - - // Set parent to previous head if it exists - if (head) - new_commit->parents.push_back(head); - - // Update head - head = new_commit; -} - -std::optional LocalVersionStore::get_staged(ObjectNumber obj) const { - auto it = staging.find(obj); - return it != staging.end() ? 
std::make_optional(it->second) : std::nullopt; -} - -std::optional> -get_committed_recursive( - const std::shared_ptr& commit, - ObjectNumber number, - std::unordered_set>& visited) { - - if (!commit || !visited.insert(commit).second) - return std::nullopt; - - std::optional> found; - - // Recurse into all parents first - for (auto& parent : commit->parents) { - auto parent_commit = get_committed_recursive(parent, number, visited); - if (parent_commit) { - if (!found.has_value()) - found = parent_commit; - else if (found.value()->changes.at(number) != parent_commit.value()->changes.at(number)) - assert(false && "Conflict detected (should be impossible in conflict-free DAG)"); - } - } - - // If this commit wrote the variable, it overrides any parent - if (commit->changes.contains(number)) - return commit; - - return found; -} - -std::shared_ptr -find_lowest_common_ancestor(std::shared_ptr a, - std::shared_ptr b) -{ - if (!a || !b) return nullptr; - - // Step 1: collect all ancestors of 'a' - std::unordered_set> ancestors_a; - std::queue> q; - q.push(a); - - while (!q.empty()) { - auto c = q.front(); q.pop(); - if (!c) continue; - if (!ancestors_a.insert(c).second) continue; // already visited - - for (auto& p : c->parents) - q.push(p); - } - - // Step 2: BFS from 'b' to find first common ancestor - std::unordered_set> visited_b; - q.push(b); - - while (!q.empty()) { - auto c = q.front(); q.pop(); - if (!c) continue; - if (!visited_b.insert(c).second) continue; - - if (ancestors_a.count(c)) - return c; // first common ancestor seen - - for (auto& p : c->parents) - q.push(p); - } - - return nullptr; // disjoint histories (shouldn’t happen) -} - -// std::optional traverse_to_lca( -// const std::shared_ptr& commit, -// ObjectNumber var, -// const std::shared_ptr& lca) -// { -// if (!commit || commit == lca) return std::nullopt; - -// auto it = commit->changes.find(var); -// if (it != commit->changes.end()) return it->second; - -// if (commit->parents.size() > 1) -// 
assert(false && "should never encounter multi-parent commit before LCA"); - -// return traverse_to_lca(commit->parents[0], var, lca); -// } - -// std::optional get_committed_recursive( -// const std::shared_ptr& commit, -// ObjectNumber var) { -// if (!commit) return std::nullopt; - -// // 1. If this commit wrote the variable, return it -// auto it = commit->changes.find(var); -// if (it != commit->changes.end()) return it->second; - -// // 2. If single parent, recurse -// if (commit->parents.size() == 1) -// return get_committed_recursive(commit->parents[0], var); - -// // 3. Merge commit -// assert(commit->parents.size() == 2); // merge commit - -// auto& p1 = commit->parents[0]; -// auto& p2 = commit->parents[1]; - -// auto lca = find_lowest_common_ancestor(p1, p2); - -// // Explore both paths from merge commit to LCA -// std::optional v1 = traverse_to_lca(p1, var, lca); -// std::optional v2 = traverse_to_lca(p2, var, lca); - -// assert(!v1 || !v2 || v1 == v2); // conflict-free invariant - -// if (v1) return v1; // found in one of the merge branches -// if (v2) return v2; - -// // 4. 
Not found yet → continue recursively from the LCA downward -// return get_committed_recursive(lca, var); -// } - -std::optional LocalVersionStore::get_committed(ObjectNumber number) const { - if (auto it = last_writer.find(number); it != last_writer.end()) - return it->second->changes.at(number); - - return std::nullopt; -} - -void LocalVersionStore::adopt_history(const LocalVersionStore& other) { - // Inherit the DAG head - head = other.head; - - // Inherit the last_writer cache so the child sees all latest commits - last_writer = other.last_writer; -} - -bool can_reach_lca( - const std::shared_ptr& commit, - const std::shared_ptr& lca, - std::unordered_map, bool>& memo) -{ - if (!commit) - return false; - - if (commit == lca) - return true; - - auto it = memo.find(commit); - if (it != memo.end()) - return it->second; - - for (const auto& parent : commit->parents) { - if (can_reach_lca(parent, lca, memo)) { - memo[commit] = true; - return true; - } - } - - memo[commit] = false; - return false; -} - -bool traverse_until_lca( - const std::shared_ptr& commit, - const std::shared_ptr& lca, - std::unordered_map>& out_map, - std::unordered_set>& visited, - std::unordered_map, bool>& reach_memo) -{ - if (!commit || commit == lca || !visited.insert(commit).second) - return true; - - if (!can_reach_lca(commit, lca, reach_memo)) - return true; - - for (const auto& [obj, _] : commit->changes) { - // first write seen dominates - if (out_map.find(obj) == out_map.end()) - out_map[obj] = commit; - } - - for (auto& parent : commit->parents) { - if (!traverse_until_lca(parent, lca, out_map, visited, reach_memo)) - return false; - } - - return true; -} - -std::optional LocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { - assert(staging.empty()); - assert(commit != nullptr); - - // trivial case: same history - if (head == commit) - return std::nullopt; - - // Create merge commit (no changes itself) - auto merge_commit = std::make_shared( - Commit{ - .id = 
base_timestamp++, - .parents = {head, commit}, - .changes = {} // merge commit does not write anything - } - ); - - // Find lowest common ancestor of the two heads - std::shared_ptr lca = find_lowest_common_ancestor(head, commit); - verbose << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; - - // Collect all writes after LCA for each branch - std::unordered_map> branch_a, branch_b; - std::unordered_set> visited; - - std::unordered_map, bool> reach_memo; - traverse_until_lca(head, lca, branch_a, visited, reach_memo); - visited.clear(); - traverse_until_lca(commit, lca, branch_b, visited, reach_memo); - - // 1. Eager conflict detection - for (const auto& [obj, commit_a] : branch_a) { - auto it = branch_b.find(obj); - if (it != branch_b.end() && it->second != commit_a) { - return Conflict{ - .obj = obj, - .timestamp_a = commit_a->id, - .timestamp_b = it->second->id - }; - } - } - - // 2. Update thread-local last_writer incrementally - // Only overwrite variables that were touched along either branch after LCA - for (const auto& [obj, commit] : branch_a) - last_writer[obj] = commit; - - for (const auto& [obj, commit] : branch_b) - last_writer[obj] = commit; - - // 3. Variables not touched in either branch remain unchanged (from before LCA) - - // 4. 
Update head - head = merge_commit; - - return std::nullopt; -} - -std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { - os << "LocalVersionStore{" - << "base=" << store.base_timestamp - << ", head="; - - if (store.head) - os << store.head->id; - else - os << "null"; - - os << ", staged={"; - - bool first = true; - for (const auto& [obj, val] : store.staging) { - if (!first) os << ", "; - first = false; - os << obj << "->" << val; - } - - os << "}}"; - return os; -} - -bool LocalVersionStore::operator==(const LocalVersionStore& other) const { - return base_timestamp == other.base_timestamp && - head == other.head && - staging == other.staging; -} - -ObjectNumber GlobalVersionStore::get_object_number(std::string var) { - auto it = _object_numbers.find(var); - if (it != _object_numbers.end()) { - return it->second; - } else { - ObjectNumber number = _next_object++; - _object_numbers[var] = number; - return number; - } -} - -std::string GlobalVersionStore::get_object_name(ObjectNumber find) { - for (const auto &[name, number] : _object_numbers) { - if (number == find) - return name; - } - assert(false && "failed to find object name for object number"); - return ""; -} - - -std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { - os << "GlobalVersionStore(next_object=" << store._next_object << ")" << std::endl; - return os; -} - -} // branching - -} // gitmem \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh index 5c4f0c1..bd00c23 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -7,8 +7,7 @@ #include "lang.hh" #include "sync_kind.hh" -#include "linear/version_store.hh" -#include "branching/version_store.hh" +#include "sync_state.hh" #include "graphviz.hh" #include "termination_status.hh" #include "thread_trace.hh" @@ -17,7 +16,6 @@ namespace gitmem { class SyncProtocol; -class ThreadSyncState; struct ThreadContext { std::unordered_map locals; diff --git 
a/src/sync_protocol.cc b/src/sync_protocol.cc index ff3f7c0..d23716f 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -1,7 +1,7 @@ #include "sync_protocol.hh" #include "linear/sync_protocol.hh" -#include "branching/eager_sync_protocol.hh" -#include "branching/lazy_sync_protocol.hh" +#include "branching/eager/sync_protocol.hh" +#include "branching/lazy/sync_protocol.hh" namespace gitmem { @@ -9,7 +9,7 @@ std::unique_ptr make_protocol(SyncKind sync_kind) { switch (sync_kind) { case SyncKind::Linear: return std::make_unique(); case SyncKind::BranchingEager: return std::make_unique(); - case SyncKind::BranchingLazy: return std::make_unique(); + case SyncKind::BranchingLazy: return std::make_unique(); } std::unreachable(); } From 11399ff578cfb097aaf935502a91e1f454ff0a5c Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Thu, 8 Jan 2026 13:09:17 +0100 Subject: [PATCH 37/58] making termination type richer and report more consitently --- src/branching/base_sync_protocol.cc | 25 +++----- src/branching/base_sync_protocol.hh | 12 ++-- src/branching/base_version_store.cc | 7 ++- src/branching/base_version_store.hh | 9 ++- src/branching/eager/version_store.cc | 4 +- src/branching/eager/version_store.hh | 2 +- src/branching/lazy/version_store.cc | 70 +++++++++++++++++++-- src/branching/lazy/version_store.hh | 26 +++++++- src/conflict.hh | 12 ++++ src/debugger.cc | 84 +++++++++---------------- src/execution_state.cc | 7 +-- src/execution_state.hh | 2 + src/interpreter.cc | 91 ++++++++++++---------------- src/interpreter.hh | 2 + src/linear/sync_protocol.cc | 22 +++---- src/linear/sync_protocol.hh | 12 ++-- src/linear/version_store.hh | 3 +- src/model_checker.cc | 7 +-- src/overloaded.hh | 8 +++ src/read_result.hh | 2 +- src/sync_protocol.hh | 12 ++-- src/termination_status.hh | 90 +++++++++++++++++++++++++-- src/thread_trace.hh | 22 +++---- 23 files changed, 334 insertions(+), 197 deletions(-) create mode 100644 src/overloaded.hh diff --git 
a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 536956a..f9515a6 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -29,14 +29,7 @@ ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, auto& store = get_store(ctx); - if (auto result = store.get_staged(number)) - return *result; - - // look in commit history - if (auto result = store.get_committed(number)) - return *result; - - return std::monostate{}; + return store.read(number); } void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, @@ -45,7 +38,7 @@ void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var store.stage(_global_store.get_object_number(var), value); } -std::optional> +std::optional> BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child) { auto& parent_store = get_store(parent); parent_store.commit_staging(); @@ -57,7 +50,7 @@ BranchingSyncProtocolBase::on_spawn(ThreadContext &parent, ThreadContext &child) return std::nullopt; } -std::optional> +std::optional> BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee) { auto& joiner_store = get_store(joiner); auto& joinee_store = get_store(joinee); @@ -67,7 +60,7 @@ BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee) std::optional conflict = joiner_store.merge_with_commit(joinee_store.get_head()); if (conflict) { - return std::make_unique( + return std::make_shared( _global_store.get_object_name(conflict->obj), std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); } @@ -75,13 +68,13 @@ BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee) return std::nullopt; } -std::optional> +std::optional> BranchingSyncProtocolBase::on_start(ThreadContext &thread) { // nothing to do, the thread will have inhereted the parent commit on spawn return std::nullopt; }; -std::optional> +std::optional> 
BranchingSyncProtocolBase::on_end(ThreadContext &thread) { auto& store = get_store(thread); store.commit_staging(); @@ -89,7 +82,7 @@ BranchingSyncProtocolBase::on_end(ThreadContext &thread) { return std::nullopt; }; -std::optional> +std::optional> BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { auto& store = get_store(thread); store.commit_staging(); @@ -100,7 +93,7 @@ BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { if (lock_commit != nullptr) { std::optional conflict = store.merge_with_commit(lock_commit); if (conflict) { - return std::make_unique( + return std::make_shared( _global_store.get_object_name(conflict->obj), std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); } @@ -111,7 +104,7 @@ BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { return std::nullopt; } -std::optional> +std::optional> BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock) { auto& store = get_store(thread); store.commit_staging(); diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index bc53087..2a28208 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -20,22 +20,22 @@ public: void write(ThreadContext &ctx, const std::string &var, size_t value) override; - std::optional> + std::optional> on_spawn(ThreadContext &parent, ThreadContext &child) override; - std::optional> + std::optional> on_join(ThreadContext &joiner, ThreadContext &joinee) override; - std::optional> + std::optional> on_start(ThreadContext &thread) override; - std::optional> + std::optional> on_end(ThreadContext &thread) override; - std::optional> + std::optional> on_lock(ThreadContext &thread, Lock &lock) override; - std::optional> + std::optional> on_unlock(ThreadContext &thread, Lock &lock) override; std::ostream &print(std::ostream &os) const override; diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index 
fb14b42..2deb6c9 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -81,9 +81,12 @@ void LocalVersionStore::commit_staging() { head = new_commit; } -std::optional LocalVersionStore::get_staged(ObjectNumber obj) const { +ReadResult LocalVersionStore::read(ObjectNumber obj) const { auto it = staging.find(obj); - return it != staging.end() ? std::make_optional(it->second) : std::nullopt; + if (it != staging.end()) + return it->second; + + return get_committed(obj); } void LocalVersionStore::adopt_history(const LocalVersionStore& other) { diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 00da798..7f6606f 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -82,15 +82,18 @@ public: std::shared_ptr get_head() const { return head; } - std::optional get_staged(ObjectNumber obj) const; - virtual std::optional get_committed(ObjectNumber number) const = 0; +private: + virtual ReadResult get_committed(ObjectNumber number) const = 0; + +public: + ReadResult read(ObjectNumber number) const; void adopt_history(const LocalVersionStore& other); virtual std::optional merge_with_commit(const std::shared_ptr& other_head) = 0; friend std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store); std::ostream &print(std::ostream &os) const override { - os << dynamic_cast(this); + os << *dynamic_cast(this); return os; } diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc index 3f4818a..0e3771f 100644 --- a/src/branching/eager/version_store.cc +++ b/src/branching/eager/version_store.cc @@ -157,11 +157,11 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha return std::nullopt; } -std::optional EagerLocalVersionStore::get_committed(ObjectNumber number) const { +ReadResult EagerLocalVersionStore::get_committed(ObjectNumber number) const { if (auto it = last_writer.find(number); it != 
last_writer.end()) return it->second->changes.at(number); - return std::nullopt; + return std::monostate{}; } } // end branching diff --git a/src/branching/eager/version_store.hh b/src/branching/eager/version_store.hh index f1f5923..074d608 100644 --- a/src/branching/eager/version_store.hh +++ b/src/branching/eager/version_store.hh @@ -13,7 +13,7 @@ public: EagerLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} std::optional merge_with_commit(const std::shared_ptr&) override; - std::optional get_committed(ObjectNumber number) const override; + ReadResult get_committed(ObjectNumber number) const override; }; diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index 0911c97..b5256af 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -1,5 +1,7 @@ #include "branching/lazy/version_store.hh" #include "debug.hh" +#include +#include namespace gitmem { @@ -85,17 +87,73 @@ namespace branching { // } -std::optional LazyLocalVersionStore::merge_with_commit(const std::shared_ptr&) { - assert(false && "todo"); - return std::nullopt; +std::optional LazyLocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { + assert(staging.empty()); + assert(commit != nullptr); -} + // trivial case: same history + if (head == commit) + return std::nullopt; + + // Create merge commit (no changes itself) + auto merge_commit = std::make_shared( + Commit{ + .id = base_timestamp++, + .parents = {head, commit}, + .changes = {} // merge commit does not write anything + } + ); + + // don't check for conflicts, we do that when later read a variable + head = merge_commit; -std::optional LazyLocalVersionStore::get_committed(ObjectNumber number) const { - assert(false && "todo"); return std::nullopt; } +ReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { + std::unordered_set visited; + std::unordered_set> writers; + + std::function&)> dfs = + [&](const std::shared_ptr& c) { + if (!c 
|| writers.size() > 1) + return; + + if (!visited.insert(c.get()).second) + return; + + // If this commit writes 'number', this path is resolved + auto it = c->changes.find(number); + if (it != c->changes.end()) { + writers.insert(c); + return; + } + + // Otherwise, explore *all* parents + for (const auto& p : c->parents) + dfs(p); + }; + + dfs(head); + + if (writers.empty()) + return std::monostate{}; + + if (writers.size() == 1) { + auto writer = *writers.begin(); + return writer->changes.at(number); + } + + // Conflict: multiple distinct writers + auto it = writers.begin(); + auto a = (*it++)->id; + auto b = (*it)->id; + + return std::unique_ptr( + new ReadConflict(number, std::make_pair(a, b)) + ); +} + } } \ No newline at end of file diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh index 45598a7..c652d33 100644 --- a/src/branching/lazy/version_store.hh +++ b/src/branching/lazy/version_store.hh @@ -6,6 +6,30 @@ namespace gitmem { namespace branching { +struct ReadConflict : ConflictBase { + ObjectNumber obj; + std::pair versions; + + ReadConflict(ObjectNumber obj, std::pair versions): + obj(obj), versions(std::move(versions)) {} + + ~ReadConflict() = default; + std::ostream &print(std::ostream &os) const override { + return os; + }; + + bool operator==(const ReadConflict &other) const { + return obj == other.obj && versions == other.versions; + } + + bool operator==(const ConflictBase& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } +}; + class LazyLocalVersionStore : public LocalVersionStore { public: ~LazyLocalVersionStore() = default; @@ -13,7 +37,7 @@ public: LazyLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} std::optional merge_with_commit(const std::shared_ptr&) override; - std::optional get_committed(ObjectNumber number) const override; + ReadResult get_committed(ObjectNumber number) const override; }; diff --git a/src/conflict.hh 
b/src/conflict.hh index f78de41..9c1e49d 100644 --- a/src/conflict.hh +++ b/src/conflict.hh @@ -11,6 +11,7 @@ struct ConflictBase { const ConflictBase &conflict) { return conflict.print(os); } + virtual bool operator==(const ConflictBase &other) const = 0; }; template @@ -22,6 +23,17 @@ struct Conflict : ConflictBase { : var(std::move(var)), versions(std::move(versions)) {} std::ostream &print(std::ostream &os) const override; + + bool operator==(const Conflict &other) const { + return var == other.var && versions == other.versions; + } + + bool operator==(const ConflictBase& other) const override { + auto* o = dynamic_cast(&other); + if (!o) + return false; + return *this == *o; + } }; template diff --git a/src/debugger.cc b/src/debugger.cc index 174fcec..cb05b81 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -3,6 +3,7 @@ #include "debug.hh" #include "debugger.hh" #include "interpreter.hh" +#include "overloaded.hh" namespace gitmem { /** A command that can be parsed by the debugger. 
Some commands store a @@ -80,23 +81,29 @@ enum class StepKind { struct StepUIResult { StepKind kind; std::optional termination; - std::string message; + std::optional message; static StepUIResult progressed() { - return {StepKind::Progressed, std::nullopt, ""}; + return {StepKind::Progressed, std::nullopt, std::nullopt}; } static StepUIResult blocked(std::string msg) { return {StepKind::Blocked, std::nullopt, std::move(msg)}; } - static StepUIResult terminated(TerminationStatus t, std::string msg) { - return {StepKind::Terminated, t, std::move(msg)}; + static StepUIResult terminated(TerminationStatus t) { + return {StepKind::Terminated, t, std::nullopt}; } static StepUIResult invalid(std::string msg) { return {StepKind::Invalid, std::nullopt, std::move(msg)}; } + + bool has_message() { return message.has_value(); } + std::string& get_message() { return *message; } + + bool has_terminated() { return termination.has_value(); } + TerminationStatus& get_termination() { return *termination; } }; StepUIResult step_thread(Interpreter& interp, ThreadID tid) { @@ -110,15 +117,7 @@ StepUIResult step_thread(Interpreter& interp, ThreadID tid) { auto& thread = gctx.threads[tid]; if (thread.terminated) { - if (*thread.terminated == TerminationStatus::completed) { - return StepUIResult::terminated( - *thread.terminated, - "Thread " + std::to_string(tid) + " has terminated normally"); - } else { - return StepUIResult::terminated( - *thread.terminated, - "Thread " + std::to_string(tid) + " has terminated with an error"); - } + StepUIResult::terminated(*thread.terminated); } auto prog_or_term = interp.progress_thread(gctx.threads[tid]); @@ -134,48 +133,7 @@ StepUIResult step_thread(Interpreter& interp, ThreadID tid) { } auto term = std::get(prog_or_term); - - switch (term) { - case TerminationStatus::completed: - return StepUIResult::terminated( - term, - "Thread " + std::to_string(tid) + " terminated normally"); - - case TerminationStatus::datarace_exception: - return 
StepUIResult::terminated( - term, - "Thread " + std::to_string(tid) + - " encountered a data race and was terminated"); - - case TerminationStatus::assertion_failure_exception: { - auto expr = - thread.block->at(thread.pc) / lang::Stmt / lang::Expr; - return StepUIResult::terminated( - term, - "Thread " + std::to_string(tid) + - " failed assertion '" + - std::string(expr->location().view()) + - "' and was terminated"); - } - - case TerminationStatus::unassigned_variable_read_exception: - return StepUIResult::terminated( - term, - "Thread " + std::to_string(tid) + - " read an uninitialised variable"); - - case TerminationStatus::unlock_exception: - return StepUIResult::terminated( - term, - "Thread " + std::to_string(tid) + - " unlocked a lock it does not own"); - - default: - return StepUIResult::terminated( - term, - "Thread " + std::to_string(tid) + - " terminated with an unknown error"); - } + return StepUIResult::terminated(term); } /** Print the execution graph if requested */ @@ -194,8 +152,20 @@ StepUIResult do_step(Interpreter &interp, bool print_graphs, const std::filesystem::path &output_file) { StepUIResult result = step_thread(interp, tid); - if (!result.message.empty()) - std::cout << result.message << std::endl; + if (result.has_message()) + std::cout << result.get_message() << std::endl; + if (result.has_terminated()) { + std::cout << "Thread " << tid << ": "; + std::visit( + overloaded{ + [&](const auto &t) { + // Any non-completed termination is exceptional + std::cout << t << std::endl; + } + }, + result.get_termination() + ); + } maybe_print_graph(interp, print_graphs, output_file); return result; diff --git a/src/execution_state.cc b/src/execution_state.cc index aeaa943..9780c14 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -141,7 +141,7 @@ std::ostream& operator<<(std::ostream& os, const ThreadContext& ctx) { os << k << "=" << v; } - os << "}"; //, tail=" << ctx.tail; + os << "}, "; //, tail=" << ctx.tail; os << 
*(ctx.sync); @@ -157,9 +157,6 @@ void show_lock(const std::string &lock_name, const struct Lock &lock) { std::cout << ""; } std::cout << std::endl; - // for (auto &[var, global] : lock.globals) { - // show_global(var, global); - // } } void GlobalContext::print(std::ostream& os, bool show_all) const { @@ -169,7 +166,7 @@ void GlobalContext::print(std::ostream& os, bool show_all) const { for (size_t i = 0; i < threads.size(); i++) { auto& thread = threads[i]; if (show_all || !thread.terminated || - *threads[i].terminated != TerminationStatus::completed) { + !std::holds_alternative(*thread.terminated)) { os << "---- Thread " << i << std::endl; os << threads[i] << std::endl; os << std::endl; diff --git a/src/execution_state.hh b/src/execution_state.hh index bd00c23..550f627 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -35,6 +35,8 @@ struct ThreadContext { friend std::ostream& operator<<(std::ostream&, const ThreadContext&); }; +using termination::TerminationStatus; + struct Thread { ThreadID tid; ThreadContext ctx; diff --git a/src/interpreter.cc b/src/interpreter.cc index cafab3d..8822a37 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -5,6 +5,7 @@ #include "debug.hh" #include "interpreter.hh" #include "sync_protocol.hh" +#include "overloaded.hh" namespace gitmem { @@ -61,7 +62,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { if (ctx.locals.contains(var)) { return ctx.locals[var]; } else { - return TerminationStatus::unassigned_variable_read_exception; + return termination::UnassignedRead(var); } } else if (e == lang::Var) { auto var = std::string(expr->location().view()); @@ -71,16 +72,16 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { return std::visit(overloaded{ [&](std::monostate) -> std::variant { // invalid: reading a variable that hasn't been written - return TerminationStatus::unassigned_variable_read_exception; + return termination::UnassignedRead(var); }, [&](Value 
value) -> std::variant { // normal read thread.trace.on_read(var, value); return value; }, - [&](std::unique_ptr& conflict) -> std::variant { + [&](std::shared_ptr& conflict) -> std::variant { verbose << (*conflict) << std::endl; - return TerminationStatus::datarace_exception; + return termination::DataRace(conflict); } }, result); } else if (e == lang::Const) { @@ -98,7 +99,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { ThreadID child_tid = gctx.threads.size(); ThreadContext child_ctx(child_tid, gctx.protocol); - if (std::optional> conflict = + if (std::optional> conflict = gctx.protocol->on_spawn(ctx, child_ctx)) { throw std::logic_error("This code path should never be reached"); } @@ -216,16 +217,16 @@ std::variant Interpreter::run_statement(Node stmt, Threa if (result >= gctx.threads.size()) { verbose << "Join: invalid thread ID " << result << ". gctx.threads.size()=" << gctx.threads.size() << std::endl; - return TerminationStatus::unassigned_variable_read_exception; + return termination::UnassignedRead(std::to_string(result)); } auto &joinee = gctx.threads[result]; if (joinee.terminated && - (*joinee.terminated == TerminationStatus::completed)) { + std::holds_alternative(*joinee.terminated)) { if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx)) { verbose << (**conflict) << std::endl; - thread.trace.on_join(result, std::move(*conflict)); - return TerminationStatus::datarace_exception; + thread.trace.on_join(result, *conflict); + return termination::DataRace(*conflict); } else { thread.trace.on_join(result); } @@ -252,8 +253,8 @@ std::variant Interpreter::run_statement(Node stmt, Threa if (auto conflict = gctx.protocol->on_lock(ctx, lock)) { verbose << (**conflict) << std::endl; - thread.trace.on_lock(var, lock.last_unlock_event, std::move(*conflict)); - return TerminationStatus::datarace_exception; + thread.trace.on_lock(var, lock.last_unlock_event, *conflict); + return termination::DataRace(*conflict); } 
thread.trace.on_lock(var, lock.last_unlock_event); @@ -271,13 +272,13 @@ std::variant Interpreter::run_statement(Node stmt, Threa Lock& lock = gctx.get_lock(var); if (!lock.owner || (lock.owner && *lock.owner != thread.tid)) { - return TerminationStatus::unlock_exception; + return termination::UnlockError(var); } if (auto conflict = gctx.protocol->on_unlock(ctx, lock)) { verbose << (**conflict) << std::endl; - thread.trace.on_unlock(var, std::move(*conflict)); - return TerminationStatus::datarace_exception; + thread.trace.on_unlock(var, *conflict); + return termination::DataRace(*conflict); } // lock.globals = ctx.globals; @@ -298,7 +299,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa } else { verbose << "Assertion failed: " << expr->location().view() << std::endl; thread.trace.on_assert_fail(std::string(expr->location().view())); - return TerminationStatus::assertion_failure_exception; + return termination::AssertionFailure(std::string(expr->location().view())); } } else { return std::get(result_or_term); @@ -364,13 +365,14 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { // Otherwise, we truly reached the end this iteration if (auto conflict = gctx.protocol->on_end(ctx)) { verbose << (**conflict) << std::endl; - thread.terminated = TerminationStatus::datarace_exception; - return TerminationStatus::datarace_exception; + TerminationStatus term = termination::DataRace(*conflict); + thread.terminated = term; + return term; } - thread.terminated = TerminationStatus::completed; + thread.terminated = termination::Completed(); thread.trace.on_end(); - return TerminationStatus::completed; + return termination::Completed(); } /** @@ -431,7 +433,7 @@ Interpreter::run_threads_to_sync() { } if (all_completed) - return TerminationStatus::completed; + return termination::Completed(); return any_progress; } @@ -458,39 +460,24 @@ int Interpreter::run() { bool exception_detected = false; for (size_t i = 0; i < gctx.threads.size(); ++i) { auto &thread = 
gctx.threads[i]; - if (thread.terminated) { - switch (thread.terminated.value()) { - case TerminationStatus::completed: - verbose << "Thread " << i << " terminated normally" << std::endl; - break; - - case TerminationStatus::unlock_exception: - verbose << "Thread " << i << " unlocked a lock it does not own" - << std::endl; - exception_detected = true; - break; - - case TerminationStatus::datarace_exception: - verbose << "Thread " << i << " encountered a data-race" << std::endl; - exception_detected = true; - break; - - case TerminationStatus::assertion_failure_exception: - verbose << "Thread " << i << " failed an assertion" << std::endl; - exception_detected = true; - break; - case TerminationStatus::unassigned_variable_read_exception: - verbose << "Thread " << i << " read an uninitialised value" - << std::endl; - exception_detected = true; - break; - - default: - verbose << "Thread " << i << " has an unhandled termination state" - << std::endl; - break; - } + if (thread.terminated) { + verbose << "Thread " << i << ": "; + + std::visit( + overloaded{ + [&](const termination::Completed &t) { + verbose << t << std::endl; + }, + + [&](const auto &t) { + // Any non-completed termination is exceptional + verbose << t << std::endl; + exception_detected = true; + } + }, + *thread.terminated + ); } else { exception_detected = true; thread.trace.on_end(); diff --git a/src/interpreter.hh b/src/interpreter.hh index 9f924c1..d9bd2e7 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -11,6 +11,8 @@ namespace gitmem { +using termination::TerminationStatus; + template using StepResult = std::variant; diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index ca7c2f9..519fbf8 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -82,7 +82,7 @@ void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, store.stage(_global_store.get_object_number(var), value); } -std::optional> +std::optional> 
LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child) { // TODO: i think we can drop the globalcontext but check after branching is // added @@ -90,7 +90,7 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child) { // push parent to global history auto& store = get_store(parent); if (auto conflict = push(store)) - return std::make_unique(std::move(*conflict)); + return std::make_shared(std::move(*conflict)); // pull into the child store = get_store(child); @@ -101,19 +101,19 @@ LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child) { return std::nullopt; } -std::optional> +std::optional> LinearSyncProtocol::on_join(ThreadContext &joiner, ThreadContext &joinee) { // we assume the joinee has already terminated and pushed // pull changes into parent auto& store = get_store(joiner); if (auto conflict = pull(store)) - return std::make_unique(std::move(*conflict)); + return std::make_shared(std::move(*conflict)); return std::nullopt; } -std::optional> +std::optional> LinearSyncProtocol::on_start(ThreadContext &thread) { // pull state from global history auto& store = get_store(thread); @@ -123,32 +123,32 @@ LinearSyncProtocol::on_start(ThreadContext &thread) { return std::nullopt; }; -std::optional> +std::optional> LinearSyncProtocol::on_end(ThreadContext &thread) { // push changes to global history auto& store = get_store(thread); if (auto conflict = push(store)) - return std::make_unique(std::move(*conflict)); + return std::make_shared(std::move(*conflict)); return std::nullopt; }; -std::optional> +std::optional> LinearSyncProtocol::on_lock(ThreadContext &thread, Lock &lock) { auto& store = get_store(thread); if (auto conflict = pull(store)) - return std::make_unique(std::move(*conflict)); + return std::make_shared(std::move(*conflict)); return std::nullopt; } -std::optional> +std::optional> LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &) { // push changes to global history auto& store = 
get_store(thread); if (auto conflict = push(store)) - return std::make_unique(std::move(*conflict)); + return std::make_shared(std::move(*conflict)); return std::nullopt; } diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index b85aa83..d8c2af6 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -26,22 +26,22 @@ public: void write(ThreadContext &ctx, const std::string &var, size_t value) override; - std::optional> + std::optional> on_spawn(ThreadContext &parent, ThreadContext &child) override; - std::optional> + std::optional> on_join(ThreadContext &joiner, ThreadContext &joinee) override; - std::optional> + std::optional> on_start(ThreadContext &thread) override; - std::optional> + std::optional> on_end(ThreadContext &thread) override; - std::optional> + std::optional> on_lock(ThreadContext &thread, Lock &lock) override; - std::optional> + std::optional> on_unlock(ThreadContext &thread, Lock &lock) override; std::ostream &print(std::ostream &os) const override; diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index e8e9bda..a5f6f39 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -107,8 +107,9 @@ public: } friend std::ostream& operator<<(std::ostream&, const LocalVersionStore&); + std::ostream &print(std::ostream &os) const override { - os << dynamic_cast(this); + os << *dynamic_cast(this); return os; } }; diff --git a/src/model_checker.cc b/src/model_checker.cc index ab9367c..dd1abef 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -103,8 +103,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi made_progress = true; cursor = cursor->extend(i); current_trace.push_back(i); - if (std::get(prog_or_term) != - TerminationStatus::completed) { + if (!std::holds_alternative(std::get(prog_or_term))) { // Thread terminated with an error, we can stop here verbose << "Thread " << i << " terminated with an error" << std::endl; @@ 
-128,12 +127,12 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi bool all_completed = std::all_of( gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { return thread.terminated && - *thread.terminated == TerminationStatus::completed; + std::holds_alternative(*thread.terminated); }); bool any_crashed = std::any_of( gctx.threads.begin(), gctx.threads.end(), [](const auto &thread) { return thread.terminated && - *thread.terminated != TerminationStatus::completed; + !std::holds_alternative(*thread.terminated); }); bool is_deadlock = !all_completed && !made_progress && cursor->is_leaf(); diff --git a/src/overloaded.hh b/src/overloaded.hh new file mode 100644 index 0000000..689c013 --- /dev/null +++ b/src/overloaded.hh @@ -0,0 +1,8 @@ +#pragma once + +template +struct overloaded : Ts... { + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; \ No newline at end of file diff --git a/src/read_result.hh b/src/read_result.hh index 3c1bc1a..d32cac5 100644 --- a/src/read_result.hh +++ b/src/read_result.hh @@ -6,6 +6,6 @@ namespace gitmem { using Value = size_t; -using ReadResult = std::variant>; +using ReadResult = std::variant>; } \ No newline at end of file diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 98de8f2..684dcf0 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -26,22 +26,22 @@ public: virtual void write(ThreadContext &ctx, const std::string &var, size_t value) = 0; - virtual std::optional> + virtual std::optional> on_spawn(ThreadContext &parent, ThreadContext &child) = 0; - virtual std::optional> + virtual std::optional> on_join(ThreadContext &joiner, ThreadContext &joinee) = 0; - virtual std::optional> + virtual std::optional> on_start(ThreadContext &thread) = 0; - virtual std::optional> + virtual std::optional> on_end(ThreadContext &thread) = 0; - virtual std::optional> + virtual std::optional> on_lock(ThreadContext &thread, Lock &lock) = 0; - virtual std::optional> 
+ virtual std::optional> on_unlock(ThreadContext &thread, Lock &lock) = 0; diff --git a/src/termination_status.hh b/src/termination_status.hh index 5417b74..eafa3dc 100644 --- a/src/termination_status.hh +++ b/src/termination_status.hh @@ -1,9 +1,87 @@ #pragma once -enum class TerminationStatus { - completed, - datarace_exception, - unlock_exception, - assertion_failure_exception, - unassigned_variable_read_exception, +#include "thread_id.hh" +#include "conflict.hh" + +#include + +namespace gitmem { + +namespace termination { + +struct Completed { + friend std::ostream& operator<<(std::ostream& os, const Completed&) { + os << "Completed successfully\n"; + return os; + } +}; + +struct DataRace { + std::shared_ptr conflict; + friend std::ostream& operator<<(std::ostream& os, const DataRace& r) { + os << "Data race occurred: '" << *r.conflict << "\n"; + return os; + } }; + +struct UnlockError { + std::string lock; + friend std::ostream& operator<<(std::ostream& os, const UnlockError& e) { + os << "Attempted to unlock '" << e.lock << "' without ownership\n"; + return os; + } +}; + +struct AssertionFailure { + std::string expression; + friend std::ostream& operator<<(std::ostream& os, const AssertionFailure& a) { + os << "Assertion failed: " << a.expression << "\n"; + return os; + } +}; + +struct UnassignedRead { + std::string variable; + friend std::ostream& operator<<(std::ostream& os, const UnassignedRead& u) { + os << "Read of unassigned variable '" << u.variable << "'\n"; + return os; + } +}; + +inline bool operator==(const Completed&, const Completed&) { return true; } + +inline bool operator==(const DataRace& a, const DataRace& b) { + return *a.conflict == *b.conflict; +} + +inline bool operator==(const UnlockError& a, const UnlockError& b) { + return a.lock == b.lock; +} + +inline bool operator==(const AssertionFailure& a, const AssertionFailure& b) { + return a.expression == b.expression; +} + +inline bool operator==(const UnassignedRead& a, const 
UnassignedRead& b) { + return a.variable == b.variable; +} + +// Optional: != operators for convenience +inline bool operator!=(const Completed& a, const Completed& b) { return !(a == b); } +inline bool operator!=(const DataRace& a, const DataRace& b) { return !(a == b); } +inline bool operator!=(const UnlockError& a, const UnlockError& b) { return !(a == b); } +inline bool operator!=(const AssertionFailure& a, const AssertionFailure& b) { return !(a == b); } +inline bool operator!=(const UnassignedRead& a, const UnassignedRead& b) { return !(a == b); } + +using TerminationStatus = + std::variant< + Completed, + DataRace, + UnlockError, + AssertionFailure, + UnassignedRead + >; + +} // end termination + +} // end gitmem \ No newline at end of file diff --git a/src/thread_trace.hh b/src/thread_trace.hh index bdd6625..0dfe830 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -9,11 +9,11 @@ struct Event; struct StartEvent {}; struct SpawnEvent { const ThreadID child_tid; }; -struct ReadEvent { const std::string var; const size_t value; std::unique_ptr maybe_conflict; }; +struct ReadEvent { const std::string var; const size_t value; std::shared_ptr maybe_conflict; }; struct WriteEvent { const std::string var; const size_t value; }; -struct LockEvent { std::string lock_name; std::unique_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; -struct UnlockEvent { const std::string lock_name; std::unique_ptr maybe_conflict; }; -struct JoinEvent { const ThreadID joinee_tid; std::unique_ptr maybe_conflict; }; +struct LockEvent { std::string lock_name; std::shared_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; +struct UnlockEvent { const std::string lock_name; std::shared_ptr maybe_conflict; }; +struct JoinEvent { const ThreadID joinee_tid; std::shared_ptr maybe_conflict; }; struct AssertEvent { const std::string condition; bool pass; }; struct EndEvent {}; @@ -139,8 +139,8 @@ private: return append(child_tid); } - std::shared_ptr on_read(const 
std::string text, const size_t value, std::unique_ptr conflict = nullptr) { - return append(std::move(text), value, std::move(conflict)); + std::shared_ptr on_read(const std::string text, const size_t value, std::shared_ptr conflict = nullptr) { + return append(std::move(text), value, conflict); } std::shared_ptr on_write(const std::string text, const size_t value) { @@ -149,16 +149,16 @@ private: std::shared_ptr on_lock(const std::string lock_name, std::shared_ptr last_unlock_event, - std::unique_ptr conflict = nullptr) { + std::shared_ptr conflict = nullptr) { return append(std::move(lock_name), std::move(conflict), last_unlock_event); } - std::shared_ptr on_unlock(const std::string lock_name, std::unique_ptr conflict = nullptr) { - return append(std::move(lock_name), std::move(conflict)); + std::shared_ptr on_unlock(const std::string lock_name, std::shared_ptr conflict = nullptr) { + return append(std::move(lock_name), conflict); } - std::shared_ptr on_join(ThreadID tid, std::unique_ptr conflict = nullptr) { - return append(tid, std::move(conflict)); + std::shared_ptr on_join(ThreadID tid, std::shared_ptr conflict = nullptr) { + return append(tid, conflict); } std::shared_ptr on_assert(std::string expr, bool pass) { From 7d2a63b15489bd46cb78f4754f41766c00a69b6c Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Thu, 8 Jan 2026 13:50:15 +0100 Subject: [PATCH 38/58] making read result restructure into the expected output type --- src/branching/base_sync_protocol.cc | 13 ++++++++++++- src/branching/base_version_store.cc | 2 +- src/branching/base_version_store.hh | 6 ++++-- src/branching/eager/version_store.cc | 2 +- src/branching/eager/version_store.hh | 2 +- src/branching/lazy/version_store.cc | 6 ++---- src/branching/lazy/version_store.hh | 26 +------------------------- src/termination_status.hh | 14 +++++++++----- 8 files changed, 31 insertions(+), 40 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index
f9515a6..3681bba 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -1,4 +1,5 @@ #include "base_sync_protocol.hh" +#include "overloaded.hh" namespace gitmem { @@ -29,7 +30,17 @@ ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, auto& store = get_store(ctx); - return store.read(number); + // convert the branching read result into a regular read result + return std::visit(overloaded{ + [](std::monostate) -> ReadResult { return std::monostate{}; }, + [](const Value& v) -> ReadResult { return v; }, + [&](const Conflict& c) -> ReadResult { + return std::make_shared( + _global_store.get_object_name(c.obj), std::pair{c.timestamp_a, c.timestamp_b} + ); + }, + + }, store.read(number)); } void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index 2deb6c9..62f4762 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -81,7 +81,7 @@ void LocalVersionStore::commit_staging() { head = new_commit; } -ReadResult LocalVersionStore::read(ObjectNumber obj) const { +BranchingReadResult LocalVersionStore::read(ObjectNumber obj) const { auto it = staging.find(obj); if (it != staging.end()) return it->second; diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 7f6606f..71905e2 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -56,6 +56,8 @@ struct Conflict { Timestamp timestamp_b; }; +using BranchingReadResult = std::variant; + inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { return os << "Conflict{obj=" << c.obj << ", timestamp_a=" << c.timestamp_a @@ -83,10 +85,10 @@ public: std::shared_ptr get_head() const { return head; } private: - virtual ReadResult get_committed(ObjectNumber number) const = 0; + virtual BranchingReadResult get_committed(ObjectNumber number) const = 0; 
public: - ReadResult read(ObjectNumber number) const; + BranchingReadResult read(ObjectNumber number) const; void adopt_history(const LocalVersionStore& other); virtual std::optional merge_with_commit(const std::shared_ptr& other_head) = 0; diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc index 0e3771f..8447dba 100644 --- a/src/branching/eager/version_store.cc +++ b/src/branching/eager/version_store.cc @@ -157,7 +157,7 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha return std::nullopt; } -ReadResult EagerLocalVersionStore::get_committed(ObjectNumber number) const { +BranchingReadResult EagerLocalVersionStore::get_committed(ObjectNumber number) const { if (auto it = last_writer.find(number); it != last_writer.end()) return it->second->changes.at(number); diff --git a/src/branching/eager/version_store.hh b/src/branching/eager/version_store.hh index 074d608..2920eb0 100644 --- a/src/branching/eager/version_store.hh +++ b/src/branching/eager/version_store.hh @@ -13,7 +13,7 @@ public: EagerLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} std::optional merge_with_commit(const std::shared_ptr&) override; - ReadResult get_committed(ObjectNumber number) const override; + BranchingReadResult get_committed(ObjectNumber number) const override; }; diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index b5256af..f01f3f7 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -110,7 +110,7 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar return std::nullopt; } -ReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { +BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { std::unordered_set visited; std::unordered_set> writers; @@ -149,9 +149,7 @@ ReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { auto a = (*it++)->id; auto b = 
(*it)->id; - return std::unique_ptr( - new ReadConflict(number, std::make_pair(a, b)) - ); + return Conflict(number, a, b); } } diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh index c652d33..881305f 100644 --- a/src/branching/lazy/version_store.hh +++ b/src/branching/lazy/version_store.hh @@ -6,30 +6,6 @@ namespace gitmem { namespace branching { -struct ReadConflict : ConflictBase { - ObjectNumber obj; - std::pair versions; - - ReadConflict(ObjectNumber obj, std::pair versions): - obj(obj), versions(std::move(versions)) {} - - ~ReadConflict() = default; - std::ostream &print(std::ostream &os) const override { - return os; - }; - - bool operator==(const ReadConflict &other) const { - return obj == other.obj && versions == other.versions; - } - - bool operator==(const ConflictBase& other) const override { - auto* o = dynamic_cast(&other); - if (!o) - return false; - return *this == *o; - } -}; - class LazyLocalVersionStore : public LocalVersionStore { public: ~LazyLocalVersionStore() = default; @@ -37,7 +13,7 @@ public: LazyLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} std::optional merge_with_commit(const std::shared_ptr&) override; - ReadResult get_committed(ObjectNumber number) const override; + BranchingReadResult get_committed(ObjectNumber number) const override; }; diff --git a/src/termination_status.hh b/src/termination_status.hh index eafa3dc..a86323c 100644 --- a/src/termination_status.hh +++ b/src/termination_status.hh @@ -11,15 +11,19 @@ namespace termination { struct Completed { friend std::ostream& operator<<(std::ostream& os, const Completed&) { - os << "Completed successfully\n"; + os << "Completed successfully"; return os; } }; struct DataRace { std::shared_ptr conflict; + + explicit DataRace(std::shared_ptr conflict): conflict(conflict) {} + friend std::ostream& operator<<(std::ostream& os, const DataRace& r) { - os << "Data race occurred: '" << *r.conflict << "\n"; + assert(r.conflict != nullptr); 
+ os << "Data race occurred: " << *r.conflict; return os; } }; @@ -27,7 +31,7 @@ struct DataRace { struct UnlockError { std::string lock; friend std::ostream& operator<<(std::ostream& os, const UnlockError& e) { - os << "Attempted to unlock '" << e.lock << "' without ownership\n"; + os << "Attempted to unlock '" << e.lock << "' without ownership"; return os; } }; @@ -35,7 +39,7 @@ struct UnlockError { struct AssertionFailure { std::string expression; friend std::ostream& operator<<(std::ostream& os, const AssertionFailure& a) { - os << "Assertion failed: " << a.expression << "\n"; + os << "Assertion failed: " << a.expression; return os; } }; @@ -43,7 +47,7 @@ struct AssertionFailure { struct UnassignedRead { std::string variable; friend std::ostream& operator<<(std::ostream& os, const UnassignedRead& u) { - os << "Read of unassigned variable '" << u.variable << "'\n"; + os << "Read of unassigned variable '" << u.variable; return os; } }; From 1d7dd342b5fe5d935a2ed0f773f2912aa4197fad Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Thu, 8 Jan 2026 14:32:12 +0100 Subject: [PATCH 39/58] preliminary version of lazy branching, i think the two failing tests are failing as a result of no read of the conflict variables --- src/branching/base_sync_protocol.cc | 1 - src/branching/base_version_store.cc | 22 +++++++++ src/branching/base_version_store.hh | 2 + src/branching/eager/version_store.cc | 28 +----------- src/branching/lazy/version_store.cc | 67 ++++++++++++++++------------ 5 files changed, 64 insertions(+), 56 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 3681bba..b5f5737 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -27,7 +27,6 @@ std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, const std::string &var) { ObjectNumber number = _global_store.get_object_number(var); - auto& 
store = get_store(ctx); // convert the branching read result into a regular read result diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index 62f4762..dde7ff2 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -152,6 +152,28 @@ std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { return os; } +bool can_reach(const std::shared_ptr& commit, const std::shared_ptr& other, std::unordered_map, bool>& memo) { + if (!commit) + return false; + + if (commit == other) + return true; + + auto it = memo.find(commit); + if (it != memo.end()) + return it->second; + + for (const auto& parent : commit->parents) { + if (can_reach(parent, other, memo)) { + memo[commit] = true; + return true; + } + } + + memo[commit] = false; + return false; +} + } // branching } // gitmem \ No newline at end of file diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 71905e2..f3c6f58 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -48,6 +48,8 @@ struct Commit { std::vector> parents; }; +bool can_reach(const std::shared_ptr& commit, const std::shared_ptr& lca, std::unordered_map, bool>& memo); + std::ostream& operator<<(std::ostream& os, const Commit& commit); struct Conflict { diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc index 8447dba..71e758c 100644 --- a/src/branching/eager/version_store.cc +++ b/src/branching/eager/version_store.cc @@ -7,32 +7,6 @@ namespace gitmem { namespace branching { -bool can_reach_lca( - const std::shared_ptr& commit, - const std::shared_ptr& lca, - std::unordered_map, bool>& memo) -{ - if (!commit) - return false; - - if (commit == lca) - return true; - - auto it = memo.find(commit); - if (it != memo.end()) - return it->second; - - for (const auto& parent : commit->parents) { - if (can_reach_lca(parent, lca, memo)) { - memo[commit] = true; 
- return true; - } - } - - memo[commit] = false; - return false; -} - bool traverse_until_lca( const std::shared_ptr& commit, const std::shared_ptr& lca, @@ -43,7 +17,7 @@ bool traverse_until_lca( if (!commit || commit == lca || !visited.insert(commit).second) return true; - if (!can_reach_lca(commit, lca, reach_memo)) + if (!can_reach(commit, lca, reach_memo)) return true; for (const auto& [obj, _] : commit->changes) { diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index f01f3f7..935babb 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -107,51 +107,62 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar // don't check for conflicts, we do that when later read a variable head = merge_commit; + // whenever we merge, we loose all the information about the last writer + last_writer.clear(); + return std::nullopt; } +// Thought, if we merge two paths that conflict on a variable, but we never read it +// and just right to it, is that okay ? 
+// if (auto it = last_writer.find(number); it != last_writer.end()) { +// return it->second->changes.at(number); +// } + BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { - std::unordered_set visited; - std::unordered_set> writers; - std::function&)> dfs = - [&](const std::shared_ptr& c) { - if (!c || writers.size() > 1) - return; - if (!visited.insert(c.get()).second) - return; + std::vector> writers; + std::unordered_map, bool> reach_memo; + + std::function)> dfs; + dfs = [&](std::shared_ptr c) { + if (!c) return; - // If this commit writes 'number', this path is resolved - auto it = c->changes.find(number); - if (it != c->changes.end()) { - writers.insert(c); + // If we've already found a writer that is an ancestor of c, skip + for (auto it = writers.begin(); it != writers.end(); ) { + if (can_reach(c, *it, reach_memo)) { + // existing writer is ancestor of this commit, remove it + it = writers.erase(it); + } else if (can_reach(*it, c, reach_memo)) { + // this commit is ancestor of existing writer, ignore this path return; + } else { + ++it; } + } - // Otherwise, explore *all* parents - for (const auto& p : c->parents) - dfs(p); - }; - - dfs(head); + if (c->changes.contains(number)) { + writers.push_back(c); + return; + } - if (writers.empty()) - return std::monostate{}; + for (auto& p : c->parents) + dfs(p); + }; - if (writers.size() == 1) { - auto writer = *writers.begin(); - return writer->changes.at(number); - } + dfs(head); - // Conflict: multiple distinct writers - auto it = writers.begin(); - auto a = (*it++)->id; - auto b = (*it)->id; + if (writers.empty()) return std::monostate{}; + if (writers.size() == 1) return writers[0]->changes.at(number); + // conflict + auto a = writers[0]->id; + auto b = writers[1]->id; return Conflict(number, a, b); } + } } \ No newline at end of file From 8851e639e2c927f6a2a08c858dbb2babeb54e945 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 9 Jan 2026 15:15:40 +0100 Subject: 
[PATCH 40/58] threading through machinery to write the revision graph --- .../semantics/branching/conditional_race.gm | 4 ++ .../semantics/branching/join_datarace.gm | 3 +- src/branching/base_sync_protocol.cc | 15 ++++ src/branching/base_sync_protocol.hh | 2 + src/branching/base_version_store.cc | 53 ++++++++++++++ src/branching/base_version_store.hh | 9 +++ src/branching/lazy/version_store.cc | 22 +++++- src/interpreter.cc | 69 +++++++++++-------- src/linear/sync_protocol.cc | 9 ++- src/linear/sync_protocol.hh | 2 + src/model_checker.cc | 6 +- src/sync_protocol.hh | 3 + src/thread_trace.hh | 1 - 13 files changed, 160 insertions(+), 38 deletions(-) diff --git a/examples/reject/semantics/branching/conditional_race.gm b/examples/reject/semantics/branching/conditional_race.gm index 6dd0a36..2a86c78 100644 --- a/examples/reject/semantics/branching/conditional_race.gm +++ b/examples/reject/semantics/branching/conditional_race.gm @@ -1,3 +1,6 @@ +// The racy schedule 0 1 1 1 2 2 2 0 0 (thread 1 gets the lock first) +// thread 1 r and flag inside a lock, and then mutates x outside of a lock +// thread 2 sees the set flag inside of a lock, and also mutates x outside of a lock x = 0; y = 0; flag = 0; @@ -30,3 +33,4 @@ $t2 = spawn { join $t1; join $t2; assert (x != y); +assert(flag == flag); \ No newline at end of file diff --git a/examples/reject/semantics/branching/join_datarace.gm b/examples/reject/semantics/branching/join_datarace.gm index 5b14916..59ea95b 100644 --- a/examples/reject/semantics/branching/join_datarace.gm +++ b/examples/reject/semantics/branching/join_datarace.gm @@ -7,4 +7,5 @@ $t = spawn { assert(x == 2); x = 4; assert(x == 4); -join $t; \ No newline at end of file +join $t; +assert(x == x); \ No newline at end of file diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index b5f5737..d40cfe3 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -133,6 +133,21 @@ 
BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock) { return std::nullopt; } +std::string BranchingSyncProtocolBase::build_revision_graph_dot( + const std::vector& thread_states) const { + + std::vector> heads; + + for (const ThreadSyncState* state_ptr : thread_states) { + const auto* local_store = dynamic_cast(state_ptr); + if (local_store && local_store->get_head()) { + heads.push_back(local_store->get_head()); + } + } + + return build_commit_graph_dot(heads); +} + } // end branching } // end gitmem \ No newline at end of file diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index 2a28208..8880fe9 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -40,6 +40,8 @@ public: std::ostream &print(std::ostream &os) const override; + std::string build_revision_graph_dot(const std::vector& thread_states) const override; + std::unique_ptr make_lock_state() const override { return std::make_unique(); } diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index dde7ff2..be820e7 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -1,6 +1,8 @@ #include "base_version_store.hh" #include #include +#include +#include #include "debug.hh" namespace gitmem { @@ -174,6 +176,57 @@ bool can_reach(const std::shared_ptr& commit, const std::shared_pt return false; } +std::string build_commit_graph_dot(const std::vector>& leaves) { + std::ostringstream dot; + dot << "digraph CommitGraph {\n"; + dot << " rankdir=BT;\n"; // bottom (leaves) → top (roots) + dot << " node [shape=box];\n"; + + std::unordered_set visited; + std::stack> stack; + + for (const auto& leaf : leaves) + if (leaf) stack.push(leaf); + + while (!stack.empty()) { + auto commit = stack.top(); + stack.pop(); + + if (!commit || !visited.insert(commit.get()).second) + continue; + + const std::string cid = to_string(commit->id); + + // Build label with 
commit ID and changes + std::ostringstream label; + label << cid; + if (!commit->changes.empty()) { + label << "\\n"; + bool first = true; + for (const auto& [obj, val] : commit->changes) { + if (!first) label << "\\n"; + first = false; + label << obj << "→" << val; + } + } + + // Emit node with label + dot << " \"" << cid << "\" [label=\"" << label.str() << "\"];\n"; + + // Emit edges to parents + for (const auto& parent : commit->parents) { + if (!parent) continue; + + const std::string pid = to_string(parent->id); + dot << " \"" << cid << "\" -> \"" << pid << "\";\n"; + stack.push(parent); + } + } + + dot << "}\n"; + return dot.str(); +} + } // branching } // gitmem \ No newline at end of file diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index f3c6f58..1848180 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -6,6 +6,7 @@ #include #include #include +#include #include "thread_id.hh" #include "sync_state.hh" #include "read_result.hh" @@ -40,6 +41,12 @@ struct Timestamp { } }; +inline std::string to_string(const Timestamp& ts) { + std::ostringstream ss; + ss << ts; + return ss.str(); +} + using ObjectNumber = uint64_t; struct Commit { @@ -48,6 +55,8 @@ struct Commit { std::vector> parents; }; +std::string build_commit_graph_dot(const std::vector>& leaves); + bool can_reach(const std::shared_ptr& commit, const std::shared_ptr& lca, std::unordered_map, bool>& memo); std::ostream& operator<<(std::ostream& os, const Commit& commit); diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index 935babb..bc30d06 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -108,7 +108,7 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar head = merge_commit; // whenever we merge, we loose all the information about the last writer - last_writer.clear(); + // last_writer.clear(); return std::nullopt; } @@ 
-120,18 +120,29 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar // } BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { - + // Thought, if we merge two paths that conflict on a variable, but we never read it + // if (auto it = last_writer.find(number); it != last_writer.end()) { + // return it->second->changes.at(number); + // } std::vector> writers; std::unordered_map, bool> reach_memo; + // std::cout << "Get committed for " << number << std::endl; + std::function)> dfs; dfs = [&](std::shared_ptr c) { if (!c) return; + // std::cout << c->id << " changes: {"; + // for (const auto& [k, v] : c->changes) { + // std::cout << k << "->" << v << ","; + // } + // std::cout << "}" << std::endl; // If we've already found a writer that is an ancestor of c, skip for (auto it = writers.begin(); it != writers.end(); ) { if (can_reach(c, *it, reach_memo)) { + // std::cout << c->id << " can reach " << (*it)->id << std::endl; // existing writer is ancestor of this commit, remove it it = writers.erase(it); } else if (can_reach(*it, c, reach_memo)) { @@ -153,8 +164,13 @@ BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) co dfs(head); + // std::cout << "=====================================" << std::endl; + if (writers.empty()) return std::monostate{}; - if (writers.size() == 1) return writers[0]->changes.at(number); + if (writers.size() == 1) { + // last_writer[number] = writers[0]; + return writers[0]->changes.at(number); + } // conflict auto a = writers[0]->id; diff --git a/src/interpreter.cc b/src/interpreter.cc index 8822a37..20c5263 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include "debug.hh" #include "interpreter.hh" @@ -39,15 +40,6 @@ static bool is_syncing(Thread &thread) { ((thread.pc >= thread.block->size()) || is_syncing(thread.block->at(thread.pc))); } -// Helper to combine multiple lambdas for std::visit -template 
-struct overloaded : Ts... { - using Ts::operator()...; -}; - -// deduction guide -template overloaded(Ts...) -> overloaded; - /* Evaluating an expression either returns the result of the expression or * a the exceptional termination status of the thread. */ @@ -70,20 +62,20 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { auto result = gctx.protocol->read(ctx, var); return std::visit(overloaded{ - [&](std::monostate) -> std::variant { - // invalid: reading a variable that hasn't been written - return termination::UnassignedRead(var); - }, - [&](Value value) -> std::variant { - // normal read - thread.trace.on_read(var, value); - return value; - }, - [&](std::shared_ptr& conflict) -> std::variant { - verbose << (*conflict) << std::endl; - return termination::DataRace(conflict); - } -}, result); + [&](std::monostate) -> std::variant { + // invalid: reading a variable that hasn't been written + return termination::UnassignedRead(var); + }, + [&](Value value) -> std::variant { + // normal read + thread.trace.on_read(var, value); + return value; + }, + [&](std::shared_ptr& conflict) -> std::variant { + verbose << (*conflict) << std::endl; + return termination::DataRace(conflict); + } + }, result); } else if (e == lang::Const) { return size_t(std::stoi(std::string(e->location().view()))); } else if (e == lang::Add) { @@ -485,15 +477,17 @@ int Interpreter::run() { } } + verbose << *gctx.protocol << std::endl; + return exception_detected ? 
1 : 0; } void Interpreter::print_thread_traces() { for (size_t tid = 0; tid < gctx.threads.size(); ++tid) { const auto& thread = gctx.threads[tid]; - std::cout << "=== Thread " << tid << " ===" << std::endl; - std::cout << thread.trace; - std::cout << "====================================\n"; + verbose << "=== Thread " << tid << " ===" << std::endl; + verbose << thread.trace; + verbose << "====================================\n"; } } @@ -573,8 +567,29 @@ int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind interp.print_thread_traces(); - // auto exec_graph = interp.build_execution_graph_from_traces(); + // Build and output revision graph - collect const raw pointers + std::vector thread_state_ptrs; + for (const auto& thread : interp.context().threads) { + thread_state_ptrs.push_back(thread.ctx.sync.get()); + } + std::string dot = interp.context().protocol->build_revision_graph_dot(thread_state_ptrs); + if (!dot.empty()) { + verbose << "=== Revision Graph ===" << std::endl; + verbose << dot << std::endl; + + // Write to file + auto dot_file = "revision_graph.dot"; + std::ofstream out(dot_file); + if (out) { + out << dot; + verbose << "Revision graph written to " << dot_file << std::endl; + } else { + verbose << "Failed to write revision graph to " << dot_file << std::endl; + } + } + + // auto exec_graph = interp.build_execution_graph_from_traces(); // graph::GraphvizPrinter gv(output_path); // gv.visit(node.get()); diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index 519fbf8..5d29e86 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -19,6 +19,12 @@ std::ostream &LinearSyncProtocol::print(std::ostream &os) const { return os; } +std::string LinearSyncProtocol::build_revision_graph_dot( + const std::vector& thread_states) const { + // Linear protocol doesn't have a commit graph structure + return ""; +} + std::optional LinearSyncProtocol::push(LocalVersionStore &local) { if (auto conflict = 
_global_store.check_conflicts(local.base_timestamp(), @@ -84,9 +90,6 @@ void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, std::optional> LinearSyncProtocol::on_spawn(ThreadContext &parent, ThreadContext &child) { - // TODO: i think we can drop the globalcontext but check after branching is - // added - // push parent to global history auto& store = get_store(parent); if (auto conflict = push(store)) diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index d8c2af6..333c33d 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -46,6 +46,8 @@ public: std::ostream &print(std::ostream &os) const override; + std::string build_revision_graph_dot(const std::vector& thread_states) const override; + std::unique_ptr make_thread_state(ThreadID tid) const override { return std::make_unique(); } diff --git a/src/model_checker.cc b/src/model_checker.cc index dd1abef..c26166a 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -188,9 +188,9 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi for (size_t tid = 0; tid < ctx->threads.size(); ++tid) { const auto& thread = ctx->threads[tid]; - std::cout << "=== Thread " << tid << " ===" << std::endl; - std::cout << thread.trace; - std::cout << "====================================\n"; + verbose << "=== Thread " << tid << " ===" << std::endl; + verbose << thread.trace; + verbose << "====================================\n"; } // ctx->print_execution_graph(path); } diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index 684dcf0..f3960f3 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -44,12 +44,15 @@ public: virtual std::optional> on_unlock(ThreadContext &thread, Lock &lock) = 0; + virtual std::string build_revision_graph_dot(const std::vector& thread_states) const = 0; virtual std::ostream &print(std::ostream &os) const = 0; friend std::ostream &operator<<(std::ostream &os, const SyncProtocol &protocol) 
{ return protocol.print(os); } + + }; } // namespace gitmem diff --git a/src/thread_trace.hh b/src/thread_trace.hh index 0dfe830..6321a5c 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -178,7 +178,6 @@ private: } }; - // --- operator<< for ThreadTrace --- inline std::ostream& operator<<(std::ostream& os, const ThreadTrace& tt) { os << "ThreadTrace[" << tt.trace.size() << " events]:\n"; From 0d7bb0f88833b629f844d23dbc61d9ee7928a8ac Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 9 Jan 2026 16:19:38 +0100 Subject: [PATCH 41/58] adding a print out for revision history in linear, and better using the timestamp --- src/branching/base_sync_protocol.cc | 12 ++--- src/branching/base_version_store.cc | 31 ++--------- src/branching/base_version_store.hh | 23 +++----- src/branching/eager/version_store.cc | 10 ++-- src/branching/eager/version_store.hh | 2 +- src/branching/lazy/version_store.cc | 16 +++--- src/branching/lazy/version_store.hh | 2 +- src/linear/sync_protocol.cc | 71 +++++++++++++++++++------ src/linear/sync_protocol.hh | 2 +- src/linear/version_store.cc | 67 ++++++++--------------- src/linear/version_store.hh | 79 +++++++++++----------------- 11 files changed, 141 insertions(+), 174 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index d40cfe3..04945b0 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -26,26 +26,24 @@ std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, const std::string &var) { - ObjectNumber number = _global_store.get_object_number(var); auto& store = get_store(ctx); - // convert the branching read result into a regular read result return std::visit(overloaded{ [](std::monostate) -> ReadResult { return std::monostate{}; }, [](const Value& v) -> ReadResult { return v; }, [&](const Conflict& c) -> ReadResult { return std::make_shared( - 
_global_store.get_object_name(c.obj), std::pair{c.timestamp_a, c.timestamp_b} + c.obj, std::pair{c.timestamp_a, c.timestamp_b} ); }, - }, store.read(number)); + }, store.read(var)); } void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, size_t value) { auto& store = get_store(ctx); - store.stage(_global_store.get_object_number(var), value); + store.stage(var, value); } std::optional> @@ -71,7 +69,7 @@ BranchingSyncProtocolBase::on_join(ThreadContext &joiner, ThreadContext &joinee) std::optional conflict = joiner_store.merge_with_commit(joinee_store.get_head()); if (conflict) { return std::make_shared( - _global_store.get_object_name(conflict->obj), + conflict->obj, std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); } @@ -104,7 +102,7 @@ BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { std::optional conflict = store.merge_with_commit(lock_commit); if (conflict) { return std::make_shared( - _global_store.get_object_name(conflict->obj), + conflict->obj, std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); } } diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index be820e7..b1a9ec2 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -54,7 +54,7 @@ std::ostream& operator<<(std::ostream& os, const Commit& commit) { return os; } -void LocalVersionStore::stage(ObjectNumber obj, Value value) { +void LocalVersionStore::stage(std::string obj, Value value) { staging[obj] = value; } @@ -83,12 +83,12 @@ void LocalVersionStore::commit_staging() { head = new_commit; } -BranchingReadResult LocalVersionStore::read(ObjectNumber obj) const { - auto it = staging.find(obj); +BranchingReadResult LocalVersionStore::read(std::string var) const { + auto it = staging.find(var); if (it != staging.end()) return it->second; - return get_committed(obj); + return get_committed(var); } void LocalVersionStore::adopt_history(const LocalVersionStore& 
other) { @@ -128,29 +128,8 @@ bool LocalVersionStore::operator==(const LocalVersionStore& other) const { staging == other.staging; } -ObjectNumber GlobalVersionStore::get_object_number(std::string var) { - auto it = _object_numbers.find(var); - if (it != _object_numbers.end()) { - return it->second; - } else { - ObjectNumber number = _next_object++; - _object_numbers[var] = number; - return number; - } -} - -std::string GlobalVersionStore::get_object_name(ObjectNumber find) { - for (const auto &[name, number] : _object_numbers) { - if (number == find) - return name; - } - assert(false && "failed to find object name for object number"); - return ""; -} - - std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { - os << "GlobalVersionStore(next_object=" << store._next_object << ")" << std::endl; + os << "GlobalVersionStore()" << std::endl; return os; } diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 1848180..36db0f8 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -47,11 +47,9 @@ inline std::string to_string(const Timestamp& ts) { return ss.str(); } -using ObjectNumber = uint64_t; - struct Commit { Timestamp id; - std::unordered_map changes; + std::unordered_map changes; std::vector> parents; }; @@ -62,7 +60,7 @@ bool can_reach(const std::shared_ptr& commit, const std::shared_pt std::ostream& operator<<(std::ostream& os, const Commit& commit); struct Conflict { - ObjectNumber obj; + std::string obj; Timestamp timestamp_a; Timestamp timestamp_b; }; @@ -79,16 +77,16 @@ class LocalVersionStore : public ThreadSyncState { protected: Timestamp base_timestamp; std::shared_ptr head; - std::unordered_map staging; + std::unordered_map staging; - std::unordered_map> last_writer; // cached + std::unordered_map> last_writer; // cached public: ~LocalVersionStore() = default; LocalVersionStore(ThreadID tid): base_timestamp(tid, 0) {} - void stage(ObjectNumber obj, Value 
value); + void stage(std::string obj, Value value); void commit_staging(); bool has_commited() { return staging.empty(); } @@ -96,10 +94,10 @@ public: std::shared_ptr get_head() const { return head; } private: - virtual BranchingReadResult get_committed(ObjectNumber number) const = 0; + virtual BranchingReadResult get_committed(std::string var) const = 0; public: - BranchingReadResult read(ObjectNumber number) const; + BranchingReadResult read(std::string var) const; void adopt_history(const LocalVersionStore& other); virtual std::optional merge_with_commit(const std::shared_ptr& other_head) = 0; @@ -120,14 +118,7 @@ public: }; class GlobalVersionStore { - ObjectNumber _next_object{0}; - std::unordered_map _object_numbers; - public: - - ObjectNumber get_object_number(std::string); - std::string get_object_name(ObjectNumber); - friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); }; diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc index 71e758c..131f6d9 100644 --- a/src/branching/eager/version_store.cc +++ b/src/branching/eager/version_store.cc @@ -10,7 +10,7 @@ namespace branching { bool traverse_until_lca( const std::shared_ptr& commit, const std::shared_ptr& lca, - std::unordered_map>& out_map, + std::unordered_map>& out_map, std::unordered_set>& visited, std::unordered_map, bool>& reach_memo) { @@ -95,7 +95,7 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha verbose << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; // Collect all writes after LCA for each branch - std::unordered_map> branch_a, branch_b; + std::unordered_map> branch_a, branch_b; std::unordered_set> visited; std::unordered_map, bool> reach_memo; @@ -131,9 +131,9 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha return std::nullopt; } -BranchingReadResult EagerLocalVersionStore::get_committed(ObjectNumber number) const { - if (auto it = 
last_writer.find(number); it != last_writer.end()) - return it->second->changes.at(number); +BranchingReadResult EagerLocalVersionStore::get_committed(std::string var) const { + if (auto it = last_writer.find(var); it != last_writer.end()) + return it->second->changes.at(var); return std::monostate{}; } diff --git a/src/branching/eager/version_store.hh b/src/branching/eager/version_store.hh index 2920eb0..3376788 100644 --- a/src/branching/eager/version_store.hh +++ b/src/branching/eager/version_store.hh @@ -13,7 +13,7 @@ public: EagerLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} std::optional merge_with_commit(const std::shared_ptr&) override; - BranchingReadResult get_committed(ObjectNumber number) const override; + BranchingReadResult get_committed(std::string var) const override; }; diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index bc30d06..1aafb34 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -119,16 +119,16 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar // return it->second->changes.at(number); // } -BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) const { +BranchingReadResult LazyLocalVersionStore::get_committed(std::string var) const { // Thought, if we merge two paths that conflict on a variable, but we never read it - // if (auto it = last_writer.find(number); it != last_writer.end()) { - // return it->second->changes.at(number); + // if (auto it = last_writer.find(var); it != last_writer.end()) { + // return it->second->changes.at(var); // } std::vector> writers; std::unordered_map, bool> reach_memo; - // std::cout << "Get committed for " << number << std::endl; + // std::cout << "Get committed for " << var << std::endl; std::function)> dfs; dfs = [&](std::shared_ptr c) { @@ -153,7 +153,7 @@ BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) co } } - if 
(c->changes.contains(number)) { + if (c->changes.contains(var)) { writers.push_back(c); return; } @@ -168,14 +168,14 @@ BranchingReadResult LazyLocalVersionStore::get_committed(ObjectNumber number) co if (writers.empty()) return std::monostate{}; if (writers.size() == 1) { - // last_writer[number] = writers[0]; - return writers[0]->changes.at(number); + // last_writer[var] = writers[0]; + return writers[0]->changes.at(var); } // conflict auto a = writers[0]->id; auto b = writers[1]->id; - return Conflict(number, a, b); + return Conflict(var, a, b); } diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh index 881305f..9254b86 100644 --- a/src/branching/lazy/version_store.hh +++ b/src/branching/lazy/version_store.hh @@ -13,7 +13,7 @@ public: LazyLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} std::optional merge_with_commit(const std::shared_ptr&) override; - BranchingReadResult get_committed(ObjectNumber number) const override; + BranchingReadResult get_committed(std::string var) const override; }; diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index 5d29e86..7057c2c 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -1,6 +1,8 @@ #include "linear/sync_protocol.hh" #include "debug.hh" #include +#include +#include namespace gitmem { @@ -21,23 +23,64 @@ std::ostream &LinearSyncProtocol::print(std::ostream &os) const { std::string LinearSyncProtocol::build_revision_graph_dot( const std::vector& thread_states) const { - // Linear protocol doesn't have a commit graph structure - return ""; + + std::ostringstream dot; + dot << "digraph LinearHistory {\n"; + dot << " rankdir=BT;\n"; + dot << " node [shape=box];\n"; + + const auto& history = _global_store.get_history(); + + if (history.empty()) { + dot << "}\n"; + return dot.str(); + } + + // Create a subgraph for each variable showing its version history + for (const auto& [obj_name, versions] : history) { + dot << " subgraph 
cluster_" << obj_name << " {\n"; + dot << " label=\"" << obj_name << "\";\n"; + dot << " style=dashed;\n"; + + // Create nodes for each version + for (size_t i = 0; i < versions.size(); ++i) { + const auto& version = versions[i]; + std::ostringstream node_id; + node_id << obj_name << "_v" << i; + + std::ostringstream label; + label << version.timestamp() << "\\n" << obj_name << "=" << version.value(); + + dot << " \"" << node_id.str() << "\" [label=\"" << label.str() << "\"];\n"; + } + + // Create edges between consecutive versions + for (size_t i = 1; i < versions.size(); ++i) { + std::ostringstream prev_id, curr_id; + prev_id << obj_name << "_v" << (i - 1); + curr_id << obj_name << "_v" << i; + dot << " \"" << curr_id.str() << "\" -> \"" << prev_id.str() << "\";\n"; + } + + dot << " }\n"; + } + + dot << "}\n"; + return dot.str(); } std::optional LinearSyncProtocol::push(LocalVersionStore &local) { - if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), + if (auto conflict = _global_store.check_conflicts(local.timestamp(), local.staged_changes())) { - // reshape the conflict return std::make_optional( - _global_store.get_object_name(conflict->object), + conflict->object, std::make_pair(conflict->local_base, conflict->global_head)); } - Timestamp new_base = _global_store.apply_changes( - local.base_timestamp(), local.staged_changes()); + uint64_t new_base = _global_store.apply_changes( + local.thread(), local.timestamp(), local.staged_changes()); local.clear_staging(); local.advance_base(new_base); @@ -46,15 +89,15 @@ LinearSyncProtocol::push(LocalVersionStore &local) { std::optional LinearSyncProtocol::pull(LocalVersionStore &local) { - if (auto conflict = _global_store.check_conflicts(local.base_timestamp(), + if (auto conflict = _global_store.check_conflicts(local.timestamp(), local.staged_changes())) { return std::make_optional( - _global_store.get_object_name(conflict->object), + conflict->object, std::make_pair(conflict->local_base, 
conflict->global_head)); } - local.advance_base(_global_store.current_timestamp()); + local.advance_base( _global_store.current_counter()); return std::nullopt; } @@ -62,15 +105,13 @@ LinearSyncProtocol::~LinearSyncProtocol() = default; ReadResult LinearSyncProtocol::read(ThreadContext &ctx, const std::string &var) { - ObjectNumber number = _global_store.get_object_number(var); - auto& store = get_store(ctx); - if (auto result = store.get_staged(number)) + if (auto result = store.get_staged(var)) return *result; std::optional value = _global_store.get_version_for_timestamp( - number, store.base_timestamp()); + var, store.timestamp()); if (value) return *value; @@ -85,7 +126,7 @@ void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, size_t value) { // write into the staging area of the thread auto& store = get_store(ctx); - store.stage(_global_store.get_object_number(var), value); + store.stage(var, value); } std::optional> diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 333c33d..32e823c 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -49,7 +49,7 @@ public: std::string build_revision_graph_dot(const std::vector& thread_states) const override; std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(); + return std::make_unique(tid); } std::unique_ptr make_lock_state() const override { diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index a30e150..e572be5 100644 --- a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -12,27 +12,27 @@ namespace linear { // LocalVersionStore // ----------------------------- -void LocalVersionStore::stage(ObjectNumber obj, Value value) { +void LocalVersionStore::stage(std::string obj, Value value) { _staging[obj] = value; } void LocalVersionStore::clear_staging() { _staging.clear(); } -void LocalVersionStore::advance_base(Timestamp ts) { _base_timestamp = ts; } +void 
LocalVersionStore::advance_base(uint64_t ts) { _timestamp = ts; } -std::optional LocalVersionStore::get_staged(ObjectNumber obj) { +std::optional LocalVersionStore::get_staged(std::string obj) { auto it = _staging.find(obj); return it != _staging.end() ? std::make_optional(it->second) : std::nullopt; } bool LocalVersionStore::operator==(const LocalVersionStore& other) const { - return _base_timestamp == other._base_timestamp && + return _timestamp == other._timestamp && _staging == other._staging; } std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { os << "LocalVersionStore{" - << "base=" << store._base_timestamp + << "base=" << store._timestamp << ", staged={"; bool first = true; @@ -50,29 +50,9 @@ std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { // GlobalVersionStore // ----------------------------- -ObjectNumber GlobalVersionStore::get_object_number(std::string var) { - auto it = _object_numbers.find(var); - if (it != _object_numbers.end()) { - return it->second; - } else { - ObjectNumber number = _next_object++; - _object_numbers[var] = number; - return number; - } -} - -std::string GlobalVersionStore::get_object_name(ObjectNumber find) { - for (const auto &[name, number] : _object_numbers) { - if (number == find) - return name; - } - assert(false && "failed to find object name for object number"); - return ""; -} - std::optional -GlobalVersionStore::get_version_for_timestamp(ObjectNumber obj, - Timestamp ts) const { +GlobalVersionStore::get_version_for_timestamp(std::string obj, + uint64_t ts) const { const auto it = _history.find(obj); if (it == _history.end()) @@ -81,7 +61,7 @@ GlobalVersionStore::get_version_for_timestamp(ObjectNumber obj, const VersionHistory &history = it->second; for (VersionHistory::const_reverse_iterator riter = history.rbegin(); riter != history.rend(); ++riter) { - if (riter->timestamp() <= ts) + if (riter->timestamp().counter <= ts) return riter->value(); } @@ -89,8 +69,8 @@ 
GlobalVersionStore::get_version_for_timestamp(ObjectNumber obj, } std::optional GlobalVersionStore::check_conflicts( - Timestamp base, - const std::unordered_map &changes) const { + uint64_t base, + const std::unordered_map &changes) const { for (const auto &[obj, _] : changes) { auto it = _history.find(obj); if (it == _history.end()) { @@ -98,7 +78,7 @@ std::optional GlobalVersionStore::check_conflicts( } const Version &latest = it->second.back(); - if (latest.timestamp() > base) { + if (latest.timestamp().counter > base) { return Conflict{ .object = obj, .local_base = base, .global_head = latest.timestamp()}; } @@ -106,32 +86,27 @@ std::optional GlobalVersionStore::check_conflicts( return std::nullopt; } -Timestamp GlobalVersionStore::apply_changes( - Timestamp base, const std::unordered_map &changes) { +uint64_t GlobalVersionStore::apply_changes( + ThreadID tid, uint64_t base, const std::unordered_map &changes) { if (auto conflict = check_conflicts(base, changes)) { throw std::logic_error("apply_changes called with conflicts"); } - Timestamp new_ts = ++_timestamp; + // Increment the global counter and create new timestamp with thread info from base + Timestamp new_ts{tid, ++_counter}; + for (const auto &[obj, value] : changes) { _history[obj].emplace_back(new_ts, value); } - _timestamp = new_ts; - return new_ts; + return _counter; } std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { - os << "GlobalVersionStore(timestamp=" << store._timestamp - << ", next_object=" << store._next_object << ")\n"; - - for (const auto& [obj_num, history] : store._history) { - os << " Object " << obj_num; - auto it = std::find_if(store._object_numbers.begin(), store._object_numbers.end(), - [&](const auto& pair){ return pair.second == obj_num; }); - if (it != store._object_numbers.end()) - os << " (" << it->first << ")"; - os << ":\n"; + os << "GlobalVersionStore(counter=" << store._counter << ")\n"; + + for (const auto& [obj_name, history] : 
store._history) { + os << " Object " << obj_name << ":\n"; for (const auto& version : history) { os << " [" << version.timestamp() << "] = " << version.value() << "\n"; diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index a5f6f39..d6dbbe9 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -16,40 +16,19 @@ namespace linear { // Timestamp // ----------------------------- -class LargeCounter { - uint64_t _epoch{0}; - uint64_t _counter{0}; - -public: - auto operator<=>(const LargeCounter &) const = default; - - LargeCounter &operator++() { - if (_counter == UINT64_MAX) { - _counter = 0; - assert(_epoch != UINT64_MAX && "timestamp overflow"); - ++_epoch; - } else { - ++_counter; - } - return *this; - } +struct Timestamp { + size_t thread{0}; + uint64_t counter{0}; - LargeCounter operator++(int) { - LargeCounter old = *this; - ++(*this); - return old; - } + auto operator<=>(const Timestamp &) const = default; - friend std::ostream &operator<<(std::ostream &os, - const LargeCounter &counter) { - os << counter._epoch << ":" << counter._counter; + friend std::ostream &operator<<(std::ostream &os, const Timestamp &ts) { + os << "t" << ts.thread << ":" << ts.counter; return os; } }; -using Timestamp = LargeCounter; using Value = size_t; -using ObjectNumber = uint64_t; // ----------------------------- // Version @@ -73,7 +52,7 @@ using VersionHistory = std::vector; // ----------------------------- struct Conflict { - ObjectNumber object; + std::string object; Timestamp local_base; Timestamp global_head; }; @@ -83,19 +62,24 @@ struct Conflict { // ----------------------------- class LocalVersionStore : public ThreadSyncState { - Timestamp _base_timestamp{}; - std::unordered_map _staging; + ThreadID tid; + uint64_t _timestamp; + std::unordered_map _staging; public: ~LocalVersionStore() = default; - Timestamp base_timestamp() const { return _base_timestamp; } + LocalVersionStore(ThreadID tid) : tid(tid), _timestamp(0) {} + + 
ThreadID thread() const { return tid; } + + uint64_t timestamp() const { return _timestamp; } const auto &staged_changes() const { return _staging; } - void stage(ObjectNumber obj, Value value); + void stage(std::string obj, Value value); void clear_staging(); - void advance_base(Timestamp ts); - std::optional get_staged(ObjectNumber obj); + void advance_base(uint64_t ts); + std::optional get_staged(std::string obj); bool operator==(const LocalVersionStore& other) const; @@ -119,28 +103,27 @@ public: // ----------------------------- class GlobalVersionStore { - Timestamp _timestamp{}; - ObjectNumber _next_object{0}; - std::unordered_map _history; - std::unordered_map _object_numbers; + uint64_t _counter{0}; + std::unordered_map _history; public: - Timestamp current_timestamp() const { return _timestamp; } - - ObjectNumber get_object_number(std::string); - std::string get_object_name(ObjectNumber); + uint64_t current_counter() const { return _counter; } - std::optional get_version_for_timestamp(ObjectNumber, Timestamp) const; + std::optional get_version_for_timestamp(std::string, uint64_t) const; std::optional - check_conflicts(Timestamp base, - const std::unordered_map &changes) const; + check_conflicts(uint64_t base, + const std::unordered_map &changes) const; - Timestamp - apply_changes(Timestamp base, - const std::unordered_map &changes); + uint64_t + apply_changes(ThreadID tid, uint64_t base, + const std::unordered_map &changes); friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); + + std::unordered_map get_history() const { + return _history; + } }; } // namespace linear From f6e9550414a454f9a76dd8504d71209033c28676 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 9 Jan 2026 16:34:17 +0100 Subject: [PATCH 42/58] better graph for hierarchical model --- .../semantics/branching/conditional_race.gm | 3 +- src/branching/base_version_store.cc | 62 ++++++++++++++----- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git 
a/examples/reject/semantics/branching/conditional_race.gm b/examples/reject/semantics/branching/conditional_race.gm index 2a86c78..427be52 100644 --- a/examples/reject/semantics/branching/conditional_race.gm +++ b/examples/reject/semantics/branching/conditional_race.gm @@ -32,5 +32,4 @@ $t2 = spawn { }; join $t1; join $t2; -assert (x != y); -assert(flag == flag); \ No newline at end of file +assert (x != y); \ No newline at end of file diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index b1a9ec2..ba715a4 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -4,6 +4,7 @@ #include #include #include "debug.hh" +#include namespace gitmem { @@ -158,12 +159,14 @@ bool can_reach(const std::shared_ptr& commit, const std::shared_pt std::string build_commit_graph_dot(const std::vector>& leaves) { std::ostringstream dot; dot << "digraph CommitGraph {\n"; - dot << " rankdir=BT;\n"; // bottom (leaves) → top (roots) + dot << " rankdir=BT;\n"; dot << " node [shape=box];\n"; std::unordered_set visited; + std::unordered_map>> commits_by_thread; std::stack> stack; + // First pass: collect all commits and organize by thread for (const auto& leaf : leaves) if (leaf) stack.push(leaf); @@ -174,25 +177,54 @@ std::string build_commit_graph_dot(const std::vectorid); + commits_by_thread[commit->id.thread].push_back(commit); + + for (const auto& parent : commit->parents) { + if (parent) stack.push(parent); + } + } - // Build label with commit ID and changes - std::ostringstream label; - label << cid; - if (!commit->changes.empty()) { - label << "\\n"; - bool first = true; - for (const auto& [obj, val] : commit->changes) { - if (!first) label << "\\n"; - first = false; - label << obj << "→" << val; + // Create subgraph clusters for each thread + for (const auto& [thread_id, commits] : commits_by_thread) { + dot << " subgraph cluster_" << thread_id << " {\n"; + dot << " label=\"Thread " << thread_id << "\";\n"; + 
dot << " style=dashed;\n"; + + for (const auto& commit : commits) { + const std::string cid = to_string(commit->id); + + std::ostringstream label; + label << cid; + if (!commit->changes.empty()) { + label << "\\n"; + bool first = true; + for (const auto& [obj, val] : commit->changes) { + if (!first) label << "\\n"; + first = false; + label << obj << "→" << val; + } } + + dot << " \"" << cid << "\" [label=\"" << label.str() << "\"];\n"; } - // Emit node with label - dot << " \"" << cid << "\" [label=\"" << label.str() << "\"];\n"; + dot << " }\n"; + } + + // Draw edges (outside clusters so they can cross boundaries) + visited.clear(); + for (const auto& leaf : leaves) + if (leaf) stack.push(leaf); + + while (!stack.empty()) { + auto commit = stack.top(); + stack.pop(); + + if (!commit || !visited.insert(commit.get()).second) + continue; + + const std::string cid = to_string(commit->id); - // Emit edges to parents for (const auto& parent : commit->parents) { if (!parent) continue; From 8f8bc0228bbfbc3d00b9349f1eb9b8aa27c51b87 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 9 Jan 2026 16:55:27 +0100 Subject: [PATCH 43/58] fixed the lazy test suites --- src/branching/lazy/version_store.cc | 51 +++++++++++++++-------------- src/interpreter.cc | 3 -- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index 1aafb34..46905f5 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -120,39 +120,40 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar // } BranchingReadResult LazyLocalVersionStore::get_committed(std::string var) const { - // Thought, if we merge two paths that conflict on a variable, but we never read it - // if (auto it = last_writer.find(var); it != last_writer.end()) { - // return it->second->changes.at(var); - // } - std::vector> writers; - std::unordered_map, bool> reach_memo; - // std::cout << "Get 
committed for " << var << std::endl; + std::cout << "Get committed for " << var << std::endl; std::function)> dfs; dfs = [&](std::shared_ptr c) { if (!c) return; - // std::cout << c->id << " changes: {"; - // for (const auto& [k, v] : c->changes) { - // std::cout << k << "->" << v << ","; - // } - // std::cout << "}" << std::endl; - - // If we've already found a writer that is an ancestor of c, skip - for (auto it = writers.begin(); it != writers.end(); ) { - if (can_reach(c, *it, reach_memo)) { - // std::cout << c->id << " can reach " << (*it)->id << std::endl; - // existing writer is ancestor of this commit, remove it - it = writers.erase(it); - } else if (can_reach(*it, c, reach_memo)) { - // this commit is ancestor of existing writer, ignore this path - return; - } else { - ++it; + std::cout << c->id << " changes: {"; + for (const auto& [k, v] : c->changes) { + std::cout << k << "->" << v << ","; + } + std::cout << "}" << std::endl; + + // Check if c is an ancestor of any existing writer + { + std::unordered_map, bool> reach_memo; + for (const auto& writer : writers) { + if (can_reach(writer, c, reach_memo)) { + // c is ancestor of existing writer, ignore this path + return; + } } } + // Remove any existing writers that are ancestors of c + { + std::unordered_map, bool> reach_memo; + writers.erase( + std::remove_if(writers.begin(), writers.end(), + [&](const auto& writer) { return can_reach(c, writer, reach_memo); }), + writers.end() + ); + } + if (c->changes.contains(var)) { writers.push_back(c); return; @@ -164,7 +165,7 @@ BranchingReadResult LazyLocalVersionStore::get_committed(std::string var) const dfs(head); - // std::cout << "=====================================" << std::endl; + std::cout << "=====================================" << std::endl; if (writers.empty()) return std::monostate{}; if (writers.size() == 1) { diff --git a/src/interpreter.cc b/src/interpreter.cc index 20c5263..9128970 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ 
-575,9 +575,6 @@ int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind std::string dot = interp.context().protocol->build_revision_graph_dot(thread_state_ptrs); if (!dot.empty()) { - verbose << "=== Revision Graph ===" << std::endl; - verbose << dot << std::endl; - // Write to file auto dot_file = "revision_graph.dot"; std::ofstream out(dot_file); From 992696d8e07d29b00948df6fba69a7b5f055fa74 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 9 Jan 2026 17:14:11 +0100 Subject: [PATCH 44/58] add printing of commit graphs on debug steps --- src/branching/lazy/version_store.cc | 9 ------- src/debugger.cc | 4 +++ src/interpreter.cc | 40 ++++++++++++++++------------- src/interpreter.hh | 2 ++ 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index 46905f5..8d013a4 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -122,16 +122,9 @@ std::optional LazyLocalVersionStore::merge_with_commit(const std::shar BranchingReadResult LazyLocalVersionStore::get_committed(std::string var) const { std::vector> writers; - std::cout << "Get committed for " << var << std::endl; - std::function)> dfs; dfs = [&](std::shared_ptr c) { if (!c) return; - std::cout << c->id << " changes: {"; - for (const auto& [k, v] : c->changes) { - std::cout << k << "->" << v << ","; - } - std::cout << "}" << std::endl; // Check if c is an ancestor of any existing writer { @@ -165,8 +158,6 @@ BranchingReadResult LazyLocalVersionStore::get_committed(std::string var) const dfs(head); - std::cout << "=====================================" << std::endl; - if (writers.empty()) return std::monostate{}; if (writers.size() == 1) { // last_writer[var] = writers[0]; diff --git a/src/debugger.cc b/src/debugger.cc index cb05b81..4e43b0f 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -142,6 +142,7 @@ void maybe_print_graph(Interpreter& interp, const 
std::filesystem::path &output_file) { if (print_graphs) { // gctx.print_execution_graph(output_file); + interp.build_and_print_revision_graph(output_file); verbose << "Execution graph written to " << output_file << std::endl; } } @@ -221,6 +222,9 @@ int interpret_interactive(const trieste::Node ast, Command command = {Command::List}; bool print_graphs = true; + // clear the graph at the start + maybe_print_graph(interp, print_graphs, output_file); + while (command.cmd != Command::Quit) { // Print threads if new threads appeared or command is List if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) { diff --git a/src/interpreter.cc b/src/interpreter.cc index 9128970..2dc5b4c 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -491,6 +491,27 @@ void Interpreter::print_thread_traces() { } } +void Interpreter::build_and_print_revision_graph(const std::filesystem::path& output_path) { + // Build revision graph - collect const raw pointers + std::vector thread_state_ptrs; + for (const auto& thread : gctx.threads) { + thread_state_ptrs.push_back(thread.ctx.sync.get()); + } + + std::string dot = gctx.protocol->build_revision_graph_dot(thread_state_ptrs); + if (!dot.empty()) { + // Write to file + auto dot_file = output_path.parent_path() / "revision_graph.dot"; + std::ofstream out(dot_file); + if (out) { + out << dot; + verbose << "Revision graph written to " << dot_file << std::endl; + } else { + verbose << "Failed to write revision graph to " << dot_file << std::endl; + } + } +} + graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { // Map each thread ID to its last graph node in the execution graph std::unordered_map> thread_tails; @@ -567,24 +588,7 @@ int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind interp.print_thread_traces(); - // Build and output revision graph - collect const raw pointers - std::vector thread_state_ptrs; - for (const auto& thread : interp.context().threads) { - 
thread_state_ptrs.push_back(thread.ctx.sync.get()); - } - - std::string dot = interp.context().protocol->build_revision_graph_dot(thread_state_ptrs); - if (!dot.empty()) { - // Write to file - auto dot_file = "revision_graph.dot"; - std::ofstream out(dot_file); - if (out) { - out << dot; - verbose << "Revision graph written to " << dot_file << std::endl; - } else { - verbose << "Failed to write revision graph to " << dot_file << std::endl; - } - } + interp.build_and_print_revision_graph(output_path); // auto exec_graph = interp.build_execution_graph_from_traces(); // graph::GraphvizPrinter gv(output_path); diff --git a/src/interpreter.hh b/src/interpreter.hh index d9bd2e7..9428c04 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -46,6 +46,8 @@ public: void print_thread_traces(); + void build_and_print_revision_graph(const std::filesystem::path& output_path); + graph::ExecutionGraph build_execution_graph_from_traces(); }; From 6415705e45fa7b50977c5b41b49707818fedcaa6 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Sun, 11 Jan 2026 20:24:44 +0100 Subject: [PATCH 45/58] optional richer revisions graphs, graph printing seems to be working and interactive graphing too --- src/branching/base_sync_protocol.cc | 28 +++- src/branching/base_sync_protocol.hh | 33 ++++ src/branching/base_version_store.cc | 6 +- src/branching/base_version_store.hh | 4 +- src/branching/eager/sync_protocol.hh | 5 +- src/branching/eager/version_store.cc | 2 +- src/branching/eager/version_store.hh | 2 +- src/branching/lazy/sync_protocol.hh | 5 +- src/branching/lazy/version_store.cc | 80 ---------- src/branching/lazy/version_store.hh | 2 +- src/debug.hh | 7 +- src/debugger.cc | 6 +- src/gitmem.cc | 13 +- src/graph.hh | 3 +- src/graphviz.cc | 2 +- src/interpreter.cc | 221 ++++++++++++++++----------- src/interpreter.hh | 8 +- src/model_checker.cc | 22 +-- src/sync_protocol.cc | 25 ++- 19 files changed, 261 insertions(+), 213 deletions(-) diff --git a/src/branching/base_sync_protocol.cc 
b/src/branching/base_sync_protocol.cc index 04945b0..4d8fa41 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -1,5 +1,7 @@ #include "base_sync_protocol.hh" #include "overloaded.hh" +#include "branching/eager/sync_protocol.hh" +#include "branching/lazy/sync_protocol.hh" namespace gitmem { @@ -20,7 +22,7 @@ LockState& get_store(Lock& ctx) { BranchingSyncProtocolBase::~BranchingSyncProtocolBase() = default; std::ostream &BranchingSyncProtocolBase::print(std::ostream &os) const { - os << _global_store << std::endl; + os << _global_store; return os; } @@ -105,9 +107,9 @@ BranchingSyncProtocolBase::on_lock(ThreadContext &thread, Lock &lock) { conflict->obj, std::make_pair(conflict->timestamp_a, conflict->timestamp_b)); } - } - lock_state.commit = store.get_head(); + lock_state.commit = store.get_head(); + } return std::nullopt; } @@ -120,11 +122,8 @@ BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock) { LockState& lock_state = get_store(lock); std::shared_ptr lock_commit = lock_state.commit; - // we don't need to check for conflicts - if (lock_commit != nullptr) { - std::optional conflict = store.merge_with_commit(lock_commit); - assert (!conflict); - } + // we know that the last committer was this thread, so no need to merge + // this sort of mixes protocol logic and lock state, i am unsure if this is ideal lock_state.commit = store.get_head(); @@ -146,6 +145,19 @@ std::string BranchingSyncProtocolBase::build_revision_graph_dot( return build_commit_graph_dot(heads); } +std::unique_ptr BranchingSyncProtocolBuilder::build() const { + switch (kind) { + case SyncKind::BranchingEager: + return std::make_unique(verbose); + + case SyncKind::BranchingLazy: + return std::make_unique(verbose); + + default: + throw std::runtime_error("Invalid sync kind for branching protocol"); + } +} + } // end branching } // end gitmem \ No newline at end of file diff --git a/src/branching/base_sync_protocol.hh 
b/src/branching/base_sync_protocol.hh index 8880fe9..c8417e3 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -12,6 +12,9 @@ namespace branching { class BranchingSyncProtocolBase : public SyncProtocol { protected: GlobalVersionStore _global_store; + bool verbose; + + explicit BranchingSyncProtocolBase(bool verbose) : verbose(verbose) {} public: ~BranchingSyncProtocolBase() override; @@ -47,6 +50,36 @@ public: } }; +// Builder for creating branching sync protocols +class BranchingSyncProtocolBuilder { +private: + SyncKind kind = SyncKind::BranchingLazy; + bool verbose = false; + +public: + BranchingSyncProtocolBuilder& with_kind(SyncKind k) { + kind = k; + return *this; + } + + BranchingSyncProtocolBuilder& eager() { + kind = SyncKind::BranchingEager; + return *this; + } + + BranchingSyncProtocolBuilder& lazy() { + kind = SyncKind::BranchingLazy; + return *this; + } + + BranchingSyncProtocolBuilder& with_verbose_commits(bool v = true) { + verbose = v; + return *this; + } + + std::unique_ptr build() const; +}; + } // end branching } // end gitmem \ No newline at end of file diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index ba715a4..51ff7fd 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -60,8 +60,8 @@ void LocalVersionStore::stage(std::string obj, Value value) { } void LocalVersionStore::commit_staging() { - // No-op commit does nothing - if (staging.empty()) { + // No-op commit does nothing unless verbose mode + if (staging.empty() && !verbose) { return; } @@ -130,7 +130,7 @@ bool LocalVersionStore::operator==(const LocalVersionStore& other) const { } std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { - os << "GlobalVersionStore()" << std::endl; + os << "GlobalVersionStore()"; return os; } diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 36db0f8..18724c3 100644 --- 
a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -81,10 +81,12 @@ protected: std::unordered_map> last_writer; // cached + bool verbose; + public: ~LocalVersionStore() = default; - LocalVersionStore(ThreadID tid): base_timestamp(tid, 0) {} + LocalVersionStore(ThreadID tid, bool verbose = false): base_timestamp(tid, 0), verbose(verbose) {} void stage(std::string obj, Value value); void commit_staging(); diff --git a/src/branching/eager/sync_protocol.hh b/src/branching/eager/sync_protocol.hh index fa5b26e..b5fcbff 100644 --- a/src/branching/eager/sync_protocol.hh +++ b/src/branching/eager/sync_protocol.hh @@ -9,12 +9,15 @@ namespace branching { class BranchingEagerSyncProtocol final : public BranchingSyncProtocolBase { public: + explicit BranchingEagerSyncProtocol(bool verbose = false) + : BranchingSyncProtocolBase(verbose) {} + ~BranchingEagerSyncProtocol() = default; SyncKind kind() const override { return SyncKind::BranchingEager; }; std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(tid); + return std::make_unique(tid, verbose); } }; diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc index 131f6d9..74cd9d8 100644 --- a/src/branching/eager/version_store.cc +++ b/src/branching/eager/version_store.cc @@ -92,7 +92,7 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha // Find lowest common ancestor of the two heads std::shared_ptr lca = find_lowest_common_ancestor(head, commit); - verbose << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; + verbose::out << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; // Collect all writes after LCA for each branch std::unordered_map> branch_a, branch_b; diff --git a/src/branching/eager/version_store.hh b/src/branching/eager/version_store.hh index 3376788..cc92be9 100644 --- a/src/branching/eager/version_store.hh +++ 
b/src/branching/eager/version_store.hh @@ -10,7 +10,7 @@ class EagerLocalVersionStore : public LocalVersionStore { public: ~EagerLocalVersionStore() = default; - EagerLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} + EagerLocalVersionStore(ThreadID tid, bool verbose) : LocalVersionStore(tid, verbose) {} std::optional merge_with_commit(const std::shared_ptr&) override; BranchingReadResult get_committed(std::string var) const override; diff --git a/src/branching/lazy/sync_protocol.hh b/src/branching/lazy/sync_protocol.hh index d30a5c5..06f56c6 100644 --- a/src/branching/lazy/sync_protocol.hh +++ b/src/branching/lazy/sync_protocol.hh @@ -9,12 +9,15 @@ namespace branching { class BranchingLazySyncProtocol final : public BranchingSyncProtocolBase { public: + explicit BranchingLazySyncProtocol(bool verbose = false) + : BranchingSyncProtocolBase(verbose) {} + ~BranchingLazySyncProtocol() = default; SyncKind kind() const override { return SyncKind::BranchingLazy; }; std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(tid); + return std::make_unique(tid, verbose); } }; diff --git a/src/branching/lazy/version_store.cc b/src/branching/lazy/version_store.cc index 8d013a4..5bdc6df 100644 --- a/src/branching/lazy/version_store.cc +++ b/src/branching/lazy/version_store.cc @@ -7,86 +7,6 @@ namespace gitmem { namespace branching { -// std::optional traverse_to_lca( -// const std::shared_ptr& commit, -// ObjectNumber var, -// const std::shared_ptr& lca) -// { -// if (!commit || commit == lca) return std::nullopt; - -// auto it = commit->changes.find(var); -// if (it != commit->changes.end()) return it->second; - -// if (commit->parents.size() > 1) -// assert(false && "should never encounter multi-parent commit before LCA"); - -// return traverse_to_lca(commit->parents[0], var, lca); -// } - -// std::optional get_committed_recursive( -// const std::shared_ptr& commit, -// ObjectNumber var) { -// if (!commit) return std::nullopt; - 
-// // 1. If this commit wrote the variable, return it -// auto it = commit->changes.find(var); -// if (it != commit->changes.end()) return it->second; - -// // 2. If single parent, recurse -// if (commit->parents.size() == 1) -// return get_committed_recursive(commit->parents[0], var); - -// // 3. Merge commit -// assert(commit->parents.size() == 2); // merge commit - -// auto& p1 = commit->parents[0]; -// auto& p2 = commit->parents[1]; - -// auto lca = find_lowest_common_ancestor(p1, p2); - -// // Explore both paths from merge commit to LCA -// std::optional v1 = traverse_to_lca(p1, var, lca); -// std::optional v2 = traverse_to_lca(p2, var, lca); - -// assert(!v1 || !v2 || v1 == v2); // conflict-free invariant - -// if (v1) return v1; // found in one of the merge branches -// if (v2) return v2; - -// // 4. Not found yet → continue recursively from the LCA downward -// return get_committed_recursive(lca, var); -// } - -// std::optional> -// get_committed_recursive( -// const std::shared_ptr& commit, -// ObjectNumber number, -// std::unordered_set>& visited) { - -// if (!commit || !visited.insert(commit).second) -// return std::nullopt; - -// std::optional> found; - -// // Recurse into all parents first -// for (auto& parent : commit->parents) { -// auto parent_commit = get_committed_recursive(parent, number, visited); -// if (parent_commit) { -// if (!found.has_value()) -// found = parent_commit; -// else if (found.value()->changes.at(number) != parent_commit.value()->changes.at(number)) -// assert(false && "Conflict detected (should be impossible in conflict-free DAG)"); -// } -// } - -// // If this commit wrote the variable, it overrides any parent -// if (commit->changes.contains(number)) -// return commit; - -// return found; -// } - - std::optional LazyLocalVersionStore::merge_with_commit(const std::shared_ptr& commit) { assert(staging.empty()); assert(commit != nullptr); diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh 
index 9254b86..13b881d 100644 --- a/src/branching/lazy/version_store.hh +++ b/src/branching/lazy/version_store.hh @@ -10,7 +10,7 @@ class LazyLocalVersionStore : public LocalVersionStore { public: ~LazyLocalVersionStore() = default; - LazyLocalVersionStore(ThreadID tid) : LocalVersionStore(tid) {} + LazyLocalVersionStore(ThreadID tid, bool verbose) : LocalVersionStore(tid, verbose) {} std::optional merge_with_commit(const std::shared_ptr&) override; BranchingReadResult get_committed(std::string var) const override; diff --git a/src/debug.hh b/src/debug.hh index e963461..78310f8 100644 --- a/src/debug.hh +++ b/src/debug.hh @@ -4,9 +4,12 @@ namespace gitmem { +namespace verbose { + /* For debug printing */ inline struct Verbose { bool enabled = false; + bool include_empty_commits = false; template const Verbose &operator<<(const T &msg) const { if (enabled) @@ -19,6 +22,8 @@ inline struct Verbose { std::cout << manip; return *this; } -} verbose; +} out; + +} // namespace verbose } // namespace gitmem \ No newline at end of file diff --git a/src/debugger.cc b/src/debugger.cc index 4e43b0f..6b1f835 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -141,9 +141,9 @@ void maybe_print_graph(Interpreter& interp, bool print_graphs, const std::filesystem::path &output_file) { if (print_graphs) { - // gctx.print_execution_graph(output_file); - interp.build_and_print_revision_graph(output_file); - verbose << "Execution graph written to " << output_file << std::endl; + interp.print_revision_graph(output_file); + interp.print_execution_graph(output_file); + verbose::out << "Execution graph written to " << output_file << std::endl; } } diff --git a/src/gitmem.cc b/src/gitmem.cc index bdffbfb..360c620 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -23,6 +23,10 @@ int main(int argc, char **argv) { app.add_flag("-v,--verbose", verbose, "Enable verbose output from the interpreter."); + bool include_empty_commits = false; + app.add_flag("--include-empty-commits", 
include_empty_commits, + "Include empty commits in branching protocol output."); + // TODO: These should probably be subcommands bool interactive = false; app.add_flag("-i,--interactive", interactive, @@ -51,9 +55,10 @@ int main(int argc, char **argv) { } try { - gitmem::verbose.enabled = verbose; + gitmem::verbose::out.enabled = verbose; + gitmem::verbose::out.include_empty_commits = include_empty_commits; - gitmem::verbose << "Reading file " << input_path << std::endl; + gitmem::verbose::out << "Reading file " << input_path << std::endl; if (!std::filesystem::exists(input_path)) { std::cerr << "Input file does not exist: " << input_path << std::endl; return 1; @@ -72,7 +77,7 @@ int main(int argc, char **argv) { if (output_path.empty()) output_path = input_path.stem().replace_extension(".dot"); - gitmem::verbose << "Output will be written to " << output_path << std::endl; + gitmem::verbose::out << "Output will be written to " << output_path << std::endl; int exit_status; wf::push_back(gitmem::lang::wf); @@ -85,7 +90,7 @@ int main(int argc, char **argv) { } wf::pop_front(); - gitmem::verbose << "Execution finished with exit status " << exit_status + gitmem::verbose::out << "Execution finished with exit status " << exit_status << std::endl; return exit_status; } catch (const std::exception &e) { diff --git a/src/graph.hh b/src/graph.hh index 2d8e4c9..1d94b0b 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -142,9 +142,10 @@ struct Pending : Node { }; struct ExecutionGraph { + std::shared_ptr entry; std::vector> threads; - ExecutionGraph() = default; + ExecutionGraph(std::shared_ptr entry) : entry(entry) {} ExecutionGraph(const ExecutionGraph&) = delete; ExecutionGraph& operator=(const ExecutionGraph&) = delete; diff --git a/src/graphviz.cc b/src/graphviz.cc index ee6a109..4768661 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -71,7 +71,7 @@ GraphvizPrinter::GraphvizPrinter(std::string filename) noexcept { void GraphvizPrinter::visit(const Node *n) { file << 
"digraph G {" << std::endl; - n->accept(this); + if (n) n->accept(this); file << "}" << std::endl; } diff --git a/src/interpreter.cc b/src/interpreter.cc index 2dc5b4c..0b87016 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -72,7 +72,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { return value; }, [&](std::shared_ptr& conflict) -> std::variant { - verbose << (*conflict) << std::endl; + verbose::out << (*conflict) << std::endl; return termination::DataRace(conflict); } }, result); @@ -130,7 +130,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa auto s = stmt / lang::Stmt; if (s == lang::Nop) { - verbose << "Nop" << std::endl; + verbose::out << "Nop" << std::endl; } else if (s == lang::Jump) { @@ -164,7 +164,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa if (lhs == lang::Reg) { // Local variables can be re-assigned whenever - verbose << "Set register '" << lhs->location().view() << "' to " << *val + verbose::out << "Set register '" << lhs->location().view() << "' to " << *val << std::endl; ctx.locals[var] = *val; @@ -178,7 +178,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa // auto &global = ctx.globals[var]; // global.val = *val; // global.commit = gctx.uuid++; - // verbose << "Set global '" << lhs->location().view() << "' to " << + // verbose::out << "Set global '" << lhs->location().view() << "' to " << // *val << " with id " << *(global.commit) << std::endl; // gctx.commit_map[*(global.commit)] = node; @@ -207,7 +207,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa auto result = gctx.cache[expr]; // Check if the thread ID is valid if (result >= gctx.threads.size()) { - verbose << "Join: invalid thread ID " << result + verbose::out << "Join: invalid thread ID " << result << ". 
gctx.threads.size()=" << gctx.threads.size() << std::endl; return termination::UnassignedRead(std::to_string(result)); } @@ -216,7 +216,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa if (joinee.terminated && std::holds_alternative(*joinee.terminated)) { if (auto conflict = gctx.protocol->on_join(ctx, joinee.ctx)) { - verbose << (**conflict) << std::endl; + verbose::out << (**conflict) << std::endl; thread.trace.on_join(result, *conflict); return termination::DataRace(*conflict); } else { @@ -224,7 +224,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa } } else { - verbose << "Waiting on thread " << result << std::endl; + verbose::out << "Waiting on thread " << result << std::endl; return 0; } } else if (s == lang::Lock) { @@ -236,7 +236,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa Lock& lock = gctx.get_lock(var); if (lock.owner) { - verbose << "Waiting for lock " << var << " owned by " + verbose::out << "Waiting for lock " << var << " owned by " << lock.owner.value() << std::endl; return 0; } @@ -244,13 +244,13 @@ std::variant Interpreter::run_statement(Node stmt, Threa lock.owner = thread.tid; if (auto conflict = gctx.protocol->on_lock(ctx, lock)) { - verbose << (**conflict) << std::endl; + verbose::out << (**conflict) << std::endl; thread.trace.on_lock(var, lock.last_unlock_event, *conflict); return termination::DataRace(*conflict); } thread.trace.on_lock(var, lock.last_unlock_event); - verbose << "Locked " << var << std::endl; + verbose::out << "Locked " << var << std::endl; } else if (s == lang::Unlock) { // We can only unlock locks we previously locked. 
We commit any @@ -268,7 +268,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa } if (auto conflict = gctx.protocol->on_unlock(ctx, lock)) { - verbose << (**conflict) << std::endl; + verbose::out << (**conflict) << std::endl; thread.trace.on_unlock(var, *conflict); return termination::DataRace(*conflict); } @@ -278,7 +278,7 @@ std::variant Interpreter::run_statement(Node stmt, Threa lock.last_unlock_event = thread.trace.on_unlock(var); - verbose << "Unlocked " << var << std::endl; + verbose::out << "Unlocked " << var << std::endl; } else if (s == lang::Assert) { @@ -286,10 +286,10 @@ std::variant Interpreter::run_statement(Node stmt, Threa auto result_or_term = evaluate_expression(expr, thread); if (size_t *result = std::get_if(&result_or_term)) { if (*result) { - verbose << "Assertion passed: " << expr->location().view() << std::endl; + verbose::out << "Assertion passed: " << expr->location().view() << std::endl; thread.trace.on_assert_pass(std::string(expr->location().view())); } else { - verbose << "Assertion failed: " << expr->location().view() << std::endl; + verbose::out << "Assertion failed: " << expr->location().view() << std::endl; thread.trace.on_assert_fail(std::string(expr->location().view())); return termination::AssertionFailure(std::string(expr->location().view())); } @@ -356,7 +356,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { // Otherwise, we truly reached the end this iteration if (auto conflict = gctx.protocol->on_end(ctx)) { - verbose << (**conflict) << std::endl; + verbose::out << (**conflict) << std::endl; TerminationStatus term = termination::DataRace(*conflict); thread.terminated = term; return term; @@ -385,7 +385,7 @@ Interpreter::progress_thread(Thread& thread) { any_progress = true; auto& new_thread = gctx.threads[i]; if (!is_syncing(new_thread)) { - verbose << "==== Thread " << i << " (spawn) ====" << std::endl; + verbose::out << "==== Thread " << i << " (spawn) ====" << std::endl; progress_thread(new_thread); 
} } @@ -400,11 +400,11 @@ Interpreter::progress_thread(Thread& thread) { */ std::variant Interpreter::run_threads_to_sync() { - verbose << "-----------------------" << std::endl; + verbose::out << "-----------------------" << std::endl; bool all_completed = true; ProgressStatus any_progress = ProgressStatus::no_progress; for (size_t i = 0; i < gctx.threads.size(); ++i) { - verbose << "==== t" << i << " ====" << std::endl; + verbose::out << "==== t" << i << " ====" << std::endl; auto& thread = gctx.threads[i]; if (!thread.terminated) { auto prog_or_term = run_single_thread_to_sync(thread); @@ -447,24 +447,24 @@ int Interpreter::run() { prog_or_term = run_threads_to_sync(); } while (!is_finished(prog_or_term)); - verbose << "----------- execution complete -----------" << std::endl; + verbose::out << "----------- execution complete -----------" << std::endl; bool exception_detected = false; for (size_t i = 0; i < gctx.threads.size(); ++i) { auto &thread = gctx.threads[i]; if (thread.terminated) { - verbose << "Thread " << i << ": "; + verbose::out << "Thread " << i << ": "; std::visit( overloaded{ [&](const termination::Completed &t) { - verbose << t << std::endl; + verbose::out << t << std::endl; }, [&](const auto &t) { // Any non-completed termination is exceptional - verbose << t << std::endl; + verbose::out << t << std::endl; exception_detected = true; } }, @@ -473,11 +473,13 @@ int Interpreter::run() { } else { exception_detected = true; thread.trace.on_end(); - verbose << "Thread " << i << " is stuck" << std::endl; + verbose::out << "Thread " << i << " is stuck" << std::endl; } } - verbose << *gctx.protocol << std::endl; + verbose::out << "------------------------------------------" << std::endl; + + print_thread_traces(); return exception_detected ? 
1 : 0; } @@ -485,13 +487,13 @@ int Interpreter::run() { void Interpreter::print_thread_traces() { for (size_t tid = 0; tid < gctx.threads.size(); ++tid) { const auto& thread = gctx.threads[tid]; - verbose << "=== Thread " << tid << " ===" << std::endl; - verbose << thread.trace; - verbose << "====================================\n"; + verbose::out << "=== Thread " << tid << " ===" << std::endl; + verbose::out << thread.trace; + verbose::out << "====================================\n"; } } -void Interpreter::build_and_print_revision_graph(const std::filesystem::path& output_path) { +void Interpreter::print_revision_graph(const std::filesystem::path& output_path) { // Build revision graph - collect const raw pointers std::vector thread_state_ptrs; for (const auto& thread : gctx.threads) { @@ -501,98 +503,145 @@ void Interpreter::build_and_print_revision_graph(const std::filesystem::path& ou std::string dot = gctx.protocol->build_revision_graph_dot(thread_state_ptrs); if (!dot.empty()) { // Write to file - auto dot_file = output_path.parent_path() / "revision_graph.dot"; + auto dot_file = output_path.parent_path() / (output_path.stem().string() + "_revision_graph.dot"); std::ofstream out(dot_file); if (out) { out << dot; - verbose << "Revision graph written to " << dot_file << std::endl; + verbose::out << "Revision graph written to " << dot_file << std::endl; } else { - verbose << "Failed to write revision graph to " << dot_file << std::endl; + verbose::out << "Failed to write revision graph to " << dot_file << std::endl; } } } graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { - // Map each thread ID to its last graph node in the execution graph + // Map each thread ID to its last graph node in the execution graph (per-thread program order) std::unordered_map> thread_tails; - // create start nodes for all threads - graph::ExecutionGraph g; + // Track the last unlock event for each lock (for lock->unlock edges) + std::unordered_map> 
last_unlock_per_lock; + + // Track write events per variable (for read->write edges) + std::unordered_map> last_write_per_var; + + // Track join nodes that need fixing up after all threads are processed + std::vector> joins_to_fix; + + // Create Start nodes for all threads + std::vector> thread_starts; + thread_starts.reserve(gctx.threads.size()); + for (ThreadID tid = 0; tid < gctx.threads.size(); ++tid) { auto node = std::make_shared(tid); - g.threads.push_back(node); + thread_starts.push_back(node); thread_tails[tid] = node; } - assert(false && "This is currently not building the current graph"); - - // Helper lambda to convert an Event to a graph Node - auto event_to_node = [&](ThreadID tid, const std::shared_ptr& e) -> std::shared_ptr { - return std::visit([&](auto&& arg) -> std::shared_ptr { - using T = std::decay_t; - std::shared_ptr node; - - if constexpr (std::is_same_v) { - node = std::make_shared(tid); - } else if constexpr (std::is_same_v) { - node = std::make_shared(); - } else if constexpr (std::is_same_v) { - node = std::make_shared(arg.var, arg.value, tid); - } else if constexpr (std::is_same_v) { - // We could track dependencies from previous writes if desired - node = std::make_shared(arg.var, arg.value, tid, thread_tails[tid]); - } else if constexpr (std::is_same_v) { - ThreadID child_tid = arg.child_tid; - node = std::make_shared(child_tid, g.threads[child_tid]); - } else if constexpr (std::is_same_v) { - node = std::make_shared(arg.joinee_tid, thread_tails[arg.joinee_tid]); - } else if constexpr (std::is_same_v) { - node = std::make_shared(arg.lock_name, thread_tails[tid]); - } else if constexpr (std::is_same_v) { - node = std::make_shared(arg.lock_name); - } else if constexpr (std::is_same_v) { - node = std::make_shared(arg.condition); - } else { - throw std::logic_error("Unknown Event type in trace"); - } - - // Link the previous tail of this thread to this new node - if (thread_tails[tid]) - thread_tails[tid]->next = node; + // The entry 
point is thread 0's start + graph::ExecutionGraph g(thread_starts[0]); + g.threads = std::move(thread_starts); - thread_tails[tid] = node; - return node; - }, e->data); + // Helper to link a node in program order for its thread + auto link_in_program_order = [&](ThreadID tid, std::shared_ptr node) { + if (thread_tails[tid]) { + thread_tails[tid]->next = node; + } + thread_tails[tid] = node; }; - // Iterate over threads in thread ID order + // Process events from all threads for (ThreadID tid = 0; tid < gctx.threads.size(); ++tid) { auto& thread = gctx.threads[tid]; - // skip the first event because it is always start and we created that to begin with - for (auto it = std::next(thread.trace.begin()); it != thread.trace.end(); ++it) { - event_to_node(tid, *it); + // Process all events from the trace + for (const auto& event : thread.trace) { + if (!event) { + // Safety check: skip null events (shouldn't happen) + continue; + } + + std::visit(overloaded{ + [&](const StartEvent&) { + // Skip: Start nodes already created above + }, + [&](const EndEvent&) { + auto node = std::make_shared(); + link_in_program_order(tid, node); + }, + [&](const WriteEvent& arg) { + auto node = std::make_shared(arg.var, arg.value, tid); + last_write_per_var[arg.var] = node; + link_in_program_order(tid, node); + }, + [&](const ReadEvent& arg) { + // Link to the write that produced this value + auto source = last_write_per_var.contains(arg.var) + ? 
last_write_per_var[arg.var] + : nullptr; + auto node = std::make_shared(arg.var, arg.value, tid, source); + link_in_program_order(tid, node); + }, + [&](const SpawnEvent& arg) { + // Link to the child thread's start node + auto node = std::make_shared(arg.child_tid, g.threads[arg.child_tid]); + link_in_program_order(tid, node); + }, + [&](const JoinEvent& arg) { + // Create join node with nullptr joinee for now - will fix up later + auto node = std::make_shared(arg.joinee_tid, nullptr); + joins_to_fix.push_back(node); + link_in_program_order(tid, node); + }, + [&](const LockEvent& arg) { + // Link to the last unlock of this lock + auto ordered_after = last_unlock_per_lock.contains(arg.lock_name) + ? last_unlock_per_lock[arg.lock_name] + : nullptr; + auto node = std::make_shared(arg.lock_name, ordered_after); + link_in_program_order(tid, node); + }, + [&](const UnlockEvent& arg) { + auto node = std::make_shared(arg.lock_name); + last_unlock_per_lock[arg.lock_name] = node; + link_in_program_order(tid, node); + }, + [&](const AssertEvent& arg) { + auto node = std::make_shared(arg.condition); + link_in_program_order(tid, node); + } + }, event->data); } - if (!thread.terminated) { + + // Add pending node if thread hasn't terminated and PC is valid + if (!thread.terminated && thread.pc < thread.block->size()) { trieste::Node stmt = thread.block->at(thread.pc); - thread_tails[tid]->next = std::make_shared(std::string(stmt->location().view())); + auto pending = std::make_shared(std::string(stmt->location().view())); + link_in_program_order(tid, pending); } } + // Fix up join nodes to point to the actual end of the joined threads + for (auto& join_node : joins_to_fix) { + ThreadID joinee_tid = join_node->tid; + // thread_tails[joinee_tid] now points to the end (or pending) of that thread + const_cast&>(join_node->joinee) = thread_tails[joinee_tid]; + } + return g; } +void Interpreter::print_execution_graph(const std::filesystem::path& output_path) { + auto exec_graph = 
build_execution_graph_from_traces(); + graph::GraphvizPrinter gv(output_path); + gv.visit(exec_graph.entry.get()); +} + + int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); int result = interp.run(); - interp.print_thread_traces(); - - interp.build_and_print_revision_graph(output_path); - - // auto exec_graph = interp.build_execution_graph_from_traces(); - // graph::GraphvizPrinter gv(output_path); - // gv.visit(node.get()); + interp.print_revision_graph(output_path); return result; } diff --git a/src/interpreter.hh b/src/interpreter.hh index 9428c04..57cc74f 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -30,6 +30,8 @@ class Interpreter { private: GlobalContext gctx; + graph::ExecutionGraph build_execution_graph_from_traces(); + public: Interpreter(GlobalContext gctx): gctx(std::move(gctx)) {} @@ -46,10 +48,8 @@ public: void print_thread_traces(); - void build_and_print_revision_graph(const std::filesystem::path& output_path); - - graph::ExecutionGraph build_execution_graph_from_traces(); - + void print_revision_graph(const std::filesystem::path& output_path); + void print_execution_graph(const std::filesystem::path& output_path); }; // Entry function diff --git a/src/model_checker.cc b/src/model_checker.cc index c26166a..4cff887 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -70,7 +70,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi const auto root = std::make_shared(0); auto cursor = root; auto current_trace = std::vector{0}; // Start with the main thread - verbose << "==== Thread " << cursor->tid_ << " ====" << std::endl; + verbose::out << "==== Thread " << cursor->tid_ << " ====" << std::endl; Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); @@ -82,7 +82,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi // We have a child that is not 
complete, we can extend that trace cursor = cursor->children.back(); current_trace.push_back(cursor->tid_); - verbose << "==== Thread " << cursor->tid_ + verbose::out << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; interp.progress_thread(gctx.threads[cursor->tid_]); } @@ -96,7 +96,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi auto& thread = gctx.threads[i]; if (!thread.terminated) { // Run the thread to the next sync point - verbose << "==== Thread " << i << " ====" << std::endl; + verbose::out << "==== Thread " << i << " ====" << std::endl; auto prog_or_term = interp.progress_thread(thread); if (is_terminated(prog_or_term)) { // Thread terminated, we can extend the trace @@ -105,7 +105,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi current_trace.push_back(i); if (!std::holds_alternative(std::get(prog_or_term))) { // Thread terminated with an error, we can stop here - verbose << "Thread " << i << " terminated with an error" + verbose::out << "Thread " << i << " terminated with an error" << std::endl; cursor->complete = true; } @@ -160,22 +160,22 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi if (cursor->complete && !root->complete) { // Reset the cursor to the root and start a new trace - verbose << std::endl << "Restarting trace..." << std::endl; + verbose::out << std::endl << "Restarting trace..." 
<< std::endl; interp = Interpreter(GlobalContext(ast, make_protocol(sync_kind))); GlobalContext& gctx = interp.context(); cursor = root; current_trace.clear(); current_trace.push_back(0); // Start with the main thread again - verbose << "==== Thread " << cursor->tid_ + verbose::out << "==== Thread " << cursor->tid_ << " (replay) ====" << std::endl; interp.progress_thread(gctx.threads[cursor->tid_]); } } - verbose << "Found a total of " << final_traces.size() + verbose::out << "Found a total of " << final_traces.size() << " trace(s) with distinct final states:" << std::endl; - print_traces(verbose, final_traces); + print_traces(verbose::out, final_traces); size_t idx = 0; if (!failing_traces.empty()) { @@ -188,9 +188,9 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi for (size_t tid = 0; tid < ctx->threads.size(); ++tid) { const auto& thread = ctx->threads[tid]; - verbose << "=== Thread " << tid << " ===" << std::endl; - verbose << thread.trace; - verbose << "====================================\n"; + verbose::out << "=== Thread " << tid << " ===" << std::endl; + verbose::out << thread.trace; + verbose::out << "====================================\n"; } // ctx->print_execution_graph(path); } diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc index d23716f..a854ce0 100644 --- a/src/sync_protocol.cc +++ b/src/sync_protocol.cc @@ -1,15 +1,30 @@ #include "sync_protocol.hh" #include "linear/sync_protocol.hh" -#include "branching/eager/sync_protocol.hh" -#include "branching/lazy/sync_protocol.hh" +#include "branching/base_sync_protocol.hh" +#include "debug.hh" namespace gitmem { std::unique_ptr make_protocol(SyncKind sync_kind) { switch (sync_kind) { - case SyncKind::Linear: return std::make_unique(); - case SyncKind::BranchingEager: return std::make_unique(); - case SyncKind::BranchingLazy: return std::make_unique(); + case SyncKind::Linear: + return std::make_unique(); + + case SyncKind::BranchingEager: { + auto builder = 
branching::BranchingSyncProtocolBuilder().eager(); + if (verbose::out.include_empty_commits) { + builder.with_verbose_commits(); + } + return builder.build(); + } + + case SyncKind::BranchingLazy: { + auto builder = branching::BranchingSyncProtocolBuilder().lazy(); + if (verbose::out.include_empty_commits) { + builder.with_verbose_commits(); + } + return builder.build(); + } } std::unreachable(); } From 15f63eac5e2eb50b01f8b2ee093ccf3633e298d4 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Sun, 11 Jan 2026 20:29:19 +0100 Subject: [PATCH 46/58] making pending ends clearer --- src/interpreter.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/interpreter.cc b/src/interpreter.cc index 0b87016..d899307 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -612,11 +612,18 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { }, event->data); } - // Add pending node if thread hasn't terminated and PC is valid - if (!thread.terminated && thread.pc < thread.block->size()) { - trieste::Node stmt = thread.block->at(thread.pc); - auto pending = std::make_shared(std::string(stmt->location().view())); - link_in_program_order(tid, pending); + // Add pending node if thread hasn't terminated + if (!thread.terminated) { + if (thread.pc < thread.block->size()) { + // Thread is stuck waiting at a specific statement + trieste::Node stmt = thread.block->at(thread.pc); + auto pending = std::make_shared(std::string(stmt->location().view())); + link_in_program_order(tid, pending); + } else { + // Thread has finished all statements but hasn't terminated yet + auto pending = std::make_shared("..."); + link_in_program_order(tid, pending); + } } } From ebbae5b9df3e92bc0ae9aa7d67071c35f8a9a797 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Sun, 11 Jan 2026 20:46:23 +0100 Subject: [PATCH 47/58] optional richer revisions graphs, graph printing seems to be working and interactive graphing too --- src/graph.hh | 7 
+++++++ src/graphviz.cc | 10 ++++++++-- src/interpreter.cc | 49 +++++++++++++++++++++++++++++++++++++++------- 3 files changed, 57 insertions(+), 9 deletions(-) diff --git a/src/graph.hh b/src/graph.hh index 1d94b0b..37226c1 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -29,7 +29,14 @@ struct Pending; struct Conflict { std::string var; + // Optional: if we can determine the conflicting nodes, store them here + // Otherwise these can be nullptr and we just mark the node red std::pair, std::shared_ptr> sources; + + // Constructor that allows creating conflicts without sources + Conflict(std::string v) : var(std::move(v)), sources{nullptr, nullptr} {} + Conflict(std::string v, std::pair, std::shared_ptr> s) + : var(std::move(v)), sources(std::move(s)) {} }; struct Visitor { diff --git a/src/graphviz.cc b/src/graphviz.cc index 4768661..217bcb6 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -60,9 +60,15 @@ void GraphvizPrinter::emitShape(const Node *n, const std::string &shape) { void GraphvizPrinter::emitConflict(const Node *n, const Conflict &conflict) { emitFillColor(n, "red"); // emitShape(n, "doubleoctagon"); + + // Only draw conflict edges if we have actual source nodes auto [s1, s2] = conflict.sources; - emitConflictEdge(n, s1.get()); - emitConflictEdge(n, s2.get()); + if (s1) { + emitConflictEdge(n, s1.get()); + } + if (s2) { + emitConflictEdge(n, s2.get()); + } } GraphvizPrinter::GraphvizPrinter(std::string filename) noexcept { diff --git a/src/interpreter.cc b/src/interpreter.cc index d899307..156fcf8 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -527,6 +527,9 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { // Track join nodes that need fixing up after all threads are processed std::vector> joins_to_fix; + // Map from trace events to graph nodes + std::unordered_map, std::shared_ptr> event_to_node; + // Create Start nodes for all threads std::vector> thread_starts; thread_starts.reserve(gctx.threads.size()); 
@@ -567,11 +570,13 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { [&](const EndEvent&) { auto node = std::make_shared(); link_in_program_order(tid, node); + event_to_node[event] = node; }, [&](const WriteEvent& arg) { auto node = std::make_shared(arg.var, arg.value, tid); last_write_per_var[arg.var] = node; link_in_program_order(tid, node); + event_to_node[event] = node; }, [&](const ReadEvent& arg) { // Link to the write that produced this value @@ -580,34 +585,64 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { : nullptr; auto node = std::make_shared(arg.var, arg.value, tid, source); link_in_program_order(tid, node); + event_to_node[event] = node; + + // Mark conflict if present (full conflict details would require more work) + if (arg.maybe_conflict) { + // For now, just mark the node red - proper conflict edges would need + // extracting source nodes from ConflictBase + // TODO: Extract and visualize conflict sources + } }, [&](const SpawnEvent& arg) { // Link to the child thread's start node auto node = std::make_shared(arg.child_tid, g.threads[arg.child_tid]); link_in_program_order(tid, node); + event_to_node[event] = node; }, [&](const JoinEvent& arg) { - // Create join node with nullptr joinee for now - will fix up later - auto node = std::make_shared(arg.joinee_tid, nullptr); + // Create join node - will fix up joinee pointer later + std::optional conflict; + if (arg.maybe_conflict) { + // Just mark as conflicting - version IDs don't map directly to nodes + conflict = graph::Conflict(""); // empty var name for joins + } + auto node = std::make_shared(arg.joinee_tid, nullptr, conflict); joins_to_fix.push_back(node); link_in_program_order(tid, node); + event_to_node[event] = node; }, [&](const LockEvent& arg) { - // Link to the last unlock of this lock - auto ordered_after = last_unlock_per_lock.contains(arg.lock_name) - ? 
last_unlock_per_lock[arg.lock_name] - : nullptr; - auto node = std::make_shared(arg.lock_name, ordered_after); + // Link to the last unlock event using the event-to-node mapping + std::shared_ptr ordered_after = nullptr; + if (arg.last_unlock_event && event_to_node.contains(arg.last_unlock_event)) { + ordered_after = event_to_node[arg.last_unlock_event]; + } + + std::optional conflict; + if (arg.maybe_conflict) { + // Mark as conflicting with the lock name + conflict = graph::Conflict(arg.lock_name); + } + auto node = std::make_shared(arg.lock_name, ordered_after, conflict); link_in_program_order(tid, node); + event_to_node[event] = node; }, [&](const UnlockEvent& arg) { auto node = std::make_shared(arg.lock_name); last_unlock_per_lock[arg.lock_name] = node; link_in_program_order(tid, node); + event_to_node[event] = node; + + // Mark conflict if present + if (arg.maybe_conflict) { + // TODO: Visualize unlock conflicts + } }, [&](const AssertEvent& arg) { auto node = std::make_shared(arg.condition); link_in_program_order(tid, node); + event_to_node[event] = node; } }, event->data); } From fca0560b6978825135a952c849b0efa02a715778 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 12 Jan 2026 10:46:02 +0100 Subject: [PATCH 48/58] infrastructure to have protocol builders rather than synckind everywhere --- CMakeLists.txt | 1 - src/branching/base_sync_protocol.cc | 13 +++------ src/branching/base_sync_protocol.hh | 22 +++++++-------- src/branching/eager/sync_protocol.hh | 6 +++-- src/branching/lazy/sync_protocol.hh | 6 +++-- src/debug.hh | 1 - src/debugger.cc | 9 +++---- src/debugger.hh | 2 +- src/execution_state.hh | 1 - src/gitmem.cc | 40 +++++++++++++++++----------- src/graph.hh | 19 ++++++++++--- src/graphviz.cc | 24 ++++++++++++++--- src/interpreter.cc | 26 ++++++++++-------- src/interpreter.hh | 3 ++- src/linear/sync_protocol.hh | 13 +++++++-- src/model_checker.cc | 10 ++++--- src/model_checker.hh | 2 +- src/sync_kind.hh | 11 -------- src/sync_protocol.cc | 
32 ---------------------- src/sync_protocol.hh | 11 ++++---- src/thread_trace.hh | 22 +++++++++------ 21 files changed, 143 insertions(+), 131 deletions(-) delete mode 100644 src/sync_kind.hh delete mode 100644 src/sync_protocol.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a417ac..9056cf5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,6 @@ add_executable(gitmem src/interpreter.cc src/debugger.cc src/model_checker.cc - src/sync_protocol.cc src/graphviz.cc ) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 4d8fa41..ed83349 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -146,15 +146,10 @@ std::string BranchingSyncProtocolBase::build_revision_graph_dot( } std::unique_ptr BranchingSyncProtocolBuilder::build() const { - switch (kind) { - case SyncKind::BranchingEager: - return std::make_unique(verbose); - - case SyncKind::BranchingLazy: - return std::make_unique(verbose); - - default: - throw std::runtime_error("Invalid sync kind for branching protocol"); + if (eager_mode) { + return std::make_unique(verbose_commits); + } else { + return std::make_unique(verbose_commits); } } diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index c8417e3..d752cf8 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -12,13 +12,16 @@ namespace branching { class BranchingSyncProtocolBase : public SyncProtocol { protected: GlobalVersionStore _global_store; - bool verbose; + bool verbose_commits; - explicit BranchingSyncProtocolBase(bool verbose) : verbose(verbose) {} + explicit BranchingSyncProtocolBase(bool verbose_commits) + : verbose_commits(verbose_commits) {} public: ~BranchingSyncProtocolBase() override; + std::unique_ptr clone() const override = 0; + ReadResult read(ThreadContext &ctx, const std::string &var) override; void write(ThreadContext &ctx, const std::string &var, size_t value) 
override; @@ -53,27 +56,22 @@ public: // Builder for creating branching sync protocols class BranchingSyncProtocolBuilder { private: - SyncKind kind = SyncKind::BranchingLazy; - bool verbose = false; + bool eager_mode = false; + bool verbose_commits = false; public: - BranchingSyncProtocolBuilder& with_kind(SyncKind k) { - kind = k; - return *this; - } - BranchingSyncProtocolBuilder& eager() { - kind = SyncKind::BranchingEager; + eager_mode = true; return *this; } BranchingSyncProtocolBuilder& lazy() { - kind = SyncKind::BranchingLazy; + eager_mode = false; return *this; } BranchingSyncProtocolBuilder& with_verbose_commits(bool v = true) { - verbose = v; + verbose_commits = v; return *this; } diff --git a/src/branching/eager/sync_protocol.hh b/src/branching/eager/sync_protocol.hh index b5fcbff..adcae17 100644 --- a/src/branching/eager/sync_protocol.hh +++ b/src/branching/eager/sync_protocol.hh @@ -14,10 +14,12 @@ public: ~BranchingEagerSyncProtocol() = default; - SyncKind kind() const override { return SyncKind::BranchingEager; }; + std::unique_ptr clone() const override { + return std::make_unique(verbose_commits); + } std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(tid, verbose); + return std::make_unique(tid, verbose_commits); } }; diff --git a/src/branching/lazy/sync_protocol.hh b/src/branching/lazy/sync_protocol.hh index 06f56c6..796e437 100644 --- a/src/branching/lazy/sync_protocol.hh +++ b/src/branching/lazy/sync_protocol.hh @@ -14,10 +14,12 @@ public: ~BranchingLazySyncProtocol() = default; - SyncKind kind() const override { return SyncKind::BranchingLazy; }; + std::unique_ptr clone() const override { + return std::make_unique(verbose_commits); + } std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(tid, verbose); + return std::make_unique(tid, verbose_commits); } }; diff --git a/src/debug.hh b/src/debug.hh index 78310f8..19dc66e 100644 --- a/src/debug.hh +++ b/src/debug.hh @@ 
-9,7 +9,6 @@ namespace verbose { /* For debug printing */ inline struct Verbose { bool enabled = false; - bool include_empty_commits = false; template const Verbose &operator<<(const T &msg) const { if (enabled) diff --git a/src/debugger.cc b/src/debugger.cc index 6b1f835..0bc6870 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -175,10 +175,9 @@ StepUIResult do_step(Interpreter &interp, /** Reset the interpreter to a fresh state */ void do_restart(Interpreter &interp, const trieste::Node ast, - SyncKind sync_kind, bool print_graphs, const std::filesystem::path &output_file) { - interp = Interpreter(GlobalContext(ast, make_protocol(sync_kind))); + interp = Interpreter(GlobalContext(ast, interp.context().protocol->clone())); maybe_print_graph(interp, print_graphs, output_file); } @@ -214,8 +213,8 @@ void print_help() { /** Main interactive interpreter loop */ int interpret_interactive(const trieste::Node ast, const std::filesystem::path &output_file, - SyncKind sync_kind) { - Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); + std::unique_ptr protocol) { + Interpreter interp(GlobalContext(ast, std::move(protocol))); GlobalContext &gctx = interp.context(); size_t prev_no_threads = 1; @@ -253,7 +252,7 @@ int interpret_interactive(const trieste::Node ast, break; case Command::Restart: - do_restart(interp, ast, sync_kind, print_graphs, output_file); + do_restart(interp, ast, print_graphs, output_file); command = {Command::List}; break; diff --git a/src/debugger.hh b/src/debugger.hh index c891fad..90fbe28 100644 --- a/src/debugger.hh +++ b/src/debugger.hh @@ -6,5 +6,5 @@ namespace gitmem { int interpret_interactive(const trieste::Node, const std::filesystem::path &output_file, - SyncKind sync_kind); + std::unique_ptr protocol); } \ No newline at end of file diff --git a/src/execution_state.hh b/src/execution_state.hh index 550f627..bcae76d 100644 --- a/src/execution_state.hh +++ b/src/execution_state.hh @@ -6,7 +6,6 @@ #include #include "lang.hh" 
-#include "sync_kind.hh" #include "sync_state.hh" #include "graphviz.hh" #include "termination_status.hh" diff --git a/src/gitmem.cc b/src/gitmem.cc index 360c620..37568fe 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -5,7 +5,8 @@ #include "model_checker.hh" #include "debugger.hh" #include "lang.hh" -#include "sync_kind.hh" +#include "linear/sync_protocol.hh" +#include "branching/base_sync_protocol.hh" int main(int argc, char **argv) { using namespace trieste; @@ -27,7 +28,6 @@ int main(int argc, char **argv) { app.add_flag("--include-empty-commits", include_empty_commits, "Include empty commits in branching protocol output."); - // TODO: These should probably be subcommands bool interactive = false; app.add_flag("-i,--interactive", interactive, "Enable interactive scheduling mode (use command ? for help)."); @@ -36,16 +36,9 @@ int main(int argc, char **argv) { app.add_flag("-e,--explore", model_check, "Explore all possible execution paths."); - auto string_to_sync = CLI::Transformer(std::map { - {"linear", gitmem::SyncKind::Linear}, - {"branching-eager", gitmem::SyncKind::BranchingEager}, - {"branching-lazy", gitmem::SyncKind::BranchingLazy} - }); - string_to_sync.description("linear,branching-eager,branching-lazy"); - - gitmem::SyncKind sync_kind = gitmem::SyncKind::Linear; - app.add_option("--sync", sync_kind, "Select a sync protocol for execution (default: linear)") - ->transform(string_to_sync) + std::string sync_protocol = "linear"; + app.add_option("--sync", sync_protocol, "Select a sync protocol for execution (default: linear)") + ->check(CLI::IsMember({"linear", "branching-eager", "branching-lazy"})) ->type_name("SYNC_KIND"); try { @@ -56,7 +49,6 @@ int main(int argc, char **argv) { try { gitmem::verbose::out.enabled = verbose; - gitmem::verbose::out.include_empty_commits = include_empty_commits; gitmem::verbose::out << "Reading file " << input_path << std::endl; if (!std::filesystem::exists(input_path)) { @@ -79,14 +71,30 @@ int main(int argc, char 
**argv) { gitmem::verbose::out << "Output will be written to " << output_path << std::endl; + // Build protocol based on command line options + std::unique_ptr protocol; + if (sync_protocol == "linear") { + protocol = gitmem::linear::LinearSyncProtocolBuilder().build(); + } else if (sync_protocol == "branching-eager") { + protocol = gitmem::branching::BranchingSyncProtocolBuilder() + .eager() + .with_verbose_commits(include_empty_commits) + .build(); + } else if (sync_protocol == "branching-lazy") { + protocol = gitmem::branching::BranchingSyncProtocolBuilder() + .lazy() + .with_verbose_commits(include_empty_commits) + .build(); + } + int exit_status; wf::push_back(gitmem::lang::wf); if (model_check) { - exit_status = gitmem::model_check(result.ast, output_path, sync_kind); + exit_status = gitmem::model_check(result.ast, output_path, std::move(protocol)); } else if (interactive) { - exit_status = gitmem::interpret_interactive(result.ast, output_path, sync_kind); + exit_status = gitmem::interpret_interactive(result.ast, output_path, std::move(protocol)); } else { - exit_status = gitmem::interpret(result.ast, output_path, sync_kind); + exit_status = gitmem::interpret(result.ast, output_path, std::move(protocol)); } wf::pop_front(); diff --git a/src/graph.hh b/src/graph.hh index 37226c1..5495a42 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -3,6 +3,7 @@ #include #include #include +#include namespace gitmem { @@ -80,13 +81,23 @@ struct Write : Node { struct Read : Node { const std::string var; - const size_t value; const size_t id; - const std::shared_ptr sauce; + struct SuccessfulRead { + size_t value; + std::shared_ptr source; + }; + + const std::variant read_result; + + // Constructor for successful read Read(const std::string var, const size_t value, const size_t id, - const std::shared_ptr sauce) - : var(var), value(value), id(id), sauce(sauce) {} + const std::shared_ptr source) + : var(var), id(id), read_result(SuccessfulRead{value, source}) {} + + // 
Constructor for conflicting read + Read(const std::string var, const size_t id, Conflict conflict) + : var(var), id(id), read_result(std::move(conflict)) {} void accept(Visitor *v) const override { v->visitRead(this); } }; diff --git a/src/graphviz.cc b/src/graphviz.cc index 217bcb6..c7f8517 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -1,4 +1,5 @@ #include "graphviz.hh" +#include "overloaded.hh" #include namespace gitmem { @@ -111,12 +112,29 @@ void GraphvizPrinter::visitWrite(const Write *n) { } void GraphvizPrinter::visitRead(const Read *n) { - emitNode(n, "R" + n->var + " = " + to_string(n->value)); + std::string label = "R" + n->var + " = "; + + std::visit(overloaded{ + [&](const Read::SuccessfulRead& success) { + label += to_string(success.value); + }, + [&](const Conflict& conflict) { + label += "conflict"; + } + }, n->read_result); + + emitNode(n, label); emitProgramOrderEdge(n, n->next.get()); visitProgramOrder(n->next.get()); - assert(n->sauce); - emitReadFromEdge(n, n->sauce.get()); + if (auto* conflict = std::get_if(&n->read_result)) { + emitConflict(n, *conflict); + } else { + auto& success = std::get(n->read_result); + if (success.source) { + emitReadFromEdge(n, success.source.get()); + } + } } void GraphvizPrinter::visitSpawn(const Spawn *n) { diff --git a/src/interpreter.cc b/src/interpreter.cc index 156fcf8..4a2c084 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -73,6 +73,7 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { }, [&](std::shared_ptr& conflict) -> std::variant { verbose::out << (*conflict) << std::endl; + thread.trace.on_read(var, conflict); return termination::DataRace(conflict); } }, result); @@ -583,16 +584,19 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { auto source = last_write_per_var.contains(arg.var) ? 
last_write_per_var[arg.var] : nullptr; - auto node = std::make_shared(arg.var, arg.value, tid, source); + + std::shared_ptr node; + std::visit(overloaded{ + [&](size_t val) { + node = std::make_shared(arg.var, val, tid, source); + }, + [&](const std::shared_ptr&) { + node = std::make_shared(arg.var, tid, graph::Conflict(arg.var)); + } + }, arg.value_or_conflict); + link_in_program_order(tid, node); event_to_node[event] = node; - - // Mark conflict if present (full conflict details would require more work) - if (arg.maybe_conflict) { - // For now, just mark the node red - proper conflict edges would need - // extracting source nodes from ConflictBase - // TODO: Extract and visualize conflict sources - } }, [&](const SpawnEvent& arg) { // Link to the child thread's start node @@ -678,9 +682,9 @@ void Interpreter::print_execution_graph(const std::filesystem::path& output_path gv.visit(exec_graph.entry.get()); } - -int interpret(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { - Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); +int interpret(const Node ast, const std::filesystem::path &output_path, + std::unique_ptr protocol) { + Interpreter interp(GlobalContext(ast, std::move(protocol))); int result = interp.run(); interp.print_revision_graph(output_path); diff --git a/src/interpreter.hh b/src/interpreter.hh index 57cc74f..a0bc5fc 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -53,6 +53,7 @@ public: }; // Entry function -int interpret(const trieste::Node, const std::filesystem::path &output_file, SyncKind sync_kind); +int interpret(const trieste::Node, const std::filesystem::path &output_file, + std::unique_ptr protocol); } // namespace gitmem \ No newline at end of file diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 32e823c..e8438d5 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -2,7 +2,6 @@ #include "../sync_protocol.hh" #include "conflict.hh" 
-#include "sync_kind.hh" #include "execution_state.hh" #include "linear/version_store.hh" @@ -20,7 +19,10 @@ class LinearSyncProtocol final : public SyncProtocol { public: ~LinearSyncProtocol() override; - SyncKind kind() const override { return SyncKind::Linear; }; + + std::unique_ptr clone() const override { + return std::make_unique(); + } ReadResult read(ThreadContext &ctx, const std::string &var) override; @@ -57,6 +59,13 @@ public: } }; +class LinearSyncProtocolBuilder { +public: + std::unique_ptr build() const { + return std::make_unique(); + } +}; + } // namespace linear } // namespace gitmem \ No newline at end of file diff --git a/src/model_checker.cc b/src/model_checker.cc index 4cff887..69da1ee 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -58,7 +58,8 @@ build_output_path(const std::filesystem::path &output_path, const size_t idx) { * Explore all possible execution paths of the program, printing one trace * for each distinct final state that led to an error. 
*/ -int model_check(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind) { +int model_check(const Node ast, const std::filesystem::path &output_path, + std::unique_ptr protocol) { auto final_contexts = std::vector>{}; auto failing_contexts = std::vector>{}; auto deadlocked_contexts = std::vector>{}; @@ -72,7 +73,10 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi auto current_trace = std::vector{0}; // Start with the main thread verbose::out << "==== Thread " << cursor->tid_ << " ====" << std::endl; - Interpreter interp(GlobalContext(ast, make_protocol(sync_kind))); + Interpreter interp(GlobalContext(ast, std::move(protocol))); + + // Keep a pointer to the protocol for cloning later + const SyncProtocol* protocol_template = interp.context().protocol.get(); GlobalContext& gctx = interp.context(); interp.progress_thread(gctx.threads[cursor->tid_]); @@ -161,7 +165,7 @@ int model_check(const Node ast, const std::filesystem::path &output_path, SyncKi if (cursor->complete && !root->complete) { // Reset the cursor to the root and start a new trace verbose::out << std::endl << "Restarting trace..." 
<< std::endl; - interp = Interpreter(GlobalContext(ast, make_protocol(sync_kind))); + interp = Interpreter(GlobalContext(ast, protocol_template->clone())); GlobalContext& gctx = interp.context(); cursor = root; diff --git a/src/model_checker.hh b/src/model_checker.hh index 448f2c0..54d8494 100644 --- a/src/model_checker.hh +++ b/src/model_checker.hh @@ -6,5 +6,5 @@ namespace gitmem { using namespace trieste; - int model_check(const Node ast, const std::filesystem::path &output_path, SyncKind sync_kind); + int model_check(const Node ast, const std::filesystem::path &output_path, std::unique_ptr protocol); } \ No newline at end of file diff --git a/src/sync_kind.hh b/src/sync_kind.hh deleted file mode 100644 index a5d2867..0000000 --- a/src/sync_kind.hh +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -namespace gitmem { - -enum class SyncKind { - Linear, - BranchingEager, - BranchingLazy -}; - -} \ No newline at end of file diff --git a/src/sync_protocol.cc b/src/sync_protocol.cc deleted file mode 100644 index a854ce0..0000000 --- a/src/sync_protocol.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "sync_protocol.hh" -#include "linear/sync_protocol.hh" -#include "branching/base_sync_protocol.hh" -#include "debug.hh" - -namespace gitmem { - -std::unique_ptr make_protocol(SyncKind sync_kind) { - switch (sync_kind) { - case SyncKind::Linear: - return std::make_unique(); - - case SyncKind::BranchingEager: { - auto builder = branching::BranchingSyncProtocolBuilder().eager(); - if (verbose::out.include_empty_commits) { - builder.with_verbose_commits(); - } - return builder.build(); - } - - case SyncKind::BranchingLazy: { - auto builder = branching::BranchingSyncProtocolBuilder().lazy(); - if (verbose::out.include_empty_commits) { - builder.with_verbose_commits(); - } - return builder.build(); - } - } - std::unreachable(); -} - -} // namespace gitmem diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index f3960f3..f3b2470 100644 --- a/src/sync_protocol.hh +++ 
b/src/sync_protocol.hh @@ -1,7 +1,6 @@ #pragma once #include "conflict.hh" -#include "sync_kind.hh" #include "sync_state.hh" #include "execution_state.hh" #include "read_result.hh" @@ -10,12 +9,16 @@ namespace gitmem { -std::unique_ptr make_protocol(SyncKind); +// Forward declaration for builder +class SyncProtocolBuilder; class SyncProtocol { public: virtual ~SyncProtocol() = default; - virtual SyncKind kind() const = 0; + + // Create a fresh copy of this protocol with reset state + virtual std::unique_ptr clone() const = 0; + virtual std::unique_ptr make_thread_state(ThreadID tid) const = 0; virtual std::unique_ptr make_lock_state() const = 0; @@ -51,8 +54,6 @@ public: const SyncProtocol &protocol) { return protocol.print(os); } - - }; } // namespace gitmem diff --git a/src/thread_trace.hh b/src/thread_trace.hh index 6321a5c..ed29ed8 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -2,6 +2,7 @@ #include "thread_id.hh" #include "conflict.hh" +#include "overloaded.hh" namespace gitmem { @@ -9,7 +10,7 @@ struct Event; struct StartEvent {}; struct SpawnEvent { const ThreadID child_tid; }; -struct ReadEvent { const std::string var; const size_t value; std::shared_ptr maybe_conflict; }; +struct ReadEvent { const std::string var; std::variant> value_or_conflict; }; struct WriteEvent { const std::string var; const size_t value; }; struct LockEvent { std::string lock_name; std::shared_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; struct UnlockEvent { const std::string lock_name; std::shared_ptr maybe_conflict; }; @@ -52,11 +53,12 @@ inline std::ostream& operator<<(std::ostream& os, const SpawnEvent& e) { } inline std::ostream& operator<<(std::ostream& os, const ReadEvent& e) { - os << "ReadEvent(var=\"" << e.var << "\", value=" << e.value; - if (e.maybe_conflict) - os << ", conflict)"; - else - os << ")"; + os << "ReadEvent(var=\"" << e.var << "\", "; + std::visit(overloaded{ + [&os](size_t val) { os << "value=" << val; }, + [&os](const 
std::shared_ptr&) { os << "conflict"; } + }, e.value_or_conflict); + os << ")"; return os; } @@ -139,8 +141,12 @@ private: return append(child_tid); } - std::shared_ptr on_read(const std::string text, const size_t value, std::shared_ptr conflict = nullptr) { - return append(std::move(text), value, conflict); + std::shared_ptr on_read(const std::string text, const size_t value) { + return append(std::move(text), value); + } + + std::shared_ptr on_read(const std::string text, std::shared_ptr conflict) { + return append(std::move(text), conflict); } std::shared_ptr on_write(const std::string text, const size_t value) { From 620ee30be03bcaa949f1fd2010132d3b1402740f Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 12 Jan 2026 11:06:40 +0100 Subject: [PATCH 49/58] removing builders for protocols, adding infra to later switch between detecting conflicts that came before suppressing writes, and updating command line opts and testsuite --- src/branching/base_sync_protocol.cc | 8 ---- src/branching/base_sync_protocol.hh | 25 ---------- src/branching/lazy/sync_protocol.hh | 11 +++-- src/branching/lazy/version_store.hh | 5 +- src/gitmem.cc | 72 +++++++++++++++++++++-------- test_gitmem.py | 20 +++++--- 6 files changed, 79 insertions(+), 62 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index ed83349..951c31f 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -145,14 +145,6 @@ std::string BranchingSyncProtocolBase::build_revision_graph_dot( return build_commit_graph_dot(heads); } -std::unique_ptr BranchingSyncProtocolBuilder::build() const { - if (eager_mode) { - return std::make_unique(verbose_commits); - } else { - return std::make_unique(verbose_commits); - } -} - } // end branching } // end gitmem \ No newline at end of file diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index d752cf8..20ce253 100644 --- 
a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -53,31 +53,6 @@ public: } }; -// Builder for creating branching sync protocols -class BranchingSyncProtocolBuilder { -private: - bool eager_mode = false; - bool verbose_commits = false; - -public: - BranchingSyncProtocolBuilder& eager() { - eager_mode = true; - return *this; - } - - BranchingSyncProtocolBuilder& lazy() { - eager_mode = false; - return *this; - } - - BranchingSyncProtocolBuilder& with_verbose_commits(bool v = true) { - verbose_commits = v; - return *this; - } - - std::unique_ptr build() const; -}; - } // end branching } // end gitmem \ No newline at end of file diff --git a/src/branching/lazy/sync_protocol.hh b/src/branching/lazy/sync_protocol.hh index 796e437..ac8a8bc 100644 --- a/src/branching/lazy/sync_protocol.hh +++ b/src/branching/lazy/sync_protocol.hh @@ -8,18 +8,21 @@ namespace gitmem { namespace branching { class BranchingLazySyncProtocol final : public BranchingSyncProtocolBase { +private: + bool raise_early_conflicts; // currently not used + public: - explicit BranchingLazySyncProtocol(bool verbose = false) - : BranchingSyncProtocolBase(verbose) {} + explicit BranchingLazySyncProtocol(bool verbose = false, bool raise_early_conflicts = false) + : BranchingSyncProtocolBase(verbose), raise_early_conflicts(raise_early_conflicts) {} ~BranchingLazySyncProtocol() = default; std::unique_ptr clone() const override { - return std::make_unique(verbose_commits); + return std::make_unique(verbose_commits, raise_early_conflicts); } std::unique_ptr make_thread_state(ThreadID tid) const override { - return std::make_unique(tid, verbose_commits); + return std::make_unique(tid, verbose_commits, raise_early_conflicts); } }; diff --git a/src/branching/lazy/version_store.hh b/src/branching/lazy/version_store.hh index 13b881d..261c154 100644 --- a/src/branching/lazy/version_store.hh +++ b/src/branching/lazy/version_store.hh @@ -7,10 +7,13 @@ namespace gitmem { namespace 
branching { class LazyLocalVersionStore : public LocalVersionStore { +private: + bool raise_early_conflicts; + public: ~LazyLocalVersionStore() = default; - LazyLocalVersionStore(ThreadID tid, bool verbose) : LocalVersionStore(tid, verbose) {} + LazyLocalVersionStore(ThreadID tid, bool verbose, bool raise_early_conflicts) : LocalVersionStore(tid, verbose), raise_early_conflicts(raise_early_conflicts) {} std::optional merge_with_commit(const std::shared_ptr&) override; BranchingReadResult get_committed(std::string var) const override; diff --git a/src/gitmem.cc b/src/gitmem.cc index 37568fe..b016c76 100644 --- a/src/gitmem.cc +++ b/src/gitmem.cc @@ -7,6 +7,8 @@ #include "lang.hh" #include "linear/sync_protocol.hh" #include "branching/base_sync_protocol.hh" +#include "branching/eager/sync_protocol.hh" +#include "branching/lazy/sync_protocol.hh" int main(int argc, char **argv) { using namespace trieste; @@ -24,9 +26,28 @@ int main(int argc, char **argv) { app.add_flag("-v,--verbose", verbose, "Enable verbose output from the interpreter."); + std::string sync_protocol = "linear"; + auto sync_opt = app.add_option("--sync", sync_protocol, "Select a sync protocol for execution (default: linear)") + ->check(CLI::IsMember({"linear", "branching"})) + ->type_name("KIND"); + + std::string branching_mode = "eager"; + auto branching_mode_opt = app.add_option("--branching-mode", branching_mode, "Select branching mode: eager or lazy (default: eager)") + ->check(CLI::IsMember({"eager", "lazy"})) + ->type_name("MODE"); + bool include_empty_commits = false; - app.add_flag("--include-empty-commits", include_empty_commits, - "Include empty commits in branching protocol output."); + auto include_empty_opt = app.add_flag("--include-empty-commits", include_empty_commits, + "Include empty commits in branching protocol output (branching mode only)."); + + bool raise_early_conflicts = false; + auto raise_early_opt = app.add_flag("--raise-early-conflicts", raise_early_conflicts, + "Raise 
conflict errors before suppressing writes (lazy branching mode only)."); + + // Set up option dependencies + branching_mode_opt->needs(sync_opt); + include_empty_opt->needs(sync_opt); + raise_early_opt->needs(branching_mode_opt); bool interactive = false; app.add_flag("-i,--interactive", interactive, @@ -36,13 +57,29 @@ int main(int argc, char **argv) { app.add_flag("-e,--explore", model_check, "Explore all possible execution paths."); - std::string sync_protocol = "linear"; - app.add_option("--sync", sync_protocol, "Select a sync protocol for execution (default: linear)") - ->check(CLI::IsMember({"linear", "branching-eager", "branching-lazy"})) - ->type_name("SYNC_KIND"); - try { app.parse(argc, argv); + + // Additional validation for logical consistency + if (sync_protocol == "linear") { + if (*branching_mode_opt) { + std::cerr << "Error: --branching-mode is only valid with --sync branching" << std::endl; + return 1; + } + if (include_empty_commits) { + std::cerr << "Error: --include-empty-commits is only valid with --sync branching" << std::endl; + return 1; + } + if (raise_early_conflicts) { + std::cerr << "Error: --raise-early-conflicts is only valid with --sync branching" << std::endl; + return 1; + } + } + + if (sync_protocol == "branching" && branching_mode == "eager" && raise_early_conflicts) { + std::cerr << "Error: --raise-early-conflicts is only valid with --branching-mode lazy" << std::endl; + return 1; + } } catch (const CLI::ParseError &e) { return app.exit(e); } @@ -74,17 +111,16 @@ int main(int argc, char **argv) { // Build protocol based on command line options std::unique_ptr protocol; if (sync_protocol == "linear") { - protocol = gitmem::linear::LinearSyncProtocolBuilder().build(); - } else if (sync_protocol == "branching-eager") { - protocol = gitmem::branching::BranchingSyncProtocolBuilder() - .eager() - .with_verbose_commits(include_empty_commits) - .build(); - } else if (sync_protocol == "branching-lazy") { - protocol = 
gitmem::branching::BranchingSyncProtocolBuilder() - .lazy() - .with_verbose_commits(include_empty_commits) - .build(); + protocol = std::make_unique(); + } else if (sync_protocol == "branching") { + if (branching_mode == "eager") { + protocol = std::make_unique(include_empty_commits); + } else { // lazy + protocol = std::make_unique( + include_empty_commits, + raise_early_conflicts + ); + } } int exit_status; diff --git a/test_gitmem.py b/test_gitmem.py index f686a5e..b03bce0 100644 --- a/test_gitmem.py +++ b/test_gitmem.py @@ -7,9 +7,9 @@ EXAMPLES_DIR = "examples" SYNC_KINDS = { - "linear": "linear", - "branching-eager": "branching-eager", - "branching-lazy": "branching-lazy", + "linear": {"sync": "linear"}, + "branching-eager": {"sync": "branching", "branching_mode": "eager"}, + "branching-lazy": {"sync": "branching", "branching_mode": "lazy"}, } def supports_color(): @@ -27,13 +27,21 @@ def red(text): return color(text, "31") def run_gitmem_test(gitmem_path, file_path, should_accept, sync_kind): + sync_config = SYNC_KINDS[sync_kind] + cmd = [ gitmem_path, file_path, - "--sync", sync_kind, + "--sync", sync_config["sync"], + ] + + if "branching_mode" in sync_config: + cmd.extend(["--branching-mode", sync_config["branching_mode"]]) + + cmd.extend([ "-e", "-o", "/dev/null" - ] + ]) try: result = subprocess.run( @@ -90,7 +98,7 @@ def main(): # If none specified, run all if not selected_syncs: - selected_syncs = list(SYNC_KINDS.values()) + selected_syncs = list(SYNC_KINDS.keys()) results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: { "total": 0, From f94a680adb65186fdf4d0ee912c141eee27f6a16 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 12 Jan 2026 11:48:10 +0100 Subject: [PATCH 50/58] allowing assert nodes in the graph to be passing or failing and only red on failing, adding infrastructure to let protocols decide what is a scheduling point --- src/branching/base_sync_protocol.cc | 17 +++++++++++ src/branching/base_sync_protocol.hh | 2 ++
src/graph.hh | 11 +++---- src/graphviz.cc | 4 +-- src/graphviz.hh | 2 +- src/interpreter.cc | 45 +++++++++++++++++------------ src/linear/sync_protocol.cc | 13 +++++++++ src/linear/sync_protocol.hh | 2 ++ src/sync_protocol.hh | 14 +++++++++ 9 files changed, 84 insertions(+), 26 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 951c31f..472595d 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -145,6 +145,23 @@ std::string BranchingSyncProtocolBase::build_revision_graph_dot( return build_commit_graph_dot(heads); } +bool BranchingSyncProtocolBase::is_scheduling_point(SyncOperation op) const { + // For branching protocol, only operations that actually synchronize state + // (lock/unlock) or require waiting (join) are scheduling points + switch (op) { + case SyncOperation::Lock: + case SyncOperation::Unlock: + case SyncOperation::Join: + return true; + case SyncOperation::Spawn: + case SyncOperation::Start: + case SyncOperation::End: + // These just inherit/commit locally - no scheduling decision needed + return false; + } + return false; +} + } // end branching } // end gitmem \ No newline at end of file diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index 20ce253..ba8d6bf 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -48,6 +48,8 @@ public: std::string build_revision_graph_dot(const std::vector& thread_states) const override; + bool is_scheduling_point(SyncOperation op) const override; + std::unique_ptr make_lock_state() const override { return std::make_unique(); } diff --git a/src/graph.hh b/src/graph.hh index 5495a42..7ef25da 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -25,7 +25,7 @@ struct Spawn; struct Join; struct Lock; struct Unlock; -struct AssertionFailure; +struct Assertion; struct Pending; struct Conflict { @@ -49,7 +49,7 @@ struct Visitor { virtual void 
visitJoin(const Join *) = 0; virtual void visitLock(const Lock *) = 0; virtual void visitUnlock(const Unlock *) = 0; - virtual void visitAssertionFailure(const AssertionFailure *) = 0; + virtual void visitAssertion(const Assertion *) = 0; virtual void visitPending(const Pending *) = 0; virtual void visit(const Node *n) { n->accept(this); } }; @@ -144,12 +144,13 @@ struct Unlock : Node { void accept(Visitor *v) const override { v->visitUnlock(this); } }; -struct AssertionFailure : Node { +struct Assertion : Node { const std::string cond; + const bool passed; - AssertionFailure(const std::string &cond) : cond(cond) {} + Assertion(const std::string &cond, const bool passed) : cond(cond), passed(passed) {} - void accept(Visitor *v) const override { v->visitAssertionFailure(this); } + void accept(Visitor *v) const override { v->visitAssertion(this); } }; struct Pending : Node { diff --git a/src/graphviz.cc b/src/graphviz.cc index c7f8517..36370fa 100644 --- a/src/graphviz.cc +++ b/src/graphviz.cc @@ -173,9 +173,9 @@ void GraphvizPrinter::visitUnlock(const Unlock *n) { visitProgramOrder(n->next.get()); } -void GraphvizPrinter::visitAssertionFailure(const AssertionFailure *n) { +void GraphvizPrinter::visitAssertion(const Assertion *n) { emitNode(n, "Assert " + n->cond); - emitFillColor(n, "red"); + if (!n->passed) emitFillColor(n, "red"); // emitShape(n, "doubleoctagon"); emitProgramOrderEdge(n, n->next.get()); visitProgramOrder(n->next.get()); diff --git a/src/graphviz.hh b/src/graphviz.hh index adfef24..c3e294b 100644 --- a/src/graphviz.hh +++ b/src/graphviz.hh @@ -12,7 +12,7 @@ struct GraphvizPrinter : Visitor { void visitJoin(const Join *) override; void visitLock(const Lock *) override; void visitUnlock(const Unlock *) override; - void visitAssertionFailure(const AssertionFailure *) override; + void visitAssertion(const Assertion *) override; void visitPending(const Pending *) override; void visit(const Node *n) override; diff --git a/src/interpreter.cc 
b/src/interpreter.cc index 4a2c084..12b68dc 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -26,18 +26,37 @@ using namespace trieste; * - t unlocking a lock l, which updates l to have t's versioned memory */ -static bool is_syncing(Node stmt) { +// Map AST node types to sync operations +static std::optional get_sync_operation(Node stmt) { auto s = stmt / lang::Stmt; - return s == lang::Join || s == lang::Lock || s == lang::Unlock; + if (s == lang::Join) return SyncOperation::Join; + if (s == lang::Lock) return SyncOperation::Lock; + if (s == lang::Unlock) return SyncOperation::Unlock; + + // Spawn is an expression, not a statement, but we check for assignment of spawn + if (s == lang::Assign) { + auto rhs = s / lang::Expr; + if (rhs == lang::Spawn) return SyncOperation::Spawn; + } + + return std::nullopt; } -static bool is_syncing(Thread &thread) { +// Check if a statement is a scheduling point according to the protocol +static bool is_syncing(const SyncProtocol& protocol, Node stmt) { + if (auto op = get_sync_operation(stmt)) { + return protocol.is_scheduling_point(*op); + } + return false; +} + +static bool is_syncing(const SyncProtocol& protocol, Thread &thread) { // Can only be true if a thread hasn't terminated - // Either it has executed all statements but not yet terminated (and my sync) + // Either it has executed all statements but not yet terminated (and may sync) // Or it is at a synchronisation node // The lazy eval here is important return !thread.terminated && - ((thread.pc >= thread.block->size()) || is_syncing(thread.block->at(thread.pc))); + ((thread.pc >= thread.block->size()) || is_syncing(protocol, thread.block->at(thread.pc))); } /* Evaluating an expression either returns the result of the expression or @@ -173,16 +192,6 @@ std::variant Interpreter::run_statement(Node stmt, Threa gctx.protocol->write(ctx, var, *val); thread.trace.on_write(var, *val); - - // // Global variable writes need to create a new commit id - // // to track 
the history of updates - // auto &global = ctx.globals[var]; - // global.val = *val; - // global.commit = gctx.uuid++; - // verbose::out << "Set global '" << lhs->location().view() << "' to " << - // *val << " with id " << *(global.commit) << std::endl; - - // gctx.commit_map[*(global.commit)] = node; } else { throw std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); @@ -330,7 +339,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { Node stmt = block->at(pc); // Stop *before* executing a sync statement (except first) - if (made_progress && is_syncing(stmt)) + if (made_progress && is_syncing(*gctx.protocol, stmt)) return ProgressStatus::progress; auto result = run_statement(stmt, thread); @@ -385,7 +394,7 @@ Interpreter::progress_thread(Thread& thread) { // If there are new threads, we can run them to sync as well any_progress = true; auto& new_thread = gctx.threads[i]; - if (!is_syncing(new_thread)) { + if (!is_syncing(*gctx.protocol, new_thread)) { verbose::out << "==== Thread " << i << " (spawn) ====" << std::endl; progress_thread(new_thread); } @@ -644,7 +653,7 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { } }, [&](const AssertEvent& arg) { - auto node = std::make_shared(arg.condition); + auto node = std::make_shared(arg.condition, arg.pass); link_in_program_order(tid, node); event_to_node[event] = node; } diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index 7057c2c..2632cad 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -197,6 +197,19 @@ LinearSyncProtocol::on_unlock(ThreadContext &thread, Lock &) { return std::nullopt; } +bool LinearSyncProtocol::is_scheduling_point(SyncOperation op) const { + switch (op) { + case SyncOperation::Lock: + case SyncOperation::Unlock: + case SyncOperation::Join: + case SyncOperation::Spawn: + case SyncOperation::Start: + case SyncOperation::End: + return true; + } + assert(false && "Unknown SyncOperation"); +} + 
} // namespace linear } // namespace gitmem diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index e8438d5..51a3bb5 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -50,6 +50,8 @@ public: std::string build_revision_graph_dot(const std::vector& thread_states) const override; + bool is_scheduling_point(SyncOperation op) const override; + std::unique_ptr make_thread_state(ThreadID tid) const override { return std::make_unique(tid); } diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index f3b2470..b5f3ad9 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -12,6 +12,16 @@ namespace gitmem { // Forward declaration for builder class SyncProtocolBuilder; +// Types of synchronization operations that may be scheduling points +enum class SyncOperation { + Spawn, + Join, + Start, + End, + Lock, + Unlock +}; + class SyncProtocol { public: virtual ~SyncProtocol() = default; @@ -47,6 +57,10 @@ public: virtual std::optional> on_unlock(ThreadContext &thread, Lock &lock) = 0; + // Returns true if the given sync operation is a scheduling point for this protocol + // (i.e., the scheduler should consider switching threads here) + virtual bool is_scheduling_point(SyncOperation op) const = 0; + virtual std::string build_revision_graph_dot(const std::vector& thread_states) const = 0; virtual std::ostream &print(std::ostream &os) const = 0; From e9c9c41b224886db5dadfbc72b6588b0ff216fea Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 12 Jan 2026 14:39:59 +0100 Subject: [PATCH 51/58] merge nodes are outside thread graphs, new test which is currently showing a bug, more scheduling points on spawn --- .../branching/lock_relock_as_sync.gm | 14 ++++++++++ src/branching/base_version_store.cc | 27 ++++++++++++++++++- src/interpreter.cc | 5 ++-- 3 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 examples/accept/semantics/branching/lock_relock_as_sync.gm diff --git 
a/examples/accept/semantics/branching/lock_relock_as_sync.gm b/examples/accept/semantics/branching/lock_relock_as_sync.gm new file mode 100644 index 0000000..a6707e5 --- /dev/null +++ b/examples/accept/semantics/branching/lock_relock_as_sync.gm @@ -0,0 +1,14 @@ +x = 0; +lock l1; +$t = spawn { + assert (x == 0); + lock l1; + x = 42; + unlock l1; + assert (x == 42); +}; +x = 2; +assert (x == 2); +unlock l1; +lock l1; +unlock l1; \ No newline at end of file diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index 51ff7fd..03afab4 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -164,6 +164,7 @@ std::string build_commit_graph_dot(const std::vector visited; std::unordered_map>> commits_by_thread; + std::vector> merge_commits; std::stack> stack; // First pass: collect all commits and organize by thread @@ -177,7 +178,12 @@ std::string build_commit_graph_dot(const std::vectorid.thread].push_back(commit); + // Merge commits (2+ parents) go outside clusters + if (commit->parents.size() >= 2) { + merge_commits.push_back(commit); + } else { + commits_by_thread[commit->id.thread].push_back(commit); + } for (const auto& parent : commit->parents) { if (parent) stack.push(parent); @@ -211,6 +217,25 @@ std::string build_commit_graph_dot(const std::vectorid); + + std::ostringstream label; + label << cid << " (merge)"; + if (!commit->changes.empty()) { + label << "\\n"; + bool first = true; + for (const auto& [obj, val] : commit->changes) { + if (!first) label << "\\n"; + first = false; + label << obj << "→" << val; + } + } + + dot << " \"" << cid << "\" [label=\"" << label.str() << "\", style=filled, fillcolor=lightgray];\n"; + } + // Draw edges (outside clusters so they can cross boundaries) visited.clear(); for (const auto& leaf : leaves) diff --git a/src/interpreter.cc b/src/interpreter.cc index 12b68dc..bf3d583 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -35,7 +35,8 @@ static 
std::optional get_sync_operation(Node stmt) { // Spawn is an expression, not a statement, but we check for assignment of spawn if (s == lang::Assign) { - auto rhs = s / lang::Expr; + // A little gross but okay for now + auto rhs = s / lang::Expr / lang::Expr; if (rhs == lang::Spawn) return SyncOperation::Spawn; } @@ -361,7 +362,7 @@ Interpreter::run_single_thread_to_sync(Thread& thread) { } // If we ran *any* statements, finishing is a sync point for next iteration - if (made_progress) + if (made_progress && gctx.protocol->is_scheduling_point(SyncOperation::End)) return ProgressStatus::progress; // Otherwise, we truly reached the end this iteration From 3a28f269b85f8ae722699057bc0173d1c20b9cac Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Mon, 12 Jan 2026 16:19:00 +0100 Subject: [PATCH 52/58] fixed bug where commits couldn't be the result of find_lowest_lca --- .../branching/lock_relock_as_sync.gm | 5 +- src/branching/base_sync_protocol.cc | 5 +- src/branching/base_version_store.cc | 45 ++++++-------- src/branching/base_version_store.hh | 13 ++++ src/branching/eager/version_store.cc | 60 +++++++++++++++---- src/debugger.cc | 31 +++++++--- src/execution_state.cc | 3 + src/interpreter.cc | 4 ++ src/interpreter.hh | 2 +- src/sync_state.hh | 11 +++- 10 files changed, 123 insertions(+), 56 deletions(-) diff --git a/examples/accept/semantics/branching/lock_relock_as_sync.gm b/examples/accept/semantics/branching/lock_relock_as_sync.gm index a6707e5..28fe69d 100644 --- a/examples/accept/semantics/branching/lock_relock_as_sync.gm +++ b/examples/accept/semantics/branching/lock_relock_as_sync.gm @@ -1,14 +1,11 @@ x = 0; -lock l1; $t = spawn { - assert (x == 0); lock l1; x = 42; unlock l1; - assert (x == 42); }; +lock l1; x = 2; -assert (x == 2); unlock l1; lock l1; unlock l1; \ No newline at end of file diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 472595d..888cecd 100644 --- a/src/branching/base_sync_protocol.cc +++ 
b/src/branching/base_sync_protocol.cc @@ -119,12 +119,9 @@ BranchingSyncProtocolBase::on_unlock(ThreadContext &thread, Lock &lock) { auto& store = get_store(thread); store.commit_staging(); - LockState& lock_state = get_store(lock); - std::shared_ptr lock_commit = lock_state.commit; - // we know that the last committer was this thread, so no need to merge // this sort of mixes protocol logic and lock state, i am unsure if this is ideal - + LockState& lock_state = get_store(lock); lock_state.commit = store.get_head(); return std::nullopt; diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index 03afab4..f44e46d 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -164,7 +164,6 @@ std::string build_commit_graph_dot(const std::vector visited; std::unordered_map>> commits_by_thread; - std::vector> merge_commits; std::stack> stack; // First pass: collect all commits and organize by thread @@ -178,12 +177,7 @@ std::string build_commit_graph_dot(const std::vectorparents.size() >= 2) { - merge_commits.push_back(commit); - } else { - commits_by_thread[commit->id.thread].push_back(commit); - } + commits_by_thread[commit->id.thread].push_back(commit); for (const auto& parent : commit->parents) { if (parent) stack.push(parent); @@ -201,6 +195,16 @@ std::string build_commit_graph_dot(const std::vectorparents.size() >= 2) { + label << " (merge"; + if (commit->conflicted) { + label << " - CONFLICT"; + } + label << ")"; + } + if (!commit->changes.empty()) { label << "\\n"; bool first = true; @@ -211,29 +215,16 @@ std::string build_commit_graph_dot(const std::vectorid); - - std::ostringstream label; - label << cid << " (merge)"; - if (!commit->changes.empty()) { - label << "\\n"; - bool first = true; - for (const auto& [obj, val] : commit->changes) { - if (!first) label << "\\n"; - first = false; - label << obj << "→" << val; + // Style merge commits differently + if (commit->parents.size() >= 2) { + 
std::string fillcolor = commit->conflicted ? "pink" : "lightgray"; + dot << " \"" << cid << "\" [label=\"" << label.str() << "\", style=filled, fillcolor=" << fillcolor << "];\n"; + } else { + dot << " \"" << cid << "\" [label=\"" << label.str() << "\"];\n"; } } - dot << " \"" << cid << "\" [label=\"" << label.str() << "\", style=filled, fillcolor=lightgray];\n"; + dot << " }\n"; } // Draw edges (outside clusters so they can cross boundaries) diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 18724c3..08b5f72 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -51,6 +51,7 @@ struct Commit { Timestamp id; std::unordered_map changes; std::vector> parents; + bool conflicted = false; }; std::string build_commit_graph_dot(const std::vector>& leaves); @@ -126,7 +127,19 @@ public: class LockState : public LockSyncState { public: + ~LockState() = default; + std::shared_ptr commit; + + inline std::ostream &print(std::ostream &os) const override { + os << "LockState{commit="; + if (commit) + os << commit->id; + else + os << "empty"; + os << "}"; + return os; + } }; } // namespace branching diff --git a/src/branching/eager/version_store.cc b/src/branching/eager/version_store.cc index 74cd9d8..9da040a 100644 --- a/src/branching/eager/version_store.cc +++ b/src/branching/eager/version_store.cc @@ -39,6 +39,37 @@ find_lowest_common_ancestor(std::shared_ptr a, std::shared_ptr b) { if (!a || !b) return nullptr; + if (a == b) return a; + + // Check if 'a' is an ancestor of 'b' + { + std::unordered_set> visited; + std::queue> q; + q.push(b); + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) continue; + if (c == a) return a; + if (!visited.insert(c).second) continue; + for (auto& p : c->parents) + q.push(p); + } + } + + // Check if 'b' is an ancestor of 'a' + { + std::unordered_set> visited; + std::queue> q; + q.push(a); + while (!q.empty()) { + auto c = q.front(); q.pop(); + if (!c) 
continue; + if (c == b) return b; + if (!visited.insert(c).second) continue; + for (auto& p : c->parents) + q.push(p); + } + } // Step 1: collect all ancestors of 'a' std::unordered_set> ancestors_a; @@ -81,15 +112,6 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha if (head == commit) return std::nullopt; - // Create merge commit (no changes itself) - auto merge_commit = std::make_shared( - Commit{ - .id = base_timestamp++, - .parents = {head, commit}, - .changes = {} // merge commit does not write anything - } - ); - // Find lowest common ancestor of the two heads std::shared_ptr lca = find_lowest_common_ancestor(head, commit); verbose::out << "found lca of " << head->id << " and " << commit->id << " to be " << lca->id << std::endl; @@ -104,17 +126,35 @@ std::optional EagerLocalVersionStore::merge_with_commit(const std::sha traverse_until_lca(commit, lca, branch_b, visited, reach_memo); // 1. Eager conflict detection + std::optional conflict; for (const auto& [obj, commit_a] : branch_a) { auto it = branch_b.find(obj); if (it != branch_b.end() && it->second != commit_a) { - return Conflict{ + conflict = Conflict{ .obj = obj, .timestamp_a = commit_a->id, .timestamp_b = it->second->id }; + break; } } + // Create merge commit (even if conflicted, for visualization) + auto merge_commit = std::make_shared( + Commit{ + .id = base_timestamp++, + .changes = {}, // merge commit does not write anything + .parents = {head, commit}, + .conflicted = conflict.has_value() + } + ); + + // If there was a conflict, update head but return the conflict + if (conflict) { + head = merge_commit; + return conflict; + } + // 2. 
Update thread-local last_writer incrementally // Only overwrite variables that were touched along either branch after LCA for (const auto& [obj, commit] : branch_a) diff --git a/src/debugger.cc b/src/debugger.cc index 0bc6870..5949e1e 100644 --- a/src/debugger.cc +++ b/src/debugger.cc @@ -1,4 +1,5 @@ #include +#include #include "debug.hh" #include "debugger.hh" @@ -23,11 +24,18 @@ struct Command { ThreadID argument = 0; }; -/** Show the global context, including locks and non-completed threads. If - * show_all is true, show all threads, even those that have terminated - * normally. */ -void show_global_context(const GlobalContext &gctx, bool show_all = false) { - std::cout << gctx << std::endl; +/** Clear the terminal screen (platform-specific) */ +void clear_terminal() { +#ifdef _WIN32 + std::system("cls"); +#else + std::system("clear"); +#endif +} + +/** Print a visual separator line */ +void print_separator() { + std::cout << std::string(60, '=') << std::endl; } /** Parse a command. See the help string for the 'Info' command for details. 
@@ -182,8 +190,13 @@ void do_restart(Interpreter &interp, } /** Print the list of threads and optionally all threads */ -void do_list(GlobalContext &gctx, bool show_all) { - gctx.print(std::cout, show_all); +void do_list(Interpreter &interp, bool show_all) { + // Uncomment the next line if you prefer clearing the screen + // clear_terminal(); + + print_separator(); + interp.print_state(std::cout, show_all); + print_separator(); } void do_finish(Interpreter& interp, bool print_graphs, const std::filesystem::path &output_file) { @@ -227,7 +240,7 @@ int interpret_interactive(const trieste::Node ast, while (command.cmd != Command::Quit) { // Print threads if new threads appeared or command is List if (command.cmd != Command::Skip || prev_no_threads != gctx.threads.size()) { - do_list(gctx, command.cmd == Command::List); + do_list(interp, command.cmd == Command::List); } prev_no_threads = gctx.threads.size(); @@ -242,7 +255,7 @@ int interpret_interactive(const trieste::Node ast, case Command::Step: { ThreadID tid = command.argument; StepUIResult res = do_step(interp, tid, print_graphs, output_file); - if (res.kind != StepKind::Progressed) + if (res.kind != StepKind::Progressed && res.kind != StepKind::Terminated) command = {Command::Skip}; break; } diff --git a/src/execution_state.cc b/src/execution_state.cc index 9780c14..46d76fc 100644 --- a/src/execution_state.cc +++ b/src/execution_state.cc @@ -156,6 +156,9 @@ void show_lock(const std::string &lock_name, const struct Lock &lock) { } else { std::cout << ""; } + if (lock.sync) { + std::cout << ", " << *(lock.sync); + } std::cout << std::endl; } diff --git a/src/interpreter.cc b/src/interpreter.cc index bf3d583..3e6779a 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -495,6 +495,10 @@ int Interpreter::run() { return exception_detected ? 
1 : 0; } +void Interpreter::print_state(std::ostream& os, bool show_all) const { + gctx.print(os, show_all); +} + void Interpreter::print_thread_traces() { for (size_t tid = 0; tid < gctx.threads.size(); ++tid) { const auto& thread = gctx.threads[tid]; diff --git a/src/interpreter.hh b/src/interpreter.hh index a0bc5fc..93aea3c 100644 --- a/src/interpreter.hh +++ b/src/interpreter.hh @@ -46,8 +46,8 @@ public: StepResult run_single_thread_to_sync(Thread&); StepResult run_threads_to_sync(); + void print_state(std::ostream& os, bool show_all = false) const; void print_thread_traces(); - void print_revision_graph(const std::filesystem::path& output_path); void print_execution_graph(const std::filesystem::path& output_path); }; diff --git a/src/sync_state.hh b/src/sync_state.hh index 2960120..2d7d680 100644 --- a/src/sync_state.hh +++ b/src/sync_state.hh @@ -17,6 +17,15 @@ public: } }; -class LockSyncState {}; +class LockSyncState { +public: + virtual ~LockSyncState() = default; + + virtual std::ostream &print(std::ostream &os) const = 0; + friend std::ostream &operator<<(std::ostream &os, + const LockSyncState &state) { + return state.print(os); + } +}; } \ No newline at end of file From 5c7e2c90f30ecdc497d10b61151cdb7b440135f7 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 16 Jan 2026 16:09:11 +0100 Subject: [PATCH 53/58] Fixing up read from edges in linear protocol --- src/branching/base_sync_protocol.cc | 6 ++--- src/branching/base_sync_protocol.hh | 3 ++- src/graph.hh | 7 ++++++ src/interpreter.cc | 36 +++++++++++++++++------------ src/linear/sync_protocol.cc | 22 ++++++++---------- src/linear/sync_protocol.hh | 2 +- src/linear/version_store.cc | 21 ++++++++++------- src/linear/version_store.hh | 23 ++++++++++-------- src/read_result.hh | 13 ++++++++++- src/sync_protocol.hh | 4 +++- src/thread_trace.hh | 32 ++++++------------------- 11 files changed, 92 insertions(+), 77 deletions(-) diff --git a/src/branching/base_sync_protocol.cc 
b/src/branching/base_sync_protocol.cc index 888cecd..42dd721 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -32,7 +32,7 @@ ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, return std::visit(overloaded{ [](std::monostate) -> ReadResult { return std::monostate{}; }, - [](const Value& v) -> ReadResult { return v; }, + [](const Value& v) -> ReadResult { return ValueWithSource{v, nullptr}; }, [&](const Conflict& c) -> ReadResult { return std::make_shared( c.obj, std::pair{c.timestamp_a, c.timestamp_b} @@ -43,9 +43,9 @@ ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, } void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, - size_t value) { + ValueWithSource value) { auto& store = get_store(ctx); - store.stage(var, value); + store.stage(var, value.value); // Branching protocol doesn't track sources yet } std::optional> diff --git a/src/branching/base_sync_protocol.hh b/src/branching/base_sync_protocol.hh index ba8d6bf..2e1091a 100644 --- a/src/branching/base_sync_protocol.hh +++ b/src/branching/base_sync_protocol.hh @@ -24,7 +24,8 @@ public: ReadResult read(ThreadContext &ctx, const std::string &var) override; - void write(ThreadContext &ctx, const std::string &var, size_t value) override; + void write(ThreadContext &ctx, const std::string &var, + ValueWithSource value) override; std::optional> on_spawn(ThreadContext &parent, ThreadContext &child) override; diff --git a/src/graph.hh b/src/graph.hh index 7ef25da..d261533 100644 --- a/src/graph.hh +++ b/src/graph.hh @@ -99,6 +99,13 @@ struct Read : Node { Read(const std::string var, const size_t id, Conflict conflict) : var(var), id(id), read_result(std::move(conflict)) {} + void set_source(const std::shared_ptr source) { + if (std::holds_alternative(read_result)) { + auto &sr = std::get(read_result); + const_cast&>(sr.source) = source; + } + } + void accept(Visitor *v) const override { v->visitRead(this); } }; diff --git 
a/src/interpreter.cc b/src/interpreter.cc index 3e6779a..30f25eb 100644 --- a/src/interpreter.cc +++ b/src/interpreter.cc @@ -86,10 +86,10 @@ Interpreter::evaluate_expression(trieste::Node expr, Thread& thread) { // invalid: reading a variable that hasn't been written return termination::UnassignedRead(var); }, - [&](Value value) -> std::variant { + [&](ValueWithSource value_with_source) -> std::variant { // normal read - thread.trace.on_read(var, value); - return value; + thread.trace.on_read(var, value_with_source); + return value_with_source.value; }, [&](std::shared_ptr& conflict) -> std::variant { verbose::out << (*conflict) << std::endl; @@ -191,8 +191,9 @@ std::variant Interpreter::run_statement(Node stmt, Threa } else if (lhs == lang::Var) { - gctx.protocol->write(ctx, var, *val); - thread.trace.on_write(var, *val); + auto write_event = thread.trace.on_write(var, *val); + gctx.protocol->write(ctx, var + , ValueWithSource{*val, write_event}); } else { throw std::runtime_error("Bad left-hand side: " + std::string(lhs->type().str())); @@ -536,12 +537,12 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { // Track the last unlock event for each lock (for lock->unlock edges) std::unordered_map> last_unlock_per_lock; - // Track write events per variable (for read->write edges) - std::unordered_map> last_write_per_var; - // Track join nodes that need fixing up after all threads are processed std::vector> joins_to_fix; + // Track read nodes that need their source fixed up + std::vector, std::shared_ptr>> reads_to_fix; + // Map from trace events to graph nodes std::unordered_map, std::shared_ptr> event_to_node; @@ -589,20 +590,18 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { }, [&](const WriteEvent& arg) { auto node = std::make_shared(arg.var, arg.value, tid); - last_write_per_var[arg.var] = node; link_in_program_order(tid, node); event_to_node[event] = node; }, [&](const ReadEvent& arg) { // Link to the write 
that produced this value - auto source = last_write_per_var.contains(arg.var) - ? last_write_per_var[arg.var] - : nullptr; - std::shared_ptr node; std::visit(overloaded{ - [&](size_t val) { - node = std::make_shared(arg.var, val, tid, source); + [&](const ReadValue& val) { + // Create the read node, but we might need to fix up the source later + node = std::make_shared(arg.var, val.value, tid, nullptr); + assert(val.source_event && "source missing"); + reads_to_fix.push_back({node, val.source_event}); }, [&](const std::shared_ptr&) { node = std::make_shared(arg.var, tid, graph::Conflict(arg.var)); @@ -687,6 +686,12 @@ graph::ExecutionGraph Interpreter::build_execution_graph_from_traces() { const_cast&>(join_node->joinee) = thread_tails[joinee_tid]; } + // Fix up read nodes to point to their source write events + for (auto& [read_node, source_event] : reads_to_fix) { + assert(event_to_node.contains(source_event) && "source missing in event_to_node map"); + read_node->set_source(event_to_node[source_event]); + } + return g; } @@ -701,6 +706,7 @@ int interpret(const Node ast, const std::filesystem::path &output_path, Interpreter interp(GlobalContext(ast, std::move(protocol))); int result = interp.run(); + interp.print_execution_graph(output_path); interp.print_revision_graph(output_path); return result; diff --git a/src/linear/sync_protocol.cc b/src/linear/sync_protocol.cc index 2632cad..1a65c9c 100644 --- a/src/linear/sync_protocol.cc +++ b/src/linear/sync_protocol.cc @@ -49,7 +49,7 @@ std::string LinearSyncProtocol::build_revision_graph_dot( node_id << obj_name << "_v" << i; std::ostringstream label; - label << version.timestamp() << "\\n" << obj_name << "=" << version.value(); + label << version.timestamp() << "\\n" << obj_name << "=" << version.value().value; dot << " \"" << node_id.str() << "\" [label=\"" << label.str() << "\"];\n"; } @@ -107,23 +107,21 @@ ReadResult LinearSyncProtocol::read(ThreadContext &ctx, const std::string &var) { auto& store = 
get_store(ctx); - if (auto result = store.get_staged(var)) + if (auto result = store.get_staged(var)) { return *result; + } - std::optional value = _global_store.get_version_for_timestamp( + std::optional value = _global_store.get_version_for_timestamp( var, store.timestamp()); - if (value) - return *value; - - // we do not need to record the staged value for correctness - // TODO: there is something about working out if a value has changed vs been - // written + if (value) { + return value.value(); + } return std::monostate{}; } void LinearSyncProtocol::write(ThreadContext &ctx, const std::string &var, - size_t value) { + ValueWithSource value) { // write into the staging area of the thread auto& store = get_store(ctx); store.stage(var, value); @@ -204,9 +202,7 @@ bool LinearSyncProtocol::is_scheduling_point(SyncOperation op) const { case SyncOperation::Join: case SyncOperation::Spawn: case SyncOperation::Start: - case SyncOperation::End: - return true; - } + case SyncOperation::End: return true; } assert(false && "Unknown SyncOperation"); } diff --git a/src/linear/sync_protocol.hh b/src/linear/sync_protocol.hh index 51a3bb5..37df6fd 100644 --- a/src/linear/sync_protocol.hh +++ b/src/linear/sync_protocol.hh @@ -26,7 +26,7 @@ public: ReadResult read(ThreadContext &ctx, const std::string &var) override; - void write(ThreadContext &ctx, const std::string &var, size_t value) override; + void write(ThreadContext &ctx, const std::string &var, ValueWithSource value) override; std::optional> on_spawn(ThreadContext &parent, ThreadContext &child) override; diff --git a/src/linear/version_store.cc b/src/linear/version_store.cc index e572be5..a6dbf31 100644 --- a/src/linear/version_store.cc +++ b/src/linear/version_store.cc @@ -3,6 +3,8 @@ #include "sync_protocol.hh" #include "version_store.hh" +#include "thread_trace.hh" +#include "read_result.hh" namespace gitmem { @@ -12,15 +14,17 @@ namespace linear { // LocalVersionStore // ----------------------------- -void 
LocalVersionStore::stage(std::string obj, Value value) { +void LocalVersionStore::stage(std::string obj, ValueWithSource value) { _staging[obj] = value; } -void LocalVersionStore::clear_staging() { _staging.clear(); } +void LocalVersionStore::clear_staging() { + _staging.clear(); +} void LocalVersionStore::advance_base(uint64_t ts) { _timestamp = ts; } -std::optional LocalVersionStore::get_staged(std::string obj) { +std::optional LocalVersionStore::get_staged(std::string obj) { auto it = _staging.find(obj); return it != _staging.end() ? std::make_optional(it->second) : std::nullopt; } @@ -39,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { for (const auto& [obj, val] : store._staging) { if (!first) os << ", "; first = false; - os << obj << "->" << val; + os << obj << "->" << val.value << " (" << val.source_event << ")"; } os << "}}"; @@ -50,7 +54,7 @@ std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { // GlobalVersionStore // ----------------------------- -std::optional +std::optional GlobalVersionStore::get_version_for_timestamp(std::string obj, uint64_t ts) const { const auto it = _history.find(obj); @@ -70,7 +74,7 @@ GlobalVersionStore::get_version_for_timestamp(std::string obj, std::optional GlobalVersionStore::check_conflicts( uint64_t base, - const std::unordered_map &changes) const { + const std::unordered_map &changes) const { for (const auto &[obj, _] : changes) { auto it = _history.find(obj); if (it == _history.end()) { @@ -87,7 +91,8 @@ std::optional GlobalVersionStore::check_conflicts( } uint64_t GlobalVersionStore::apply_changes( - ThreadID tid, uint64_t base, const std::unordered_map &changes) { + ThreadID tid, uint64_t base, + const std::unordered_map &changes) { if (auto conflict = check_conflicts(base, changes)) { throw std::logic_error("apply_changes called with conflicts"); } @@ -109,7 +114,7 @@ std::ostream& operator<<(std::ostream& os, const GlobalVersionStore& store) { os << " 
Object " << obj_name << ":\n"; for (const auto& version : history) { - os << " [" << version.timestamp() << "] = " << version.value() << "\n"; + os << " [" << version.timestamp() << "] = " << version.value().value << "\n"; } } diff --git a/src/linear/version_store.hh b/src/linear/version_store.hh index d6dbbe9..fd719b8 100644 --- a/src/linear/version_store.hh +++ b/src/linear/version_store.hh @@ -7,9 +7,13 @@ #include #include #include "sync_state.hh" +#include "thread_id.hh" +#include "read_result.hh" namespace gitmem { +struct Event; // Forward declaration + namespace linear { // ----------------------------- @@ -36,13 +40,14 @@ using Value = size_t; class Version { Timestamp _timestamp; - Value _value; + ValueWithSource _value; public: - Version(Timestamp ts, Value value) : _timestamp(ts), _value(value) {} + Version(Timestamp ts, ValueWithSource value) + : _timestamp(ts), _value(value) {} Timestamp timestamp() const { return _timestamp; } - Value value() const { return _value; } + ValueWithSource value() const { return _value; } }; using VersionHistory = std::vector; @@ -64,7 +69,7 @@ struct Conflict { class LocalVersionStore : public ThreadSyncState { ThreadID tid; uint64_t _timestamp; - std::unordered_map _staging; + std::unordered_map _staging; public: ~LocalVersionStore() = default; @@ -76,10 +81,10 @@ public: uint64_t timestamp() const { return _timestamp; } const auto &staged_changes() const { return _staging; } - void stage(std::string obj, Value value); + void stage(std::string obj, ValueWithSource value); void clear_staging(); void advance_base(uint64_t ts); - std::optional get_staged(std::string obj); + std::optional get_staged(std::string obj); bool operator==(const LocalVersionStore& other) const; @@ -109,15 +114,15 @@ class GlobalVersionStore { public: uint64_t current_counter() const { return _counter; } - std::optional get_version_for_timestamp(std::string, uint64_t) const; + std::optional get_version_for_timestamp(std::string, uint64_t) const; 
std::optional check_conflicts(uint64_t base, - const std::unordered_map &changes) const; + const std::unordered_map &changes) const; uint64_t apply_changes(ThreadID tid, uint64_t base, - const std::unordered_map &changes); + const std::unordered_map &changes); friend std::ostream& operator<<(std::ostream&, const GlobalVersionStore&); diff --git a/src/read_result.hh b/src/read_result.hh index d32cac5..3ea3722 100644 --- a/src/read_result.hh +++ b/src/read_result.hh @@ -1,11 +1,22 @@ #pragma once #include +#include #include "conflict.hh" namespace gitmem { +struct Event; + using Value = size_t; -using ReadResult = std::variant>; + +struct ValueWithSource { + Value value; + std::shared_ptr source_event; + + auto operator<=>(const ValueWithSource&) const = default; +}; + +using ReadResult = std::variant>; } \ No newline at end of file diff --git a/src/sync_protocol.hh b/src/sync_protocol.hh index b5f3ad9..d0cc8a8 100644 --- a/src/sync_protocol.hh +++ b/src/sync_protocol.hh @@ -9,6 +9,8 @@ namespace gitmem { +struct Event; // Forward declaration + // Forward declaration for builder class SyncProtocolBuilder; @@ -37,7 +39,7 @@ public: // Write a shared variable (staged, not committed) virtual void write(ThreadContext &ctx, const std::string &var, - size_t value) = 0; + ValueWithSource value) = 0; virtual std::optional> on_spawn(ThreadContext &parent, ThreadContext &child) = 0; diff --git a/src/thread_trace.hh b/src/thread_trace.hh index ed29ed8..90dba56 100644 --- a/src/thread_trace.hh +++ b/src/thread_trace.hh @@ -3,6 +3,7 @@ #include "thread_id.hh" #include "conflict.hh" #include "overloaded.hh" +#include "read_result.hh" namespace gitmem { @@ -10,7 +11,8 @@ struct Event; struct StartEvent {}; struct SpawnEvent { const ThreadID child_tid; }; -struct ReadEvent { const std::string var; std::variant> value_or_conflict; }; +struct ReadValue { const size_t value; const std::shared_ptr source_event; }; +struct ReadEvent { const std::string var; std::variant> 
value_or_conflict; }; struct WriteEvent { const std::string var; const size_t value; }; struct LockEvent { std::string lock_name; std::shared_ptr maybe_conflict; std::shared_ptr last_unlock_event; }; struct UnlockEvent { const std::string lock_name; std::shared_ptr maybe_conflict; }; @@ -55,7 +57,7 @@ inline std::ostream& operator<<(std::ostream& os, const SpawnEvent& e) { inline std::ostream& operator<<(std::ostream& os, const ReadEvent& e) { os << "ReadEvent(var=\"" << e.var << "\", "; std::visit(overloaded{ - [&os](size_t val) { os << "value=" << val; }, + [&os](const ReadValue& val) { os << "value=" << val.value << " (from " << val.source_event->eid << ")"; }, [&os](const std::shared_ptr&) { os << "conflict"; } }, e.value_or_conflict); os << ")"; @@ -141,8 +143,8 @@ private: return append(child_tid); } - std::shared_ptr on_read(const std::string text, const size_t value) { - return append(std::move(text), value); + std::shared_ptr on_read(const std::string text, ValueWithSource value) { + return append(std::move(text), ReadValue{value.value, value.source_event}); } std::shared_ptr on_read(const std::string text, std::shared_ptr conflict) { @@ -193,24 +195,4 @@ inline std::ostream& operator<<(std::ostream& os, const ThreadTrace& tt) { return os; } -} // namespace gitmem - -// template -// std::shared_ptr thread_append_node(ThreadContext &ctx, Args &&...args) { -// assert(ctx.tail); -// auto node = std::make_shared(std::forward(args)...); -// ctx.tail->next = node; -// ctx.tail = node; -// return node; -// } - -// template <> -// std::shared_ptr -// thread_append_node(ThreadContext &ctx, std::string &&stmt) { -// // pending nodes don't update the tail position as we will destroy them -// // once we execute the node -// auto s = std::regex_replace(stmt, std::regex("\n"), "\\l "); -// auto node = make_shared(std::move(s)); -// ctx.tail->next = node; -// return node; -// } \ No newline at end of file +} // namespace gitmem \ No newline at end of file From 
649a0eebf73293bf4bbbab31803ea7a70106d787 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 16 Jan 2026 16:14:19 +0100 Subject: [PATCH 54/58] Fixing up read from edges in branching --- src/branching/base_sync_protocol.cc | 4 ++-- src/branching/base_version_store.cc | 8 ++++---- src/branching/base_version_store.hh | 8 ++++---- src/read_result.hh | 4 +--- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/branching/base_sync_protocol.cc b/src/branching/base_sync_protocol.cc index 42dd721..c15c82f 100644 --- a/src/branching/base_sync_protocol.cc +++ b/src/branching/base_sync_protocol.cc @@ -32,7 +32,7 @@ ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, return std::visit(overloaded{ [](std::monostate) -> ReadResult { return std::monostate{}; }, - [](const Value& v) -> ReadResult { return ValueWithSource{v, nullptr}; }, + [](const ValueWithSource& v) -> ReadResult { return v; }, [&](const Conflict& c) -> ReadResult { return std::make_shared( c.obj, std::pair{c.timestamp_a, c.timestamp_b} @@ -45,7 +45,7 @@ ReadResult BranchingSyncProtocolBase::read(ThreadContext &ctx, void BranchingSyncProtocolBase::write(ThreadContext &ctx, const std::string &var, ValueWithSource value) { auto& store = get_store(ctx); - store.stage(var, value.value); // Branching protocol doesn't track sources yet + store.stage(var, value); } std::optional> diff --git a/src/branching/base_version_store.cc b/src/branching/base_version_store.cc index f44e46d..e3d0360 100644 --- a/src/branching/base_version_store.cc +++ b/src/branching/base_version_store.cc @@ -26,7 +26,7 @@ void print_commit_recursive(std::ostream& os, // Print changes for (const auto& [obj, val] : commit->changes) { - os << std::string((depth + 1) * 2, ' ') << obj << " -> " << val << "\n"; + os << std::string((depth + 1) * 2, ' ') << obj << " -> " << val.value << "\n"; } // Print parents @@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& os, const Commit& commit) { return os; } -void 
LocalVersionStore::stage(std::string obj, Value value) { +void LocalVersionStore::stage(std::string obj, ValueWithSource value) { staging[obj] = value; } @@ -116,7 +116,7 @@ std::ostream& operator<<(std::ostream& os, const LocalVersionStore& store) { for (const auto& [obj, val] : store.staging) { if (!first) os << ", "; first = false; - os << obj << "->" << val; + os << obj << "->" << val.value; } os << "}}"; @@ -211,7 +211,7 @@ std::string build_commit_graph_dot(const std::vectorchanges) { if (!first) label << "\\n"; first = false; - label << obj << "→" << val; + label << obj << "→" << val.value; } } diff --git a/src/branching/base_version_store.hh b/src/branching/base_version_store.hh index 08b5f72..c62ff9f 100644 --- a/src/branching/base_version_store.hh +++ b/src/branching/base_version_store.hh @@ -49,7 +49,7 @@ inline std::string to_string(const Timestamp& ts) { struct Commit { Timestamp id; - std::unordered_map changes; + std::unordered_map changes; std::vector> parents; bool conflicted = false; }; @@ -66,7 +66,7 @@ struct Conflict { Timestamp timestamp_b; }; -using BranchingReadResult = std::variant; +using BranchingReadResult = std::variant; inline std::ostream& operator<<(std::ostream& os, const Conflict& c) { return os << "Conflict{obj=" << c.obj @@ -78,7 +78,7 @@ class LocalVersionStore : public ThreadSyncState { protected: Timestamp base_timestamp; std::shared_ptr head; - std::unordered_map staging; + std::unordered_map staging; std::unordered_map> last_writer; // cached @@ -89,7 +89,7 @@ public: LocalVersionStore(ThreadID tid, bool verbose = false): base_timestamp(tid, 0), verbose(verbose) {} - void stage(std::string obj, Value value); + void stage(std::string obj, ValueWithSource value); void commit_staging(); bool has_commited() { return staging.empty(); } diff --git a/src/read_result.hh b/src/read_result.hh index 3ea3722..cb6a9b6 100644 --- a/src/read_result.hh +++ b/src/read_result.hh @@ -8,10 +8,8 @@ namespace gitmem { struct Event; -using Value = 
size_t; - struct ValueWithSource { - Value value; + size_t value; std::shared_ptr source_event; auto operator<=>(const ValueWithSource&) const = default; From f6a5f19417e8aeed636560f81c03713ed0a5235d Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 6 Feb 2026 15:13:49 +0100 Subject: [PATCH 55/58] pushing a few examples and change for more comprehensive error --- .../accept/semantics/branching/singleton.gm | 41 +++++++++++++++++++ .../semantics/linear/conditional_non_race.gm | 8 ++-- examples/accept/semantics/linear/singleton.gm | 41 +++++++++++++++++++ src/model_checker.cc | 9 ++-- 4 files changed, 91 insertions(+), 8 deletions(-) create mode 100644 examples/accept/semantics/branching/singleton.gm create mode 100644 examples/accept/semantics/linear/singleton.gm diff --git a/examples/accept/semantics/branching/singleton.gm b/examples/accept/semantics/branching/singleton.gm new file mode 100644 index 0000000..b39cecc --- /dev/null +++ b/examples/accept/semantics/branching/singleton.gm @@ -0,0 +1,41 @@ +// Construct the singleton pattern: +// A shared variable should be initialised only once by any thread +// Once initialised all threads should read the same value +// Perform the initialisation using double checked locking +// Each thread: +// 1. Checks the variable +// 1a. If it is uninitialised (here we use 0), then take the lock +// 1b. Taking the lock may pull in other thread updates, so check if the +// variable is still uninitialised +// 1bi. If the variable is still uninitialised then we know we need to +// initialise it. So do that and store the initialised value in a local +// var. +// 1bii. Otherwise the variable was initialised and we can read the value +// 2a. Otherwise the variable was initialised and we can read the value +// 3. 
Check that in all code paths we read the expected initialised value + +instance = 0; + +$t1 = spawn { + if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; + } + assert(instance == 100); +}; + + +if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; +} +assert(instance == 100); + +join $t1; +assert(instance == 100); \ No newline at end of file diff --git a/examples/accept/semantics/linear/conditional_non_race.gm b/examples/accept/semantics/linear/conditional_non_race.gm index b8d4206..baf06f1 100644 --- a/examples/accept/semantics/linear/conditional_non_race.gm +++ b/examples/accept/semantics/linear/conditional_non_race.gm @@ -12,7 +12,7 @@ $t2 = spawn { }; join $t1; join $t2; -// FIXME in branching these assertions are always true, but not in linear -assert (x != 0); -assert (y != 0); -assert (x == y); + +// We don't know the interleaving of the two threads, but we know that they won't conflict, +// and that the reads are reading consistent values. +// We can get all of x = 0, y = 1; x = 1, y = 0; or x = 1, y = 1; but we won't get x = 0, y = 0. \ No newline at end of file diff --git a/examples/accept/semantics/linear/singleton.gm b/examples/accept/semantics/linear/singleton.gm new file mode 100644 index 0000000..b39cecc --- /dev/null +++ b/examples/accept/semantics/linear/singleton.gm @@ -0,0 +1,41 @@ +// Construct the singleton pattern: +// A shared variable should be initialised only once by any thread +// Once initialised all threads should read the same value +// Perform the initialisation using double checked locking +// Each thread: +// 1. Checks the variable +// 1a. If it is uninitialised (here we use 0), then take the lock +// 1b. Taking the lock may pull in other thread updates, so check if the +// variable is still uninitialised +// 1bi. If the variable is still uninitialised then we know we need to +// initialise it. So do that and store the initialised value in a local +// var. 
+// 1bii. Otherwise the variable was initialised and we can read the value +// 2a. Otherwise the variable was initialised and we can read the value +// 3. Check that in all code paths we read the expected initialised value + +instance = 0; + +$t1 = spawn { + if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; + } + assert(instance == 100); +}; + + +if (instance == 0) { + lock l; + if (instance == 0) { + instance = 100; + } + unlock l; +} +assert(instance == 100); + +join $t1; +assert(instance == 100); \ No newline at end of file diff --git a/src/model_checker.cc b/src/model_checker.cc index 69da1ee..5126c8f 100644 --- a/src/model_checker.cc +++ b/src/model_checker.cc @@ -177,14 +177,15 @@ int model_check(const Node ast, const std::filesystem::path &output_path, } } - verbose::out << "Found a total of " << final_traces.size() - << " trace(s) with distinct final states:" << std::endl; + std::cout << "Found a total of " << final_traces.size() + << " trace(s) with distinct final states" + << " (errors: " << failing_traces.size() + << ", no errors: " << final_traces.size() - failing_traces.size() << ")" + << std::endl; print_traces(verbose::out, final_traces); size_t idx = 0; if (!failing_traces.empty()) { - std::cout << "Found " << failing_traces.size() - << " trace(s) with errors:" << std::endl; print_traces(std::cout, failing_traces); for (const auto &ctx : failing_contexts) { From 3775b3a3696cb8e98e2ccbbacd067ee5244d0a58 Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 6 Feb 2026 15:14:22 +0100 Subject: [PATCH 56/58] another example --- .../semantics/branching/variable_under_two_locks.gm | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 examples/reject/semantics/branching/variable_under_two_locks.gm diff --git a/examples/reject/semantics/branching/variable_under_two_locks.gm b/examples/reject/semantics/branching/variable_under_two_locks.gm new file mode 100644 index 0000000..79fb341 --- /dev/null +++ 
b/examples/reject/semantics/branching/variable_under_two_locks.gm @@ -0,0 +1,11 @@ +x = 0; +$t1 = spawn { + lock l1; + x = 1; + unlock l1; +}; +lock l2; +x = 2; +unlock l2; +join $t1; +assert(x == x); \ No newline at end of file From a608a38bb8705bb98654314cb3a5d442b1fdc08f Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 6 Feb 2026 15:15:28 +0100 Subject: [PATCH 57/58] another example --- .../semantics/linear/variable_under_two_locks.gm | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 examples/reject/semantics/linear/variable_under_two_locks.gm diff --git a/examples/reject/semantics/linear/variable_under_two_locks.gm b/examples/reject/semantics/linear/variable_under_two_locks.gm new file mode 100644 index 0000000..79fb341 --- /dev/null +++ b/examples/reject/semantics/linear/variable_under_two_locks.gm @@ -0,0 +1,11 @@ +x = 0; +$t1 = spawn { + lock l1; + x = 1; + unlock l1; +}; +lock l2; +x = 2; +unlock l2; +join $t1; +assert(x == x); \ No newline at end of file From 471e74a88e190eb4727bdb54213a7b1a3ce4d60d Mon Sep 17 00:00:00 2001 From: Luke Cheeseman Date: Fri, 6 Feb 2026 16:57:23 +0100 Subject: [PATCH 58/58] up-to-date README --- README.md | 235 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 209 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 651bee4..3022a86 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,231 @@ # gitmem -Experimental interpreter for creating execution diagrams for a new -concurrency model. +An experimental interpreter and model checker for exploring concurrent programs with Git-inspired memory semantics. Gitmem allows you to write multi-threaded programs and automatically explore all possible interleavings to detect data races, deadlocks, and assertion failures. -## Building and Running +## Overview + +Gitmem is a research tool that models concurrent memory operations using version control semantics. 
It provides: + +- **Multiple sync protocols**: Linear and branching (with eager/lazy variants) semantics for thread synchronization +- **Automatic model checking**: Explores all possible execution paths to find concurrency bugs +- **Interactive debugging**: Step through different thread schedules interactively +- **Execution visualization**: Generates GraphViz diagrams showing execution traces and revision graphs + +## Language Features -To build you need CMake and Ninja. CMake will fetch any other dependencies. +The gitmem language supports: -The following commands should set you up: +- **Shared variables**: `x = value` (global shared state) +- **Thread-local registers**: `$r = value` (prefixed with `$`) +- **Thread operations**: `spawn { ... }`, `join $thread` +- **Synchronization**: `lock var`, `unlock var` +- **Control flow**: `if (condition) { ... } else { ... }` +- **Assertions**: `assert(condition)` +- **Operators**: `==`, `!=`, `+` +### Example Program + +```gitmem +x = 0; +$t1 = spawn { + lock l; + x = x + 1; + unlock l; +}; +$t2 = spawn { + lock l; + x = x + 1; + unlock l; +}; +join $t1; +join $t2; +assert(x == 2); ``` + +## Building and Running + +### Prerequisites + +- CMake (3.14+) +- Ninja build system +- C++23 compatible compiler (Clang recommended) +- Python 3 (for tests) + +### Build Instructions + +```bash mkdir build cd build cmake -G Ninja .. -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Debug ninja ``` -If you need, set the C standard with `-DCMAKE_CXX_STANDARD=20`. -If you are running a recent version of CMake, you may need -`-DCMAKE_POLICY_VERSION_MINIMUM=3.5`. 
+Optional CMake flags: +- `-DCMAKE_CXX_STANDARD=23` - Set C++ standard (if needed) +- `-DCMAKE_BUILD_TYPE=Release` - Build optimized version -You can test if the build was successful by running the following -command in the `build` directory: +### Quick Test -``` +Test your build with: + +```bash ./gitmem -e ../examples/race_condition.gm ``` -The build script creates two executables: +## Usage + +### Basic Execution + +Run a single execution trace: + +```bash +./gitmem examples/race_condition.gm +``` + +This generates a GraphViz `.dot` file showing the execution trace. + +### Model Checking Mode + +Explore all possible execution paths: + +```bash +./gitmem -e examples/race_condition.gm +``` + +Model checking will: +- Try all possible thread interleavings +- Report data races, deadlocks, and assertion failures +- Generate execution diagrams for failing traces +- Exit with status 0 if all paths succeed, non-zero if any fail + +### Interactive Mode + +Step through executions manually: + +```bash +./gitmem -i examples/singleton.gm +``` + +Commands in interactive mode: +- `?` - Show help +- `s` - Show current state +- `n` - Step one thread forward +- `r` - Run to next synchronization point -- `gitmem` parses source code and runs the interpreter in order to - create an execution diagram (work in progress). You can run the - interpreter interactively with the `-i` flag, and automatically - explore all possible traces with the `-e` flag (showing failing - runs). -- `gitmem_trieste` is the default - [Trieste](https://github.com/microsoft/Trieste) driver which can - be used to inspect the parsed source code and test the parser. - Running `gitmem_trieste build foo.gm` will create a file - `foo.trieste` with the parsed source code as an S-expression. +### Sync Protocols + +Gitmem supports different memory models: + +#### Linear (Default) +```bash +./gitmem --sync linear program.gm +``` +Traditional sequential consistency model. 
+ +#### Branching (Eager) +```bash +./gitmem --sync branching --branching-mode eager program.gm +``` +Git-like branching semantics where threads create branches that merge at synchronization points. Conflicts are detected eagerly. + +#### Branching (Lazy) +```bash +./gitmem --sync branching --branching-mode lazy program.gm +``` +Lazy conflict detection variant that defers checking until synchronization. + +Additional flags: +- `--include-empty-commits` - Include empty commits in branching output +- `--raise-early-conflicts` - Raise conflicts before write suppression (lazy mode only) +- `-v, --verbose` - Enable verbose interpreter output +- `-o, --output PATH` - Specify output file path + +## Testing + +Run the test suite: + +```bash +# From build directory +ninja run_gitmem_tests + +# Or using CTest +ctest +``` + +The test suite includes: +- **Accept tests**: Programs that should execute successfully +- **Reject tests**: Programs with errors (deadlocks, races, assertion failures) +- Tests for both linear and branching semantics + +## Project Structure + +``` +src/ + ├── gitmem.cc - Main entry point + ├── lang.hh - Language token definitions + ├── parser.cc - Parser implementation + ├── interpreter.cc - Interpreter core + ├── model_checker.cc - Model checking engine + ├── debugger.cc - Interactive debugger + ├── execution_state.hh - Thread and memory state + ├── sync_protocol.hh - Sync protocol interface + ├── linear/ - Linear sync protocol + └── branching/ - Branching sync protocols + ├── base_sync_protocol.cc + ├── eager/ - Eager conflict detection + └── lazy/ - Lazy conflict detection + +examples/ + ├── accept/semantics/ - Valid programs + │ ├── linear/ - Linear semantics tests + │ └── branching/ - Branching semantics tests + └── reject/semantics/ - Programs with bugs + ├── linear/ - Deadlocks, races for linear + └── branching/ - Bugs for branching semantics +``` + +## Executables + +The build produces two binaries: + +### `gitmem` +The main interpreter and model
checker. Executes programs and generates execution diagrams. + +### `gitmem_trieste` +Parser diagnostic tool built on [Trieste](https://github.com/microsoft/Trieste). Use it to inspect the AST: + +```bash +./gitmem_trieste build program.gm +# Creates program.trieste with S-expression AST +``` ## VSCode Extension -You should be able to use `Developer: Install Extension from -Location` in the VSCode command palette to install a rudimentary -extension in the `gitmem-extension` directory and get syntax -highlighting.. +A syntax highlighting extension is available in `gitmem-extension/`. + +Install via: +1. Open VSCode Command Palette (`Cmd+Shift+P` / `Ctrl+Shift+P`) +2. Run: `Developer: Install Extension from Location` +3. Select the `gitmem-extension` directory + +This provides syntax highlighting for `.gm` files. + +## Output + +Gitmem generates GraphViz `.dot` files visualizing: + +- **Execution traces**: Thread operations and memory states +- **Revision graphs**: For branching semantics, shows branch/merge structure +- **Conflict detection**: Highlights data races and conflicts + +View `.dot` files with GraphViz: + +```bash +dot -Tpng output.dot -o output.png +``` + +## Exit Codes + +- `0` - All execution paths succeeded +- `1` - Assertion failure, deadlock, data race, or error detected +- Other non-zero codes indicate internal errors \ No newline at end of file