Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ TaskGraphExecutionTrace simulate_task_graph_execution(
"simulate_task_graph_execution cannot simulate cyclic directed graphs");
}

TaskGraphExecutionState execution_state =
TaskGraphExecutionState{/*ready_tasks=*/set_of(get_sources(task_graph)),
/*in_progress_tasks=*/{},
/*finished_tasks=*/{},
/*current_time=*/0.0};
TaskGraphExecutionState execution_state = TaskGraphExecutionState{
/*ready_tasks=*/set_of(get_initial_nodes(task_graph)),
/*in_progress_tasks=*/{},
/*finished_tasks=*/{},
/*current_time=*/0.0};

std::unordered_set<TaskProfile> task_profiles;

Expand Down
11 changes: 6 additions & 5 deletions lib/pcg/src/pcg/machine_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#include "utils/containers/scanl.h"
#include "utils/containers/sum.h"
#include "utils/containers/transform.h"
#include "utils/containers/zip.h"
#include "utils/containers/zip3_strict.h"
#include "utils/containers/zip_with_strict.h"
#include "utils/exception.h"
#include "utils/nonnegative_int/nonnegative_range.h"
#include "utils/nonnegative_int/num_elements.h"
Expand Down Expand Up @@ -52,9 +53,9 @@ MachineView machine_view_from_strides_and_machine_spec_dimensions(
start,
strides));
}
std::vector<MachineViewDimension> dimensions =
transform(zip(strides, dims), [&](auto const &p) {
return MachineViewDimension{p.first, p.second};
std::vector<MachineViewDimension> dimensions = zip_with_strict(
strides, dims, [](stride_t s, MachineSpecificationDimension d) {
return MachineViewDimension{s, d};
});
return MachineView{start, dimensions};
}
Expand Down Expand Up @@ -109,7 +110,7 @@ std::optional<MachineSpaceCoordinate> get_machine_space_coordinate(

nonnegative_int index = start_idx;
for (auto [coeff, coord_point, stride] :
zip(coeffs, coord_points, strides)) {
zip3(coeffs, coord_points, strides)) {
index += coeff * coord_point * stride;
}
return index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ std::unordered_set<ParallelComputationGraphEdge>

std::unordered_set<parallel_layer_guid_t>
get_initial_layers(ParallelComputationGraph const &pcg) {
std::unordered_set<Node> raw_sources = get_sources(pcg.raw_graph);
std::unordered_set<Node> raw_sources = get_initial_nodes(pcg.raw_graph);
return transform(raw_sources,
[](Node const &n) { return parallel_layer_guid_t{n}; });
}
Expand Down
2 changes: 1 addition & 1 deletion lib/utils/include/utils/commutative_pair.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ template <typename T>
struct commutative_pair {
public:
commutative_pair() = delete;
commutative_pair(T const &x, T const &y) : first(x), second(y) {}
explicit commutative_pair(T const &x, T const &y) : first(x), second(y) {}

bool operator==(commutative_pair const &other) const {
return this->tie() == other.tie() || this->rtie() == other.tie();
Expand Down
7 changes: 7 additions & 0 deletions lib/utils/include/utils/containers/find.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H

#include <algorithm>
#include <unordered_set>

namespace FlexFlow {

Expand All @@ -11,6 +12,12 @@ typename Container::const_iterator
return std::find(c.cbegin(), c.cend(), e);
}

template <typename V>
typename std::unordered_set<V>::const_iterator
find(std::unordered_set<V> const &c, V const &e) {
return c.find(e);
}

} // namespace FlexFlow

#endif
13 changes: 1 addition & 12 deletions lib/utils/include/utils/containers/zip.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H

#include <tuple>
#include <set>
#include <utility>
#include <vector>

Expand All @@ -17,17 +17,6 @@ std::vector<std::pair<L, R>> zip(std::vector<L> const &l,
return result;
}

template <typename A, typename B, typename C>
std::vector<std::tuple<A, B, C>> zip(std::vector<A> const &a,
std::vector<B> const &b,
std::vector<C> const &c) {
std::vector<std::tuple<A, B, C>> result;
for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) {
result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i)));
}
return result;
}

} // namespace FlexFlow

#endif
24 changes: 24 additions & 0 deletions lib/utils/include/utils/containers/zip3.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H

#include <algorithm>
#include <set>
#include <tuple>
#include <vector>

namespace FlexFlow {

template <typename A, typename B, typename C>
std::vector<std::tuple<A, B, C>> zip3(std::vector<A> const &a,
std::vector<B> const &b,
std::vector<C> const &c) {
std::vector<std::tuple<A, B, C>> result;
for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) {
result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i)));
}
return result;
}

} // namespace FlexFlow

#endif
31 changes: 31 additions & 0 deletions lib/utils/include/utils/containers/zip3_strict.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H

#include "utils/containers/zip3.h"
#include "utils/exception.h"
#include "utils/fmt/vector.h"

namespace FlexFlow {

template <typename A, typename B, typename C>
std::vector<std::tuple<A, B, C>> zip3_strict(std::vector<A> const &as,
std::vector<B> const &bs,
std::vector<C> const &cs) {
if (!(as.size() == bs.size() && bs.size() == cs.size())) {
throw mk_runtime_error(fmt::format(
"zip3_strict requires as, bs, and cs to have the same length, but "
"received as={} (length {}), bs={} (length {}), and cs={} (length {})",
as,
as.size(),
bs,
bs.size(),
cs,
cs.size()));
}

return zip3(as, bs, cs);
}

} // namespace FlexFlow

#endif
28 changes: 28 additions & 0 deletions lib/utils/include/utils/containers/zip_strict.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H

#include "utils/containers/zip.h"
#include "utils/exception.h"
#include "utils/fmt/vector.h"

namespace FlexFlow {

template <typename L, typename R>
std::vector<std::pair<L, R>> zip_strict(std::vector<L> const &lhs,
std::vector<R> const &rhs) {
if (lhs.size() != rhs.size()) {
throw mk_runtime_error(
fmt::format("zip_strict requires lhs and rhs to have the same length, "
"but received lhs={} (length {}), rhs={} (length {})",
lhs,
lhs.size(),
rhs,
rhs.size()));
}

return zip(lhs, rhs);
}

} // namespace FlexFlow

#endif
24 changes: 24 additions & 0 deletions lib/utils/include/utils/containers/zip_with.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H

#include <vector>

namespace FlexFlow {

template <typename T1,
typename T2,
typename F,
typename Result = std::invoke_result_t<F, T1, T2>>
std::vector<Result>
zip_with(std::vector<T1> const &l, std::vector<T2> const &r, F &&f) {
std::vector<Result> result;
for (int i = 0; i < l.size() && i < r.size(); i++) {
result.push_back(f(l.at(i), r.at(i)));
}

return result;
}

} // namespace FlexFlow

#endif
33 changes: 33 additions & 0 deletions lib/utils/include/utils/containers/zip_with_strict.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H

#include "utils/containers/zip_with.h"
#include "utils/exception.h"
#include "utils/fmt/vector.h"
#include <vector>

namespace FlexFlow {

template <typename T1,
typename T2,
typename F,
typename Result = std::invoke_result_t<F, T1, T2>>
std::vector<Result> zip_with_strict(std::vector<T1> const &lhs,
std::vector<T2> const &rhs,
F &&f) {
if (lhs.size() != rhs.size()) {
throw mk_runtime_error(fmt::format(
"zip_with_strict requires inputs to have the same length, but received "
"lhs = {} (length {}) and rhs = {} (length {})",
lhs,
lhs.size(),
rhs,
rhs.size()));
}

return zip_with(lhs, rhs, f);
}

} // namespace FlexFlow

#endif
42 changes: 42 additions & 0 deletions lib/utils/include/utils/fmt/tuple.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H

#include "utils/check_fmtable.h"
#include "utils/join_strings.h"
#include "utils/tuple/visit.h"
#include <cassert>
#include <fmt/format.h>
#include <tuple>
#include <vector>

namespace fmt {

template <typename... Ts, typename Char>
struct formatter<std::tuple<Ts...>, Char> : formatter<std::string> {

template <typename FormatContext>
auto format(std::tuple<Ts...> const &t, FormatContext &ctx) const
-> decltype(ctx.out()) {

std::vector<std::string> stringified_elements;
::FlexFlow::visit_tuple(t, [&](auto const &element) -> void {
stringified_elements.push_back(fmt::to_string(element));
});

return formatter<std::string>::format(
"{" + ::FlexFlow::join_strings(stringified_elements, ", ") + "}", ctx);
}
};

} // namespace fmt

namespace FlexFlow {

template <typename... Ts>
std::ostream &operator<<(std::ostream &s, std::tuple<Ts...> const &t) {
return (s << fmt::to_string(t));
}

} // namespace FlexFlow

#endif
Loading
Loading