From aa1f91a462c59a55fc86ddb9007cb1c7ebf7635a Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 20 Oct 2024 21:13:29 -0700 Subject: [PATCH 01/62] Add utils/orthotope targeting #1528 --- flake.lock | 6 +- .../include/compiler/allowed_machine_views.h | 2 +- .../cost_estimator/network_cost_model.h | 14 ++ .../src/compiler/allowed_machine_views.cc | 2 +- .../cost_estimator/network_cost_model.cc | 13 ++ .../include/op-attrs/dim_ordered/enumerate.h | 4 +- .../include/op-attrs/dim_ordered/get_idxs.h | 4 +- ..._parallel_tensor_space_mapping.struct.toml | 20 +++ .../include/op-attrs}/operator_task_space.h | 8 +- .../op-attrs}/operator_task_space.struct.toml | 0 .../op-attrs/parallel_tensor_dim_degrees.h | 18 +++ .../parallel_tensor_dim_idx_t.variant.toml | 3 + .../parallel_tensor_space_coordinate.h | 14 ++ ...rallel_tensor_space_coordinate.struct.toml | 25 ++++ .../task_space_coordinate.struct.toml | 0 .../dim_ordered/ff_ordered_from_map.cc | 13 ++ .../src/op-attrs/dim_ordered/get_idxs.cc | 10 ++ .../src/op-attrs}/operator_task_space.cc | 2 +- .../op-attrs/parallel_tensor_dim_degrees.cc | 49 +++++++ .../parallel_tensor_space_coordinate.cc | 20 +++ .../test/src/op-attrs}/operator_task_space.cc | 2 +- .../op-attrs/parallel_tensor_dim_degrees.cc | 79 +++++++++++ lib/pcg/include/pcg/machine_view.h | 8 +- .../pcg/start_invariant_machine_view.h | 4 +- lib/pcg/src/pcg/machine_view.cc | 18 +-- .../src/pcg/start_invariant_machine_view.cc | 2 +- .../utils/archetypes/ordered_value_type.h | 52 ++++++++ lib/utils/include/utils/containers/all_of.h | 4 + lib/utils/include/utils/containers/count.h | 2 - .../include/utils/containers/filter_idxs.h | 24 ++++ lib/utils/include/utils/containers/group_by.h | 25 +++- .../include/utils/containers/is_subseteq_of.h | 16 +++ lib/utils/include/utils/containers/scanr.h | 77 +++++++++++ lib/utils/include/utils/containers/uncurry.h | 18 +++ .../utils/containers/vector_from_idx_map.h | 27 ++++ lib/utils/include/utils/containers/zip_with.h | 20 +++ 
...ordered_set_labelled_open_dataflow_graph.h | 4 +- lib/utils/include/utils/orthotope/orthotope.h | 17 +++ .../utils/orthotope/orthotope.struct.toml | 20 +++ .../utils/orthotope/orthotope_coordinate.h | 16 +++ .../orthotope_coordinate.struct.toml | 20 +++ .../utils/orthotope/orthotope_dim_idx_t.h | 13 ++ .../orthotope/orthotope_dim_idx_t.struct.toml | 12 ++ .../orthotope_surjective_projection.h | 30 +++++ ...rthotope_surjective_projection.struct.toml | 25 ++++ lib/utils/include/utils/orthotope/orthtope.h | 10 ++ .../utils/archetypes/ordered_value_type.cc | 7 + lib/utils/src/utils/cli/cli_spec.cc | 4 +- lib/utils/src/utils/containers/all_of.cc | 14 ++ lib/utils/src/utils/containers/count.cc | 12 -- lib/utils/src/utils/containers/filter_idxs.cc | 11 ++ lib/utils/src/utils/containers/group_by.cc | 25 ++++ .../src/utils/containers/is_subseteq_of.cc | 14 ++ lib/utils/src/utils/containers/uncurry.cc | 14 ++ .../utils/containers/vector_from_idx_map.cc | 11 ++ lib/utils/src/utils/containers/zip_with.cc | 14 ++ .../instances/unordered_set_dataflow_graph.cc | 4 +- lib/utils/src/utils/orthotope/orthotope.cc | 16 +++ .../utils/orthotope/orthotope_coordinate.cc | 26 ++++ .../utils/orthotope/orthotope_dim_idx_t.cc | 12 ++ .../orthotope_surjective_projection.cc | 126 ++++++++++++++++++ .../test/src/utils/containers/filter_idxs.cc | 17 +++ .../test/src/utils/containers/group_by.cc | 46 +++++++ .../test/src/utils/containers/uncurry.cc | 27 ++++ .../test/src/utils/containers/zip_with.cc | 73 ++++++++++ .../test/src/utils/orthotope/orthotope.cc | 96 +++++++++++++ .../orthotope_surjective_projection.cc | 50 +++++++ 67 files changed, 1339 insertions(+), 52 deletions(-) create mode 100644 lib/compiler/include/compiler/cost_estimator/network_cost_model.h create mode 100644 lib/compiler/src/compiler/cost_estimator/network_cost_model.cc create mode 100644 lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml rename lib/{pcg/include/pcg => 
op-attrs/include/op-attrs}/operator_task_space.h (62%) rename lib/{pcg/include/pcg => op-attrs/include/op-attrs}/operator_task_space.struct.toml (100%) create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml rename lib/{pcg/include/pcg => op-attrs/include/op-attrs}/task_space_coordinate.struct.toml (100%) rename lib/{pcg/src/pcg => op-attrs/src/op-attrs}/operator_task_space.cc (96%) create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc rename lib/{pcg/test/src/pcg => op-attrs/test/src/op-attrs}/operator_task_space.cc (98%) create mode 100644 lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc create mode 100644 lib/utils/include/utils/archetypes/ordered_value_type.h create mode 100644 lib/utils/include/utils/containers/filter_idxs.h create mode 100644 lib/utils/include/utils/containers/scanr.h create mode 100644 lib/utils/include/utils/containers/uncurry.h create mode 100644 lib/utils/include/utils/containers/vector_from_idx_map.h create mode 100644 lib/utils/include/utils/containers/zip_with.h create mode 100644 lib/utils/include/utils/orthotope/orthotope.h create mode 100644 lib/utils/include/utils/orthotope/orthotope.struct.toml create mode 100644 lib/utils/include/utils/orthotope/orthotope_coordinate.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml create mode 100644 lib/utils/include/utils/orthotope/orthotope_surjective_projection.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml create mode 100644 
lib/utils/include/utils/orthotope/orthtope.h create mode 100644 lib/utils/src/utils/archetypes/ordered_value_type.cc create mode 100644 lib/utils/src/utils/containers/filter_idxs.cc create mode 100644 lib/utils/src/utils/containers/uncurry.cc create mode 100644 lib/utils/src/utils/containers/vector_from_idx_map.cc create mode 100644 lib/utils/src/utils/containers/zip_with.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_coordinate.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc create mode 100644 lib/utils/test/src/utils/containers/filter_idxs.cc create mode 100644 lib/utils/test/src/utils/containers/group_by.cc create mode 100644 lib/utils/test/src/utils/containers/uncurry.cc create mode 100644 lib/utils/test/src/utils/containers/zip_with.cc create mode 100644 lib/utils/test/src/utils/orthotope/orthotope.cc create mode 100644 lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc diff --git a/flake.lock b/flake.lock index 87fae7f446..f5da93c18c 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1728341842, - "narHash": "sha256-XMS52KBSS6z3k2VaiVcHyZQD6b2QUm1wIvTClel4xwg=", + "lastModified": 1728582984, + "narHash": "sha256-9WHOLLqUdSaTjzeL8SaWpk/C2DHveKlQm5p/lWfH8dg=", "owner": "lockshaw", "repo": "proj", - "rev": "830fb5b1a0c7087752693990e90bbbf021168dfe", + "rev": "317f3253f93ae570e47119ddcbf8fdd3736ee5b5", "type": "github" }, "original": { diff --git a/lib/compiler/include/compiler/allowed_machine_views.h b/lib/compiler/include/compiler/allowed_machine_views.h index 9bb73fd1a9..e53f7329b4 100644 --- a/lib/compiler/include/compiler/allowed_machine_views.h +++ b/lib/compiler/include/compiler/allowed_machine_views.h @@ -3,7 +3,7 @@ #include "pcg/machine_specification.dtg.h" #include "pcg/machine_view.dtg.h" -#include 
"pcg/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/cost_estimator/network_cost_model.h b/lib/compiler/include/compiler/cost_estimator/network_cost_model.h new file mode 100644 index 0000000000..4bd8be18c1 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/network_cost_model.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_NETWORK_COST_MODEL_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_NETWORK_COST_MODEL_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +float estimate_communication_cost(MachineSpecification const &, + TensorSetMovement const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc index 1c226f79b0..75d1d8f6e3 100644 --- a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -2,7 +2,7 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/multi_dimensional_stride.dtg.h" -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "utils/containers/all_of.h" #include "utils/containers/cartesian_product.h" #include "utils/containers/extend.h" diff --git a/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc new file mode 100644 index 0000000000..c603a7a02b --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc @@ -0,0 +1,13 @@ +#include "compiler/cost_estimator/network_cost_model.h" + +namespace FlexFlow { + +float estimate_communication_cost(MachineSpecification const &machine_spec, + TensorSetMovement const &tensor_set_movement) { + + // for (SingleTensorMovement const 
&single_tensor_movement : tensor_set_movement.single_tensor_movements) { + // for + // } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h index 38e7da4bb2..6f85e7bab3 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -3,7 +3,7 @@ #include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/bidict/bidict.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" namespace FlexFlow { @@ -18,7 +18,7 @@ namespace FlexFlow { template std::map enumerate(FFOrdered const &ff_ordered) { std::map result; - for (int raw_ff_dim : count(ff_ordered.size())) { + for (int raw_ff_dim : range(ff_ordered.size())) { ff_dim_t ff_dim = ff_dim_t{raw_ff_dim}; result.insert({ff_dim, ff_ordered.at(ff_dim)}); } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index 7343dc0e69..6e55e8e22a 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -2,14 +2,14 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H #include "op-attrs/dim_ordered/dim_ordered.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" #include "utils/containers/transform.h" namespace FlexFlow { template std::vector get_idxs(FFOrdered const &d) { - return transform(count(d.size()), [](int i) { return ff_dim_t{i}; }); + return transform(range(d.size()), [](int i) { return ff_dim_t{i}; }); } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml new file mode 100644 index 0000000000..004ba0b7d8 --- /dev/null +++ 
b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorSpaceToParallelTensorSpaceMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "raw_mapping" +type = "std::unordered_map" diff --git a/lib/pcg/include/pcg/operator_task_space.h b/lib/op-attrs/include/op-attrs/operator_task_space.h similarity index 62% rename from lib/pcg/include/pcg/operator_task_space.h rename to lib/op-attrs/include/op-attrs/operator_task_space.h index 61cab4eff1..4250fb9cf7 100644 --- a/lib/pcg/include/pcg/operator_task_space.h +++ b/lib/op-attrs/include/op-attrs/operator_task_space.h @@ -1,8 +1,8 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H -#define _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_H -#include "pcg/operator_task_space.dtg.h" -#include "pcg/task_space_coordinate.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" #include #include diff --git a/lib/pcg/include/pcg/operator_task_space.struct.toml b/lib/op-attrs/include/op-attrs/operator_task_space.struct.toml similarity index 100% rename from lib/pcg/include/pcg/operator_task_space.struct.toml rename to lib/op-attrs/include/op-attrs/operator_task_space.struct.toml diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h new file mode 100644 index 0000000000..d95a717695 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_DEGREES_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_DEGREES_H + +#include 
"op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" + +namespace FlexFlow { + +std::unordered_map + get_parallel_tensor_degree_map(ParallelTensorDimDegrees const &); + +std::unordered_set + get_parallel_tensor_space_coordinates(ParallelTensorDimDegrees const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml index 9396cbcbe8..86833b9935 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml @@ -15,6 +15,9 @@ includes = [ [[values]] type = "::FlexFlow::ff_dim_t" +key = "shard_dim" + [[values]] type = "::FlexFlow::ReplicaType" +key = "replica_dim" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h new file mode 100644 index 0000000000..d0da8033c1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_COORDINATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_COORDINATE_H + +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" + +namespace FlexFlow { + +ParallelTensorSpaceCoordinate + parallel_tensor_space_coord_from_map(std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml new file mode 100644 index 0000000000..359c4b96a9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" 
+name = "ParallelTensorSpaceCoordinate" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/dim_ordered/dim_ordered.h", +] + +[[fields]] +name = "sum_idx" +type = "int" + +[[fields]] +name = "discard_copy_idx" +type = "int" + +[[fields]] +name = "shard_idxs" +type = "::FlexFlow::FFOrdered" diff --git a/lib/pcg/include/pcg/task_space_coordinate.struct.toml b/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml similarity index 100% rename from lib/pcg/include/pcg/task_space_coordinate.struct.toml rename to lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc index 2de88f38c8..267c64ef07 100644 --- a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -1 +1,14 @@ #include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + FFOrdered ff_ordered_from_map(std::map const &); + +template + FFOrdered ff_ordered_from_map(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc b/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc index 175ae8d4bd..baeb130324 100644 --- a/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc +++ b/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc @@ -1 +1,11 @@ #include "op-attrs/dim_ordered/get_idxs.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::vector get_idxs(FFOrdered const &); + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/op-attrs/src/op-attrs/operator_task_space.cc similarity index 96% rename from lib/pcg/src/pcg/operator_task_space.cc rename to lib/op-attrs/src/op-attrs/operator_task_space.cc index 
02522ae411..755df7ad97 100644 --- a/lib/pcg/src/pcg/operator_task_space.cc +++ b/lib/op-attrs/src/op-attrs/operator_task_space.cc @@ -1,4 +1,4 @@ -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "utils/containers/cartesian_product.h" #include "utils/containers/maximum.h" #include "utils/containers/product.h" diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc new file mode 100644 index 0000000000..24d9683ab4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -0,0 +1,49 @@ +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "op-attrs/dim_ordered/get_idxs.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/range.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::unordered_map + get_parallel_tensor_degree_map(ParallelTensorDimDegrees const °rees) { + + std::unordered_map replica_dim_degrees = { + {parallel_tensor_dim_idx_t{ReplicaType::SUM}, degrees.sum_degree.value}, + {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, degrees.discard_copy_degree.value}, + }; + + std::unordered_map shard_dim_degrees = + generate_map(get_idxs(degrees.shard_degrees), + [&](ff_dim_t const &dim) { return degrees.shard_degrees.at(dim); }); + + return merge_maps( + replica_dim_degrees, + map_keys(shard_dim_degrees, [](ff_dim_t const &dim) { return parallel_tensor_dim_idx_t{dim}; })); +} + +std::unordered_set + get_parallel_tensor_space_coordinates(ParallelTensorDimDegrees const °rees) { + + std::unordered_map degree_map = 
get_parallel_tensor_degree_map(degrees); + + std::unordered_map< + parallel_tensor_dim_idx_t, + std::unordered_set> possible_per_dim_coords + = map_values(degree_map, [](int degree) { return unordered_set_of(range(degree)); }); + + return transform(get_all_assignments(possible_per_dim_coords), + [](std::unordered_map const &m) { + return parallel_tensor_space_coord_from_map(m); + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc new file mode 100644 index 0000000000..ec6b117b4e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc @@ -0,0 +1,20 @@ +#include "op-attrs/parallel_tensor_space_coordinate.h" +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "utils/containers/filtermap_keys.h" + +namespace FlexFlow { + +ParallelTensorSpaceCoordinate + parallel_tensor_space_coord_from_map(std::unordered_map const &m) { + + std::unordered_map shard_map = filtermap_keys + (m, [](parallel_tensor_dim_idx_t const &d) { return d.try_require_shard_dim(); }); + + return ParallelTensorSpaceCoordinate{ + /*sum_idx=*/m.at(parallel_tensor_dim_idx_t{ReplicaType::SUM}), + /*discard_copy_idx=*/m.at(parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}), + /*shard_idxs=*/ff_ordered_from_map(shard_map), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/operator_task_space.cc b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc similarity index 98% rename from lib/pcg/test/src/pcg/operator_task_space.cc rename to lib/op-attrs/test/src/op-attrs/operator_task_space.cc index 13198d9456..228a3b9d9e 100644 --- a/lib/pcg/test/src/pcg/operator_task_space.cc +++ b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc @@ -1,4 +1,4 @@ -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "utils/fmt/unordered_set.h" #include diff --git 
a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc new file mode 100644 index 0000000000..f45de6883a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -0,0 +1,79 @@ +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_parallel_tensor_degree_map") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 2, + 1, + }, + }; + + std::unordered_map result = get_parallel_tensor_degree_map(degrees); + std::unordered_map correct = { + {parallel_tensor_dim_idx_t{ReplicaType::SUM}, 3}, + {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, 1}, + {parallel_tensor_dim_idx_t{ff_dim_t{0}}, 1}, + {parallel_tensor_dim_idx_t{ff_dim_t{1}}, 2}, + {parallel_tensor_dim_idx_t{ff_dim_t{2}}, 1}, + }; + + CHECK(result == correct); + } + + TEST_CASE("get_parallel_tensor_space_coordinates") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 2, + 1, + }, + }; + + std::unordered_set result = get_parallel_tensor_space_coordinates(degrees); + std::unordered_set correct = { + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/0, + /*discard_copy_idx=*/0, + /*shard_idxs=*/FFOrdered{0, 0, 0}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/1, + /*discard_copy_idx=*/0, + /*shard_idxs=*/FFOrdered{0, 0, 0}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/2, + /*discard_copy_idx=*/0, + /*shard_idxs=*/FFOrdered{0, 0, 0}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/0, + /*discard_copy_idx=*/0, + /*shard_idxs=*/FFOrdered{0, 1, 0}, + }, + ParallelTensorSpaceCoordinate{ + /*sum_idx=*/1, + /*discard_copy_idx=*/0, + /*shard_idxs=*/FFOrdered{0, 1, 0}, + }, + 
ParallelTensorSpaceCoordinate{ + /*sum_idx=*/2, + /*discard_copy_idx=*/0, + /*shard_idxs=*/FFOrdered{0, 1, 0}, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 293227b7a1..376cb69c32 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,11 +1,11 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "machine_specification.dtg.h" -#include "machine_view.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "pcg/operator_task_space.dtg.h" -#include "task_space_coordinate.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" #include #include #include diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.h b/lib/pcg/include/pcg/start_invariant_machine_view.h index f5091c69d1..0cf515bb6b 100644 --- a/lib/pcg/include/pcg/start_invariant_machine_view.h +++ b/lib/pcg/include/pcg/start_invariant_machine_view.h @@ -4,9 +4,9 @@ #include "pcg/machine_space_offset.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_view.dtg.h" -#include "pcg/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "pcg/start_invariant_machine_view.dtg.h" -#include "pcg/task_space_coordinate.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index 18f6cacb7e..431b3cc4fc 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -1,8 +1,8 @@ #include "pcg/machine_view.h" #include "pcg/machine_specification.h" -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "utils/containers/contains.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" 
#include "utils/containers/filter.h" #include "utils/containers/scanl.h" #include "utils/containers/sum.h" @@ -52,21 +52,21 @@ std::optional get_machine_space_coordinate( [&](MachineSpecificationDimension dimension) { std::vector mv_dimensions = get_dimensions(machine_view); - return filter(count(mv_dimensions.size()), [&](size_t idx) { + return filter(range(mv_dimensions.size()), [&](int idx) { return mv_dimensions.at(idx) == dimension; }); }; auto compute_index = [&](int start_idx, - std::vector const &dimension_indices) { + std::vector const &dimension_indices) { std::vector mv_strides = get_strides(machine_view); - std::vector sizes = transform(dimension_indices, [&](size_t i) { + std::vector sizes = transform(dimension_indices, [&](int i) { return task.degrees.at(i) * mv_strides.at(i).unwrapped; }); std::vector coord_points = transform( - dimension_indices, [&](size_t i) { return coord.raw_coord.at(i); }); - std::vector strides = transform(dimension_indices, [&](size_t i) { + dimension_indices, [&](int i) { return coord.raw_coord.at(i); }); + std::vector strides = transform(dimension_indices, [&](int i) { return mv_strides.at(i).unwrapped; }); @@ -80,10 +80,10 @@ std::optional get_machine_space_coordinate( return index; }; - std::vector inter_dimension_indices = + std::vector inter_dimension_indices = get_dimension_indices_for_dimension( MachineSpecificationDimension::INTER_NODE); - std::vector intra_dimension_indices = + std::vector intra_dimension_indices = get_dimension_indices_for_dimension( MachineSpecificationDimension::INTRA_NODE); diff --git a/lib/pcg/src/pcg/start_invariant_machine_view.cc b/lib/pcg/src/pcg/start_invariant_machine_view.cc index 1fcc3ea12f..75abc765c5 100644 --- a/lib/pcg/src/pcg/start_invariant_machine_view.cc +++ b/lib/pcg/src/pcg/start_invariant_machine_view.cc @@ -1,7 +1,7 @@ #include "pcg/start_invariant_machine_view.h" #include "pcg/machine_space_offset.h" #include "pcg/machine_view.h" -#include "pcg/operator_task_space.h" 
+#include "op-attrs/operator_task_space.h" #include "utils/containers/count.h" #include "utils/containers/filter.h" #include "utils/containers/scanl.h" diff --git a/lib/utils/include/utils/archetypes/ordered_value_type.h b/lib/utils/include/utils/archetypes/ordered_value_type.h new file mode 100644 index 0000000000..270368afc1 --- /dev/null +++ b/lib/utils/include/utils/archetypes/ordered_value_type.h @@ -0,0 +1,52 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_ORDERED_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_ORDERED_VALUE_TYPE_H + +#include +#include + +namespace FlexFlow { + +template +struct ordered_value_type { + ordered_value_type() = delete; + + ordered_value_type(ordered_value_type const &) { + assert(false); + } + ordered_value_type &operator=(ordered_value_type const &) { + assert(false); + } + + ordered_value_type(ordered_value_type &&) { + assert(false); + } + ordered_value_type &operator=(ordered_value_type &&) { + assert(false); + } + + bool operator==(ordered_value_type const &) const { + assert(false); + } + bool operator!=(ordered_value_type const &) const { + assert(false); + } + + bool operator<(ordered_value_type const &) const { + assert(false); + } +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::ordered_value_type> { + size_t operator()(::FlexFlow::ordered_value_type const &) const { + assert(false); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/containers/all_of.h b/lib/utils/include/utils/containers/all_of.h index fb44aeaed8..87c9e067dc 100644 --- a/lib/utils/include/utils/containers/all_of.h +++ b/lib/utils/include/utils/containers/all_of.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ALL_OF_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ALL_OF_H +#include + namespace FlexFlow { template @@ -13,6 +15,8 @@ bool all_of(C const &c, F const &f) { return true; } +bool all_of(std::vector const &); + } // 
namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/count.h b/lib/utils/include/utils/containers/count.h index bae4ba104c..30ad805b06 100644 --- a/lib/utils/include/utils/containers/count.h +++ b/lib/utils/include/utils/containers/count.h @@ -17,8 +17,6 @@ int count(C const &c, F const &f) { return result; } -std::vector count(size_t n); - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/filter_idxs.h b/lib/utils/include/utils/containers/filter_idxs.h new file mode 100644 index 0000000000..c71ca5e2c5 --- /dev/null +++ b/lib/utils/include/utils/containers/filter_idxs.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_IDXS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_IDXS_H + +#include +#include + +namespace FlexFlow { + +template +std::vector filter_idxs(std::vector const &input, std::function const &f) { + std::vector result; + + for (int idx = 0; idx < input.size(); idx++) { + if (f(idx)) { + result.push_back(input.at(idx)); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/group_by.h b/lib/utils/include/utils/containers/group_by.h index 6abffbfed0..720f37c71c 100644 --- a/lib/utils/include/utils/containers/group_by.h +++ b/lib/utils/include/utils/containers/group_by.h @@ -4,12 +4,14 @@ #include #include #include +#include +#include namespace FlexFlow { template > std::unordered_map> - group_by(std::unordered_set const &vs, F f) { + group_by(std::unordered_set const &vs, F &&f) { std::unordered_map> result; for (V const &v : vs) { result[f(v)].insert(v); @@ -17,6 +19,27 @@ std::unordered_map> return result; } +template > +std::unordered_map> + group_by(std::vector const &vs, F &&f) { + std::unordered_map> result; + for (V const &v : vs) { + result[f(v)].push_back(v); + } + return result; +} + +template > +std::unordered_map> + group_by(std::set const &vs, F &&f) { + std::unordered_map> result; + for (V 
const &v : vs) { + result[f(v)].insert(v); + } + return result; +} + + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/is_subseteq_of.h b/lib/utils/include/utils/containers/is_subseteq_of.h index 26543ca75b..55f7d03814 100644 --- a/lib/utils/include/utils/containers/is_subseteq_of.h +++ b/lib/utils/include/utils/containers/is_subseteq_of.h @@ -3,6 +3,7 @@ #include "utils/containers/contains.h" #include +#include namespace FlexFlow { @@ -21,6 +22,21 @@ bool is_subseteq_of(std::unordered_set const &l, return true; } +template +bool is_subseteq_of(std::set const &l, + std::set const &r) { + if (l.size() > r.size()) { + return false; + } + + for (auto const &ll : l) { + if (!contains(r, ll)) { + return false; + } + } + return true; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/scanr.h b/lib/utils/include/utils/containers/scanr.h new file mode 100644 index 0000000000..80e07e3f6f --- /dev/null +++ b/lib/utils/include/utils/containers/scanr.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `op` to the elements of `c` from right to left, accumulating + * the intermediate results in a vector. `init` is used as the starting point + * for the accumulation. 
+ * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl(nums, 0, [](int a, int b) {return a+b;}); + * result -> {10, 9, 7, 4, 0} + * + * @note + * Essentially a foldl which stores the intermediate results + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scan4 + */ +template +std::vector scanr(C const &c, T init, F const &op) { + std::vector result; + + result.push_back(init); + for (auto const &elem : c) { + init = op(init, elem); + result.push_back(init); + } + + return result; +} + +/** + * @brief + * Applies `op` to the elements of `c` from right to left, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {10, 9, 7, 4} + * + * @note + * Essentially a foldl1 which stores the intermediate results. + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 + */ +template +std::vector scanl1(C const &c, F op) { + + if (c.empty()) { + return std::vector(); + } + + std::optional init = std::nullopt; + std::vector result; + + for (T const &elem : c) { + if (!init.has_value()) { + init = elem; + } else { + init = op(init.value(), elem); + } + result.push_back(init.value()); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/uncurry.h b/lib/utils/include/utils/containers/uncurry.h new file mode 100644 index 0000000000..24019d6d61 --- /dev/null +++ b/lib/utils/include/utils/containers/uncurry.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNCURRY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNCURRY_H + +#include +#include + +namespace FlexFlow { + +template > +std::function const &)> uncurry(F &&f) { + return [f](std::pair const &p) -> Result 
{ + return f(p.first, p.second); + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/vector_from_idx_map.h b/lib/utils/include/utils/containers/vector_from_idx_map.h new file mode 100644 index 0000000000..dbd5f26552 --- /dev/null +++ b/lib/utils/include/utils/containers/vector_from_idx_map.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_FROM_IDX_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_FROM_IDX_MAP_H + +#include +#include +#include "utils/containers/contains_key.h" +#include + +namespace FlexFlow { + +template +std::optional> vector_from_idx_map(std::unordered_map const &m) { + std::vector result; + + for (int i = 0; i < m.size(); i++) { + if (!contains_key(m, i)) { + return std::nullopt; + } + result.push_back(m.at(i)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with.h b/lib/utils/include/utils/containers/zip_with.h new file mode 100644 index 0000000000..fb10f2a89e --- /dev/null +++ b/lib/utils/include/utils/containers/zip_with.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H + +#include + +namespace FlexFlow { + +template > +std::vector zip_with(std::vector const &l, std::vector const &r, F &&f) { + std::vector result; + for (int i = 0; i < l.size() && i < r.size(); i++) { + result.push_back(f(l.at(i), r.at(i))); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index f1063c1f21..eb64434fd1 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -1,7 +1,7 @@ #ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_UNORDERED_SET_LABELLED_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_UNORDERED_SET_LABELLED_DATAFLOW_GRAPH_H -#include "utils/containers/count.h" +#include "utils/containers/range.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/filter.h" #include "utils/containers/generate_map.h" @@ -57,7 +57,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final } std::vector new_outputs = - transform(count(output_labels.size()), [&](int output_idx) { + transform(range(output_labels.size()), [&](int output_idx) { return DataflowOutput{new_node, output_idx}; }); diff --git a/lib/utils/include/utils/orthotope/orthotope.h b/lib/utils/include/utils/orthotope/orthotope.h new file mode 100644 index 0000000000..fa694e7552 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_H + +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope_coordinate.dtg.h" +#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" +#include + +namespace FlexFlow { + +bool orthotope_contains_coord(Orthotope const &, OrthotopeCoordinate const &); + +Orthotope restrict_orthotope_dims(Orthotope const &, std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope.struct.toml b/lib/utils/include/utils/orthotope/orthotope.struct.toml new file mode 100644 index 0000000000..9fd715df1d --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "Orthotope" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "dims" +type = "std::vector" diff --git 
a/lib/utils/include/utils/orthotope/orthotope_coordinate.h b/lib/utils/include/utils/orthotope/orthotope_coordinate.h new file mode 100644 index 0000000000..7fe068a601 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_coordinate.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORDINATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORDINATE_H + +#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" +#include "utils/orthotope/orthotope_coordinate.dtg.h" +#include + +namespace FlexFlow { + +std::set get_orthotope_coord_dims(OrthotopeCoordinate const &); + +OrthotopeCoordinate restrict_orthotope_coord_dims(OrthotopeCoordinate const &, std::set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml new file mode 100644 index 0000000000..cdfbee4d2a --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OrthotopeCoordinate" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "" +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "idxs" +type = "std::vector" diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h b/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h new file mode 100644 index 0000000000..d14e2633d7 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_IDX_T_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_IDX_T_H + +#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" +#include + +namespace FlexFlow { + +std::set dim_idxs_for_orthotope_with_num_dims(int num_dims); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml b/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml new file mode 100644 index 0000000000..68ee54c40f --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "orthotope_dim_idx_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h new file mode 100644 index 0000000000..8ee7cae131 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_SURJECTIVE_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_SURJECTIVE_PROJECTION_H + +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope_surjective_projection.dtg.h" +#include "utils/orthotope/orthotope_coordinate.dtg.h" +#include + +namespace FlexFlow { + +OrthotopeSurjectiveProjection + make_orthotope_projection_from_map(std::unordered_map const &); + +std::unordered_map get_src_to_dst_dim_map(OrthotopeSurjectiveProjection const &); + +orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeSurjectiveProjection const &, orthotope_dim_idx_t const &); + +int get_src_num_dims(OrthotopeSurjectiveProjection const &); +int get_dst_num_dims(OrthotopeSurjectiveProjection const &); + +OrthotopeSurjectiveProjection reverse_projection(OrthotopeSurjectiveProjection const &); + +std::unordered_set get_all_surjective_projections_between(Orthotope const &src, Orthotope const &dst); + +int deconflict_noninjective_dims(std::vector> const &coords_and_sizes); +OrthotopeCoordinate project_coordinate_through(OrthotopeSurjectiveProjection const &, Orthotope const &, OrthotopeCoordinate const &); + +} // namespace FlexFlow + +#endif diff 
--git a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml new file mode 100644 index 0000000000..4ae4d71dd0 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "OrthotopeSurjectiveProjection" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "", + "utils/orthotope/orthotope_dim_idx_t.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "dim_mapping" +type = "std::vector<::FlexFlow::orthotope_dim_idx_t>" + +[[fields]] +name = "reversed" +type = "bool" diff --git a/lib/utils/include/utils/orthotope/orthtope.h b/lib/utils/include/utils/orthotope/orthtope.h new file mode 100644 index 0000000000..b723e8057f --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthtope.h @@ -0,0 +1,10 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHTOPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHTOPE_H + +namespace FlexFlow { + +Orthotope orthotope_from_dim_map(std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/archetypes/ordered_value_type.cc b/lib/utils/src/utils/archetypes/ordered_value_type.cc new file mode 100644 index 0000000000..572a03e3cf --- /dev/null +++ b/lib/utils/src/utils/archetypes/ordered_value_type.cc @@ -0,0 +1,7 @@ +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +template struct ordered_value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_spec.cc b/lib/utils/src/utils/cli/cli_spec.cc index ca51cfe57f..e7ad5e8df4 100644 --- a/lib/utils/src/utils/cli/cli_spec.cc +++ b/lib/utils/src/utils/cli/cli_spec.cc @@ -1,5 +1,5 @@ #include "utils/cli/cli_spec.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" #include "utils/containers/transform.h" #include 
"utils/integer_conversions.h" @@ -10,7 +10,7 @@ CLISpec empty_cli_spec() { } std::vector cli_get_flag_keys(CLISpec const &cli) { - return transform(count(cli.flags.size()), + return transform(range(cli.flags.size()), [](int idx) { return CLIFlagKey{idx}; }); } diff --git a/lib/utils/src/utils/containers/all_of.cc b/lib/utils/src/utils/containers/all_of.cc index 5b33efc6e6..9f02c1aaf7 100644 --- a/lib/utils/src/utils/containers/all_of.cc +++ b/lib/utils/src/utils/containers/all_of.cc @@ -1 +1,15 @@ #include "utils/containers/all_of.h" + +namespace FlexFlow { + +bool all_of(std::vector const &v) { + for (bool v : v) { + if (!v) { + return false; + } + } + + return true; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/count.cc b/lib/utils/src/utils/containers/count.cc index 928e777a37..3baff11fb4 100644 --- a/lib/utils/src/utils/containers/count.cc +++ b/lib/utils/src/utils/containers/count.cc @@ -1,13 +1 @@ #include "utils/containers/count.h" - -namespace FlexFlow { - -std::vector count(size_t n) { - std::vector v(n); - for (size_t i = 0; i < n; i++) { - v[i] = i; - } - return v; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/filter_idxs.cc b/lib/utils/src/utils/containers/filter_idxs.cc new file mode 100644 index 0000000000..fd0d61dcf8 --- /dev/null +++ b/lib/utils/src/utils/containers/filter_idxs.cc @@ -0,0 +1,11 @@ +#include "utils/containers/filter_idxs.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::vector filter_idxs(std::vector const &, std::function const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/group_by.cc b/lib/utils/src/utils/containers/group_by.cc index ac05cee861..aa53295902 100644 --- a/lib/utils/src/utils/containers/group_by.cc +++ b/lib/utils/src/utils/containers/group_by.cc @@ -1 +1,26 @@ #include "utils/containers/group_by.h" +#include "utils/archetypes/value_type.h" +#include 
"utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +template + std::unordered_map> + group_by(std::unordered_set const &, F &&); + +template + std::unordered_map> + group_by(std::vector const &, F &&); + +using V2 = ordered_value_type<1>; +using F2 = std::function; + +template + std::unordered_map> + group_by(std::set const &, F2 &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/is_subseteq_of.cc b/lib/utils/src/utils/containers/is_subseteq_of.cc index b5efa5dd90..9e59ece7e1 100644 --- a/lib/utils/src/utils/containers/is_subseteq_of.cc +++ b/lib/utils/src/utils/containers/is_subseteq_of.cc @@ -1 +1,15 @@ #include "utils/containers/is_subseteq_of.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +template + bool is_subseteq_of(std::unordered_set const &, std::unordered_set const &); + +using T2 = ordered_value_type<0>; +template + bool is_subseteq_of(std::unordered_set const &, std::unordered_set const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/uncurry.cc b/lib/utils/src/utils/containers/uncurry.cc new file mode 100644 index 0000000000..fe6d2a17e3 --- /dev/null +++ b/lib/utils/src/utils/containers/uncurry.cc @@ -0,0 +1,14 @@ +#include "utils/containers/uncurry.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template + std::function const &)> uncurry(F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/vector_from_idx_map.cc b/lib/utils/src/utils/containers/vector_from_idx_map.cc new file mode 100644 index 0000000000..71714446e9 --- /dev/null +++ b/lib/utils/src/utils/containers/vector_from_idx_map.cc @@ -0,0 +1,11 @@ +#include 
"utils/containers/vector_from_idx_map.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::optional> vector_from_idx_map(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with.cc b/lib/utils/src/utils/containers/zip_with.cc new file mode 100644 index 0000000000..499d6ac8b2 --- /dev/null +++ b/lib/utils/src/utils/containers/zip_with.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template + std::vector zip_with(std::vector const &, std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc index 1ffc5f423f..dfb26cb4e1 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc @@ -1,6 +1,6 @@ #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/containers/are_disjoint.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/extend.h" #include "utils/containers/transform.h" @@ -36,7 +36,7 @@ NodeAddedResult UnorderedSetDataflowGraph::add_node( Node new_node = this->node_source.new_node(); std::vector new_outputs = - transform(count(num_outputs), [&](int output_idx) { + transform(range(num_outputs), [&](int output_idx) { return DataflowOutput{new_node, output_idx}; }); diff --git a/lib/utils/src/utils/orthotope/orthotope.cc b/lib/utils/src/utils/orthotope/orthotope.cc new file mode 100644 index 0000000000..2cf75acb73 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope.cc @@ -0,0 +1,16 @@ +#include 
"utils/orthotope/orthotope.h" +#include "utils/containers/zip_with.h" +#include "utils/containers/all_of.h" +#include "utils/exception.h" + +namespace FlexFlow { + +bool orthotope_contains_coord(Orthotope const &o, OrthotopeCoordinate const &c) { + if (o.dims.size() != c.idxs.size()) { + throw mk_runtime_error(fmt::format("orthotope_contains_coord expected orthotope and coord to have the same number of dims, but received o={}, c={}", o, c)); + } + + return all_of(zip_with(o.dims, c.idxs, [](int dim_size, int dim_coord) { return dim_coord >= 0 && dim_coord < dim_size; })); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_coordinate.cc b/lib/utils/src/utils/orthotope/orthotope_coordinate.cc new file mode 100644 index 0000000000..35fc52e258 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_coordinate.cc @@ -0,0 +1,26 @@ +#include "utils/orthotope/orthotope_coordinate.h" +#include "utils/containers/filter_idxs.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/exception.h" +#include "utils/orthotope/orthotope_dim_idx_t.h" +#include "utils/fmt/set.h" + +namespace FlexFlow { + +std::set get_orthotope_coord_dims(OrthotopeCoordinate const &coord) { + return dim_idxs_for_orthotope_with_num_dims(coord.idxs.size()); +} + +OrthotopeCoordinate restrict_orthotope_coord_dims(OrthotopeCoordinate const &coord, std::set const &mask) { + std::set coord_dims = get_orthotope_coord_dims(coord); + + if (!is_subseteq_of(coord_dims, mask)) { + throw mk_runtime_error(fmt::format("restrict_orthotope_coord_dims expected mask to be a subset of coord dims, but got coord={}, mask={}", coord, mask)); + } + + std::vector new_idxs = filter_idxs(coord.idxs, [&](int i) { return contains(mask, orthotope_dim_idx_t{i}); }); + + return OrthotopeCoordinate{new_idxs}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc b/lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc new file mode 100644 index 
0000000000..a26645a521 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc @@ -0,0 +1,12 @@ +#include "utils/orthotope/orthotope_dim_idx_t.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" +#include "utils/containers/range.h" + +namespace FlexFlow { + +std::set dim_idxs_for_orthotope_with_num_dims(int num_dims) { + return set_of(transform(range(num_dims), [](int idx) { return orthotope_dim_idx_t{idx}; })); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc b/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc new file mode 100644 index 0000000000..4b3c309bb1 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc @@ -0,0 +1,126 @@ +#include "utils/orthotope/orthotope_surjective_projection.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/group_by.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/range.h" +#include "utils/containers/set_of.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/filter.h" +#include "utils/containers/values.h" +#include "utils/containers/zip_with.h" +#include "utils/exception.h" +#include "utils/orthotope/orthotope.h" +#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" +#include "utils/orthotope/orthotope_dim_idx_t.h" +#include "utils/containers/vector_from_idx_map.h" +#include "utils/containers/scanl.h" +#include "utils/containers/all_of.h" +#include "utils/fmt/vector.h" + +namespace FlexFlow { + +OrthotopeSurjectiveProjection + make_orthotope_projection_from_map(std::unordered_map const &m) { + std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t const &k) { return k.raw_idx; }); + return OrthotopeSurjectiveProjection{ + 
/*dim_mapping=*/vector_from_idx_map(raw_idx_map).value(), + /*reversed=*/false, + }; +} + +std::unordered_map get_src_to_dst_dim_map(OrthotopeSurjectiveProjection const &p) { + if (p.reversed) { + throw mk_runtime_error(fmt::format("get_src_to_dst_dim_map expected p.reversed=false, but received p={}", p)); + } + + std::unordered_map raw_idx_map = generate_map(range(p.dim_mapping.size()), [&](int x) { return p.dim_mapping.at(x); }); + return map_keys(raw_idx_map, [](int src_dim_idx) { return orthotope_dim_idx_t{src_dim_idx}; }); +} + +orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeSurjectiveProjection const &p, orthotope_dim_idx_t const &src_idx) { + return p.dim_mapping.at(src_idx.raw_idx); +} + +int get_src_num_dims(OrthotopeSurjectiveProjection const &p) { + return p.dim_mapping.size(); +} + +int get_dst_num_dims(OrthotopeSurjectiveProjection const &p) { + return unordered_set_of(p.dim_mapping).size(); +} + +OrthotopeSurjectiveProjection reverse_projection(OrthotopeSurjectiveProjection const &p) { + OrthotopeSurjectiveProjection result = p; + result.reversed = !p.reversed; + return result; +} + +std::unordered_set get_all_surjective_projections_between(int src_num_dims, int dst_num_dims) { + if (src_num_dims < dst_num_dims) { + return transform(get_all_surjective_projections_between(dst_num_dims, src_num_dims), + [](OrthotopeSurjectiveProjection const &p) { return reverse_projection(p); }); + } + + std::set src_dim_idxs = dim_idxs_for_orthotope_with_num_dims(src_num_dims); + std::set dst_dim_idxs = dim_idxs_for_orthotope_with_num_dims(dst_num_dims); + + std::unordered_map> src_to_dst_idxs = + generate_map(src_dim_idxs, [&](orthotope_dim_idx_t) { return unordered_set_of(dst_dim_idxs); }); + + std::unordered_set> valid_mappings = + filter(get_all_assignments(src_to_dst_idxs), + [&](std::unordered_map const &src_to_dst_idx) { + return set_of(values(src_to_dst_idx)) == dst_dim_idxs; + }); + + return transform(valid_mappings, make_orthotope_projection_from_map); 
+} + +int deconflict_noninjective_dims(std::vector> const &coords_and_sizes) { + if (coords_and_sizes.size() == 0) { + throw mk_runtime_error("deconflict_noninjective_dims expected non-empty vector, but receieved empty vector"); + } + + std::vector coords = transform(coords_and_sizes, [](std::pair const &p) { return p.first; }); + std::vector dim_sizes = transform(coords_and_sizes, [](std::pair const &p) { return p.second; }); + + if (!all_of(zip_with(coords, dim_sizes, [](int coord, int dim_size) { return coord > 0 && coord < dim_size; }))) { + throw mk_runtime_error(fmt::format("coords out of bounds of dim sizes: coords={}, dim_sizes={}", coords, dim_sizes)); + } + + + std::vector strides = scanl(dim_sizes, 1, [](int accum, int next) { return accum * next; }); + return sum(zip_with(coords, strides, [](int coord, int stride) { return coord * stride; })); +} + +OrthotopeCoordinate project_coordinate_through(OrthotopeSurjectiveProjection const &p, Orthotope const &o, OrthotopeCoordinate const &c) { + auto calculate_nonoverlapping_result = [](std::vector const &coords, std::vector const &dim_sizes) -> int { + NOT_IMPLEMENTED(); + }; + + if (p.reversed) { + NOT_IMPLEMENTED(); // TODO @lockshaw + } else { + if (c.idxs.size() != get_src_num_dims(p)) { + throw mk_runtime_error(fmt::format("project_coordinate_through requires projection src and coordinate to have same num dims, but got {} and {} respectively", + get_src_num_dims(p), + c.idxs.size())); + } + + if (!orthotope_contains_coord(o, c)) { + throw mk_runtime_error(fmt::format("project_coordinate_through requires coord to be in the orthotope, but got coord={} and orthotope={} respectively", c, o)); + } + + std::unordered_map> by_dst_dim_idx = + group_by(dim_idxs_for_orthotope_with_num_dims(o.dims.size()), + [&](orthotope_dim_idx_t const &src_dim_idx) { return get_dst_dim_for_src_dim(p, src_dim_idx); }); + + + NOT_IMPLEMENTED(); + } +} + +} // namespace FlexFlow diff --git 
a/lib/utils/test/src/utils/containers/filter_idxs.cc b/lib/utils/test/src/utils/containers/filter_idxs.cc new file mode 100644 index 0000000000..36a8e2a4f5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/filter_idxs.cc @@ -0,0 +1,17 @@ +#include "utils/containers/filter_idxs.h" +#include +#include +#include "test/utils/doctest/fmt/vector.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_idxs") { + std::vector input = {"hello", "world", "!"}; + + std::vector result = filter_idxs(input, [](int idx) { return idx % 2 == 0; }); + std::vector correct = {"hello", "!"}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/group_by.cc b/lib/utils/test/src/utils/containers/group_by.cc new file mode 100644 index 0000000000..e76430fa6d --- /dev/null +++ b/lib/utils/test/src/utils/containers/group_by.cc @@ -0,0 +1,46 @@ +#include "utils/containers/group_by.h" +#include +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/unordered_set.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("group_by(std::unordered_set, F)") { + std::unordered_set input = {0, 3, 2, 9, 8}; + + std::unordered_map> result = group_by(input, [](int x) { return x % 3; }); + std::unordered_map> correct = { + {0, {0, 3, 9}}, + {2, {2, 8}}, + }; + + CHECK(result == correct); + } + + TEST_CASE("group_by(std::vector, F)") { + std::vector input = {0, 3, 0, 2, 2, 9, 8, 9}; + + std::unordered_map> result = group_by(input, [](int x) { return x % 3; }); + std::unordered_map> correct = { + {0, {0, 3, 0, 9, 9}}, + {2, {2, 2, 8}}, + }; + + CHECK(result == correct); + } + + TEST_CASE("group_by(std::set, F)") { + std::set input = {0, 3, 2, 9, 8}; + + std::unordered_map> result = group_by(input, [](int x) { return x % 3; }); + std::unordered_map> correct = { + {0, {0, 3, 9}}, + {2, {2, 8}}, + }; + + 
CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/uncurry.cc b/lib/utils/test/src/utils/containers/uncurry.cc new file mode 100644 index 0000000000..8ea3d5bfb6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/uncurry.cc @@ -0,0 +1,27 @@ +#include "utils/containers/uncurry.h" +#include +#include +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("uncurry") { + auto f = [](int x, std::string const &y) { + return std::make_pair(y, x); + }; + + std::function(std::pair const &)> + result_f = uncurry(f); + + SUBCASE("has same behavior as f") { + int x = 1; + std::string y = "aa"; + std::pair p = {1, "aa"}; + + std::pair result = result_f(p); + std::pair correct = f(x, y); + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_with.cc b/lib/utils/test/src/utils/containers/zip_with.cc new file mode 100644 index 0000000000..df95358803 --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_with.cc @@ -0,0 +1,73 @@ +#include "utils/containers/zip_with.h" +#include +#include +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_with(std::vector, std::vector, F)") { + SUBCASE("result types and input types are all different") { + std::vector v1 = {1, 3, 4, 3}; + std::vector v2 = {"aa", "cc", "bb", "dd"}; + + std::vector> result = zip_with(v1, v2, [](int x1, std::string const &x2) { return std::make_pair(x1, x2); }); + std::vector> correct = { + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, + }; + + CHECK(result == correct); + } + + + SUBCASE("input lengths don't match") { + auto add = [](int x1, int x2) { return x1 + x2; }; + + std::vector shorter = {1, 2}; + std::vector longer = {1, 3, 5, 7}; + + SUBCASE("first input is shorter") { + std::vector result = zip_with(shorter, longer, add); + std::vector correct = 
{1+1, 2+3}; + } + + SUBCASE("second input is shorter") { + std::vector result = zip_with(longer, shorter, add); + std::vector correct = {1+1, 2+3}; + } + } + + SUBCASE("properly handles empty inputs") { + std::vector nonempty = {1, 2}; + std::vector empty = {}; + + auto throw_err = [](int x1, int x2) -> int { throw std::runtime_error("error"); }; + + SUBCASE("first input is empty") { + std::vector result = zip_with(empty, nonempty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + + SUBCASE("second input is empty") { + std::vector result = zip_with(nonempty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both inputs are empty") { + std::vector result = zip_with(empty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/orthotope/orthotope.cc b/lib/utils/test/src/utils/orthotope/orthotope.cc new file mode 100644 index 0000000000..28411eed6d --- /dev/null +++ b/lib/utils/test/src/utils/orthotope/orthotope.cc @@ -0,0 +1,96 @@ +#include "utils/orthotope/orthotope.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("orthotope_contains_coord") { + Orthotope orthotope = Orthotope{ + {3, 1}, + }; + + SUBCASE("returns true if coord is in orthotope bounds") { + SUBCASE("smallest allowed coord") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {0, 0}, + }; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("largest allowed coord") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {2, 0}, + }; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("returns false if coord is out of orthotope bounds") { + SUBCASE("too low") { + // exhaustively check all dims because we can + SUBCASE("dim 0") { + OrthotopeCoordinate 
coord = OrthotopeCoordinate{ + {-1, 0}, + }; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("dim 1") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {1, -1}, + }; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = false; + + CHECK(result == correct); + } + } + + SUBCASE("too high") { + // exhaustively check all dims because we can + SUBCASE("dim 0") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {3, 0}, + }; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("dim 1") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {1, 1}, + }; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = false; + + CHECK(result == correct); + } + } + } + + SUBCASE("throws if num dims of coord does not match num dims of the orthotope") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {0, 0, 0}, + }; + + CHECK_THROWS(orthotope_contains_coord(orthotope, coord)); + } + } +} diff --git a/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc new file mode 100644 index 0000000000..b6d7875e91 --- /dev/null +++ b/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc @@ -0,0 +1,50 @@ +#include "utils/orthotope/orthotope_surjective_projection.h" +#include "utils/containers/zip.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("deconflict_noninjective_dims") { + SUBCASE("single input dim is unaffected") { + std::vector coords = {2}; + std::vector dim_sizes = {5}; + + int result = deconflict_noninjective_dims(zip(coords, dim_sizes)); + int correct = 2; + + CHECK(result == correct); + } + + SUBCASE("basic example") { + std::vector coords = {4, 1}; + std::vector dim_sizes = {5, 3}; + + int result = 
deconflict_noninjective_dims(zip(coords, dim_sizes)); + int correct = 4 * 3 + 1; + + CHECK(result == correct); + } + + SUBCASE("order matters") { + std::vector coords = {1, 4}; + std::vector dim_sizes = {3, 5}; + + int result = deconflict_noninjective_dims(zip(coords, dim_sizes)); + int correct = 1 * 5 + 4; + + CHECK(result == correct); + } + + SUBCASE("throws if coord is outside of corresponding dim_size") { + std::vector coords = {2, 3, 1}; + std::vector dim_sizes = {5, 3, 2}; + + CHECK_THROWS(deconflict_noninjective_dims(zip(coords, dim_sizes))); + } + + SUBCASE("throws if input is empty") { + CHECK_THROWS(deconflict_noninjective_dims({})); + } + } +} From 67dbece0467c7327e70ab2e5b2f03314ff89e000 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 20 Oct 2024 22:49:35 -0700 Subject: [PATCH 02/62] Pass test for deconflict_overlapping_dims --- .../cost_estimator/network_cost_model.cc | 2 +- lib/utils/include/utils/containers/foldr.h | 29 +++++++++++ lib/utils/include/utils/containers/scanl.h | 44 ++--------------- lib/utils/include/utils/containers/scanl1.h | 49 +++++++++++++++++++ lib/utils/include/utils/containers/scanr.h | 49 +++---------------- lib/utils/include/utils/containers/scanr1.h | 49 +++++++++++++++++++ .../orthotope_surjective_projection.h | 2 +- .../utils/containers/reversed_container.cc | 12 +++++ lib/utils/src/utils/containers/scanl.cc | 20 ++++++++ lib/utils/src/utils/containers/scanl1.cc | 21 ++++++++ lib/utils/src/utils/containers/scanr.cc | 22 +++++++++ lib/utils/src/utils/containers/scanr1.cc | 21 ++++++++ .../orthotope_surjective_projection.cc | 12 ++--- lib/utils/test/src/utils/containers/scanl.cc | 38 +++----------- lib/utils/test/src/utils/containers/scanl1.cc | 35 +++++++++++++ lib/utils/test/src/utils/containers/scanr.cc | 44 +++++++++++++++++ lib/utils/test/src/utils/containers/scanr1.cc | 35 +++++++++++++ .../orthotope_surjective_projection.cc | 12 ++--- 18 files changed, 364 insertions(+), 132 deletions(-) create mode 100644 
lib/utils/include/utils/containers/foldr.h create mode 100644 lib/utils/include/utils/containers/scanl1.h create mode 100644 lib/utils/include/utils/containers/scanr1.h create mode 100644 lib/utils/src/utils/containers/scanl1.cc create mode 100644 lib/utils/src/utils/containers/scanr.cc create mode 100644 lib/utils/src/utils/containers/scanr1.cc create mode 100644 lib/utils/test/src/utils/containers/scanl1.cc create mode 100644 lib/utils/test/src/utils/containers/scanr.cc create mode 100644 lib/utils/test/src/utils/containers/scanr1.cc diff --git a/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc index c603a7a02b..76fe66e88c 100644 --- a/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc +++ b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc @@ -4,7 +4,7 @@ namespace FlexFlow { float estimate_communication_cost(MachineSpecification const &machine_spec, TensorSetMovement const &tensor_set_movement) { - + NOT_IMPLEMENTED(); // TODO @lockshaw // for (SingleTensorMovement const &single_tensor_movement : tensor_set_movement.single_tensor_movements) { // for // } diff --git a/lib/utils/include/utils/containers/foldr.h b/lib/utils/include/utils/containers/foldr.h new file mode 100644 index 0000000000..b9fe30a476 --- /dev/null +++ b/lib/utils/include/utils/containers/foldr.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldr1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldr1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.crbegin(); + T result = *it; + it++; + for (; it != vec.crend(); it++) { + result = f(result, *it); + } + + return result; +} + + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/containers/scanl.h b/lib/utils/include/utils/containers/scanl.h index a30a9e1576..0d5a4fd7c4 100644 --- a/lib/utils/include/utils/containers/scanl.h +++ b/lib/utils/include/utils/containers/scanl.h @@ -1,14 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H -#include #include namespace FlexFlow { /** * @brief - * Applies `op` to the elements of `c` from left to right, accumulating + * Applies `f` to the elements of `c` from left to right, accumulating * the intermediate results in a vector. `init` is used as the starting point * for the accumulation. * @@ -23,55 +22,18 @@ namespace FlexFlow { * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl */ template -std::vector scanl(C const &c, T init, F const &op) { +std::vector scanl(C const &c, T init, F &&f) { std::vector result; result.push_back(init); for (auto const &elem : c) { - init = op(init, elem); + init = f(init, elem); result.push_back(init); } return result; } -/** - * @brief - * Applies `op` to the elements of `c` from left to right, accumulating - * the intermediate results in a vector. The first item of `c` is used as the - * starting point for the accumulation. - * - * @example - * std::vector nums = {1, 2, 3, 4}; - * auto result = scanl1(nums, [](int a, int b) {return a+b;}); - * result -> {1,3,6,10} - * - * @note - * Essentially a foldl1 which stores the intermediate results. 
- * For more information, see - * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 - */ -template -std::vector scanl1(C const &c, F op) { - - if (c.empty()) { - return std::vector(); - } - - std::optional init = std::nullopt; - std::vector result; - - for (T const &elem : c) { - if (!init.has_value()) { - init = elem; - } else { - init = op(init.value(), elem); - } - result.push_back(init.value()); - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/scanl1.h b/lib/utils/include/utils/containers/scanl1.h new file mode 100644 index 0000000000..4363a2b055 --- /dev/null +++ b/lib/utils/include/utils/containers/scanl1.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL1_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `f` to the elements of `c` from left to right, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {1,3,6,10} + * + * @note + * Essentially a foldl1 which stores the intermediate results. 
+ * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 + */ +template +std::vector scanl1(C const &c, F &&f) { + + if (c.empty()) { + return std::vector(); + } + + std::optional init = std::nullopt; + std::vector result; + + for (T const &elem : c) { + if (!init.has_value()) { + init = elem; + } else { + init = f(init.value(), elem); + } + result.push_back(init.value()); + } + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/scanr.h b/lib/utils/include/utils/containers/scanr.h index 80e07e3f6f..55a36b8245 100644 --- a/lib/utils/include/utils/containers/scanr.h +++ b/lib/utils/include/utils/containers/scanr.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR_H -#include #include +#include "utils/containers/reversed.h" namespace FlexFlow { /** * @brief - * Applies `op` to the elements of `c` from right to left, accumulating + * Applies `f` to the elements of `c` from right to left, accumulating * the intermediate results in a vector. `init` is used as the starting point * for the accumulation. * @@ -23,53 +23,16 @@ namespace FlexFlow { * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scan4 */ template -std::vector scanr(C const &c, T init, F const &op) { +std::vector scanr(C const &c, T init, F &&f) { std::vector result; result.push_back(init); - for (auto const &elem : c) { - init = op(init, elem); + for (auto it = c.crbegin(); it != c.crend(); it++) { + init = f(*it, init); result.push_back(init); } - return result; -} - -/** - * @brief - * Applies `op` to the elements of `c` from right to left, accumulating - * the intermediate results in a vector. The first item of `c` is used as the - * starting point for the accumulation. 
- * - * @example - * std::vector nums = {1, 2, 3, 4}; - * auto result = scanl1(nums, [](int a, int b) {return a+b;}); - * result -> {10, 9, 7, 4} - * - * @note - * Essentially a foldl1 which stores the intermediate results. - * For more information, see - * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 - */ -template -std::vector scanl1(C const &c, F op) { - - if (c.empty()) { - return std::vector(); - } - - std::optional init = std::nullopt; - std::vector result; - - for (T const &elem : c) { - if (!init.has_value()) { - init = elem; - } else { - init = op(init.value(), elem); - } - result.push_back(init.value()); - } - return result; + return reversed(result); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/scanr1.h b/lib/utils/include/utils/containers/scanr1.h new file mode 100644 index 0000000000..5649f305d4 --- /dev/null +++ b/lib/utils/include/utils/containers/scanr1.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANR1_H + +#include +#include +#include "utils/containers/reversed.h" + +namespace FlexFlow { + +/** + * @brief + * Applies `op` to the elements of `c` from right to left, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {10, 9, 7, 4} + * + * @note + * Essentially a foldr1 which stores the intermediate results. 
+ * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 + */ +template +std::vector scanr1(C const &c, F &&f) { + + if (c.empty()) { + return std::vector(); + } + + std::optional init = std::nullopt; + std::vector result; + + for (auto it = c.crbegin(); it != c.crend(); it++) { + if (!init.has_value()) { + init = *it; + } else { + init = f(*it, init.value()); + } + result.push_back(init.value()); + } + return reversed(result); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h index 8ee7cae131..26c24ba7d3 100644 --- a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h +++ b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h @@ -22,7 +22,7 @@ OrthotopeSurjectiveProjection reverse_projection(OrthotopeSurjectiveProjection c std::unordered_set get_all_surjective_projections_between(Orthotope const &src, Orthotope const &dst); -int deconflict_noninjective_dims(std::vector> const &coords_and_sizes); +int deconflict_overlapping_dims(std::vector> const &coords_and_sizes); OrthotopeCoordinate project_coordinate_through(OrthotopeSurjectiveProjection const &, Orthotope const &, OrthotopeCoordinate const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/reversed_container.cc b/lib/utils/src/utils/containers/reversed_container.cc index 1a2fe3cf63..088684c081 100644 --- a/lib/utils/src/utils/containers/reversed_container.cc +++ b/lib/utils/src/utils/containers/reversed_container.cc @@ -1 +1,13 @@ #include "utils/containers/reversed_container.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T = value_type<0>; +using C = std::vector; + +template + reversed_container_t reversed_container(C const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/scanl.cc 
b/lib/utils/src/utils/containers/scanl.cc index 4f7ff78b9f..cfc31fd10e 100644 --- a/lib/utils/src/utils/containers/scanl.cc +++ b/lib/utils/src/utils/containers/scanl.cc @@ -1 +1,21 @@ #include "utils/containers/scanl.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T = value_type<0>; +using C = std::vector; +using F = std::function; + +template + std::vector scanl(std::vector const &, T, F &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template + std::vector scanl(std::set const &, T2, F2 &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/scanl1.cc b/lib/utils/src/utils/containers/scanl1.cc new file mode 100644 index 0000000000..0406b7b9d2 --- /dev/null +++ b/lib/utils/src/utils/containers/scanl1.cc @@ -0,0 +1,21 @@ +#include "utils/containers/scanl1.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T = value_type<0>; +using C = std::vector; +using F = std::function; + +template + std::vector scanl1(std::vector const &, F &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template + std::vector scanl1(std::set const &, F2 &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/scanr.cc b/lib/utils/src/utils/containers/scanr.cc new file mode 100644 index 0000000000..f0fee9c856 --- /dev/null +++ b/lib/utils/src/utils/containers/scanr.cc @@ -0,0 +1,22 @@ +#include "utils/containers/scanr.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T = value_type<0>; +using C = std::vector; +using F = std::function; + +template + std::vector scanr(std::vector const &, T, F &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template + std::vector scanr(std::set const &, T2, F2 &&); + + +} // namespace FlexFlow 
diff --git a/lib/utils/src/utils/containers/scanr1.cc b/lib/utils/src/utils/containers/scanr1.cc new file mode 100644 index 0000000000..8b7f5b9276 --- /dev/null +++ b/lib/utils/src/utils/containers/scanr1.cc @@ -0,0 +1,21 @@ +#include "utils/containers/scanr1.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T = value_type<0>; +using C = std::vector; +using F = std::function; + +template + std::vector scanr1(std::vector const &, F &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template + std::vector scanr1(std::set const &, F2 &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc b/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc index 4b3c309bb1..ae66ee4b12 100644 --- a/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc +++ b/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc @@ -5,6 +5,7 @@ #include "utils/containers/map_keys.h" #include "utils/containers/range.h" #include "utils/containers/set_of.h" +#include "utils/containers/subvec.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" @@ -16,7 +17,7 @@ #include "utils/orthotope/orthotope_dim_idx_t.dtg.h" #include "utils/orthotope/orthotope_dim_idx_t.h" #include "utils/containers/vector_from_idx_map.h" -#include "utils/containers/scanl.h" +#include "utils/containers/scanr.h" #include "utils/containers/all_of.h" #include "utils/fmt/vector.h" @@ -79,7 +80,7 @@ std::unordered_set get_all_surjective_projections return transform(valid_mappings, make_orthotope_projection_from_map); } -int deconflict_noninjective_dims(std::vector> const &coords_and_sizes) { +int deconflict_overlapping_dims(std::vector> const &coords_and_sizes) { if (coords_and_sizes.size() == 0) { throw mk_runtime_error("deconflict_noninjective_dims expected non-empty 
vector, but receieved empty vector"); } @@ -91,16 +92,11 @@ int deconflict_noninjective_dims(std::vector> const &coords_ throw mk_runtime_error(fmt::format("coords out of bounds of dim sizes: coords={}, dim_sizes={}", coords, dim_sizes)); } - - std::vector strides = scanl(dim_sizes, 1, [](int accum, int next) { return accum * next; }); + std::vector strides = scanr(subvec(dim_sizes, 1, std::nullopt), 1, [](int next, int accum) { return accum * next; }); return sum(zip_with(coords, strides, [](int coord, int stride) { return coord * stride; })); } OrthotopeCoordinate project_coordinate_through(OrthotopeSurjectiveProjection const &p, Orthotope const &o, OrthotopeCoordinate const &c) { - auto calculate_nonoverlapping_result = [](std::vector const &coords, std::vector const &dim_sizes) -> int { - NOT_IMPLEMENTED(); - }; - if (p.reversed) { NOT_IMPLEMENTED(); // TODO @lockshaw } else { diff --git a/lib/utils/test/src/utils/containers/scanl.cc b/lib/utils/test/src/utils/containers/scanl.cc index d6da0ac0a1..d2e5169312 100644 --- a/lib/utils/test/src/utils/containers/scanl.cc +++ b/lib/utils/test/src/utils/containers/scanl.cc @@ -11,16 +11,16 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("sum") { std::vector input = {1, 2, 3, 4}; std::vector result = - scanl(input, 0, [](int a, int b) { return a + b; }); + scanl(input, 0, [](int accum, int x) { return accum + x; }); std::vector correct = {0, 1, 3, 6, 10}; CHECK(result == correct); } - SUBCASE("custom function") { + SUBCASE("noncommutative function") { std::vector input = {1, 3, 1, 2}; - auto op = [](int a, int b) { return (a + 1) * (b + 1); }; + auto op = [](int accum, int x) { return accum - x; }; std::vector result = scanl(input, 1, op); - std::vector correct = {1, 4, 20, 42, 129}; + std::vector correct = {1, 0, -3, -4, -6}; CHECK(result == correct); } @@ -37,34 +37,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("empty input") { std::vector input = {}; std::vector result = - scanl(input, 0, [](int a, int b) { return a + b; }); - 
std::vector correct = {0}; - CHECK(result == correct); - } - } - - TEST_CASE("scanl1") { - SUBCASE("sum") { - std::vector input = {1, 2, 3, 4}; - std::vector result = - scanl1(input, [](int a, int b) { return a + b; }); - std::vector correct = {1, 3, 6, 10}; - CHECK(result == correct); - } - - SUBCASE("custom function") { - std::vector input = {1, 2, 5, 2}; - auto op = [](int a, int b) { return a * b + 1; }; - std::vector result = scanl1(input, op); - std::vector correct = {1, 3, 16, 33}; - CHECK(result == correct); - } - - SUBCASE("empty input") { - std::vector input = {}; - std::vector result = - scanl1(input, [](int a, int b) { return a + b; }); - std::vector correct = {}; + scanl(input, 2, [](int accum, int x) -> int { throw std::runtime_error("should not be called"); }); + std::vector correct = {2}; CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/scanl1.cc b/lib/utils/test/src/utils/containers/scanl1.cc new file mode 100644 index 0000000000..426d51a83b --- /dev/null +++ b/lib/utils/test/src/utils/containers/scanl1.cc @@ -0,0 +1,35 @@ +#include "utils/containers/scanl1.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("scanl1") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanl1(input, [](int accum, int x) { return accum + x; }); + std::vector correct = {1, 3, 6, 10}; + CHECK(result == correct); + } + + SUBCASE("noncommutative function") { + std::vector input = {1, 2, 5, 2}; + auto f = [](int accum, int x) { return accum - x; }; + std::vector result = scanl1(input, f); + std::vector correct = {1, -1, -6, -8}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanl1(input, [](int x, int accum) -> int { throw std::runtime_error("should not be called"); }); + std::vector correct = {}; + CHECK(result == correct); + } + } +} diff --git 
a/lib/utils/test/src/utils/containers/scanr.cc b/lib/utils/test/src/utils/containers/scanr.cc new file mode 100644 index 0000000000..79e6443d19 --- /dev/null +++ b/lib/utils/test/src/utils/containers/scanr.cc @@ -0,0 +1,44 @@ +#include "utils/containers/scanr.h" +#include +#include +#include "test/utils/doctest/fmt/vector.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("scanr") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanr(input, 0, [](int x, int accum) { return x + accum; }); + std::vector correct = {10, 9, 7, 4, 0}; + CHECK(result == correct); + } + + SUBCASE("noncommutative function") { + std::vector input = {1, 3, 1, 2}; + auto op = [](int x, int accum) { return accum - x; }; + std::vector result = scanr(input, 1, op); + std::vector correct = {-6, -5, -2, -1, 1}; + CHECK(result == correct); + } + + SUBCASE("heterogeneous types") { + std::vector input = {1, 2, 3, 4}; + auto op = [](int x, std::string const &accum) { + return accum + std::to_string(x); + }; + std::vector result = scanr(input, std::string(""), op); + std::vector correct = {"4321", "432", "43", "4", ""}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanr(input, 2, [](int x, int accum) -> int { throw std::runtime_error("should not be called"); }); + std::vector correct = {2}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/scanr1.cc b/lib/utils/test/src/utils/containers/scanr1.cc new file mode 100644 index 0000000000..e2526b7102 --- /dev/null +++ b/lib/utils/test/src/utils/containers/scanr1.cc @@ -0,0 +1,35 @@ +#include "utils/containers/scanr1.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("scanr1") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanr1(input, [](int x, int accum) { 
return x + accum; }); + std::vector correct = {10, 9, 7, 4}; + CHECK(result == correct); + } + + SUBCASE("noncommutative function") { + std::vector input = {1, 2, 5, 2}; + auto f = [](int x, int accum) { return accum - x; }; + std::vector result = scanr1(input, f); + std::vector correct = {-6, -5, -3, 2}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanr1(input, [](int x, int accum) -> int { throw std::runtime_error("should not be called"); }); + std::vector correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc index b6d7875e91..0418e53cc4 100644 --- a/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc +++ b/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc @@ -5,12 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("deconflict_noninjective_dims") { + TEST_CASE("deconflict_overlapping_dims") { SUBCASE("single input dim is unaffected") { std::vector coords = {2}; std::vector dim_sizes = {5}; - int result = deconflict_noninjective_dims(zip(coords, dim_sizes)); + int result = deconflict_overlapping_dims(zip(coords, dim_sizes)); int correct = 2; CHECK(result == correct); @@ -20,7 +20,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector coords = {4, 1}; std::vector dim_sizes = {5, 3}; - int result = deconflict_noninjective_dims(zip(coords, dim_sizes)); + int result = deconflict_overlapping_dims(zip(coords, dim_sizes)); int correct = 4 * 3 + 1; CHECK(result == correct); @@ -30,7 +30,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector coords = {1, 4}; std::vector dim_sizes = {3, 5}; - int result = deconflict_noninjective_dims(zip(coords, dim_sizes)); + int result = deconflict_overlapping_dims(zip(coords, dim_sizes)); int correct = 1 * 5 + 4; CHECK(result == correct); @@ -40,11 +40,11 @@ 
TEST_SUITE(FF_TEST_SUITE) { std::vector coords = {2, 3, 1}; std::vector dim_sizes = {5, 3, 2}; - CHECK_THROWS(deconflict_noninjective_dims(zip(coords, dim_sizes))); + CHECK_THROWS(deconflict_overlapping_dims(zip(coords, dim_sizes))); } SUBCASE("throws if input is empty") { - CHECK_THROWS(deconflict_noninjective_dims({})); + CHECK_THROWS(deconflict_overlapping_dims({})); } } } From dcf2dfcdd8146b3c4127ea3a2c2405053968e90a Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 23 Oct 2024 16:18:13 -0700 Subject: [PATCH 03/62] Pass initial orthotope projection tests --- .../op-attrs/dim_ordered/dim_ordered.h | 2 +- .../src/op-attrs/dim_ordered/dim_ordered.cc | 1 + lib/pcg/src/pcg/machine_view.cc | 3 +- .../include/utils/archetypes/value_type.h | 6 + .../include/utils/containers/intersection.h | 14 ++ .../containers/map_from_keys_and_values.h | 1 + .../include/utils/containers/map_values.h | 6 +- .../include/utils/containers/merge_maps.h | 13 + lib/utils/include/utils/containers/subvec.h | 3 +- lib/utils/include/utils/containers/zip.h | 13 +- lib/utils/include/utils/containers/zip3.h | 24 ++ lib/utils/include/utils/fmt/tuple.h | 40 +++ lib/utils/include/utils/ord/vector.h | 17 ++ lib/utils/include/utils/orthotope/orthotope.h | 5 +- .../utils/orthotope/orthotope.struct.toml | 9 +- .../orthotope_bijective_projection.h | 36 +++ ...rthotope_bijective_projection.struct.toml} | 2 +- .../utils/orthotope/orthotope_coordinate.h | 2 +- .../orthotope_coordinate.struct.toml | 9 +- .../orthotope/orthotope_dim_indexed/all_of.h | 11 + .../orthotope_dim_indexed/drop_idxs_except.h | 31 +++ .../orthotope_dim_indexed.h | 183 ++++++++++++++ .../orthotope_dim_indexed_from_idx_map.h | 29 +++ .../orthotope_dim_indexed_of.h | 16 ++ .../orthotope_dim_indexed/transform.h | 17 ++ .../orthotope_dim_indexed/zip_with.h | 22 ++ .../orthotope_surjective_projection.h | 30 --- lib/utils/include/utils/tuple.h | 15 -- lib/utils/include/utils/tuple/visit.h | 21 ++ .../src/utils/containers/intersection.cc | 
20 ++ .../containers/map_from_keys_and_values.cc | 14 ++ lib/utils/src/utils/containers/map_values.cc | 13 + lib/utils/src/utils/containers/merge_maps.cc | 24 ++ lib/utils/src/utils/containers/subvec.cc | 12 + lib/utils/src/utils/containers/zip.cc | 12 + lib/utils/src/utils/containers/zip3.cc | 15 ++ lib/utils/src/utils/fmt/tuple.cc | 9 + lib/utils/src/utils/ord/vector.cc | 11 + lib/utils/src/utils/orthotope/orthotope.cc | 17 +- .../orthotope_bijective_projection.cc | 228 ++++++++++++++++++ .../utils/orthotope/orthotope_coordinate.cc | 13 +- .../orthotope/orthotope_dim_indexed/all_of.cc | 10 + .../orthotope_dim_indexed/drop_idxs_except.cc | 11 + .../orthotope_dim_indexed.cc | 17 ++ .../orthotope_dim_indexed_from_idx_map.cc | 11 + .../orthotope_dim_indexed_of.cc | 11 + .../orthotope_dim_indexed/transform.cc | 13 + .../orthotope_dim_indexed/zip_with.cc | 14 ++ .../orthotope_surjective_projection.cc | 122 ---------- lib/utils/src/utils/tuple/visit.cc | 15 ++ .../include/test/utils/doctest/fmt/tuple.h | 18 ++ .../src/test/utils/doctest/fmt/tuple.cc | 8 + .../test/src/utils/containers/intersection.cc | 31 +-- lib/utils/test/src/utils/containers/zip.cc | 81 +++++++ lib/utils/test/src/utils/containers/zip3.cc | 92 +++++++ lib/utils/test/src/utils/fmt/pair.cc | 12 + lib/utils/test/src/utils/fmt/tuple.cc | 70 ++++++ .../test/src/utils/orthotope/orthotope.cc | 48 ++++ .../orthotope_bijective_projection.cc | 181 ++++++++++++++ .../orthotope_surjective_projection.cc | 50 ---- lib/utils/test/src/utils/tuple/visit.cc | 40 +++ 61 files changed, 1547 insertions(+), 277 deletions(-) create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc create mode 100644 lib/utils/include/utils/containers/zip3.h create mode 100644 lib/utils/include/utils/fmt/tuple.h create mode 100644 lib/utils/include/utils/ord/vector.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_bijective_projection.h rename 
lib/utils/include/utils/orthotope/{orthotope_surjective_projection.struct.toml => orthotope_bijective_projection.struct.toml} (89%) create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_surjective_projection.h create mode 100644 lib/utils/include/utils/tuple/visit.h create mode 100644 lib/utils/src/utils/containers/map_from_keys_and_values.cc create mode 100644 lib/utils/src/utils/containers/zip3.cc create mode 100644 lib/utils/src/utils/fmt/tuple.cc create mode 100644 lib/utils/src/utils/ord/vector.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc create 
mode 100644 lib/utils/src/utils/tuple/visit.cc create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc create mode 100644 lib/utils/test/src/utils/containers/zip.cc create mode 100644 lib/utils/test/src/utils/containers/zip3.cc create mode 100644 lib/utils/test/src/utils/fmt/tuple.cc create mode 100644 lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc delete mode 100644 lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc create mode 100644 lib/utils/test/src/utils/tuple/visit.cc diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 6aa23d40fc..2d382d86fa 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -123,7 +123,7 @@ struct DimOrdered { } reverse_iterator rend() { - return this->contents.crend(); + return this->contents.rend(); } const_reverse_iterator rend() const { diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc b/lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc new file mode 100644 index 0000000000..511c69d333 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/dim_ordered.h" diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index 431b3cc4fc..1a3e240f71 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -8,6 +8,7 @@ #include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" +#include "utils/containers/zip3.h" namespace FlexFlow { @@ -74,7 +75,7 @@ std::optional get_machine_space_coordinate( int index = start_idx; for (auto [coeff, coord_point, stride] : - zip(coeffs, coord_points, strides)) { + zip3(coeffs, coord_points, strides)) { index += coeff * coord_point * 
stride; } return index; diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h index 1635747612..ade427b392 100644 --- a/lib/utils/include/utils/archetypes/value_type.h +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -32,6 +33,11 @@ struct value_type { } }; +template +std::string format_as(value_type const &) { + assert (false); +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/containers/intersection.h b/lib/utils/include/utils/containers/intersection.h index 938ebd68c9..2379fa5171 100644 --- a/lib/utils/include/utils/containers/intersection.h +++ b/lib/utils/include/utils/containers/intersection.h @@ -4,6 +4,7 @@ #include "utils/containers/contains.h" #include #include +#include namespace FlexFlow { @@ -19,6 +20,19 @@ std::unordered_set intersection(std::unordered_set const &l, return result; } +template +std::set intersection(std::set const &l, + std::set const &r) { + std::set result; + for (T const &ll : l) { + if (contains(r, ll)) { + result.insert(ll); + } + } + return result; +} + + template std::optional intersection(C const &c) { std::optional result; diff --git a/lib/utils/include/utils/containers/map_from_keys_and_values.h b/lib/utils/include/utils/containers/map_from_keys_and_values.h index 499965dc5e..9791bc4248 100644 --- a/lib/utils/include/utils/containers/map_from_keys_and_values.h +++ b/lib/utils/include/utils/containers/map_from_keys_and_values.h @@ -4,6 +4,7 @@ #include "utils/containers/zip.h" #include "utils/exception.h" #include +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/map_values.h b/lib/utils/include/utils/containers/map_values.h index 9f7a4f4add..5b527998f2 100644 --- a/lib/utils/include/utils/containers/map_values.h +++ b/lib/utils/include/utils/containers/map_values.h @@ -11,10 +11,10 @@ template > std::unordered_map 
map_values(std::unordered_map const &m, - F const &f) { + F &&f) { std::unordered_map result; - for (auto const &kv : m) { - result.insert({kv.first, f(kv.second)}); + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.second)}); } return result; } diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index dd886ab8aa..14a79086d0 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -30,6 +30,19 @@ std::unordered_map merge_maps(std::unordered_map const &lhs, return result; } +template +std::unordered_map merge_maps(C const &c) { + std::unordered_map result; + + for (std::unordered_map const &m : c) { + result = merge_maps(result, m); + } + + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index c89e9227de..edeba6eb58 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -5,6 +5,7 @@ #include #include #include +#include "utils/fmt/optional.h" namespace FlexFlow { @@ -23,7 +24,7 @@ std::vector subvec(std::vector const &v, new_idx = size + idx; } if (new_idx < 0 || new_idx > size) { - throw mk_runtime_error("Index {} is out of bounds for array {}"); + throw mk_runtime_error(fmt::format("Index {} is out of bounds for array of size {}", new_idx, v.size())); } return new_idx; }; diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 0f6dbed1d3..7bfca5e8b1 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H -#include #include #include +#include namespace FlexFlow { @@ -17,17 +17,6 @@ std::vector> zip(std::vector const &l, return result; } 
-template -std::vector> zip(std::vector const &a, - std::vector const &b, - std::vector const &c) { - std::vector> result; - for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { - result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/zip3.h b/lib/utils/include/utils/containers/zip3.h new file mode 100644 index 0000000000..18fcb28d03 --- /dev/null +++ b/lib/utils/include/utils/containers/zip3.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H + +#include +#include +#include +#include + +namespace FlexFlow { + +template +std::vector> zip3(std::vector const &a, + std::vector const &b, + std::vector const &c) { + std::vector> result; + for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { + result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/tuple.h b/lib/utils/include/utils/fmt/tuple.h new file mode 100644 index 0000000000..80054f8e5e --- /dev/null +++ b/lib/utils/include/utils/fmt/tuple.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H + +#include "utils/check_fmtable.h" +#include +#include +#include +#include +#include "utils/join_strings.h" +#include "utils/tuple/visit.h" + +namespace fmt { + +template +struct formatter, Char> + : formatter { + + template + auto format(std::tuple const &t, FormatContext &ctx) const + -> decltype(ctx.out()) { + + std::vector stringified_elements; + ::FlexFlow::visit_tuple(t, [&](auto const &element) -> void { stringified_elements.push_back(fmt::to_string(element)); }); + + return formatter::format("{" + ::FlexFlow::join_strings(stringified_elements, ", ") + "}", ctx); + } +}; + +} // namespace fmt + +namespace 
FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::tuple const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/ord/vector.h b/lib/utils/include/utils/ord/vector.h new file mode 100644 index 0000000000..60989688b8 --- /dev/null +++ b/lib/utils/include/utils/ord/vector.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORD_VECTOR_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +std::enable_if_t, bool> operator<(std::vector const &lhs, std::vector const &rhs) { + return std::lexicographical_compare(lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope.h b/lib/utils/include/utils/orthotope/orthotope.h index fa694e7552..b5276c84e0 100644 --- a/lib/utils/include/utils/orthotope/orthotope.h +++ b/lib/utils/include/utils/orthotope/orthotope.h @@ -8,9 +8,12 @@ namespace FlexFlow { +std::set get_orthotope_dims(Orthotope const &); + bool orthotope_contains_coord(Orthotope const &, OrthotopeCoordinate const &); +int orthotope_get_volume(Orthotope const &); -Orthotope restrict_orthotope_dims(Orthotope const &, std::unordered_set const &); +Orthotope orthotope_drop_dims_except(Orthotope const &, std::set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/orthotope/orthotope.struct.toml b/lib/utils/include/utils/orthotope/orthotope.struct.toml index 9fd715df1d..ccc07373ef 100644 --- a/lib/utils/include/utils/orthotope/orthotope.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope.struct.toml @@ -7,14 +7,9 @@ features = [ ] includes = [ - "", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", + "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h", ] [[fields]] name = "dims" -type = "std::vector" +type = 
"::FlexFlow::OrthotopeDimIndexed" diff --git a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h new file mode 100644 index 0000000000..b2306e616f --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_BIJECTIVE_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_BIJECTIVE_PROJECTION_H + +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope_bijective_projection.dtg.h" +#include "utils/orthotope/orthotope_coordinate.dtg.h" +#include + +namespace FlexFlow { + +OrthotopeBijectiveProjection + make_orthotope_projection_from_map(std::unordered_map const &); + +std::unordered_map get_src_to_dst_dim_map(OrthotopeBijectiveProjection const &); + +orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeBijectiveProjection const &, orthotope_dim_idx_t const &); + +int get_src_num_dims(OrthotopeBijectiveProjection const &); +int get_dst_num_dims(OrthotopeBijectiveProjection const &); + +OrthotopeBijectiveProjection reverse_projection(OrthotopeBijectiveProjection const &); + +std::unordered_set get_all_bijective_projections_between(Orthotope const &src, Orthotope const &dst); + +int project_into_1d(Orthotope const &, OrthotopeCoordinate const &); +OrthotopeCoordinate project_out_of_1d(int, Orthotope const &); + + +OrthotopeCoordinate project_coordinate_through(OrthotopeBijectiveProjection const &projection, + Orthotope const &src_orthotope, + OrthotopeCoordinate const &src_coord, + Orthotope const &dst_orthotope); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml similarity index 89% rename from lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml rename to 
lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml index 4ae4d71dd0..08a29248b9 100644 --- a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "OrthotopeSurjectiveProjection" +name = "OrthotopeBijectiveProjection" features = [ "eq", "fmt", diff --git a/lib/utils/include/utils/orthotope/orthotope_coordinate.h b/lib/utils/include/utils/orthotope/orthotope_coordinate.h index 7fe068a601..c4ac1114bd 100644 --- a/lib/utils/include/utils/orthotope/orthotope_coordinate.h +++ b/lib/utils/include/utils/orthotope/orthotope_coordinate.h @@ -9,7 +9,7 @@ namespace FlexFlow { std::set get_orthotope_coord_dims(OrthotopeCoordinate const &); -OrthotopeCoordinate restrict_orthotope_coord_dims(OrthotopeCoordinate const &, std::set const &); +OrthotopeCoordinate orthotope_coord_drop_dims_except(OrthotopeCoordinate const &, std::set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml index cdfbee4d2a..4e99261cb1 100644 --- a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml @@ -7,14 +7,9 @@ features = [ ] includes = [ - "" -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", + "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h", ] [[fields]] name = "idxs" -type = "std::vector" +type = "::FlexFlow::OrthotopeDimIndexed" diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h new file mode 100644 index 0000000000..b1f4531af3 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h @@ -0,0 +1,11 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ALL_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ALL_OF_H + +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +namespace FlexFlow { + +bool all_of(OrthotopeDimIndexed const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h new file mode 100644 index 0000000000..4ee9ee93ac --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_DROP_IDXS_EXCEPT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_DROP_IDXS_EXCEPT_H + +#include "utils/containers/contains.h" +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include "utils/exception.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/fmt/set.h" + +namespace FlexFlow { + +template +OrthotopeDimIndexed drop_idxs_except(OrthotopeDimIndexed const &d, std::set const &keep) { + OrthotopeDimIndexed result; + + if (!is_subseteq_of(d.indices(), keep)) { + throw mk_runtime_error(fmt::format("drop_idxs_except expected keep to be a subset of d's dims, but got d={}, keep={}", d, keep)); + } + + for (orthotope_dim_idx_t const &idx : d.indices()) { + if (contains(keep, idx)) { + result.push_back(d.at(idx)); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h new file mode 100644 index 0000000000..1a82e85a01 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h @@ -0,0 +1,183 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_H + +#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" +#include +#include "utils/hash-utils.h" +#include +#include "utils/hash/vector.h" +#include "utils/hash/tuple.h" +#include "utils/fmt/vector.h" +#include "utils/type_traits_core.h" +#include "utils/orthotope/orthotope_dim_idx_t.h" +#include +#include "utils/ord/vector.h" + +namespace FlexFlow { + +template +struct OrthotopeDimIndexed { +public: + OrthotopeDimIndexed() + : contents() + { } + + OrthotopeDimIndexed(std::initializer_list const &l) + : contents(l) + { } + + template + OrthotopeDimIndexed(Iter begin, Iter end) + : contents(begin, end) + { } + + T const &at(orthotope_dim_idx_t const &idx) const { + return this->contents.at(idx.raw_idx); + } + + T &at(orthotope_dim_idx_t const &idx) { + return this->contents.at(idx.raw_idx); + } + + T const &back() const { + return this->contents.back(); + } + + T &back() { + return this->contents.back(); + } + + T const &front() const { + return this->contents.front(); + } + + T &front() { + return this->contents.front(); + } + + void push_back(T const &t) { + this->contents.push_back(t); + } + + std::vector const &get_contents() const { + return this->contents; + } + + bool operator==(OrthotopeDimIndexed const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(OrthotopeDimIndexed const &other) const { + return this->tie() != other.tie(); + } + + std::set indices() const { + return dim_idxs_for_orthotope_with_num_dims(this->size()); + } + + std::tuple const &> tie() const { + return std::tie(contents); + } +private: + std::vector contents; +public: + using iterator = typename decltype(contents)::iterator; + using const_iterator = typename decltype(contents)::const_iterator; + using reverse_iterator = typename decltype(contents)::reverse_iterator; + using 
const_reverse_iterator = typename decltype(contents)::const_reverse_iterator; + + using value_type = typename decltype(contents)::value_type; + using pointer = typename decltype(contents)::pointer; + using const_pointer = typename decltype(contents)::const_pointer; + using reference = typename decltype(contents)::reference; + using const_reference = typename decltype(contents)::const_reference; + + iterator begin() { + return this->contents.begin(); + } + + const_iterator begin() const { + return this->cbegin(); + } + + const_iterator cbegin() const { + return this->contents.cbegin(); + } + + iterator end() { + return this->contents.end(); + } + + const_iterator end() const { + return this->cend(); + } + + const_iterator cend() const { + return this->contents.cend(); + } + + reverse_iterator rbegin() { + return this->contents.rbegin(); + } + + const_reverse_iterator rbegin() const { + return this->crbegin(); + } + + const_reverse_iterator crbegin() const { + return this->contents.crbegin(); + } + + reverse_iterator rend() { + return this->contents.rend(); + } + + const_reverse_iterator rend() const { + return this->crend(); + } + + const_reverse_iterator crend() const { + return this->contents.crend(); + } + + size_t size() const { + return this->contents.size(); + } + + size_t empty() const { + return this->contents.empty(); + } +}; + +// template +// std::enable_if_t, bool> operator<(OrthotopeDimIndexed const &lhs, OrthotopeDimIndexed const &rhs) { +// return lhs.tie() < rhs.tie(); +// } + +template +std::vector format_as(OrthotopeDimIndexed const &d) { + return d.get_contents(); +} + +template +std::ostream &operator<<(std::ostream &s, OrthotopeDimIndexed const &d) { + return (s << fmt::to_string(d)); +} + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::OrthotopeDimIndexed> { + size_t operator()(::FlexFlow::OrthotopeDimIndexed const &t) const { + static_assert(::FlexFlow::is_hashable::value, + "Elements must be hashable"); + + 
return ::FlexFlow::get_std_hash(t.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h new file mode 100644 index 0000000000..858224dee5 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_FROM_IDX_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_FROM_IDX_MAP_H + +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" +#include "utils/containers/vector_from_idx_map.h" +#include "utils/containers/map_keys.h" + +namespace FlexFlow { + +template +std::optional> orthotope_dim_indexed_from_idx_map(std::unordered_map const &m) { + std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t idx) { return idx.raw_idx; }); + + std::vector raw_vec = ({ + std::optional> returned = vector_from_idx_map(raw_idx_map); + if (!returned.has_value()) { + return std::nullopt; + } + + returned.value(); + }); + + return orthotope_dim_indexed_of(raw_vec); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h new file mode 100644 index 0000000000..bb1794aeb3 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_OF_H + +#include 
"utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include + +namespace FlexFlow { + +template +OrthotopeDimIndexed orthotope_dim_indexed_of(std::vector const &v) { + return OrthotopeDimIndexed(v.cbegin(), v.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h new file mode 100644 index 0000000000..9871284864 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_TRANSFORM_H + +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template > +OrthotopeDimIndexed transform(OrthotopeDimIndexed const &d, F &&f) { + return orthotope_dim_indexed_of(transform(d.get_contents(), f)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h new file mode 100644 index 0000000000..f37c9725bc --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ZIP_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ZIP_WITH_H + +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include "utils/containers/intersection.h" + +namespace FlexFlow { + +template > +OrthotopeDimIndexed zip_with(OrthotopeDimIndexed const &l, OrthotopeDimIndexed const &r, F &&f) { + OrthotopeDimIndexed result; + for (orthotope_dim_idx_t i : intersection(l.indices(), r.indices())) { + 
result.push_back(f(l.at(i), r.at(i))); + } + + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h b/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h deleted file mode 100644 index 26c24ba7d3..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_surjective_projection.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_SURJECTIVE_PROJECTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_SURJECTIVE_PROJECTION_H - -#include "utils/orthotope/orthotope.dtg.h" -#include "utils/orthotope/orthotope_surjective_projection.dtg.h" -#include "utils/orthotope/orthotope_coordinate.dtg.h" -#include - -namespace FlexFlow { - -OrthotopeSurjectiveProjection - make_orthotope_projection_from_map(std::unordered_map const &); - -std::unordered_map get_src_to_dst_dim_map(OrthotopeSurjectiveProjection const &); - -orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeSurjectiveProjection const &, orthotope_dim_idx_t const &); - -int get_src_num_dims(OrthotopeSurjectiveProjection const &); -int get_dst_num_dims(OrthotopeSurjectiveProjection const &); - -OrthotopeSurjectiveProjection reverse_projection(OrthotopeSurjectiveProjection const &); - -std::unordered_set get_all_surjective_projections_between(Orthotope const &src, Orthotope const &dst); - -int deconflict_overlapping_dims(std::vector> const &coords_and_sizes); -OrthotopeCoordinate project_coordinate_through(OrthotopeSurjectiveProjection const &, Orthotope const &, OrthotopeCoordinate const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 0296e365a3..2afa87e7ac 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -32,21 +32,6 @@ struct index_of : index_of_impl {}; } // namespace TupleUtils -template -void visit_tuple_impl(Visitor &v, std::tuple const &tup) 
{ - v(Idx, std::get(tup)); - if (Idx >= std::tuple_size::value) { - return; - } else { - visit_tuple_impl<(Idx + 1)>(v, tup); - } -} - -template -void visit_tuple(Visitor &v, std::tuple const &tup) { - visit_tuple_impl<0>(v, tup); -} - struct tuple_get_visitor { tuple_get_visitor() = delete; tuple_get_visitor(int requested_idx, std::any &result) diff --git a/lib/utils/include/utils/tuple/visit.h b/lib/utils/include/utils/tuple/visit.h new file mode 100644 index 0000000000..741eac1e88 --- /dev/null +++ b/lib/utils/include/utils/tuple/visit.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H + +#include +#include + +namespace FlexFlow { + +template +void visit_tuple_impl(Tuple const &tuple, Visitor &&v, std::index_sequence) { + (v(std::get(tuple)), ...); +} + +template +void visit_tuple(std::tuple const &tuple, Visitor &&v) { + visit_tuple_impl(tuple, v, std::index_sequence_for{}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/intersection.cc b/lib/utils/src/utils/containers/intersection.cc index 9c61e99b98..b9e178ebb2 100644 --- a/lib/utils/src/utils/containers/intersection.cc +++ b/lib/utils/src/utils/containers/intersection.cc @@ -1 +1,21 @@ #include "utils/containers/intersection.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::unordered_set intersection(std::unordered_set const &, std::unordered_set const &); +template + std::optional> intersection(std::vector> const &); + +using T2 = ordered_value_type<0>; + +template + std::set intersection(std::set const &, std::set const &); +template + std::optional> intersection(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_from_keys_and_values.cc b/lib/utils/src/utils/containers/map_from_keys_and_values.cc new file mode 100644 index 
0000000000..3c7dccf34f --- /dev/null +++ b/lib/utils/src/utils/containers/map_from_keys_and_values.cc @@ -0,0 +1,14 @@ +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K1 = value_type<0>; +using V1 = value_type<1>; + +template + std::unordered_map + map_from_keys_and_values(std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_values.cc b/lib/utils/src/utils/containers/map_values.cc index 6166aa92f1..dbe37f8d19 100644 --- a/lib/utils/src/utils/containers/map_values.cc +++ b/lib/utils/src/utils/containers/map_values.cc @@ -1 +1,14 @@ #include "utils/containers/map_values.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using V2 = value_type<2>; +using F = std::function; + +template +std::unordered_map map_values(std::unordered_map const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_maps.cc b/lib/utils/src/utils/containers/merge_maps.cc index a36217fbeb..ea1e25e243 100644 --- a/lib/utils/src/utils/containers/merge_maps.cc +++ b/lib/utils/src/utils/containers/merge_maps.cc @@ -1 +1,25 @@ #include "utils/containers/merge_maps.h" +#include "utils/archetypes/value_type.h" +#include "utils/hash/unordered_map.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template + std::unordered_map merge_maps(std::unordered_map const &, + std::unordered_map const &); + +using C = std::vector>; + +template + std::unordered_map merge_maps(C const &); + +using C2 = std::unordered_set>; + +template + std::unordered_map merge_maps(C2 const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/subvec.cc b/lib/utils/src/utils/containers/subvec.cc index 93c7de31c5..ada130d8c8 100644 --- a/lib/utils/src/utils/containers/subvec.cc +++ b/lib/utils/src/utils/containers/subvec.cc 
@@ -1 +1,13 @@ #include "utils/containers/subvec.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::vector subvec(std::vector const &, + std::optional const &, + std::optional const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip.cc b/lib/utils/src/utils/containers/zip.cc index a02c9c8a35..7569640802 100644 --- a/lib/utils/src/utils/containers/zip.cc +++ b/lib/utils/src/utils/containers/zip.cc @@ -1 +1,13 @@ #include "utils/containers/zip.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L1 = value_type<0>; +using R1 = value_type<1>; + +template + std::vector> zip(std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3.cc b/lib/utils/src/utils/containers/zip3.cc new file mode 100644 index 0000000000..6f1d9e46ae --- /dev/null +++ b/lib/utils/src/utils/containers/zip3.cc @@ -0,0 +1,15 @@ +#include "utils/containers/zip3.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A1 = value_type<0>; +using B1 = value_type<1>; +using C1 = value_type<2>; + +template + std::vector> zip3(std::vector const &, + std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/tuple.cc b/lib/utils/src/utils/fmt/tuple.cc new file mode 100644 index 0000000000..3ba2c53d89 --- /dev/null +++ b/lib/utils/src/utils/fmt/tuple.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/tuple.h" + +namespace FlexFlow { + + +template + std::ostream &operator<<(std::ostream &s, std::tuple const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/ord/vector.cc b/lib/utils/src/utils/ord/vector.cc new file mode 100644 index 0000000000..c630bc8e46 --- /dev/null +++ b/lib/utils/src/utils/ord/vector.cc @@ -0,0 +1,11 @@ +#include "utils/ord/vector.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T = 
ordered_value_type<0>; + +template + bool operator<(std::vector const &, std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope.cc b/lib/utils/src/utils/orthotope/orthotope.cc index 2cf75acb73..075db4f94a 100644 --- a/lib/utils/src/utils/orthotope/orthotope.cc +++ b/lib/utils/src/utils/orthotope/orthotope.cc @@ -1,10 +1,17 @@ #include "utils/orthotope/orthotope.h" -#include "utils/containers/zip_with.h" +#include "utils/containers/product.h" +#include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" +#include "utils/orthotope/orthotope_dim_indexed/zip_with.h" +#include "utils/orthotope/orthotope_dim_indexed/all_of.h" #include "utils/containers/all_of.h" #include "utils/exception.h" namespace FlexFlow { +std::set get_orthotope_dims(Orthotope const &orthotope) { + return orthotope.dims.indices(); +} + bool orthotope_contains_coord(Orthotope const &o, OrthotopeCoordinate const &c) { if (o.dims.size() != c.idxs.size()) { throw mk_runtime_error(fmt::format("orthotope_contains_coord expected orthotope and coord to have the same number of dims, but received o={}, c={}", o, c)); @@ -13,4 +20,12 @@ bool orthotope_contains_coord(Orthotope const &o, OrthotopeCoordinate const &c) return all_of(zip_with(o.dims, c.idxs, [](int dim_size, int dim_coord) { return dim_coord >= 0 && dim_coord < dim_size; })); } +int orthotope_get_volume(Orthotope const &o) { + return product(o.dims.get_contents()); +} + +Orthotope orthotope_drop_dims_except(Orthotope const &o, std::set const &keep) { + return Orthotope{drop_idxs_except(o.dims, keep)}; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc new file mode 100644 index 0000000000..d9a62dc1e6 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc @@ -0,0 +1,228 @@ +#include "utils/orthotope/orthotope_bijective_projection.h" +#include 
"utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/group_by.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/range.h" +#include "utils/containers/set_of.h" +#include "utils/containers/subvec.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/filter.h" +#include "utils/containers/values.h" +#include "utils/containers/zip_with.h" +#include "utils/exception.h" +#include "utils/orthotope/orthotope.h" +#include "utils/orthotope/orthotope_coordinate.h" +#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" +#include "utils/orthotope/orthotope_dim_idx_t.h" +#include "utils/containers/vector_from_idx_map.h" +#include "utils/containers/scanr.h" +#include "utils/containers/scanr1.h" +#include "utils/containers/all_of.h" +#include "utils/fmt/vector.h" +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h" + +namespace FlexFlow { + +OrthotopeBijectiveProjection + make_orthotope_projection_from_map(std::unordered_map const &m) { + std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t const &k) { return k.raw_idx; }); + return OrthotopeBijectiveProjection{ + /*dim_mapping=*/vector_from_idx_map(raw_idx_map).value(), + /*reversed=*/false, + }; +} + +std::unordered_map get_src_to_dst_dim_map(OrthotopeBijectiveProjection const &p) { + if (p.reversed) { + throw mk_runtime_error(fmt::format("get_src_to_dst_dim_map expected p.reversed=false, but received p={}", p)); + } + + std::unordered_map raw_idx_map = generate_map(range(p.dim_mapping.size()), [&](int x) { return p.dim_mapping.at(x); }); + return map_keys(raw_idx_map, [](int src_dim_idx) { return orthotope_dim_idx_t{src_dim_idx}; }); +} + +orthotope_dim_idx_t 
get_dst_dim_for_src_dim(OrthotopeBijectiveProjection const &p, orthotope_dim_idx_t const &src_idx) { + if (p.reversed) { + throw mk_runtime_error(fmt::format("get_dst_dim_for_src_dim expected a non-reversed projection, but received: projection={}", p)); + } + + return p.dim_mapping.at(src_idx.raw_idx); +} + +orthotope_dim_idx_t get_src_dim_for_dst_dim(OrthotopeBijectiveProjection const &p, orthotope_dim_idx_t const &dst_idx) { + if (!p.reversed) { + throw mk_runtime_error(fmt::format("get_src_dim_for_dst_dim expected a reversed projection, but received: projection={}", p)); + } + + return get_dst_dim_for_src_dim(reverse_projection(p), dst_idx); +} + +int get_src_num_dims(OrthotopeBijectiveProjection const &p) { + if (p.reversed) { + return get_dst_num_dims(reverse_projection(p)); + } + + return p.dim_mapping.size(); +} + +int get_dst_num_dims(OrthotopeBijectiveProjection const &p) { + if (p.reversed) { + return get_src_num_dims(reverse_projection(p)); + } + + return unordered_set_of(p.dim_mapping).size(); +} + +OrthotopeBijectiveProjection reverse_projection(OrthotopeBijectiveProjection const &p) { + OrthotopeBijectiveProjection result = p; + result.reversed = !p.reversed; + return result; +} + +std::unordered_set get_all_bijective_projections_between(int src_num_dims, int dst_num_dims) { + if (src_num_dims < dst_num_dims) { + return transform(get_all_bijective_projections_between(dst_num_dims, src_num_dims), + [](OrthotopeBijectiveProjection const &p) { return reverse_projection(p); }); + } + + std::set src_dim_idxs = dim_idxs_for_orthotope_with_num_dims(src_num_dims); + std::set dst_dim_idxs = dim_idxs_for_orthotope_with_num_dims(dst_num_dims); + + std::unordered_map> src_to_dst_idxs = + generate_map(src_dim_idxs, [&](orthotope_dim_idx_t) { return unordered_set_of(dst_dim_idxs); }); + + std::unordered_set> valid_mappings = + filter(get_all_assignments(src_to_dst_idxs), + [&](std::unordered_map const &src_to_dst_idx) { + return set_of(values(src_to_dst_idx)) == 
dst_dim_idxs; + }); + + return transform(valid_mappings, make_orthotope_projection_from_map); +} + +int project_into_1d(Orthotope const &orthotope, OrthotopeCoordinate const &coord) { + if (!orthotope_contains_coord(orthotope, coord)) { + throw mk_runtime_error(fmt::format("coord out of bounds of orthotope: orthotope={}, coord={}", orthotope, coord)); + } + + if (orthotope.dims.size() == 0) { + return 0; + } + + std::vector> coords_and_sizes = zip(coord.idxs.get_contents(), + orthotope.dims.get_contents()); + + std::vector coords = transform(coords_and_sizes, [](std::pair const &p) { return p.first; }); + std::vector dim_sizes = transform(coords_and_sizes, [](std::pair const &p) { return p.second; }); + + std::vector strides = scanr(subvec(dim_sizes, 1, std::nullopt), 1, [](int next, int accum) { return accum * next; }); + return sum(zip_with(coords, strides, [](int coord, int stride) { return coord * stride; })); +} + +OrthotopeCoordinate project_out_of_1d(int one_dimensional_coord, Orthotope const &dst_orthotope) { + if (dst_orthotope.dims.size() == 0) { + if (one_dimensional_coord == 0) { + return OrthotopeCoordinate{{}}; + } else { + throw mk_runtime_error(fmt::format("Only valid one_dimensional_coord for zero-dimensional orthotope is 0, but receieved one_dimensional_coord={}", one_dimensional_coord)); + } + } + + if (one_dimensional_coord >= orthotope_get_volume(dst_orthotope)) { + throw mk_runtime_error(fmt::format("project_out_of_1d received coordinate that would be out of bounds of dst orthotope: dst_orthotope={}, coordinate={}", dst_orthotope, one_dimensional_coord)); + } + + std::vector dim_sizes = dst_orthotope.dims.get_contents(); + std::vector strides = scanr(subvec(dim_sizes, 1, std::nullopt), 1, [](int next, int accum) { return accum * next; }); + + OrthotopeCoordinate result = OrthotopeCoordinate{ + orthotope_dim_indexed_of(zip_with(dim_sizes, strides, [&](int dim_size, int stride) { return (one_dimensional_coord / stride) % dim_size; })), + }; + 
return result; +} + +OrthotopeCoordinate project_coordinate_through(OrthotopeBijectiveProjection const &p, Orthotope const &src_orthotope, OrthotopeCoordinate const &src_coord, Orthotope const &dst_orthotope) { + std::set dst_dim_idxs = transform(get_orthotope_dims(dst_orthotope), [](orthotope_dim_idx_t const &idx) { return idx; }); + std::set src_dim_idxs = transform(get_orthotope_dims(src_orthotope), [](orthotope_dim_idx_t const &idx) { return idx; }); + + if (src_coord.idxs.size() != get_src_num_dims(p)) { + throw mk_runtime_error(fmt::format("project_coordinate_through requires projection src and coordinate to have same num dims, but got {} and {} respectively", + get_src_num_dims(p), + src_coord.idxs.size())); + } + + if (!orthotope_contains_coord(src_orthotope, src_coord)) { + throw mk_runtime_error(fmt::format("project_coordinate_through requires coord to be in the orthotope, but got coord={} and orthotope={} respectively", src_coord, src_orthotope)); + } + + if (p.reversed) { + std::unordered_map> + dst_dim_idxs_by_src_dim_idx = + group_by(dst_dim_idxs, + [&](orthotope_dim_idx_t const &dst_dim_idx) { return get_src_dim_for_dst_dim(p, dst_dim_idx); }); + + + std::unordered_map dst_sub_orthotopes_by_src_dim_idx = + map_values(dst_dim_idxs_by_src_dim_idx, + [&](std::set const &dst_dim_idxs) { + return orthotope_drop_dims_except(dst_orthotope, dst_dim_idxs); + }); + + std::unordered_map dst_coords_by_src_dim_idx = + generate_map(src_dim_idxs, + [&](orthotope_dim_idx_t const &src_idx) -> OrthotopeCoordinate { + return project_out_of_1d(src_coord.idxs.at(src_idx), + dst_sub_orthotopes_by_src_dim_idx.at(src_idx)); + }); + + std::unordered_map dst_coords = merge_maps( + transform(vector_of(src_dim_idxs), [&](orthotope_dim_idx_t const &src_idx) -> std::unordered_map { + return map_from_keys_and_values( + vector_of(dst_dim_idxs_by_src_dim_idx.at(src_idx)), + dst_coords_by_src_dim_idx.at(src_idx).idxs.get_contents()); + })); + + return OrthotopeCoordinate{ + 
orthotope_dim_indexed_from_idx_map(dst_coords).value(), + }; + } else { + std::unordered_map> src_dim_idxs_by_dst_dim_idx = + group_by(src_dim_idxs, + [&](orthotope_dim_idx_t const &src_dim_idx) { return get_dst_dim_for_src_dim(p, src_dim_idx); }); + + + std::unordered_map src_sub_orthotopes_by_dst_dim_idx = + map_values(src_dim_idxs_by_dst_dim_idx, + [&](std::set const &src_dim_idxs) { + return orthotope_drop_dims_except(src_orthotope, src_dim_idxs); + }); + + std::unordered_map src_sub_coords_by_dst_dim_idx = + map_values(src_dim_idxs_by_dst_dim_idx, + [&](std::set const &src_dim_idxs) { + return orthotope_coord_drop_dims_except(src_coord, src_dim_idxs); + }); + + std::unordered_map dst_coords = + generate_map(dst_dim_idxs, + [&](orthotope_dim_idx_t const &dst_idx) { + return project_into_1d( + src_sub_orthotopes_by_dst_dim_idx.at(dst_idx), + src_sub_coords_by_dst_dim_idx.at(dst_idx)); + }); + + + return OrthotopeCoordinate{ + orthotope_dim_indexed_from_idx_map(dst_coords).value(), + }; + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_coordinate.cc b/lib/utils/src/utils/orthotope/orthotope_coordinate.cc index 35fc52e258..bcf310c043 100644 --- a/lib/utils/src/utils/orthotope/orthotope_coordinate.cc +++ b/lib/utils/src/utils/orthotope/orthotope_coordinate.cc @@ -4,21 +4,16 @@ #include "utils/exception.h" #include "utils/orthotope/orthotope_dim_idx_t.h" #include "utils/fmt/set.h" +#include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" namespace FlexFlow { std::set get_orthotope_coord_dims(OrthotopeCoordinate const &coord) { - return dim_idxs_for_orthotope_with_num_dims(coord.idxs.size()); + return coord.idxs.indices(); } -OrthotopeCoordinate restrict_orthotope_coord_dims(OrthotopeCoordinate const &coord, std::set const &mask) { - std::set coord_dims = get_orthotope_coord_dims(coord); - - if (!is_subseteq_of(coord_dims, mask)) { - throw mk_runtime_error(fmt::format("restrict_orthotope_coord_dims expected mask to be a 
subset of coord dims, but got coord={}, mask={}", coord, mask)); - } - - std::vector new_idxs = filter_idxs(coord.idxs, [&](int i) { return contains(mask, orthotope_dim_idx_t{i}); }); +OrthotopeCoordinate orthotope_coord_drop_dims_except(OrthotopeCoordinate const &coord, std::set const &mask) { + OrthotopeDimIndexed new_idxs = drop_idxs_except(coord.idxs, mask); return OrthotopeCoordinate{new_idxs}; } diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc new file mode 100644 index 0000000000..3a4d680392 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc @@ -0,0 +1,10 @@ +#include "utils/orthotope/orthotope_dim_indexed/all_of.h" +#include "utils/containers/all_of.h" + +namespace FlexFlow { + +bool all_of(OrthotopeDimIndexed const &d) { + return all_of(d.get_contents()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc new file mode 100644 index 0000000000..13d9fa2779 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc @@ -0,0 +1,11 @@ +#include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + OrthotopeDimIndexed drop_idxs_except(OrthotopeDimIndexed const &, std::set const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc new file mode 100644 index 0000000000..56662bd7de --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc @@ -0,0 +1,17 @@ +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include "utils/archetypes/ordered_value_type.h" +#include 
"utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + struct OrthotopeDimIndexed; + +using T2 = ordered_value_type<0>; + +// template +// bool operator<(OrthotopeDimIndexed const &, OrthotopeDimIndexed const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc new file mode 100644 index 0000000000..f7d4012688 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc @@ -0,0 +1,11 @@ +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::optional> orthotope_dim_indexed_from_idx_map(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc new file mode 100644 index 0000000000..7d4a515205 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc @@ -0,0 +1,11 @@ +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + OrthotopeDimIndexed orthotope_dim_indexed_of(std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc new file mode 100644 index 0000000000..593e673ee9 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc @@ -0,0 +1,13 @@ +#include "utils/orthotope/orthotope_dim_indexed/transform.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow 
{ + +using T = value_type<0>; +using Result = value_type<1>; +using F = std::function; + +template + OrthotopeDimIndexed transform(OrthotopeDimIndexed const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc new file mode 100644 index 0000000000..d9f9e4bb1f --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc @@ -0,0 +1,14 @@ +#include "utils/orthotope/orthotope_dim_indexed/zip_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template + OrthotopeDimIndexed zip_with(OrthotopeDimIndexed const &, OrthotopeDimIndexed const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc b/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc deleted file mode 100644 index ae66ee4b12..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_surjective_projection.cc +++ /dev/null @@ -1,122 +0,0 @@ -#include "utils/orthotope/orthotope_surjective_projection.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/get_all_assignments.h" -#include "utils/containers/group_by.h" -#include "utils/containers/map_keys.h" -#include "utils/containers/range.h" -#include "utils/containers/set_of.h" -#include "utils/containers/subvec.h" -#include "utils/containers/sum.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/containers/filter.h" -#include "utils/containers/values.h" -#include "utils/containers/zip_with.h" -#include "utils/exception.h" -#include "utils/orthotope/orthotope.h" -#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" -#include "utils/orthotope/orthotope_dim_idx_t.h" -#include "utils/containers/vector_from_idx_map.h" -#include 
"utils/containers/scanr.h" -#include "utils/containers/all_of.h" -#include "utils/fmt/vector.h" - -namespace FlexFlow { - -OrthotopeSurjectiveProjection - make_orthotope_projection_from_map(std::unordered_map const &m) { - std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t const &k) { return k.raw_idx; }); - return OrthotopeSurjectiveProjection{ - /*dim_mapping=*/vector_from_idx_map(raw_idx_map).value(), - /*reversed=*/false, - }; -} - -std::unordered_map get_src_to_dst_dim_map(OrthotopeSurjectiveProjection const &p) { - if (p.reversed) { - throw mk_runtime_error(fmt::format("get_src_to_dst_dim_map expected p.reversed=false, but received p={}", p)); - } - - std::unordered_map raw_idx_map = generate_map(range(p.dim_mapping.size()), [&](int x) { return p.dim_mapping.at(x); }); - return map_keys(raw_idx_map, [](int src_dim_idx) { return orthotope_dim_idx_t{src_dim_idx}; }); -} - -orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeSurjectiveProjection const &p, orthotope_dim_idx_t const &src_idx) { - return p.dim_mapping.at(src_idx.raw_idx); -} - -int get_src_num_dims(OrthotopeSurjectiveProjection const &p) { - return p.dim_mapping.size(); -} - -int get_dst_num_dims(OrthotopeSurjectiveProjection const &p) { - return unordered_set_of(p.dim_mapping).size(); -} - -OrthotopeSurjectiveProjection reverse_projection(OrthotopeSurjectiveProjection const &p) { - OrthotopeSurjectiveProjection result = p; - result.reversed = !p.reversed; - return result; -} - -std::unordered_set get_all_surjective_projections_between(int src_num_dims, int dst_num_dims) { - if (src_num_dims < dst_num_dims) { - return transform(get_all_surjective_projections_between(dst_num_dims, src_num_dims), - [](OrthotopeSurjectiveProjection const &p) { return reverse_projection(p); }); - } - - std::set src_dim_idxs = dim_idxs_for_orthotope_with_num_dims(src_num_dims); - std::set dst_dim_idxs = dim_idxs_for_orthotope_with_num_dims(dst_num_dims); - - std::unordered_map> src_to_dst_idxs = - 
generate_map(src_dim_idxs, [&](orthotope_dim_idx_t) { return unordered_set_of(dst_dim_idxs); }); - - std::unordered_set> valid_mappings = - filter(get_all_assignments(src_to_dst_idxs), - [&](std::unordered_map const &src_to_dst_idx) { - return set_of(values(src_to_dst_idx)) == dst_dim_idxs; - }); - - return transform(valid_mappings, make_orthotope_projection_from_map); -} - -int deconflict_overlapping_dims(std::vector> const &coords_and_sizes) { - if (coords_and_sizes.size() == 0) { - throw mk_runtime_error("deconflict_noninjective_dims expected non-empty vector, but receieved empty vector"); - } - - std::vector coords = transform(coords_and_sizes, [](std::pair const &p) { return p.first; }); - std::vector dim_sizes = transform(coords_and_sizes, [](std::pair const &p) { return p.second; }); - - if (!all_of(zip_with(coords, dim_sizes, [](int coord, int dim_size) { return coord > 0 && coord < dim_size; }))) { - throw mk_runtime_error(fmt::format("coords out of bounds of dim sizes: coords={}, dim_sizes={}", coords, dim_sizes)); - } - - std::vector strides = scanr(subvec(dim_sizes, 1, std::nullopt), 1, [](int next, int accum) { return accum * next; }); - return sum(zip_with(coords, strides, [](int coord, int stride) { return coord * stride; })); -} - -OrthotopeCoordinate project_coordinate_through(OrthotopeSurjectiveProjection const &p, Orthotope const &o, OrthotopeCoordinate const &c) { - if (p.reversed) { - NOT_IMPLEMENTED(); // TODO @lockshaw - } else { - if (c.idxs.size() != get_src_num_dims(p)) { - throw mk_runtime_error(fmt::format("project_coordinate_through requires projection src and coordinate to have same num dims, but got {} and {} respectively", - get_src_num_dims(p), - c.idxs.size())); - } - - if (!orthotope_contains_coord(o, c)) { - throw mk_runtime_error(fmt::format("project_coordinate_through requires coord to be in the orthotope, but got coord={} and orthotope={} respectively", c, o)); - } - - std::unordered_map> by_dst_dim_idx = - 
group_by(dim_idxs_for_orthotope_with_num_dims(o.dims.size()), - [&](orthotope_dim_idx_t const &src_dim_idx) { return get_dst_dim_for_src_dim(p, src_dim_idx); }); - - - NOT_IMPLEMENTED(); - } -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/tuple/visit.cc b/lib/utils/src/utils/tuple/visit.cc new file mode 100644 index 0000000000..f0d218b207 --- /dev/null +++ b/lib/utils/src/utils/tuple/visit.cc @@ -0,0 +1,15 @@ +#include "utils/tuple/visit.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; +using Visitor = std::function const &)>; + +template + void visit_tuple(std::tuple const &, Visitor &&); + +} // namespace FlexFlow diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h new file mode 100644 index 0000000000..c4a4b6ac63 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_TUPLE_H + +#include "utils/fmt/tuple.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::tuple const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc new file mode 100644 index 0000000000..717bef5ed8 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc @@ -0,0 +1,8 @@ +#include "test/utils/doctest/fmt/tuple.h" + +namespace doctest { + +template + struct StringMaker>; + +} // namespace doctest diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/intersection.cc index 52de6ee6d3..c9beaa8d85 100644 --- 
a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/intersection.cc @@ -2,45 +2,46 @@ #include "test/utils/doctest/fmt/optional.h" #include "test/utils/doctest/fmt/unordered_set.h" #include +#include "test/utils/doctest/fmt/set.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("intersection(std::unordered_set, std::unordered_set)") { - std::unordered_set input_l = {1, 2, 3}; - std::unordered_set input_r = {2, 3, 5}; + TEST_CASE_TEMPLATE("intersection(S, S)", S, std::unordered_set, std::set) { + S input_l = {1, 2, 3}; + S input_r = {2, 3, 5}; - std::unordered_set result = intersection(input_l, input_r); - std::unordered_set correct = {2, 3}; + S result = intersection(input_l, input_r); + S correct = {2, 3}; CHECK(result == correct); } - TEST_CASE("intersection(C)") { + TEST_CASE_TEMPLATE("intersection(C)", S, std::unordered_set, std::set) { SUBCASE("input is empty container") { - std::vector> input = {}; + std::vector input = {}; - std::optional> result = intersection(input); - std::optional> correct = std::nullopt; + std::optional result = intersection(input); + std::optional correct = std::nullopt; CHECK(result == correct); } SUBCASE("input is has only one set") { - std::vector> input = {{1, 2, 3}}; + std::vector input = {{1, 2, 3}}; - std::optional> result = intersection(input); - std::optional> correct = {{1, 2, 3}}; + std::optional result = intersection(input); + std::optional correct = {{1, 2, 3}}; CHECK(result == correct); } SUBCASE("input has multiple sets") { - std::vector> input = { + std::vector input = { {1, 2, 3}, {2, 3, 4}, {3, 4, 5}}; - std::optional> result = intersection(input); - std::optional> correct = {{3}}; + std::optional result = intersection(input); + std::optional correct = {{3}}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/zip.cc b/lib/utils/test/src/utils/containers/zip.cc new file mode 100644 index 0000000000..c305e53f69 --- /dev/null +++ 
b/lib/utils/test/src/utils/containers/zip.cc @@ -0,0 +1,81 @@ +#include +#include "utils/containers/zip.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip(std::vector, std::vector)") { + SUBCASE("L and R types are the same") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}, {1, 4}, {2, 8}}; + + CHECK(result == correct); + } + + SUBCASE("L and R types are different") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{"a", 5}, {"b", 4}, {"b", 8}}; + + CHECK(result == correct); + } + + SUBCASE("left is longer than right") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {5, 4}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}, {1, 4}}; + + CHECK(result == correct); + } + + SUBCASE("right is longer than left") { + std::vector lhs = {2}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}}; + + CHECK(result == correct); + } + + SUBCASE("left is empty") { + std::vector lhs = {}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("right is empty") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("both are empty") { + std::vector lhs = {}; + std::vector rhs = {}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip3.cc b/lib/utils/test/src/utils/containers/zip3.cc new file mode 100644 index 0000000000..f1613105ee --- /dev/null +++ 
b/lib/utils/test/src/utils/containers/zip3.cc @@ -0,0 +1,92 @@ +#include +#include "utils/containers/zip3.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/tuple.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip3(std::vector, std::vector, std::vector)") { + SUBCASE("types are same") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("types are different") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {"a", "d", "d"}; + std::vector> input_c = {{1, 2}, {}, {3, 1}}; + + std::vector>> result = zip3(input_a, input_b, input_c); + std::vector>> correct = { + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, + }; + + CHECK(result == correct); + } + + SUBCASE("A list is shortest") { + std::vector input_a = {2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("B list is shortest") { + std::vector input_a = {2, 1, 2, 4}; + std::vector input_b = {5, 4}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 4}}; + + CHECK(result == correct); + } + + SUBCASE("C list is shortest") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 3}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 3}}; + + CHECK(result == correct); + } + + SUBCASE("one list is empty") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> 
correct = {}; + + CHECK(result == correct); + } + + SUBCASE("all lists are empty") { + std::vector input_a = {}; + std::vector input_b = {}; + std::vector input_c = {}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/fmt/pair.cc b/lib/utils/test/src/utils/fmt/pair.cc index e848eb08c7..14b62730fa 100644 --- a/lib/utils/test/src/utils/fmt/pair.cc +++ b/lib/utils/test/src/utils/fmt/pair.cc @@ -1,5 +1,6 @@ #include "utils/fmt/pair.h" #include +#include using namespace ::FlexFlow; @@ -10,4 +11,15 @@ TEST_SUITE(FF_TEST_SUITE) { std::string correct = "{3, 5}"; CHECK(result == correct); } + + TEST_CASE("operator<<(ostream &, std::pair)") { + std::pair input = {3, 5}; + + std::ostringstream oss; + oss << input; + std::string result = oss.str(); + + std::string correct = "{3, 5}"; + CHECK(result == correct); + } } diff --git a/lib/utils/test/src/utils/fmt/tuple.cc b/lib/utils/test/src/utils/fmt/tuple.cc new file mode 100644 index 0000000000..1ee7d63a1f --- /dev/null +++ b/lib/utils/test/src/utils/fmt/tuple.cc @@ -0,0 +1,70 @@ +#include "utils/fmt/tuple.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::tuple)") { + SUBCASE("types are different") { + std::tuple input = {3, false, "hello"}; + + std::string result = fmt::to_string(input); + std::string correct = "{3, false, hello}"; + + CHECK(result == correct); + } + + SUBCASE("types are the same") { + std::tuple input = {3, 5}; + + std::string result = fmt::to_string(input); + std::string correct = "{3, 5}"; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + std::string result = fmt::to_string(input); + std::string correct = "{}"; + + CHECK(result == correct); + } + } + + TEST_CASE("operator<<(ostream &, std::tuple)") { + auto through_ostringstream = [](auto const &t) { + std::ostringstream oss; + oss << t; + 
return oss.str(); + }; + + SUBCASE("types are different") { + std::tuple input = {3, false, "hello"}; + + std::string result = through_ostringstream(input); + std::string correct = "{3, false, hello}"; + + CHECK(result == correct); + } + + SUBCASE("types are the same") { + std::tuple input = {3, 5}; + + std::string result = through_ostringstream(input); + std::string correct = "{3, 5}"; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + std::string result = through_ostringstream(input); + std::string correct = "{}"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/orthotope/orthotope.cc b/lib/utils/test/src/utils/orthotope/orthotope.cc index 28411eed6d..c4cdb91f56 100644 --- a/lib/utils/test/src/utils/orthotope/orthotope.cc +++ b/lib/utils/test/src/utils/orthotope/orthotope.cc @@ -92,5 +92,53 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK_THROWS(orthotope_contains_coord(orthotope, coord)); } + + SUBCASE("works if the orthotope is zero-dimensional") { + Orthotope orthotope = Orthotope{{}}; + OrthotopeCoordinate coord = OrthotopeCoordinate{{}}; + + bool result = orthotope_contains_coord(orthotope, coord); + bool correct = true; + + CHECK(result == correct); + } + } + + TEST_CASE("orthotope_get_volume") { + SUBCASE("1d orthotope volume is just dim size") { + Orthotope input = Orthotope{{8}}; + + int result = orthotope_get_volume(input); + int correct = 8; + + CHECK(result == correct); + } + + SUBCASE("multi-dimensional orthotope") { + Orthotope input = Orthotope{{3, 5, 1, 2}}; + + int result = orthotope_get_volume(input); + int correct = 30; + + CHECK(result == correct); + } + + SUBCASE("any dim size being zero makes the volume zero") { + Orthotope input = Orthotope{{3, 5, 0, 2}}; + + int result = orthotope_get_volume(input); + int correct = 0; + + CHECK(result == correct); + } + + SUBCASE("zero-dimensional orthotope has volume 1") { + Orthotope input = Orthotope{{}}; + + int result = 
orthotope_get_volume(input); + int correct = 1; + + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc new file mode 100644 index 0000000000..b8cefd5005 --- /dev/null +++ b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc @@ -0,0 +1,181 @@ +#include "utils/orthotope/orthotope_bijective_projection.h" +#include "utils/containers/zip.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("project_into_1d") { + SUBCASE("to 1d from 1d is identity") { + OrthotopeCoordinate coord = OrthotopeCoordinate{{2}}; + Orthotope orthotope = Orthotope{{5}}; + + int result = project_into_1d(orthotope, coord); + int correct = 2; + + CHECK(result == correct); + } + + SUBCASE("basic example") { + OrthotopeCoordinate coord = OrthotopeCoordinate{{4, 1}}; + Orthotope orthotope = Orthotope{{5, 3}}; + + int result = project_into_1d(orthotope, coord); + int correct = 4 * 3 + 1; + + CHECK(result == correct); + } + + SUBCASE("order matters") { + OrthotopeCoordinate coord = OrthotopeCoordinate{{1, 4}}; + Orthotope orthotope = Orthotope{{3, 5}}; + + int result = project_into_1d(orthotope, coord); + int correct = 1 * 5 + 4; + + CHECK(result == correct); + } + + SUBCASE("throws if coord is outside of orthotope") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {2, 3, 1}, + }; + + Orthotope orthotope = Orthotope{ + {5, 3, 2}, + }; + + CHECK_THROWS(project_into_1d(orthotope, coord)); + } + + SUBCASE("throws if coord does not have same dimension as orthotope") { + OrthotopeCoordinate coord = OrthotopeCoordinate{ + {2, 3, 1}, + }; + + Orthotope orthotope = Orthotope{ + {5, 3}, + }; + + CHECK_THROWS(project_into_1d(orthotope, coord)); + } + + SUBCASE("returns 0 if orthotope is 0-dimensional") { + OrthotopeCoordinate coord = OrthotopeCoordinate{{}}; + Orthotope orthotope = Orthotope{{}}; + + int result = 
project_into_1d(orthotope, coord); + int correct = 0; + + CHECK(result == correct); + } + } + + TEST_CASE("project_out_of_1d") { + SUBCASE("from 1d to 1d is identity") { + Orthotope orthotope = Orthotope{{5}}; + int coord = 2; + + OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); + OrthotopeCoordinate correct = OrthotopeCoordinate{{2}}; + + CHECK(result == correct); + } + + SUBCASE("basic example") { + Orthotope orthotope = Orthotope{{5, 3}}; + int coord = 4 * 3 + 1; + + OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); + OrthotopeCoordinate correct = OrthotopeCoordinate{{4, 1}}; + + CHECK(result == correct); + } + + SUBCASE("orthotope dimension order matters") { + Orthotope orthotope = Orthotope{{3, 5}}; + int coord = 1 * 5 + 4; + + OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); + OrthotopeCoordinate correct = OrthotopeCoordinate{{1, 4}}; + + CHECK(result == correct); + } + + SUBCASE("throws if coord would be projected outside of orthotope") { + Orthotope orthotope = Orthotope{{5, 3}}; + + SUBCASE("smallest coord outside of orthotope") { + int coord = 15; + + CHECK_THROWS(project_out_of_1d(coord, orthotope)); + } + + SUBCASE("largest coord inside of orthotope") { + int coord = 14; + + OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); + OrthotopeCoordinate correct = OrthotopeCoordinate{{4, 2}}; + + CHECK(result == correct); + } + } + + SUBCASE("if dst orthotope is 0-dimensional") { + Orthotope orthotope = Orthotope{{}}; + + SUBCASE("returns 0-d coord if input coord is 0") { + int input_coord = 0; + + OrthotopeCoordinate result = project_out_of_1d(input_coord, orthotope); + OrthotopeCoordinate correct = OrthotopeCoordinate{{}}; + + CHECK(result == correct); + } + + SUBCASE("throws if input coord is anything other than zero") { + int input_coord = 1; + + CHECK_THROWS(project_out_of_1d(input_coord, orthotope)); + } + } + } + + TEST_CASE("project_coordinate_through") { + Orthotope src = Orthotope{ + {2, 
3}, + }; + + Orthotope dst = Orthotope{ + {6}, + }; + + OrthotopeBijectiveProjection proj = OrthotopeBijectiveProjection{ + {orthotope_dim_idx_t{0}, orthotope_dim_idx_t{0}}, + /*reversed=*/false, + }; + + OrthotopeCoordinate src_coord = OrthotopeCoordinate{ + {1, 2}, + }; + OrthotopeCoordinate dst_coord = OrthotopeCoordinate{ + {1*3+2}, + }; + + SUBCASE("forward") { + OrthotopeCoordinate result = project_coordinate_through(proj, src, src_coord, dst); + OrthotopeCoordinate correct = dst_coord; + + CHECK(result == correct); + } + + SUBCASE("backward") { + OrthotopeBijectiveProjection reversed = reverse_projection(proj); + + OrthotopeCoordinate result = project_coordinate_through(reversed, dst, dst_coord, src); + OrthotopeCoordinate correct = src_coord; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc deleted file mode 100644 index 0418e53cc4..0000000000 --- a/lib/utils/test/src/utils/orthotope/orthotope_surjective_projection.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "utils/orthotope/orthotope_surjective_projection.h" -#include "utils/containers/zip.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("deconflict_overlapping_dims") { - SUBCASE("single input dim is unaffected") { - std::vector coords = {2}; - std::vector dim_sizes = {5}; - - int result = deconflict_overlapping_dims(zip(coords, dim_sizes)); - int correct = 2; - - CHECK(result == correct); - } - - SUBCASE("basic example") { - std::vector coords = {4, 1}; - std::vector dim_sizes = {5, 3}; - - int result = deconflict_overlapping_dims(zip(coords, dim_sizes)); - int correct = 4 * 3 + 1; - - CHECK(result == correct); - } - - SUBCASE("order matters") { - std::vector coords = {1, 4}; - std::vector dim_sizes = {3, 5}; - - int result = deconflict_overlapping_dims(zip(coords, dim_sizes)); - int correct = 1 * 5 + 4; - - CHECK(result == 
correct); - } - - SUBCASE("throws if coord is outside of corresponding dim_size") { - std::vector coords = {2, 3, 1}; - std::vector dim_sizes = {5, 3, 2}; - - CHECK_THROWS(deconflict_overlapping_dims(zip(coords, dim_sizes))); - } - - SUBCASE("throws if input is empty") { - CHECK_THROWS(deconflict_overlapping_dims({})); - } - } -} diff --git a/lib/utils/test/src/utils/tuple/visit.cc b/lib/utils/test/src/utils/tuple/visit.cc new file mode 100644 index 0000000000..7024f12e65 --- /dev/null +++ b/lib/utils/test/src/utils/tuple/visit.cc @@ -0,0 +1,40 @@ +#include "utils/tuple/visit.h" +#include +#include "utils/overload.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("visit(std::tuple, Visitor)") { + std::ostringstream oss; + auto visitor = overload { + [&](int const &i) -> void { oss << "int(" << i << "), "; }, + [&](bool const &b) -> void { oss << "bool(" << b << "), "; }, + [&](std::string const &s) -> void { oss << "string(" << s << "), "; }, + }; + + SUBCASE("repeated types") { + std::tuple input = {3, "hello", false, "world"}; + + visit_tuple(input, visitor); + + std::string result = oss.str(); + std::string correct = "int(3), string(hello), bool(0), string(world), "; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + visit_tuple(input, visitor); + + std::string result = oss.str(); + std::string correct = ""; + + CHECK(result == correct); + } + } +} From de871f7fc797e2792f642043a30949e5786a98aa Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 23 Oct 2024 17:52:54 -0700 Subject: [PATCH 04/62] Start moving some machine view types over to use orthotope --- .proj.toml | 30 +++---- ..._parallel_tensor_space_mapping.struct.toml | 15 ++++ ..._parallel_tensor_space_mapping.struct.toml | 20 ----- .../op-attrs/operator_task_space.struct.toml | 11 +-- lib/op-attrs/include/op-attrs/ops/linear.h | 6 ++ .../parallel_tensor_space.struct.toml | 16 ++++ 
.../task_space_coordinate.struct.toml | 11 +-- .../src/op-attrs/operator_task_space.cc | 18 ++--- .../test/src/op-attrs/operator_task_space.cc | 32 ++++---- .../parallel_computation_graph.cc | 2 +- lib/utils/include/utils/containers/all_of.h | 26 +++++- lib/utils/include/utils/orthotope/orthotope.h | 3 + .../utils/orthotope/orthotope.struct.toml | 6 ++ .../orthotope_bijective_projection.h | 9 ++- .../orthotope_coordinate.struct.toml | 6 ++ .../orthotope/orthotope_dim_indexed/json.h | 23 ++++++ .../orthotope_dim_indexed.h | 8 +- lib/utils/src/utils/containers/all_of.cc | 37 +++++++++ lib/utils/src/utils/orthotope/orthotope.cc | 25 ++++++ .../orthotope_bijective_projection.cc | 69 ++++++++++++++-- .../orthotope/orthotope_dim_indexed/json.cc | 8 ++ .../utils/graph/dataflow_graph/algorithms.cc | 2 +- .../orthotope_bijective_projection.cc | 81 +++++++++++++++++++ 23 files changed, 370 insertions(+), 94 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml delete mode 100644 lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml create mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h create mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc diff --git a/.proj.toml b/.proj.toml index 5592f184ad..04fcc0a573 100644 --- a/.proj.toml +++ b/.proj.toml @@ -6,27 +6,27 @@ header_extension = ".h" build_targets = [ "utils", "op-attrs", - "kernels", - "pcg", - "substitutions", - "compiler", - "substitution-generator", - "local-execution", - "models", - "export-model-arch", - "substitution-to-dot", + # "kernels", + # "pcg", + # "substitutions", + # "compiler", + # "substitution-generator", + # "local-execution", + # "models", + # "export-model-arch", + # "substitution-to-dot", ] test_targets = [ # "kernels-tests", "utils-tests", "op-attrs-tests", - 
"pcg-tests", - "substitutions-tests", - "compiler-tests", - "substitution-generator-tests", - "local-execution-tests", - "models-tests", + # "pcg-tests", + # "substitutions-tests", + # "compiler-tests", + # "substitution-generator-tests", + # "local-execution-tests", + # "models-tests", ] [cmake_flags_extra] diff --git a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml new file mode 100644 index 0000000000..0655e205cc --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "OperatorSpaceParallelTensorSpaceMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/orthotope/orthotope_bijective_projection.dtg.h", +] + +[[fields]] +name = "raw_projection" +type = "::FlexFlow::OrthotopeBijectiveProjection" diff --git a/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml deleted file mode 100644 index 004ba0b7d8..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorSpaceToParallelTensorSpaceMapping" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", -] - -[[fields]] -name = "raw_mapping" -type = "std::unordered_map" diff --git a/lib/op-attrs/include/op-attrs/operator_task_space.struct.toml b/lib/op-attrs/include/op-attrs/operator_task_space.struct.toml index 3ab8b83173..d02422c7e0 100644 --- a/lib/op-attrs/include/op-attrs/operator_task_space.struct.toml +++ b/lib/op-attrs/include/op-attrs/operator_task_space.struct.toml @@ -10,14 +10,9 @@ features = [ ] includes = [ - "", -] - 
-src_includes = [ - "utils/fmt/vector.h", - "utils/hash/vector.h" + "utils/orthotope/orthotope.dtg.h", ] [[fields]] -name = "degrees" -type = "std::vector" +name = "raw_orthotope" +type = "::FlexFlow::Orthotope" diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 065cc7e38e..a4bf49651c 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -8,6 +8,8 @@ #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" #include +#include "op-attrs/parallel_tensor_space.dtg.h" +#include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" namespace FlexFlow { @@ -34,6 +36,10 @@ tl::expected get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +tl::expected + get_output_space_mapping(LinearAttrs const &attrs, + ParallelTensorSpace const &input); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml new file mode 100644 index 0000000000..b46c98b1d6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ParallelTensorSpace" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/orthotope/orthotope.dtg.h", +] + +[[fields]] +name = "raw_orthotope" +type = "::FlexFlow::Orthotope" diff --git a/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml b/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml index 65aea167cb..508d0c21b6 100644 --- a/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml +++ b/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml @@ -10,14 +10,9 @@ features = [ ] includes = [ - "", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", + "utils/orthotope/orthotope_coordinate.dtg.h", ] [[fields]] -name = "raw_coord" -type = 
"std::vector" +name = "orthotope_coord" +type = "::FlexFlow::OrthotopeCoordinate" diff --git a/lib/op-attrs/src/op-attrs/operator_task_space.cc b/lib/op-attrs/src/op-attrs/operator_task_space.cc index 163c47c9ce..62cd9e1347 100644 --- a/lib/op-attrs/src/op-attrs/operator_task_space.cc +++ b/lib/op-attrs/src/op-attrs/operator_task_space.cc @@ -6,22 +6,15 @@ #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/fmt/unordered_set.h" +#include "utils/orthotope/orthotope.h" namespace FlexFlow { std::unordered_set get_task_space_coordinates(OperatorTaskSpace const &task) { - std::vector> coordinate_ranges = transform( - task.degrees, [&](int const &num_points) { return range(num_points); }); - - std::unordered_set> raw_coordinates = - unordered_set_of(cartesian_product(coordinate_ranges)); - std::unordered_set task_space_coordinates = - transform(raw_coordinates, [](std::vector const &point) { - return TaskSpaceCoordinate{point}; - }); - return task_space_coordinates; + return transform(orthotope_get_contained_coordinates(task.raw_orthotope), + [](OrthotopeCoordinate const &c) { return TaskSpaceCoordinate{c}; }); } TaskSpaceCoordinate @@ -30,10 +23,11 @@ TaskSpaceCoordinate } size_t num_dims(OperatorTaskSpace const &task) { - return task.degrees.size(); + return orthotope_num_dims(task.raw_orthotope); } + size_t num_tasks(OperatorTaskSpace const &task) { - return product(task.degrees); + return orthotope_get_volume(task.raw_orthotope); } } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/operator_task_space.cc b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc index 228a3b9d9e..3d9a58ab11 100644 --- a/lib/op-attrs/test/src/op-attrs/operator_task_space.cc +++ b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc @@ -8,23 +8,23 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_task_space_coordinates") { SUBCASE("OperatorTaskSpace has 0 dimensions") { - OperatorTaskSpace task = OperatorTaskSpace{{}}; + 
OperatorTaskSpace task = OperatorTaskSpace{Orthotope{{}}}; std::unordered_set correct = { - TaskSpaceCoordinate{{}}}; + TaskSpaceCoordinate{OrthotopeCoordinate{{}}}}; std::unordered_set result = get_task_space_coordinates(task); CHECK(correct == result); } SUBCASE("OperatorTaskSpace has 2 dimensions") { - OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + OperatorTaskSpace task = OperatorTaskSpace{Orthotope{{2, 2}}}; std::unordered_set correct = {{ - TaskSpaceCoordinate{{0, 0}}, - TaskSpaceCoordinate{{0, 1}}, - TaskSpaceCoordinate{{1, 0}}, - TaskSpaceCoordinate{{1, 1}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{0, 0}}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{0, 1}}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{1, 0}}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{1, 1}}}, }}; std::unordered_set result = get_task_space_coordinates(task); @@ -32,13 +32,13 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("OperatorTaskSpace has 3 dimensions") { - OperatorTaskSpace task = OperatorTaskSpace{{1, 2, 2}}; + OperatorTaskSpace task = OperatorTaskSpace{Orthotope{{1, 2, 2}}}; std::unordered_set correct = {{ - TaskSpaceCoordinate{{0, 0, 0}}, - TaskSpaceCoordinate{{0, 0, 1}}, - TaskSpaceCoordinate{{0, 1, 0}}, - TaskSpaceCoordinate{{0, 1, 1}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{0, 0, 0}}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{0, 0, 1}}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{0, 1, 0}}}, + TaskSpaceCoordinate{OrthotopeCoordinate{{0, 1, 1}}}, }}; std::unordered_set result = get_task_space_coordinates(task); @@ -48,17 +48,17 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_task_space_maximum_coordinate") { SUBCASE("OperatorTaskSpace has 2 dimensions") { - OperatorTaskSpace task = OperatorTaskSpace{{3, 2}}; + OperatorTaskSpace task = OperatorTaskSpace{Orthotope{{3, 2}}}; - TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2, 1}}; + TaskSpaceCoordinate correct = TaskSpaceCoordinate{OrthotopeCoordinate{{2, 1}}}; TaskSpaceCoordinate result = 
get_task_space_maximum_coordinate(task); CHECK(correct == result); } SUBCASE("OperatorTaskSpace has 3 dimensions") { - OperatorTaskSpace task = OperatorTaskSpace{{3, 2, 4}}; + OperatorTaskSpace task = OperatorTaskSpace{Orthotope{{3, 2, 4}}}; - TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2, 1, 3}}; + TaskSpaceCoordinate correct = TaskSpaceCoordinate{OrthotopeCoordinate{{2, 1, 3}}}; TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); CHECK(correct == result); } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index fc07edf5b3..56bd7a4335 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -9,7 +9,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("topological_ordering") { + TEST_CASE("topological_ordering(ParallelComputationGraph)") { // TODO(@lockshaw) should probably be replaced with a rapidcheck test that // compares ParallelComputationGraph to DataflowGraph, but since we // currently don't have rapidcheck generation for DataflowGraph this will diff --git a/lib/utils/include/utils/containers/all_of.h b/lib/utils/include/utils/containers/all_of.h index 87c9e067dc..cb167a65e8 100644 --- a/lib/utils/include/utils/containers/all_of.h +++ b/lib/utils/include/utils/containers/all_of.h @@ -2,11 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ALL_OF_H #include +#include +#include namespace FlexFlow { template -bool all_of(C const &c, F const &f) { +bool all_of(C const &c, F &&f) { for (auto const &v : c) { if (!f(v)) { return false; @@ -15,6 +17,28 @@ bool all_of(C const &c, F const &f) { return true; } +template +bool all_of(std::unordered_map const &m, F &&f) { + for (auto const &[k, v] : m) { + if (!f(k, v)) { + return false; + } + } + + return true; +} + +template 
+bool all_of(std::map const &m, F &&f) { + for (auto const &[k, v] : m) { + if (!f(k, v)) { + return false; + } + } + + return true; +} + bool all_of(std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/orthotope/orthotope.h b/lib/utils/include/utils/orthotope/orthotope.h index b5276c84e0..8660173e52 100644 --- a/lib/utils/include/utils/orthotope/orthotope.h +++ b/lib/utils/include/utils/orthotope/orthotope.h @@ -9,8 +9,11 @@ namespace FlexFlow { std::set get_orthotope_dims(Orthotope const &); +int orthotope_num_dims(Orthotope const &); bool orthotope_contains_coord(Orthotope const &, OrthotopeCoordinate const &); +std::unordered_set orthotope_get_contained_coordinates(Orthotope const &); + int orthotope_get_volume(Orthotope const &); Orthotope orthotope_drop_dims_except(Orthotope const &, std::set const &); diff --git a/lib/utils/include/utils/orthotope/orthotope.struct.toml b/lib/utils/include/utils/orthotope/orthotope.struct.toml index ccc07373ef..0117b44893 100644 --- a/lib/utils/include/utils/orthotope/orthotope.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope.struct.toml @@ -2,14 +2,20 @@ namespace = "FlexFlow" name = "Orthotope" features = [ "eq", + "ord", "fmt", "hash", + "json", ] includes = [ "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h", ] +src_includes = [ + "utils/orthotope/orthotope_dim_indexed/json.h", +] + [[fields]] name = "dims" type = "::FlexFlow::OrthotopeDimIndexed" diff --git a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h index b2306e616f..cb5a64cf23 100644 --- a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h +++ b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h @@ -5,15 +5,21 @@ #include "utils/orthotope/orthotope_bijective_projection.dtg.h" #include "utils/orthotope/orthotope_coordinate.dtg.h" #include +#include namespace FlexFlow { 
OrthotopeBijectiveProjection - make_orthotope_projection_from_map(std::unordered_map const &); + make_orthotope_projection_from_map(std::unordered_map const &, bool reversed); + +bool is_valid_projection_between(OrthotopeBijectiveProjection const &proj, Orthotope const &src, Orthotope const &dst); std::unordered_map get_src_to_dst_dim_map(OrthotopeBijectiveProjection const &); +std::unordered_map> get_dst_dims_by_src_dim_map(OrthotopeBijectiveProjection const &); +std::unordered_map> get_src_dims_by_dst_dim_map(OrthotopeBijectiveProjection const &); orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeBijectiveProjection const &, orthotope_dim_idx_t const &); +orthotope_dim_idx_t get_src_dim_for_dst_dim(OrthotopeBijectiveProjection const &, orthotope_dim_idx_t const &); int get_src_num_dims(OrthotopeBijectiveProjection const &); int get_dst_num_dims(OrthotopeBijectiveProjection const &); @@ -21,6 +27,7 @@ int get_dst_num_dims(OrthotopeBijectiveProjection const &); OrthotopeBijectiveProjection reverse_projection(OrthotopeBijectiveProjection const &); std::unordered_set get_all_bijective_projections_between(Orthotope const &src, Orthotope const &dst); +std::unordered_set get_all_bijective_projections_between_dim_numbers(int src_num_dims, int dst_num_dims); int project_into_1d(Orthotope const &, OrthotopeCoordinate const &); OrthotopeCoordinate project_out_of_1d(int, Orthotope const &); diff --git a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml index 4e99261cb1..fdaef519fa 100644 --- a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml @@ -2,14 +2,20 @@ namespace = "FlexFlow" name = "OrthotopeCoordinate" features = [ "eq", + "ord", "hash", "fmt", + "json", ] includes = [ "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h", ] +src_includes = [ + 
"utils/orthotope/orthotope_dim_indexed/json.h", +] + [[fields]] name = "idxs" type = "::FlexFlow::OrthotopeDimIndexed" diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h new file mode 100644 index 0000000000..7f1f8ed9e9 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_JSON_H + +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" +#include + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::OrthotopeDimIndexed> { + static ::FlexFlow::OrthotopeDimIndexed from_json(json const &j) { + return ::FlexFlow::orthotope_dim_indexed_of(j.get>()); + } + + static void to_json(json &j, ::FlexFlow::OrthotopeDimIndexed const &d) { + j = d.get_contents(); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h index 1a82e85a01..23f76037ac 100644 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h +++ b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h @@ -149,10 +149,10 @@ struct OrthotopeDimIndexed { } }; -// template -// std::enable_if_t, bool> operator<(OrthotopeDimIndexed const &lhs, OrthotopeDimIndexed const &rhs) { -// return lhs.tie() < rhs.tie(); -// } +template +std::enable_if_t, bool> operator<(OrthotopeDimIndexed const &lhs, OrthotopeDimIndexed const &rhs) { + return lhs.tie() < rhs.tie(); +} template std::vector format_as(OrthotopeDimIndexed const &d) { diff --git a/lib/utils/src/utils/containers/all_of.cc 
b/lib/utils/src/utils/containers/all_of.cc index 9f02c1aaf7..1d4b24067c 100644 --- a/lib/utils/src/utils/containers/all_of.cc +++ b/lib/utils/src/utils/containers/all_of.cc @@ -1,7 +1,44 @@ #include "utils/containers/all_of.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include +#include namespace FlexFlow { +using T1 = value_type<0>; +using F1 = std::function; + +template + bool all_of(std::vector const &, F1 &&); +template + bool all_of(std::unordered_set const &, F1 &&); +template + bool all_of(std::unordered_multiset const &, F1 &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template + bool all_of(std::set const &, F2 &&); +template + bool all_of(std::multiset const &, F2 &&); + +using K3 = value_type<0>; +using V3 = value_type<1>; +using F3 = std::function; + +template + bool all_of(std::unordered_map const &, F3 &&); + +using K4 = ordered_value_type<0>; +using V4 = ordered_value_type<1>; +using F4 = std::function; + +template + bool all_of(std::map const &, F4 &&); + + bool all_of(std::vector const &v) { for (bool v : v) { if (!v) { diff --git a/lib/utils/src/utils/orthotope/orthotope.cc b/lib/utils/src/utils/orthotope/orthotope.cc index 075db4f94a..fe570bc978 100644 --- a/lib/utils/src/utils/orthotope/orthotope.cc +++ b/lib/utils/src/utils/orthotope/orthotope.cc @@ -1,9 +1,15 @@ #include "utils/orthotope/orthotope.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" #include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/unordered_set_of.h" #include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" +#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h" #include "utils/orthotope/orthotope_dim_indexed/zip_with.h" #include "utils/orthotope/orthotope_dim_indexed/all_of.h" #include "utils/containers/all_of.h" +#include "utils/containers/transform.h" #include 
"utils/exception.h" namespace FlexFlow { @@ -12,6 +18,10 @@ std::set get_orthotope_dims(Orthotope const &orthotope) { return orthotope.dims.indices(); } +int orthotope_num_dims(Orthotope const &orthotope) { + return orthotope.dims.size(); +} + bool orthotope_contains_coord(Orthotope const &o, OrthotopeCoordinate const &c) { if (o.dims.size() != c.idxs.size()) { throw mk_runtime_error(fmt::format("orthotope_contains_coord expected orthotope and coord to have the same number of dims, but received o={}, c={}", o, c)); @@ -20,6 +30,21 @@ bool orthotope_contains_coord(Orthotope const &o, OrthotopeCoordinate const &c) return all_of(zip_with(o.dims, c.idxs, [](int dim_size, int dim_coord) { return dim_coord >= 0 && dim_coord < dim_size; })); } +std::unordered_set orthotope_get_contained_coordinates(Orthotope const &orthotope) { + std::unordered_map> possible_coord_assignments = + generate_map(get_orthotope_dims(orthotope), + [&](orthotope_dim_idx_t const &dim_idx) { + return unordered_set_of(range(orthotope.dims.at(dim_idx))); + }); + + return transform(get_all_assignments(possible_coord_assignments), + [](std::unordered_map const &assignment) { + return OrthotopeCoordinate{ + orthotope_dim_indexed_from_idx_map(assignment).value(), + }; + }); +} + int orthotope_get_volume(Orthotope const &o) { return product(o.dims.get_contents()); } diff --git a/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc index d9a62dc1e6..001722e006 100644 --- a/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc +++ b/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc @@ -6,6 +6,7 @@ #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" #include "utils/containers/merge_maps.h" +#include "utils/containers/product.h" #include "utils/containers/range.h" #include "utils/containers/set_of.h" #include "utils/containers/subvec.h" @@ -30,14 +31,61 @@ namespace FlexFlow { 
OrthotopeBijectiveProjection - make_orthotope_projection_from_map(std::unordered_map const &m) { + make_orthotope_projection_from_map(std::unordered_map const &m, bool reversed) { std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t const &k) { return k.raw_idx; }); return OrthotopeBijectiveProjection{ /*dim_mapping=*/vector_from_idx_map(raw_idx_map).value(), - /*reversed=*/false, + /*reversed=*/reversed, }; } +bool is_valid_projection_between(OrthotopeBijectiveProjection const &proj, Orthotope const &src, Orthotope const &dst) { + if (proj.reversed) { + return is_valid_projection_between(reverse_projection(proj), dst, src); + } + + auto get_src_dim_size = [&](orthotope_dim_idx_t const &src_idx) -> int { + return src.dims.at(src_idx); + }; + + auto get_dst_dim_size = [&](orthotope_dim_idx_t const &dst_idx) -> int { + return dst.dims.at(dst_idx); + }; + + std::unordered_map> src_dims_by_dst_dim + = get_src_dims_by_dst_dim_map(proj); + + return all_of(src_dims_by_dst_dim, + [&](orthotope_dim_idx_t const &dst_idx, std::set const &src_idxs) -> bool { + std::vector src_dim_sizes = transform(vector_of(src_idxs), get_src_dim_size); + + return get_dst_dim_size(dst_idx) == product(src_dim_sizes); + }); +} +std::unordered_map> + get_src_dims_by_dst_dim_map(OrthotopeBijectiveProjection const &p) { + if (p.reversed) { + throw mk_runtime_error(fmt::format("get_src_dims_by_dst_dim_map expected p.reversed=false, but received p={}", p)); + } + + std::set src_dim_idxs = dim_idxs_for_orthotope_with_num_dims(get_src_num_dims(p)); + + return group_by(src_dim_idxs, + [&](orthotope_dim_idx_t const &src_dim_idx) { return get_dst_dim_for_src_dim(p, src_dim_idx); }); +} + +std::unordered_map> + get_dst_dims_by_src_dim_map(OrthotopeBijectiveProjection const &p) { + if (!p.reversed) { + throw mk_runtime_error(fmt::format("get_dst_dims_by_src_dim_map expected p.reversed=true, but received p={}", p)); + } + + std::set dst_dim_idxs = 
dim_idxs_for_orthotope_with_num_dims(get_dst_num_dims(p)); + + return group_by(dst_dim_idxs, + [&](orthotope_dim_idx_t const &dst_dim_idx) { return get_src_dim_for_dst_dim(p, dst_dim_idx); }); +} + std::unordered_map get_src_to_dst_dim_map(OrthotopeBijectiveProjection const &p) { if (p.reversed) { throw mk_runtime_error(fmt::format("get_src_to_dst_dim_map expected p.reversed=false, but received p={}", p)); @@ -85,9 +133,9 @@ OrthotopeBijectiveProjection reverse_projection(OrthotopeBijectiveProjection con return result; } -std::unordered_set get_all_bijective_projections_between(int src_num_dims, int dst_num_dims) { +std::unordered_set get_all_bijective_projections_between_dim_numbers(int src_num_dims, int dst_num_dims) { if (src_num_dims < dst_num_dims) { - return transform(get_all_bijective_projections_between(dst_num_dims, src_num_dims), + return transform(get_all_bijective_projections_between_dim_numbers(dst_num_dims, src_num_dims), [](OrthotopeBijectiveProjection const &p) { return reverse_projection(p); }); } @@ -103,7 +151,14 @@ std::unordered_set get_all_bijective_projections_b return set_of(values(src_to_dst_idx)) == dst_dim_idxs; }); - return transform(valid_mappings, make_orthotope_projection_from_map); + return transform(valid_mappings, [](std::unordered_map const &m) { return make_orthotope_projection_from_map(m, /*reversed=*/false); }); +} + +std::unordered_set get_all_bijective_projections_between(Orthotope const &src, Orthotope const &dst) { + return filter(get_all_bijective_projections_between_dim_numbers(/*src_num_dims=*/orthotope_num_dims(src), /*dst_num_dims=*/orthotope_num_dims(dst)), + [&](OrthotopeBijectiveProjection const &p) { + return is_valid_projection_between(p, /*src=*/src, /*dst=*/dst); + }); } int project_into_1d(Orthotope const &orthotope, OrthotopeCoordinate const &coord) { @@ -148,8 +203,8 @@ OrthotopeCoordinate project_out_of_1d(int one_dimensional_coord, Orthotope const } OrthotopeCoordinate 
project_coordinate_through(OrthotopeBijectiveProjection const &p, Orthotope const &src_orthotope, OrthotopeCoordinate const &src_coord, Orthotope const &dst_orthotope) { - std::set dst_dim_idxs = transform(get_orthotope_dims(dst_orthotope), [](orthotope_dim_idx_t const &idx) { return idx; }); - std::set src_dim_idxs = transform(get_orthotope_dims(src_orthotope), [](orthotope_dim_idx_t const &idx) { return idx; }); + std::set dst_dim_idxs = get_orthotope_dims(dst_orthotope); + std::set src_dim_idxs = get_orthotope_dims(src_orthotope); if (src_coord.idxs.size() != get_src_num_dims(p)) { throw mk_runtime_error(fmt::format("project_coordinate_through requires projection src and coordinate to have same num dims, but got {} and {} respectively", diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc new file mode 100644 index 0000000000..539e7e7ba0 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc @@ -0,0 +1,8 @@ +#include "utils/orthotope/orthotope_dim_indexed/json.h" + +namespace nlohmann { + +template + struct adl_serializer<::FlexFlow::OrthotopeDimIndexed>; + +} // namespace nlohmann diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc index 25f990f80e..a2945bb537 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc @@ -41,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("topological_ordering") { + TEST_CASE("topological_ordering(DataflowGraphView)") { DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc index b8cefd5005..bb07b92c39 100644 --- 
a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc +++ b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc @@ -1,10 +1,91 @@ #include "utils/orthotope/orthotope_bijective_projection.h" #include "utils/containers/zip.h" #include +#include "test/utils/doctest/fmt/unordered_set.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_bijective_projections_between") { + SUBCASE("dst num dims greater than src num dims") { + Orthotope src = Orthotope{{6, 4}}; + orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; + + Orthotope dst = Orthotope{{3, 4, 2}}; + orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; + orthotope_dim_idx_t dst2 = orthotope_dim_idx_t{2}; + + std::unordered_set result = get_all_bijective_projections_between(src, dst); + std::unordered_set correct = { + make_orthotope_projection_from_map({ + {dst0, src0}, + {dst1, src1}, + {dst2, src0}, + }, /*reversed=*/true), + }; + + CHECK(result == correct); + } + + SUBCASE("src num dims greater than dst num dims") { + Orthotope src = Orthotope{{3, 4, 2}}; + orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; + orthotope_dim_idx_t src2 = orthotope_dim_idx_t{2}; + + Orthotope dst = Orthotope{{6, 4}}; + orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; + + std::unordered_set result = get_all_bijective_projections_between(src, dst); + std::unordered_set correct = { + make_orthotope_projection_from_map({ + {src0, dst0}, + {src1, dst1}, + {src2, dst0}, + }, /*reversed=*/false), + }; + + CHECK(result == correct); + } + + SUBCASE("multiple possible mappings") { + Orthotope src = Orthotope{{3, 3}}; + orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; + + Orthotope dst = Orthotope{{3, 3}}; + 
orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; + + std::unordered_set result = get_all_bijective_projections_between(src, dst); + std::unordered_set correct = { + make_orthotope_projection_from_map({ + {src0, dst0}, + {src1, dst1}, + }, /*reversed=*/false), + make_orthotope_projection_from_map({ + {src0, dst1}, + {src1, dst0}, + }, /*reversed=*/false), + }; + + CHECK(result == correct); + } + + SUBCASE("no possible mappings") { + Orthotope src = Orthotope{{4, 3}}; + Orthotope dst = Orthotope{{6, 2}}; + + std::unordered_set result = get_all_bijective_projections_between(src, dst); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + TEST_CASE("project_into_1d") { SUBCASE("to 1d from 1d is identity") { OrthotopeCoordinate coord = OrthotopeCoordinate{{2}}; From e258d3d66e8aaa7f767eb76c5af5e570729c4060 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 23 Oct 2024 21:27:00 -0700 Subject: [PATCH 05/62] Add OneToMany, ManyToOne, and DimProjection to fix projection equality issue --- lib/op-attrs/include/op-attrs/ops/linear.h | 14 +++ lib/op-attrs/src/op-attrs/ops/linear.cc | 6 ++ .../include/utils/many_to_one/many_to_one.h | 86 +++++++++++++++++++ .../include/utils/one_to_many/one_to_many.h | 86 +++++++++++++++++++ .../orthotope/dim_projection.variant.toml | 29 +++++++ .../orthotope/down_projection.struct.toml | 20 +++++ .../utils/orthotope/eq_projection.struct.toml | 19 ++++ .../utils/orthotope/up_projection.struct.toml | 20 +++++ .../src/utils/many_to_one/many_to_one.cc | 11 +++ .../src/utils/one_to_many/one_to_many.cc | 12 +++ .../orthotope_bijective_projection.cc | 19 ++++ 11 files changed, 322 insertions(+) create mode 100644 lib/utils/include/utils/many_to_one/many_to_one.h create mode 100644 lib/utils/include/utils/one_to_many/one_to_many.h create mode 100644 lib/utils/include/utils/orthotope/dim_projection.variant.toml create mode 100644 
lib/utils/include/utils/orthotope/down_projection.struct.toml create mode 100644 lib/utils/include/utils/orthotope/eq_projection.struct.toml create mode 100644 lib/utils/include/utils/orthotope/up_projection.struct.toml create mode 100644 lib/utils/src/utils/many_to_one/many_to_one.cc create mode 100644 lib/utils/src/utils/one_to_many/one_to_many.cc diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index a4bf49651c..def302ad9c 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -4,6 +4,7 @@ #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ -27,6 +28,13 @@ tl::expected get_bias_shape(LinearAttrs const &attrs, tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input); +tl::expected + get_projection_parallel_dim_degrees(LinearAttrs const &attrs, ParallelTensorDimDegrees const &input); +tl::expected + get_bias_parallel_dim_degrees(LinearAttrs const &attrs, ParallelTensorDimDegrees const &input); +tl::expected + get_output_parallel_dim_degrees(LinearAttrs const &attrs, ParallelTensorDimDegrees const &input); + tl::expected get_projection_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); @@ -36,6 +44,12 @@ tl::expected get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +tl::expected + get_projection_space_mapping(LinearAttrs const &attrs, + ParallelTensorSpace const &input); +tl::expected + get_bias_space_mapping(LinearAttrs const &attrs, + ParallelTensorSpace const &input); tl::expected get_output_space_mapping(LinearAttrs const &attrs, ParallelTensorSpace const &input); diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc 
b/lib/op-attrs/src/op-attrs/ops/linear.cc index feac647216..a81155b799 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -139,4 +139,10 @@ tl::expected unpar, sum_degree, discard_copy_degree, shard_degrees); } +tl::expected + get_output_space_mapping(LinearAttrs const &attrs, + ParallelTensorSpace const &input) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h new file mode 100644 index 0000000000..985fe43076 --- /dev/null +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -0,0 +1,86 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H + +#include +#include +#include "utils/containers/try_at.h" +#include +#include "utils/hash-utils.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_set.h" +#include "utils/containers/keys.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +struct ManyToOne { +public: + ManyToOne() + : l_to_r(), r_to_l() + { } + + bool operator==(ManyToOne const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(ManyToOne const &other) const { + return this->tie() != other.tie(); + } + + void insert(std::pair const &p) { + L l = p.first; + R r = p.second; + + std::optional found_r = try_at(this->l_to_r, l); + + if (!found_r.has_value()) { + this->l_to_r.insert({l, r}); + this->r_to_l[r].insert(l); + } else if (found_r.value() == r) { + return; + } else { + throw mk_runtime_error(fmt::format("Existing mapping found for left value {}: tried to map to right value {}, but is already bound to right value {}", l, r, found_r.value())); + } + } + + R const &at_l(L const &l) const { + return this->l_to_r.at(l); + } + + std::unordered_set const &at_r(R const &r) const { + return this->r_to_l.at(r); + } + + std::unordered_set 
left_values() const { + return keys(this->l_to_r); + } + + std::unordered_set right_values() const { + return keys(this->r_to_l); + } +private: + std::unordered_map l_to_r; + std::unordered_map> r_to_l; +private: + std::tuple tie() const { + return std::tie(this->l_to_r, this->r_to_l); + } + + friend struct std::hash>; +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::ManyToOne> { + size_t operator()(::FlexFlow::ManyToOne const &m) { + return ::FlexFlow::get_std_hash(m.tie()); + } +}; + +} + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h new file mode 100644 index 0000000000..58cbd5c40e --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -0,0 +1,86 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_H + +#include +#include +#include "utils/containers/try_at.h" +#include +#include "utils/hash-utils.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_set.h" +#include "utils/containers/keys.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +struct OneToMany { +public: + OneToMany() + : l_to_r(), r_to_l() + { } + + bool operator==(OneToMany const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(OneToMany const &other) const { + return this->tie() != other.tie(); + } + + void insert(std::pair const &p) { + L l = p.first; + R r = p.second; + + std::optional found_l = try_at(this->r_to_l, r); + + if (!found_l.has_value()) { + this->r_to_l.insert({r, l}); + this->l_to_r[l].insert(r); + } else if (found_l.value() == l) { + return; + } else { + throw mk_runtime_error(fmt::format("Existing mapping found for right value {}: tried to map to left value {}, but is already bound to left value {}", r, l, found_l.value())); + } + } + + std::unordered_set const &at_l(L const &l) const { + return 
this->l_to_r.at(l); + } + + L const &at_r(R const &r) const { + return this->r_to_l.at(r); + } + + std::unordered_set left_values() const { + return keys(this->l_to_r); + } + + std::unordered_set right_values() const { + return keys(this->r_to_l); + } +private: + std::unordered_map> l_to_r; + std::unordered_map r_to_l; +private: + std::tuple tie() const { + return std::tie(this->l_to_r, this->r_to_l); + } + + friend struct std::hash>; +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::OneToMany> { + size_t operator()(::FlexFlow::OneToMany const &m) { + return ::FlexFlow::get_std_hash(m.tie()); + } +}; + +} + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_projection.variant.toml b/lib/utils/include/utils/orthotope/dim_projection.variant.toml new file mode 100644 index 0000000000..b22e327259 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_projection.variant.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "DimProjection" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/orthotope/up_projection.dtg.h", + "utils/orthotope/eq_projection.dtg.h", + "utils/orthotope/down_projection.dtg.h", +] + +[[values]] +type = "::FlexFlow::UpProjection" +key = "up_proj" + +[[values]] +type = "::FlexFlow::EqProjection" +key = "eq_proj" + +[[values]] +type = "::FlexFlow::DownProjection" +key = "down_proj" diff --git a/lib/utils/include/utils/orthotope/down_projection.struct.toml b/lib/utils/include/utils/orthotope/down_projection.struct.toml new file mode 100644 index 0000000000..e9d9747bec --- /dev/null +++ b/lib/utils/include/utils/orthotope/down_projection.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "DownProjection" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/many_to_one/many_to_one.h", + "utils/orthotope/orthotope_dim_idx_t.dtg.h", +] + +[[fields]] +name = "dim_mapping" +type = 
"::FlexFlow::ManyToOne" diff --git a/lib/utils/include/utils/orthotope/eq_projection.struct.toml b/lib/utils/include/utils/orthotope/eq_projection.struct.toml new file mode 100644 index 0000000000..f70e25d7cd --- /dev/null +++ b/lib/utils/include/utils/orthotope/eq_projection.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "EqProjection" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/bidict/bidict.h", +] + +[[fields]] +name = "dim_mapping" +type = "::FlexFlow::bidict" diff --git a/lib/utils/include/utils/orthotope/up_projection.struct.toml b/lib/utils/include/utils/orthotope/up_projection.struct.toml new file mode 100644 index 0000000000..b37aba037a --- /dev/null +++ b/lib/utils/include/utils/orthotope/up_projection.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "UpProjection" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "L", "R" +] + +includes = [ + "utils/one_to_many/one_to_many.h", + "utils/orthotope/orthotope_dim_idx_t.dtg.h", +] + +[[fields]] +name = "dim_mapping" +type = "::FlexFlow::OneToMany" diff --git a/lib/utils/src/utils/many_to_one/many_to_one.cc b/lib/utils/src/utils/many_to_one/many_to_one.cc new file mode 100644 index 0000000000..f4b2a59756 --- /dev/null +++ b/lib/utils/src/utils/many_to_one/many_to_one.cc @@ -0,0 +1,11 @@ +#include "utils/many_to_one/many_to_one.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template struct ManyToOne; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many.cc b/lib/utils/src/utils/one_to_many/one_to_many.cc new file mode 100644 index 0000000000..a962de7f52 --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/one_to_many.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + 
+template struct OneToMany; + +} // namespace FlexFlow + diff --git a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc index bb07b92c39..7fff0a709f 100644 --- a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc +++ b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc @@ -6,6 +6,25 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("operator==(OrthotopeBijectiveProjection, OrthotopeBijectiveProjection)") { + SUBCASE("if src num dims and dst num dims are the same, projections are equivalent") { + orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; + + orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; + orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; + + OrthotopeBijectiveProjection p = make_orthotope_projection_from_map( + { + {src0, dst0}, + {src1, dst1}, + }, + /*reversed=*/false); + + CHECK(p == reverse_projection(p)); + } + } + TEST_CASE("get_all_bijective_projections_between") { SUBCASE("dst num dims greater than src num dims") { Orthotope src = Orthotope{{6, 4}}; From e635b1fe6a1c8a2e8ecb14fe301a4e3843be65fb Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 27 Oct 2024 00:12:00 -0700 Subject: [PATCH 06/62] Improved OneToMany and ManyToOne, and some bidict refactoring --- .proj.toml | 4 +- .../include/op-attrs/dim_ordered/get_idxs.h | 4 +- ...ator_space_parallel_tensor_space_mapping.h | 14 +++ ..._parallel_tensor_space_mapping.struct.toml | 6 +- .../operator_task_space_dim_idx_t.struct.toml | 13 +++ lib/op-attrs/include/op-attrs/ops/linear.h | 7 +- .../op-attrs/parallel_tensor_dim_degrees.h | 2 + .../parallel_tensor_space.struct.toml | 16 --- ...tor_space_parallel_tensor_space_mapping.cc | 26 +++++ lib/op-attrs/src/op-attrs/ops/linear.cc | 22 +++- .../op-attrs/parallel_tensor_dim_degrees.cc | 25 +++++ .../algorithms/exhaustive_relational_join.h | 26 +++++ 
.../utils/bidict/algorithms/filter_keys.h | 21 ++++ .../utils/bidict/algorithms/filter_values.h | 22 ++++ .../utils/bidict/algorithms/filtrans_keys.h | 25 +++++ .../utils/bidict/algorithms/filtrans_values.h | 25 +++++ .../utils/bidict/algorithms/transform.h | 24 +++++ .../utils/bidict/algorithms/transform_keys.h | 22 ++++ .../bidict/algorithms/transform_values.h | 22 ++++ lib/utils/include/utils/bidict/bidict.h | 98 ++--------------- .../include/utils/containers/multiset_of.h | 15 +++ .../include/utils/containers/set_union.h | 8 ++ .../algorithms/is_isomorphic_under.h | 5 +- .../many_to_one/exhaustive_relational_join.h | 29 +++++ .../include/utils/many_to_one/many_to_one.h | 34 ++++++ .../one_to_many/exhaustive_relational_join.h | 29 +++++ .../include/utils/one_to_many/one_to_many.h | 29 +++++ .../one_to_many_from_l_to_r_mapping.h | 23 ++++ .../include/utils/orthotope/dim_projection.h | 49 +++++++++ .../algorithms/exhaustive_relational_join.cc | 13 +++ .../utils/bidict/algorithms/filter_keys.cc | 13 +++ .../utils/bidict/algorithms/filter_values.cc | 13 +++ .../utils/bidict/algorithms/filtrans_keys.cc | 14 +++ .../bidict/algorithms/filtrans_values.cc | 15 +++ .../src/utils/bidict/algorithms/transform.cc | 15 +++ .../utils/bidict/algorithms/transform_keys.cc | 14 +++ .../bidict/algorithms/transform_values.cc | 14 +++ lib/utils/src/utils/bidict/bidict.cc | 16 +++ lib/utils/src/utils/containers/multiset_of.cc | 11 ++ lib/utils/src/utils/containers/set_union.cc | 20 ++++ .../algorithms/is_isomorphic_under.cc | 5 +- .../many_to_one/exhaustive_relational_join.cc | 13 +++ .../src/utils/many_to_one/many_to_one.cc | 14 ++- .../one_to_many/exhaustive_relational_join.cc | 13 +++ .../src/utils/one_to_many/one_to_many.cc | 14 ++- .../one_to_many_from_l_to_r_mapping.cc | 12 +++ .../algorithms/exhaustive_relational_join.cc | 102 ++++++++++++++++++ .../utils/bidict/algorithms/filter_keys.cc | 20 ++++ .../utils/bidict/algorithms/filter_values.cc | 21 ++++ 
.../utils/bidict/algorithms/filtrans_keys.cc | 30 ++++++ .../bidict/algorithms/filtrans_values.cc | 26 +++++ .../src/utils/bidict/algorithms/transform.cc | 24 +++++ .../utils/bidict/algorithms/transform_keys.cc | 25 +++++ .../bidict/algorithms/transform_values.cc | 21 ++++ lib/utils/test/src/utils/bidict/bidict.cc | 85 --------------- .../test/src/utils/containers/multiset_of.cc | 15 +++ .../many_to_one/exhaustive_relational_join.cc | 92 ++++++++++++++++ .../test/src/utils/many_to_one/many_to_one.cc | 10 ++ .../one_to_many/exhaustive_relational_join.cc | 93 ++++++++++++++++ .../test/src/utils/one_to_many/one_to_many.cc | 27 +++++ 60 files changed, 1257 insertions(+), 208 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h create mode 100644 lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml delete mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc create mode 100644 lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h create mode 100644 lib/utils/include/utils/bidict/algorithms/filter_keys.h create mode 100644 lib/utils/include/utils/bidict/algorithms/filter_values.h create mode 100644 lib/utils/include/utils/bidict/algorithms/filtrans_keys.h create mode 100644 lib/utils/include/utils/bidict/algorithms/filtrans_values.h create mode 100644 lib/utils/include/utils/bidict/algorithms/transform.h create mode 100644 lib/utils/include/utils/bidict/algorithms/transform_keys.h create mode 100644 lib/utils/include/utils/bidict/algorithms/transform_values.h create mode 100644 lib/utils/include/utils/containers/multiset_of.h create mode 100644 lib/utils/include/utils/many_to_one/exhaustive_relational_join.h create mode 100644 lib/utils/include/utils/one_to_many/exhaustive_relational_join.h create mode 100644 
lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h create mode 100644 lib/utils/include/utils/orthotope/dim_projection.h create mode 100644 lib/utils/src/utils/bidict/algorithms/exhaustive_relational_join.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/filter_keys.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/filter_values.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/filtrans_values.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/transform.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/transform_keys.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/transform_values.cc create mode 100644 lib/utils/src/utils/containers/multiset_of.cc create mode 100644 lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc create mode 100644 lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc create mode 100644 lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/exhaustive_relational_join.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/filter_values.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/transform.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/transform_keys.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/transform_values.cc create mode 100644 lib/utils/test/src/utils/containers/multiset_of.cc create mode 100644 lib/utils/test/src/utils/many_to_one/exhaustive_relational_join.cc create mode 100644 lib/utils/test/src/utils/many_to_one/many_to_one.cc create mode 100644 
lib/utils/test/src/utils/one_to_many/exhaustive_relational_join.cc create mode 100644 lib/utils/test/src/utils/one_to_many/one_to_many.cc diff --git a/.proj.toml b/.proj.toml index 04fcc0a573..5e503677ab 100644 --- a/.proj.toml +++ b/.proj.toml @@ -5,7 +5,7 @@ header_extension = ".h" build_targets = [ "utils", - "op-attrs", + # "op-attrs", # "kernels", # "pcg", # "substitutions", @@ -20,7 +20,7 @@ build_targets = [ test_targets = [ # "kernels-tests", "utils-tests", - "op-attrs-tests", + # "op-attrs-tests", # "pcg-tests", # "substitutions-tests", # "compiler-tests", diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index 6e55e8e22a..4d9e014c04 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -8,8 +8,8 @@ namespace FlexFlow { template -std::vector get_idxs(FFOrdered const &d) { - return transform(range(d.size()), [](int i) { return ff_dim_t{i}; }); +std::set get_idxs(FFOrdered const &d) { + return transform(set_of(range(d.size())), [](int i) { return ff_dim_t{i}; }); } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h new file mode 100644 index 0000000000..50c8db7d32 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H + +#include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" + +namespace FlexFlow { + +OperatorSpaceParallelTensorSpaceMapping + get_identity_mapping(ParallelTensorDimDegrees const &); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml index 0655e205cc..24c7676527 100644 --- a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml +++ b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml @@ -7,9 +7,11 @@ features = [ ] includes = [ - "utils/orthotope/orthotope_bijective_projection.dtg.h", + "utils/orthotope/dim_projection.dtg.h", + "op-attrs/operator_task_space_dim_idx_t.dtg.h", + "op-attrs/parallel_tensor_dim_idx_t.dtg.h", ] [[fields]] name = "raw_projection" -type = "::FlexFlow::OrthotopeBijectiveProjection" +type = "::FlexFlow::DimProjection<::FlexFlow::operator_task_space_dim_idx_t, ::FlexFlow::parallel_tensor_dim_idx_t>" diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml new file mode 100644 index 0000000000..124b46013a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "operator_task_space_dim_idx_t" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index def302ad9c..50d4f0fd36 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -9,7 +9,6 @@ #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" #include -#include "op-attrs/parallel_tensor_space.dtg.h" #include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" namespace FlexFlow { @@ -46,13 +45,13 @@ tl::expected tl::expected get_projection_space_mapping(LinearAttrs const &attrs, - ParallelTensorSpace const &input); + ParallelTensorDimDegrees const &input); 
tl::expected get_bias_space_mapping(LinearAttrs const &attrs, - ParallelTensorSpace const &input); + ParallelTensorDimDegrees const &input); tl::expected get_output_space_mapping(LinearAttrs const &attrs, - ParallelTensorSpace const &input); + ParallelTensorDimDegrees const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h index d95a717695..22128ca74e 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h @@ -7,6 +7,8 @@ namespace FlexFlow { +std::set get_nontrivial_parallel_tensor_dim_indices(ParallelTensorDimDegrees const &); + std::unordered_map get_parallel_tensor_degree_map(ParallelTensorDimDegrees const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml deleted file mode 100644 index b46c98b1d6..0000000000 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_space.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "ParallelTensorSpace" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/orthotope/orthotope.dtg.h", -] - -[[fields]] -name = "raw_orthotope" -type = "::FlexFlow::Orthotope" diff --git a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc b/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc new file mode 100644 index 0000000000..651966840a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc @@ -0,0 +1,26 @@ +#include "op-attrs/operator_space_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_degrees.h" +#include "utils/containers/range.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" + +namespace 
FlexFlow { + +OperatorSpaceParallelTensorSpaceMapping + get_identity_mapping(ParallelTensorDimDegrees const &degrees) { + + std::set parallel_tensor_dim_indices + = get_nontrivial_parallel_tensor_dim_indices(degrees); + + std::set operator_space_dim_indices + = transform(set_of(range(parallel_tensor_dim_indices.size())), + [](int raw_idx) { return operator_task_space_dim_idx_t{raw_idx}; }); + + bidict raw_bidict + = bidict_from_keys_and_values(vector_of(operator_space_dim_indices), vector_of(parallel_tensor_dim_indices)); + + return OperatorSpaceParallelTensorSpaceMapping{DimProjection{EqProjection{raw_bidict}}}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index a81155b799..438bec9708 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -1,6 +1,7 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/dim_ordered/slice.h" #include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/operator_space_parallel_tensor_space_mapping.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/product.h" @@ -139,10 +140,27 @@ tl::expected unpar, sum_degree, discard_copy_degree, shard_degrees); } +tl::expected + get_projection_space_mapping(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input) { + + SumDegree sum_degree = SumDegree{1}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + get_sum_degree(input) * + product( + slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; + FFOrdered shard_degrees = FFOrdered{ + shard_dim_at_idx(input, ff_dim_t{-1}).degree, + get_discard_copy_degree(input), + }; + + return +} + tl::expected get_output_space_mapping(LinearAttrs const &attrs, - ParallelTensorSpace const &input) { - NOT_IMPLEMENTED(); + ParallelTensorDimDegrees const &input) { + return get_identity_mapping(input); } } // namespace FlexFlow diff --git
a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc index 24d9683ab4..aef224a31e 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -2,6 +2,7 @@ #include "op-attrs/dim_ordered/get_idxs.h" #include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" #include "op-attrs/parallel_tensor_space_coordinate.h" +#include "utils/containers/filtrans.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/map_keys.h" @@ -10,9 +11,33 @@ #include "utils/containers/range.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/transform.h" +#include "utils/containers/set_union.h" namespace FlexFlow { +std::set get_nontrivial_parallel_tensor_dim_indices(ParallelTensorDimDegrees const &degrees) { + std::set nontrivial_replica_dims; + + if (degrees.sum_degree.value > 1) { + nontrivial_replica_dims.insert(parallel_tensor_dim_idx_t{ReplicaType::SUM}); + } + + if (degrees.discard_copy_degree.value > 1) { + nontrivial_replica_dims.insert(parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}); + } + + std::set nontrivial_shard_dims = + filtrans(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) -> std::optional { + if (degrees.shard_degrees.at(dim) > 1) { + return parallel_tensor_dim_idx_t{dim}; + } else { + return std::nullopt; + } + }); + + return set_union(nontrivial_replica_dims, nontrivial_shard_dims); +} + std::unordered_map get_parallel_tensor_degree_map(ParallelTensorDimDegrees const &degrees) { diff --git a/lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h b/lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h new file mode 100644 index 0000000000..781844af8b --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/exhaustive_relational_join.h @@ -0,0 +1,26 @@ +#ifndef
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_EXHAUSTIVE_RELATIONAL_JOIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/bidict/bidict.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict exhaustive_relational_join(bidict const &fst, bidict const &snd) { + if (fst.size() != snd.size()) { + throw mk_runtime_error(fmt::format("exhaustive_relational_join received bidicts of different sizes: fst has size {} while snd has size {}", fst.size(), snd.size())); + } + + bidict result; + + for (auto const &[v1, v2] : fst) { + result.equate({v1, snd.at_l(v2)}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filter_keys.h b/lib/utils/include/utils/bidict/algorithms/filter_keys.h new file mode 100644 index 0000000000..2734dfaeb5 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filter_keys.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict filter_keys(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + if (f(kv.first)) { + result.equate(kv); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filter_values.h b/lib/utils/include/utils/bidict/algorithms/filter_values.h new file mode 100644 index 0000000000..be95de51f6 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filter_values.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict filter_values(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { 
+ if (f(kv.second)) { + result.equate(kv); + } + } + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h b/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h new file mode 100644 index 0000000000..df6495b400 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template ::value_type> +bidict filtrans_keys(bidict const &m, F &&f) { + bidict result; + for (auto const &[k, v] : m) { + std::optional new_k = f(k); + if (new_k.has_value()) { + result.equate(new_k.value(), v); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_values.h b/lib/utils/include/utils/bidict/algorithms/filtrans_values.h new file mode 100644 index 0000000000..11180938b8 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/filtrans_values.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template ::value_type> +bidict filtrans_values(bidict const &m, F &&f) { + bidict result; + for (auto const &[k, v] : m) { + std::optional new_v = f(v); + if (new_v.has_value()) { + result.equate(k, new_v.value()); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/transform.h b/lib/utils/include/utils/bidict/algorithms/transform.h new file mode 100644 index 0000000000..2a56d54dab --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/transform.h @@ -0,0 +1,24 @@ +#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template ::first_type, + typename V2 = typename std::invoke_result_t::second_type> +bidict transform(bidict const &m, F &&f) { + bidict result; + for (auto const &[k, v] : m) { + result.equate(f(k, v)); + } + return result; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/transform_keys.h b/lib/utils/include/utils/bidict/algorithms/transform_keys.h new file mode 100644 index 0000000000..8ecb10c401 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/transform_keys.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_KEYS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template > +bidict transform_keys(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + result.equate(f(kv.first), kv.second); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/transform_values.h b/lib/utils/include/utils/bidict/algorithms/transform_values.h new file mode 100644 index 0000000000..ef5b34ebe9 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/transform_values.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_TRANSFORM_VALUES_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template > +bidict transform_values(bidict const &m, F &&f) { + bidict result; + for (auto const &kv : m) { + result.equate({kv.first, f(kv.second)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 
8b19313002..d02be7c19b 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -6,6 +6,7 @@ #include #include #include +#include "utils/containers/keys.h" namespace FlexFlow { @@ -85,6 +86,14 @@ struct bidict { return bwd_map.at(r); } + std::unordered_set left_values() const { + return keys(this->fwd_map); + } + + std::unordered_set right_values() const { + return keys(this->bwd_map); + } + std::size_t size() const { assert(fwd_map.size() == bwd_map.size()); return fwd_map.size(); @@ -208,95 +217,6 @@ std::ostream &operator<<(std::ostream &s, bidict const &b) { return s << fmt::to_string(b); } -template ()(std::declval()))> -bidict map_keys(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - result.equate(f(kv.first), kv.second); - } - return result; -} - -template ()(std::declval()))> -bidict map_values(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - result.equate({kv.first, f(kv.second)}); - } - return result; -} - -template -bidict filter_keys(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - if (f(kv.first)) { - result.equate(kv); - } - } - return result; -} - -template -bidict filter_values(bidict const &m, F const &f) { - bidict result; - for (auto const &kv : m) { - if (f(kv.second)) { - result.equate(kv); - } - } - return result; -} - -template ::value_type> -bidict filtermap_keys(bidict const &m, F const &f) { - bidict result; - for (auto const &[k, v] : m) { - std::optional new_k = f(k); - if (new_k.has_value()) { - result.equate(new_k.value(), v); - } - } - return result; -} - -template ::value_type> -bidict filtermap_values(bidict const &m, F const &f) { - bidict result; - for (auto const &[k, v] : m) { - std::optional new_v = f(v); - if (new_v.has_value()) { - result.equate(k, new_v.value()); - } - } - return result; -} - -template ::first_type, - typename V2 = typename std::invoke_result_t::second_type> -bidict 
transform(bidict const &m, F const &f) { - bidict result; - for (auto const &[k, v] : m) { - result.equate(f(k, v)); - } - return result; -} - } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/containers/multiset_of.h b/lib/utils/include/utils/containers/multiset_of.h new file mode 100644 index 0000000000..79bfbc40a3 --- /dev/null +++ b/lib/utils/include/utils/containers/multiset_of.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_OF_H + +#include + +namespace FlexFlow { + +template +std::multiset multiset_of(C const &c) { + return std::multiset{std::cbegin(c), std::cend(c)}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/set_union.h b/lib/utils/include/utils/containers/set_union.h index 0f7b895f7a..2d073b177a 100644 --- a/lib/utils/include/utils/containers/set_union.h +++ b/lib/utils/include/utils/containers/set_union.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_UNION_H #include +#include namespace FlexFlow { @@ -13,6 +14,13 @@ std::unordered_set set_union(std::unordered_set const &l, return result; } +template +std::set set_union(std::set const &l, std::set const &r) { + std::set result = l; + result.insert(r.cbegin(), r.cend()); + return result; +} + template std::unordered_set set_union(C const &sets) { std::unordered_set result; diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h index ecf9c22143..fea564f2c1 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h @@ -7,6 +7,7 @@ #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" #include 
"utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" #include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/bidict/algorithms/transform_values.h" namespace FlexFlow { @@ -17,11 +18,11 @@ bool is_isomorphic_under( OpenDataflowGraphIsomorphism const &candidate_isomorphism) { bidict node_permutation = - map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { + transform_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { return NewNode{dst_node}; }).reversed(); bidict input_permutation = - map_values(candidate_isomorphism.input_mapping, + transform_values(candidate_isomorphism.input_mapping, [](DataflowGraphInput const &dst_input) { return NewDataflowGraphInput{dst_input}; }) diff --git a/lib/utils/include/utils/many_to_one/exhaustive_relational_join.h b/lib/utils/include/utils/many_to_one/exhaustive_relational_join.h new file mode 100644 index 0000000000..32c2239f5a --- /dev/null +++ b/lib/utils/include/utils/many_to_one/exhaustive_relational_join.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_EXHAUSTIVE_RELATIONAL_JOIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +ManyToOne exhaustive_relational_join(ManyToOne const &fst, ManyToOne const &snd) { + ManyToOne result; + + if (fst.right_values() != snd.left_values()) { + throw mk_runtime_error(fmt::format("exhaustive_relational_join for ManyToOne received inputs with non-matching inner dimensions: right dimension of fst is {} while left dimension of snd is {}", fst.right_values(), snd.left_values())); + } + + for (T3 const &t3 : snd.right_values()) { + for (T2 const &t2 : snd.at_r(t3)) { + for (T1 const &t1 : fst.at_r(t2)) { + result.insert({t1, t3}); + } + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index 985fe43076..95c070a72d 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -8,8 +8,11 @@ #include "utils/hash-utils.h" #include "utils/hash/unordered_map.h" #include "utils/hash/unordered_set.h" +#include "utils/hash/tuple.h" #include "utils/containers/keys.h" #include "utils/exception.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -20,6 +23,21 @@ struct ManyToOne { : l_to_r(), r_to_l() { } + template + ManyToOne(It start, It end) + : ManyToOne() + { + for (; start < end; start++) { + for (L const &l : start->first) { + this->insert(std::pair{l, start->second}); + } + } + } + + ManyToOne(std::initializer_list, R>> const &l_to_r) + : ManyToOne(l_to_r.begin(), l_to_r.end()) + { } + bool operator==(ManyToOne const &other) const { return this->tie() == other.tie(); } @@ -70,6 +88,22 @@ struct ManyToOne { friend struct std::hash>; }; +template +std::unordered_map, R> format_as(ManyToOne const &m) { + std::unordered_map, R> result; + + for (R const &r : m.right_values()) { + result.insert({m.at_r(r), r}); + } + + return result; +} + +template +std::ostream &operator<<(std::ostream &s, ManyToOne const &m) { + return (s << fmt::to_string(m)); +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/one_to_many/exhaustive_relational_join.h b/lib/utils/include/utils/one_to_many/exhaustive_relational_join.h new file mode 100644 index 0000000000..b70320e130 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/exhaustive_relational_join.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_EXHAUSTIVE_RELATIONAL_JOIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_EXHAUSTIVE_RELATIONAL_JOIN_H + +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany 
exhaustive_relational_join(OneToMany const &fst, OneToMany const &snd) { + OneToMany result; + + if (fst.right_values() != snd.left_values()) { + throw mk_runtime_error(fmt::format("exhaustive_relational_join for OneToMany received inputs with non-matching inner dimensions: right dimension of fst is {} while left dimension of snd is {}", fst.right_values(), snd.left_values())); + } + + for (T1 const &t1 : fst.left_values()) { + for (T2 const &t2 : fst.at_l(t1)) { + for (T3 const &t3 : snd.at_l(t2)) { + result.insert({t1, t3}); + } + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 58cbd5c40e..8ce8808207 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -8,8 +8,12 @@ #include "utils/hash-utils.h" #include "utils/hash/unordered_map.h" #include "utils/hash/unordered_set.h" +#include "utils/hash/tuple.h" #include "utils/containers/keys.h" #include "utils/exception.h" +#include "utils/containers/generate_map.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -20,6 +24,21 @@ struct OneToMany { : l_to_r(), r_to_l() { } + template + OneToMany(It start, It end) + : OneToMany() + { + for (; start < end; start++) { + for (R const &r : start->second) { + this->insert(std::pair{start->first, r}); + } + } + } + + OneToMany(std::initializer_list>> const &l_to_r) + : OneToMany(l_to_r.begin(), l_to_r.end()) + { } + bool operator==(OneToMany const &other) const { return this->tie() == other.tie(); } @@ -70,6 +89,16 @@ struct OneToMany { friend struct std::hash>; }; +template +std::unordered_map> format_as(OneToMany const &m) { + return generate_map(m.left_values(), [&](L const &l) { return m.at_l(l); }); +} + +template +std::ostream &operator<<(std::ostream &s, OneToMany const &m) { + return (s << fmt::to_string(m)); +} + } // 
namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h b/lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h new file mode 100644 index 0000000000..fccd878282 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_from_l_to_r_mapping.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_FROM_L_TO_R_MAPPING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_FROM_L_TO_R_MAPPING_H + +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany one_to_many_from_l_to_r_mapping(std::unordered_map> const &m) { + OneToMany result; + + for (auto const &[l, rs] : m) { + for (auto const &r : rs) { + result.insert({l, r}); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_projection.h b/lib/utils/include/utils/orthotope/dim_projection.h new file mode 100644 index 0000000000..47f470ba7a --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_projection.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_PROJECTION_H + +#include "utils/orthotope/down_projection.dtg.h" +#include "utils/orthotope/eq_projection.dtg.h" +#include "utils/orthotope/up_projection.dtg.h" + +namespace FlexFlow { + +template +EqProjection compose_dim_projections(EqProjection const &fst, EqProjection const &snd) { + return EqProjection{ + exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping) + }; +} + +template +UpProjection compose_dim_projections(UpProjection const &fst, UpProjection const &snd) { + NOT_IMPLEMENTED(); +} + +template +DownProjection compose_dim_projections(DownProjection const &fst, DownProjection const &snd) { + NOT_IMPLEMENTED(); +} + +template +UpProjection compose_dim_projections(EqProjection const &fst, UpProjection const &snd) { + NOT_IMPLEMENTED(); +} + +template 
+UpProjection compose_dim_projections(UpProjection const &fst, EqProjection const &snd) { + NOT_IMPLEMENTED(); +} + +template +DownProjection compose_dim_projections(EqProjection const &fst, DownProjection const &snd) { + NOT_IMPLEMENTED(); +} + +template +DownProjection compose_dim_projections(DownProjection const &fst, EqProjection const &snd) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/bidict/algorithms/exhaustive_relational_join.cc b/lib/utils/src/utils/bidict/algorithms/exhaustive_relational_join.cc new file mode 100644 index 0000000000..a1471c16f3 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/exhaustive_relational_join.cc @@ -0,0 +1,13 @@ +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; + +template + bidict exhaustive_relational_join(bidict const &, bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filter_keys.cc b/lib/utils/src/utils/bidict/algorithms/filter_keys.cc new file mode 100644 index 0000000000..37c89ee459 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/filter_keys.cc @@ -0,0 +1,13 @@ +#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +template + bidict filter_keys(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filter_values.cc b/lib/utils/src/utils/bidict/algorithms/filter_values.cc new file mode 100644 index 0000000000..e24870e382 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/filter_values.cc @@ -0,0 +1,13 @@ +#include "utils/bidict/algorithms/filter_values.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using 
V = value_type<1>; +using F = std::function; + +template + bidict filter_values(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc b/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc new file mode 100644 index 0000000000..3644df0bb7 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc @@ -0,0 +1,14 @@ +#include "utils/bidict/algorithms/filtrans_keys.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using K2 = value_type<2>; +using F = std::function(K)>; + +template + bidict filtrans_keys(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc b/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc new file mode 100644 index 0000000000..fe4767e196 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc @@ -0,0 +1,15 @@ +#include "utils/bidict/algorithms/filtrans_values.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using V2 = value_type<2>; +using F = std::function(V)>; + + +template + bidict filtrans_values(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/transform.cc b/lib/utils/src/utils/bidict/algorithms/transform.cc new file mode 100644 index 0000000000..7ab98e50a3 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/transform.cc @@ -0,0 +1,15 @@ +#include "utils/bidict/algorithms/transform.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using K2 = value_type<2>; +using V2 = value_type<3>; +using F = std::function(K, V)>; + +template + bidict transform(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/transform_keys.cc 
b/lib/utils/src/utils/bidict/algorithms/transform_keys.cc new file mode 100644 index 0000000000..039f7b30d5 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/transform_keys.cc @@ -0,0 +1,14 @@ +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using K2 = value_type<2>; +using F = std::function; + +template + bidict transform_keys(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/transform_values.cc b/lib/utils/src/utils/bidict/algorithms/transform_values.cc new file mode 100644 index 0000000000..8e3c0e9594 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/transform_values.cc @@ -0,0 +1,14 @@ +#include "utils/bidict/algorithms/transform_values.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using V2 = value_type<2>; +using F = std::function; + +template + bidict transform_values(bidict const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/bidict.cc b/lib/utils/src/utils/bidict/bidict.cc index 57d75dfdfd..227b547163 100644 --- a/lib/utils/src/utils/bidict/bidict.cc +++ b/lib/utils/src/utils/bidict/bidict.cc @@ -1 +1,17 @@ #include "utils/bidict/bidict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template struct bidict; + +template + std::unordered_map format_as(bidict const &); + +template + std::ostream &operator<<(std::ostream &, bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/multiset_of.cc b/lib/utils/src/utils/containers/multiset_of.cc new file mode 100644 index 0000000000..ea1cb18906 --- /dev/null +++ b/lib/utils/src/utils/containers/multiset_of.cc @@ -0,0 +1,11 @@ +#include "utils/containers/multiset_of.h" +#include 
"utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T = ordered_value_type<0>; + +template std::multiset multiset_of(std::vector const &); +template std::multiset multiset_of(std::set const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/set_union.cc b/lib/utils/src/utils/containers/set_union.cc index 646e862808..b3f581b42d 100644 --- a/lib/utils/src/utils/containers/set_union.cc +++ b/lib/utils/src/utils/containers/set_union.cc @@ -1 +1,21 @@ #include "utils/containers/set_union.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + std::unordered_set set_union(std::unordered_set const &, + std::unordered_set const &); + +using T2 = ordered_value_type<0>; + +template + std::set set_union(std::set const &, std::set const &); + +template + std::unordered_set set_union(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc index 77e23d9c87..4f150c96cd 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -4,6 +4,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" #include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" #include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/bidict/algorithms/transform_values.h" namespace FlexFlow { @@ -13,11 +14,11 @@ bool is_isomorphic_under( OpenDataflowGraphIsomorphism const &candidate_isomorphism) { bidict node_permutation = - map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { + transform_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { return 
NewNode{dst_node}; }).reversed(); bidict input_permutation = - map_values(candidate_isomorphism.input_mapping, + transform_values(candidate_isomorphism.input_mapping, [](DataflowGraphInput const &dst_input) { return NewDataflowGraphInput{dst_input}; }) diff --git a/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc b/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc new file mode 100644 index 0000000000..13177281c6 --- /dev/null +++ b/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc @@ -0,0 +1,13 @@ +#include "utils/many_to_one/exhaustive_relational_join.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; + +template + ManyToOne exhaustive_relational_join(ManyToOne const &, ManyToOne const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/many_to_one/many_to_one.cc b/lib/utils/src/utils/many_to_one/many_to_one.cc index f4b2a59756..633ba24fa2 100644 --- a/lib/utils/src/utils/many_to_one/many_to_one.cc +++ b/lib/utils/src/utils/many_to_one/many_to_one.cc @@ -1,11 +1,23 @@ #include "utils/many_to_one/many_to_one.h" #include "utils/archetypes/value_type.h" -namespace FlexFlow { +using namespace ::FlexFlow; using L = value_type<0>; using R = value_type<1>; +namespace FlexFlow { + template struct ManyToOne; +template std::unordered_map, R> format_as(ManyToOne const &); + +template std::ostream &operator<<(std::ostream &, ManyToOne const &); + } // namespace FlexFlow + +namespace std { + +template struct hash>; + +} diff --git a/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc b/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc new file mode 100644 index 0000000000..8dad537b9b --- /dev/null +++ b/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc @@ -0,0 +1,13 @@ +#include "utils/one_to_many/exhaustive_relational_join.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + 
+using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; + +template + OneToMany exhaustive_relational_join(OneToMany const &, OneToMany const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many.cc b/lib/utils/src/utils/one_to_many/one_to_many.cc index a962de7f52..b50c64d5b9 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many.cc @@ -1,12 +1,24 @@ #include "utils/one_to_many/one_to_many.h" #include "utils/archetypes/value_type.h" -namespace FlexFlow { +using namespace ::FlexFlow; using L = value_type<0>; using R = value_type<1>; +namespace FlexFlow { + template struct OneToMany; +template std::unordered_map> format_as(OneToMany const &); + +template std::ostream &operator<<(std::ostream &, OneToMany const &); + } // namespace FlexFlow +namespace std { + +template struct hash>; + +} + diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc new file mode 100644 index 0000000000..2e5bf0176c --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/one_to_many_from_l_to_r_mapping.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + OneToMany one_to_many_from_l_to_r_mapping(std::unordered_map> const &); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/bidict/algorithms/exhaustive_relational_join.cc b/lib/utils/test/src/utils/bidict/algorithms/exhaustive_relational_join.cc new file mode 100644 index 0000000000..60ef6cb8c5 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/exhaustive_relational_join.cc @@ -0,0 +1,102 @@ +#include "utils/bidict/algorithms/exhaustive_relational_join.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + 
TEST_CASE("exhaustive_relational_join(bidict, bidict)") { + SUBCASE("inputs are empty") { + bidict fst = {}; + bidict> snd = {}; + + bidict> result = exhaustive_relational_join(fst, snd); + bidict> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("join is exhaustive") { + bidict fst = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + bidict> snd = { + {"one", {2, 0}}, + {"two", {3, 1}}, + {"three", {4, 2}} + }; + + bidict> result = exhaustive_relational_join(fst, snd); + bidict> correct = { + {1, {2, 0}}, + {2, {3, 1}}, + {3, {4, 2}}, + }; + + CHECK(result == correct); + } + + SUBCASE("extra relation in fst") { + bidict fst = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + bidict> snd = { + {"one", {2, 0}}, + {"two", {3, 1}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("extra relation in snd") { + bidict fst = { + {1, "one"}, + {3, "three"}, + }; + bidict> snd = { + {"one", {2, 0}}, + {"two", {3, 1}}, + {"three", {4, 2}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("same number of relations in fst and snd, but not matching") { + bidict fst = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + bidict> snd = { + {"one", {2, 0}}, + {"three", {4, 2}}, + {"four", {5, 3}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("works even if all the types are the same") { + bidict fst = { + {1, 2}, + {2, 3}, + }; + bidict snd = { + {2, 3}, + {3, 4}, + }; + + bidict result = exhaustive_relational_join(fst, snd); + bidict correct = { + {1, 3}, + {2, 4}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc new file mode 100644 index 0000000000..ae1e3310db --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc @@ -0,0 +1,20 @@ +#include "utils/bidict/algorithms/filter_keys.h" +#include + +using namespace ::FlexFlow; + 
+TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_keys(bidict, F)") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = + filter_keys(dict, [](int k) { return k == 1; }); + bidict correct = { + {1, "one"}, + }; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc b/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc new file mode 100644 index 0000000000..44110f297b --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc @@ -0,0 +1,21 @@ +#include "utils/bidict/algorithms/filter_values.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_values(bidict, F") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = + filter_values(dict, [](std::string const &v) { return v == "two"; }); + bidict correct = { + {2, "two"}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc new file mode 100644 index 0000000000..5a6a4cac73 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc @@ -0,0 +1,30 @@ +#include "utils/bidict/algorithms/filtrans_keys.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filtrans_keys(bidict, F)") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = + filtrans_keys(dict, [](int k) -> std::optional { + if (k == 1) { + return std::nullopt; + } else { + std::ostringstream oss; + oss << (k + 1); + return oss.str(); + } + }); + + bidict correct = { + {"3", "two"}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc b/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc new file mode 100644 index 0000000000..200b03dc88 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc @@ -0,0 +1,26 @@ 
+#include "utils/bidict/algorithms/filtrans_values.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filtrans_values(bidict, F)") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = filtrans_values( + dict, [](std::string const &v) -> std::optional { + if (v == "two") { + return std::nullopt; + } else { + return v.size() + 1; + } + }); + bidict correct = { + {1, 4}, + }; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/transform.cc b/lib/utils/test/src/utils/bidict/algorithms/transform.cc new file mode 100644 index 0000000000..835f942f6b --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/transform.cc @@ -0,0 +1,24 @@ +#include "utils/bidict/algorithms/transform.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transform(bidict, F)") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = + transform(dict, [](int k, std::string const &v) { + return std::make_pair(v, k); + }); + bidict correct = { + {"one", 1}, + {"two", 2}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/transform_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/transform_keys.cc new file mode 100644 index 0000000000..0f2114213f --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/transform_keys.cc @@ -0,0 +1,25 @@ +#include "utils/bidict/algorithms/transform_keys.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transform_keys(bidict, F)") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = transform_keys(dict, [](int k) { + std::ostringstream oss; + oss << k; + return oss.str(); + }); + bidict correct = { + {"1", "one"}, + {"2", "two"}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/algorithms/transform_values.cc 
b/lib/utils/test/src/utils/bidict/algorithms/transform_values.cc new file mode 100644 index 0000000000..79ea5f838a --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/transform_values.cc @@ -0,0 +1,21 @@ +#include "utils/bidict/algorithms/transform_values.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transform_values(bidict, F)") { + bidict dict = { + {1, "one"}, + {2, "two"}, + }; + + bidict result = + transform_values(dict, [](std::string const &v) { return v + "a"; }); + bidict correct = { + {1, "onea"}, + {2, "twoa"}, + }; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index d158af129f..f6cad73d5c 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -92,91 +92,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK_WITHOUT_STRINGIFY(it == dict.end()); } - SUBCASE("map_keys(bidict, F)") { - bidict result = map_keys(dict, [](int k) { - std::ostringstream oss; - oss << k; - return oss.str(); - }); - bidict correct = { - {"1", "one"}, - {"2", "two"}, - }; - CHECK(result == correct); - } - - SUBCASE("map_values(bidict, F)") { - bidict result = - map_values(dict, [](std::string const &v) { return v + "a"; }); - bidict correct = { - {1, "onea"}, - {2, "twoa"}, - }; - CHECK(result == correct); - } - - SUBCASE("filter_keys(bidict, F") { - bidict result = - filter_keys(dict, [](int k) { return k == 1; }); - bidict correct = { - {1, "one"}, - }; - CHECK(result == correct); - } - - SUBCASE("filter_values(bidict, F") { - bidict result = - filter_values(dict, [](std::string const &v) { return v == "two"; }); - bidict correct = { - {2, "two"}, - }; - CHECK(result == correct); - } - - SUBCASE("filtermap_keys(bidict, F)") { - bidict result = - filtermap_keys(dict, [](int k) -> std::optional { - if (k == 1) { - return std::nullopt; - } else { - std::ostringstream oss; - oss << (k + 1); - return oss.str(); 
- } - }); - bidict correct = { - {"3", "two"}, - }; - CHECK(result == correct); - } - - SUBCASE("filtermap_values(bidict, F)") { - bidict result = filtermap_values( - dict, [](std::string const &v) -> std::optional { - if (v == "two") { - return std::nullopt; - } else { - return v.size() + 1; - } - }); - bidict correct = { - {1, 4}, - }; - CHECK(result == correct); - } - - SUBCASE("transform(bidict, F)") { - bidict result = - transform(dict, [](int k, std::string const &v) { - return std::make_pair(v, k); - }); - bidict correct = { - {"one", 1}, - {"two", 2}, - }; - CHECK(result == correct); - } - SUBCASE("fmt::to_string(bidict)") { std::string result = fmt::to_string(dict); std::string correct = fmt::to_string(dict.as_unordered_map()); diff --git a/lib/utils/test/src/utils/containers/multiset_of.cc b/lib/utils/test/src/utils/containers/multiset_of.cc new file mode 100644 index 0000000000..d44979f655 --- /dev/null +++ b/lib/utils/test/src/utils/containers/multiset_of.cc @@ -0,0 +1,15 @@ +#include "utils/containers/multiset_of.h" +#include "test/utils/doctest/fmt/multiset.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("multiset_of") { + std::vector input = {1, 2, 3, 3, 2, 3}; + std::multiset result = multiset_of(input); + std::multiset correct = {1, 2, 3, 3, 2, 3}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/many_to_one/exhaustive_relational_join.cc b/lib/utils/test/src/utils/many_to_one/exhaustive_relational_join.cc new file mode 100644 index 0000000000..5488b65f51 --- /dev/null +++ b/lib/utils/test/src/utils/many_to_one/exhaustive_relational_join.cc @@ -0,0 +1,92 @@ +#include "utils/many_to_one/exhaustive_relational_join.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("exhaustive_relational_join(ManyToOne, ManyToOne)") { + SUBCASE("inputs are empty") { + ManyToOne fst = {}; + ManyToOne> snd = {}; + + ManyToOne> result = 
exhaustive_relational_join(fst, snd); + ManyToOne> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("succeeds if join is exhaustive") { + ManyToOne fst = { + {{2, 4}, "2"}, + {{3, 9, 27}, "3"}, + {{5, 25, 125}, "5"}, + }; + + ManyToOne> snd = { + {{"2"}, {"even", true}}, + {{"3", "5"}, {"odd", false}}, + }; + + ManyToOne> result = exhaustive_relational_join(fst, snd); + ManyToOne> correct = { + {{2, 4}, {"even", true}}, + {{3, 9, 27, 5, 25, 125}, {"odd", false}}, + }; + + CHECK(result == correct); + } + + SUBCASE("throws if extra R in fst") { + ManyToOne fst = { + {{2, 4}, "2"}, + {{3, 9, 27}, "3"}, + {{5, 25, 125}, "5"}, + {{6, 36}, "6"}, + }; + + ManyToOne> snd = { + {{"2"}, {"even", true}}, + {{"3", "5"}, {"odd", false}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("throws if extra L in snd") { + ManyToOne fst = { + {{2, 4}, "2"}, + {{3, 9, 27}, "3"}, + {{5, 25, 125}, "5"}, + }; + + ManyToOne> snd = { + {{"2", "6"}, {"even", true}}, + {{"3", "5"}, {"odd", false}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("works even if all types are the same") { + ManyToOne fst = { + {{2, 4}, 2}, + {{3, 9, 27}, 3}, + {{5, 25, 125}, 5}, + }; + + ManyToOne snd = { + {{2}, 1}, + {{3, 5}, 0}, + }; + + + ManyToOne result = exhaustive_relational_join(fst, snd); + ManyToOne correct = { + {{2, 4}, 1}, + {{3, 9, 27, 5, 25, 125}, 0}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/many_to_one/many_to_one.cc b/lib/utils/test/src/utils/many_to_one/many_to_one.cc new file mode 100644 index 0000000000..a416862e37 --- /dev/null +++ b/lib/utils/test/src/utils/many_to_one/many_to_one.cc @@ -0,0 +1,10 @@ +#include "utils/many_to_one/many_to_one.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ManyToOne") { + FAIL("TODO"); + } +} diff --git a/lib/utils/test/src/utils/one_to_many/exhaustive_relational_join.cc 
b/lib/utils/test/src/utils/one_to_many/exhaustive_relational_join.cc new file mode 100644 index 0000000000..a6395494f7 --- /dev/null +++ b/lib/utils/test/src/utils/one_to_many/exhaustive_relational_join.cc @@ -0,0 +1,93 @@ +#include "utils/one_to_many/exhaustive_relational_join.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("exhaustive_relational_join(OneToMany, OneToMany)") { + SUBCASE("inputs are empty") { + OneToMany fst = {}; + OneToMany> snd = {}; + + OneToMany> result = exhaustive_relational_join(fst, snd); + OneToMany> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("succeeds if join is exhaustive") { + OneToMany fst = { + {1, {"one", "ONE"}}, + {2, {"two"}}, + {3, {"three"}}, + }; + OneToMany> snd = { + {"one", {{"one", 0}}}, + {"ONE", {{"ONE", 0}, {"ONE", 1}}}, + {"two", {{"two", 0}, {"two", 1}}}, + {"three", {{"three", 2}}}, + }; + + OneToMany> result = exhaustive_relational_join(fst, snd); + OneToMany> correct = { + {1, {{"one", 0}, {"ONE", 0}, {"ONE", 1}}}, + {2, {{"two", 0}, {"two", 1}}}, + {3, {{"three", 2}}}, + }; + + CHECK(result == correct); + } + + SUBCASE("throws if extra R in fst") { + OneToMany fst = { + {1, {"one", "One", "ONE"}}, + {2, {"two"}}, + {3, {"three"}}, + }; + OneToMany> snd = { + {"one", {{"one", 0}}}, + {"ONE", {{"ONE", 0}, {"ONE", 1}}}, + {"two", {{"two", 0}, {"two", 1}, {"two", 2}}}, + {"three", {{"three", 2}}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("throws if extra L in snd") { + OneToMany fst = { + {1, {"one"}}, + {2, {"two"}}, + {3, {"three"}}, + }; + OneToMany> snd = { + {"one", {{"one", 0}}}, + {"ONE", {{"ONE", 0}, {"ONE", 1}}}, + {"two", {{"two", 0}, {"two", 1}, {"two", 2}}}, + {"three", {{"three", 2}}}, + }; + + CHECK_THROWS(exhaustive_relational_join(fst, snd)); + } + + SUBCASE("works even if all types are the same") { + OneToMany fst = { + {4, {2}}, + {18*20, {18, 20}}, + }; + OneToMany snd = { + {2, {2}}, + {18, {3, 6}}, + {20, 
{4, 5}}, + }; + + OneToMany result = exhaustive_relational_join(fst, snd); + OneToMany correct = { + {4, {2}}, + {18*20, {3, 4, 5, 6}}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many.cc b/lib/utils/test/src/utils/one_to_many/one_to_many.cc new file mode 100644 index 0000000000..38fe74ecf9 --- /dev/null +++ b/lib/utils/test/src/utils/one_to_many/one_to_many.cc @@ -0,0 +1,27 @@ +#include "utils/one_to_many/one_to_many.h" +#include "utils/one_to_many/one_to_many_from_l_to_r_mapping.h" +#include +#include "utils/containers/multiset_of.h" +#include "test/utils/doctest/fmt/multiset.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("OneToMany") { + FAIL("TODO"); + } + + TEST_CASE("fmt::to_string(OneToMany)") { + OneToMany input + = one_to_many_from_l_to_r_mapping({ + {1, {"hello", "world"}}, + {2, {}}, + {3, {"HELLO"}} + }); + + std::string result = fmt::to_string(input); + std::string correct = "{{1, {hello, world}}, {3, {HELLO}}}"; + + CHECK(multiset_of(result) == multiset_of(correct)); + } +} From 5e774c65501b97d376c5383e6098b9cd53bfccd2 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 27 Dec 2024 15:20:32 -0800 Subject: [PATCH 07/62] Fri Dec 27 03:20:32 PM PST 2024 --- lib/kernels/include/kernels/legion_dim.h | 2 + lib/kernels/src/{ => kernels}/legion_dim.cc | 7 +++ lib/op-attrs/include/op-attrs/ff_dim.h | 18 ------ lib/op-attrs/include/op-attrs/ff_dim_t.h | 20 ++++++ ...f_dim.struct.toml => ff_dim_t.struct.toml} | 0 ...ator_space_parallel_tensor_space_mapping.h | 4 +- lib/op-attrs/include/op-attrs/ops/linear.h | 8 ++- .../op-attrs/parallel_tensor_dim_idx_t.h | 14 +++++ .../op-attrs/parallel_tensor_space_mapping.h | 8 +++ .../parallel_tensor_space_mapping.struct.toml | 16 +++++ .../include/op-attrs/tensor_num_dims.h | 13 ++++ .../op-attrs/tensor_num_dims.struct.toml | 14 +++++ lib/op-attrs/src/op-attrs/ff_dim_t.cc | 24 ++++++++ ...tor_space_parallel_tensor_space_mapping.cc | 2 
+- lib/op-attrs/src/op-attrs/ops/linear.cc | 61 ++++++++++++++++++- .../src/op-attrs/parallel_tensor_dim_idx_t.cc | 17 ++++++ lib/op-attrs/src/op-attrs/tensor_num_dims.cc | 10 +++ .../utils/many_to_one/invert_many_to_one.h | 22 +++++++ .../many_to_one/many_to_one_from_bidict.h | 22 +++++++ .../utils/one_to_many/invert_one_to_many.h | 22 +++++++ .../one_to_many/one_to_many_from_bidict.h | 22 +++++++ .../include/utils/orthotope/dim_projection.h | 37 ----------- .../include/utils/orthotope/down_projection.h | 48 +++++++++++++++ .../include/utils/orthotope/eq_projection.h | 26 ++++++++ .../include/utils/orthotope/up_projection.h | 47 ++++++++++++++ .../utils/many_to_one/invert_many_to_one.cc | 12 ++++ .../many_to_one/many_to_one_from_bidict.cc | 12 ++++ .../utils/one_to_many/invert_one_to_many.cc | 12 ++++ .../one_to_many/one_to_many_from_bidict.cc | 12 ++++ .../src/utils/orthotope/down_projection.cc | 28 +++++++++ .../src/utils/orthotope/eq_projection.cc | 19 ++++++ .../src/utils/orthotope/up_projection.cc | 28 +++++++++ .../one_to_many/one_to_many_from_bidict.cc | 47 ++++++++++++++ 33 files changed, 593 insertions(+), 61 deletions(-) rename lib/kernels/src/{ => kernels}/legion_dim.cc (58%) delete mode 100644 lib/op-attrs/include/op-attrs/ff_dim.h create mode 100644 lib/op-attrs/include/op-attrs/ff_dim_t.h rename lib/op-attrs/include/op-attrs/{ff_dim.struct.toml => ff_dim_t.struct.toml} (100%) create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/tensor_num_dims.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/ff_dim_t.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc create mode 100644 
lib/op-attrs/src/op-attrs/tensor_num_dims.cc create mode 100644 lib/utils/include/utils/many_to_one/invert_many_to_one.h create mode 100644 lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h create mode 100644 lib/utils/include/utils/one_to_many/invert_one_to_many.h create mode 100644 lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h create mode 100644 lib/utils/include/utils/orthotope/down_projection.h create mode 100644 lib/utils/include/utils/orthotope/eq_projection.h create mode 100644 lib/utils/include/utils/orthotope/up_projection.h create mode 100644 lib/utils/src/utils/many_to_one/invert_many_to_one.cc create mode 100644 lib/utils/src/utils/many_to_one/many_to_one_from_bidict.cc create mode 100644 lib/utils/src/utils/one_to_many/invert_one_to_many.cc create mode 100644 lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc create mode 100644 lib/utils/src/utils/orthotope/down_projection.cc create mode 100644 lib/utils/src/utils/orthotope/eq_projection.cc create mode 100644 lib/utils/src/utils/orthotope/up_projection.cc create mode 100644 lib/utils/test/src/utils/one_to_many/one_to_many_from_bidict.cc diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index e4dd9723b8..d2476d7dac 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -6,6 +6,8 @@ namespace FlexFlow { +std::set legion_dim_range(int end); + legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value); legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions); diff --git a/lib/kernels/src/legion_dim.cc b/lib/kernels/src/kernels/legion_dim.cc similarity index 58% rename from lib/kernels/src/legion_dim.cc rename to lib/kernels/src/kernels/legion_dim.cc index 9ef47d40ae..d7a632367d 100644 --- a/lib/kernels/src/legion_dim.cc +++ b/lib/kernels/src/kernels/legion_dim.cc @@ -1,7 +1,14 @@ #include "kernels/legion_dim.h" +#include "utils/containers/range.h" +#include 
"utils/containers/transform.h" +#include "utils/containers/set_of.h" namespace FlexFlow { +std::set legion_dim_range(int end) { + return set_of(transform(range(end), [](int i) { return ff_dim_t{i}; })); +} + legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) { return legion_dim_t(legion_dim.value + value); } diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h deleted file mode 100644 index e78ce4b51e..0000000000 --- a/lib/op-attrs/include/op-attrs/ff_dim.h +++ /dev/null @@ -1,18 +0,0 @@ - -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H - -#include "op-attrs/ff_dim.dtg.h" -#include "rapidcheck.h" - -namespace rc { -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::construct( - gen::inRange(0, MAX_TENSOR_DIM)); - } -}; -} // namespace rc - -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h new file mode 100644 index 0000000000..354c7ec4aa --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H + +#include "op-attrs/ff_dim.dtg.h" +#include "rapidcheck.h" + +namespace FlexFlow { + +std::set ff_dim_range(int end); + +} // namespace FlexFlow + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif diff --git a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/ff_dim.struct.toml rename to lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml diff --git a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h index 
50c8db7d32..b659e2ffe5 100644 --- a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h +++ b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h @@ -2,12 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H #include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" -#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/tensor_num_dims.dtg.h" namespace FlexFlow { OperatorSpaceParallelTensorSpaceMapping - get_identity_mapping(ParallelTensorDimDegrees const &); + get_identity_mapping(TensorNumDims const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 50d4f0fd36..2cde2bd22a 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -6,10 +6,12 @@ #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_num_dims.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" #include #include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/parallel_tensor_space_mapping.dtg.h" namespace FlexFlow { @@ -27,6 +29,10 @@ tl::expected get_bias_shape(LinearAttrs const &attrs, tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input); +tl::expected + get_projection_to_output_parallel_dim_mapping(LinearAttrs const &attrs, + ParallelTensorDimDegrees const &input); + tl::expected get_projection_parallel_dim_degrees(LinearAttrs const &attrs, ParallelTensorDimDegrees const &input); tl::expected @@ -51,7 +57,7 @@ tl::expected ParallelTensorDimDegrees const &input); tl::expected get_output_space_mapping(LinearAttrs const &attrs, - ParallelTensorDimDegrees const &input); + TensorNumDims const &input_num_dims); } // namespace FlexFlow 
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h new file mode 100644 index 0000000000..61dfdd2b17 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_IDX_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIM_IDX_T_H + +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" + +namespace FlexFlow { + +parallel_tensor_dim_idx_t sum_dim_idx(); +parallel_tensor_dim_idx_t discard_copy_dim_idx(); +parallel_tensor_dim_idx_t shard_dim_idx(ff_dim_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.h new file mode 100644 index 0000000000..ee982f2c4e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.h @@ -0,0 +1,8 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_MAPPING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SPACE_MAPPING_H + +namespace FlexFlow { + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.struct.toml new file mode 100644 index 0000000000..d611d77bfc --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_mapping.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ParallelTensorSpaceMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/orthotope/dim_projection.dtg.h", + "op-attrs/parallel_tensor_dim_idx_t.dtg.h", +] + +[[fields]] +name = "raw_projection" +type = "::FlexFlow::DimProjection<::FlexFlow::parallel_tensor_dim_idx_t, ::FlexFlow::parallel_tensor_dim_idx_t>" diff --git a/lib/op-attrs/include/op-attrs/tensor_num_dims.h b/lib/op-attrs/include/op-attrs/tensor_num_dims.h 
new file mode 100644 index 0000000000..372c1ccd5c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_num_dims.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_NUM_DIMS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_NUM_DIMS_H + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/tensor_num_dims.dtg.h" + +namespace FlexFlow { + +std::set ff_dim_idxs_for_num_dims(TensorNumDims const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml new file mode 100644 index 0000000000..f421425290 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "TensorNumDims" +features = [ + "eq", + "hash", + "ord", + "fmt", + "rapidcheck", + "json", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/src/op-attrs/ff_dim_t.cc new file mode 100644 index 0000000000..735a0f5ce0 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_dim_t.cc @@ -0,0 +1,24 @@ +#include "op-attrs/ff_dim_t.h" +#include "utils/containers/range.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" + +using ::FlexFlow::ff_dim_t; + +namespace FlexFlow { + +std::set ff_dim_range(int end) { + return set_of(transform(range(end), [](int i) { return ff_dim_t{i}; })); +} + +} // namespace FlexFlow + +namespace rc { + +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::inRange(0, MAX_TENSOR_DIM)); +} + + +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc b/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc index 651966840a..f3fcf84f6e 100644 --- a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc +++ b/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc @@ -8,7 +8,7 @@ namespace 
FlexFlow { OperatorSpaceParallelTensorSpaceMapping - get_identity_mapping(ParallelTensorDimDegrees const °rees) { + get_identity_mapping(TensorNumDims const &tensor_num_dims) { std::set parallel_tensor_dim_indices = get_nontrivial_parallel_tensor_dim_indices(degrees); diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 438bec9708..4d259813a4 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -1,11 +1,17 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/dim_ordered/slice.h" #include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/operator_space_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_num_dims.dtg.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/product.h" #include "utils/integer_conversions.h" +#include "utils/orthotope/down_projection.h" +#include "utils/orthotope/eq_projection.h" +#include "utils/orthotope/up_projection.h" namespace FlexFlow { @@ -140,9 +146,60 @@ tl::expected unpar, sum_degree, discard_copy_degree, shard_degrees); } +// tl::expected +// get_input_to_projection_parallel_mapping(LinearAttrs const &attrs, +// ParallelTensorDimDegrees const &input) { +// return ParallelTensor{ +// +// }; +// } + +tl::expected + get_input_to_output_projection(LinearAttrs const &attrs, TensorNumDims const &input_num_dims) { + + auto inp_to_out = make_empty_down_projection(); + + project_dims(inp_to_out, + /*from=*/{sum_dim_idx(), shard_dim_idx(ff_dim_t{-1})}, + /*onto=*/sum_dim_idx()); + project_dims(inp_to_out, + /*from=*/{discard_copy_dim_idx()}, + /*onto=*/shard_dim_idx(ff_dim_t{-1})); + + for (ff_dim_t const &idx : ff_dim_range(input_num_dims.value - 1)) { + project_dims(inp_to_out, + /*from=*/{shard_dim_idx(idx)}, + /*onto=*/shard_dim_idx(idx)); + } + + return 
ParallelTensorSpaceMapping{DimProjection{inp_to_out}}; +} + tl::expected - get_projection_space_mapping(LinearAttrs const &attrs, - ParallelTensorDimDegrees const &input) { + get_operator_to_input_projection(LinearAttrs const &attrs, + TensorNumDims const &input_num_dims) { + + TensorNumDims output_num_dims = input_num_dims; + + UpProjection + out_to_inp = invert_down_projection(throw_if_unexpected(get_input_to_output_projection(attrs, input_num_dims)).raw_projection.require_down_proj()); + + EqProjection + op_to_out = throw_if_unexpected(get_output_space_mapping(attrs, output_num_dims)).raw_projection.require_eq_proj(); + + return OperatorSpaceParallelTensorSpaceMapping{ + DimProjection{ + compose_up_projections(up_from_eq_proj(op_to_out), out_to_inp), + }, + }; +} + +tl::expected + get_operator_to_output_projection(LinearAttrs const &attrs, + TensorNumDims const &output_num_dims) { + + return OperatorSpaceParallelTensorSpaceMapping{ + get SumDegree sum_degree = SumDegree{1}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc new file mode 100644 index 0000000000..ef0fd4616c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc @@ -0,0 +1,17 @@ +#include "op-attrs/parallel_tensor_dim_idx_t.h" + +namespace FlexFlow { + +parallel_tensor_dim_idx_t sum_dim_idx() { + return parallel_tensor_dim_idx_t{ReplicaType::SUM}; +} + +parallel_tensor_dim_idx_t discard_copy_dim_idx() { + return parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}; +} + +parallel_tensor_dim_idx_t shard_dim_idx(ff_dim_t idx) { + return parallel_tensor_dim_idx_t{idx}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_num_dims.cc b/lib/op-attrs/src/op-attrs/tensor_num_dims.cc new file mode 100644 index 0000000000..94689c06de --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_num_dims.cc @@ -0,0 +1,10 @@ +#include 
"op-attrs/tensor_num_dims.h" +#include "op-attrs/ff_dim_t.h" + +namespace FlexFlow { + +std::set ff_dim_idxs_for_num_dims(TensorNumDims const &num_dims) { + return ff_dim_range(num_dims.value); +} + +} // namespace FlexFlow diff --git a/lib/utils/include/utils/many_to_one/invert_many_to_one.h b/lib/utils/include/utils/many_to_one/invert_many_to_one.h new file mode 100644 index 0000000000..83423541bb --- /dev/null +++ b/lib/utils/include/utils/many_to_one/invert_many_to_one.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_INVERT_MANY_TO_ONE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_INVERT_MANY_TO_ONE_H + +#include "utils/many_to_one/many_to_one.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany invert_many_to_one(ManyToOne const &many_to_one) { + OneToMany result; + + for (L const &l : many_to_one.left_values()) { + result.insert({many_to_one.at_l(l), l}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h b/lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h new file mode 100644 index 0000000000..ba50a960c2 --- /dev/null +++ b/lib/utils/include/utils/many_to_one/many_to_one_from_bidict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_BIDICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_BIDICT_H + +#include "utils/bidict/bidict.h" +#include "utils/many_to_one/many_to_one.h" + +namespace FlexFlow { + +template +ManyToOne many_to_one_from_bidict(bidict const &b) { + ManyToOne result; + + for (auto const &[l, r] : b) { + result.insert({l, r}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/invert_one_to_many.h b/lib/utils/include/utils/one_to_many/invert_one_to_many.h new file mode 100644 index 0000000000..bde623d387 --- /dev/null +++ 
b/lib/utils/include/utils/one_to_many/invert_one_to_many.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_INVERT_ONE_TO_MANY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_INVERT_ONE_TO_MANY_H + +#include "utils/many_to_one/many_to_one.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +ManyToOne invert_one_to_many(OneToMany const &one_to_many) { + ManyToOne result; + + for (R const &r : one_to_many.right_values()) { + result.insert({r, one_to_many.at_r(r)}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h b/lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h new file mode 100644 index 0000000000..3783f1f663 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_from_bidict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_BIDICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_BIDICT_H + +#include "utils/bidict/bidict.h" +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany one_to_many_from_bidict(bidict const &b) { + OneToMany result; + + for (auto const &[l, r] : b) { + result.insert({l, r}); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_projection.h b/lib/utils/include/utils/orthotope/dim_projection.h index 47f470ba7a..07e5885d49 100644 --- a/lib/utils/include/utils/orthotope/dim_projection.h +++ b/lib/utils/include/utils/orthotope/dim_projection.h @@ -7,43 +7,6 @@ namespace FlexFlow { -template -EqProjection compose_dim_projections(EqProjection const &fst, EqProjection const &snd) { - return EqProjection{ - exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping) - }; -} - -template -UpProjection compose_dim_projections(UpProjection const &fst, UpProjection const &snd) { - NOT_IMPLEMENTED(); -} - -template 
-DownProjection compose_dim_projections(DownProjection const &fst, DownProjection const &snd) { - NOT_IMPLEMENTED(); -} - -template -UpProjection compose_dim_projections(EqProjection const &fst, UpProjection const &snd) { - NOT_IMPLEMENTED(); -} - -template -UpProjection compose_dim_projections(UpProjection const &fst, EqProjection const &snd) { - NOT_IMPLEMENTED(); -} - -template -DownProjection compose_dim_projections(EqProjection const &fst, DownProjection const &snd) { - NOT_IMPLEMENTED(); -} - -template -DownProjection compose_dim_projections(DownProjection const &fst, EqProjection const &snd) { - NOT_IMPLEMENTED(); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/orthotope/down_projection.h b/lib/utils/include/utils/orthotope/down_projection.h new file mode 100644 index 0000000000..472ffcc4d4 --- /dev/null +++ b/lib/utils/include/utils/orthotope/down_projection.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DOWN_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DOWN_PROJECTION_H + +#include "utils/orthotope/down_projection.dtg.h" +#include "utils/orthotope/eq_projection.dtg.h" +#include "utils/orthotope/up_projection.dtg.h" +#include "utils/many_to_one/many_to_one_from_bidict.h" +#include "utils/many_to_one/exhaustive_relational_join.h" + +namespace FlexFlow { + +template +DownProjection make_empty_down_projection() { + return DownProjection{ManyToOne{}}; +} + +template +void project_dims(DownProjection &proj, std::unordered_set const &from, R const &onto) { + for (L const &l : from) { + proj.dim_mapping.insert({l, onto}); + } +} + +template +UpProjection invert_down_projection(DownProjection const &down_proj) { + return UpProjection{ + invert_many_to_one(down_proj.dim_mapping), + }; +} + +template +DownProjection compose_down_projections(DownProjection const &fst, DownProjection const &snd) { + return DownProjection{ + exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping), + }; +} + 
+template +DownProjection down_from_eq_proj(EqProjection const &eq) { + return DownProjection{ + many_to_one_from_bidict(eq.dim_mapping), + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/eq_projection.h b/lib/utils/include/utils/orthotope/eq_projection.h new file mode 100644 index 0000000000..f9b0766f5d --- /dev/null +++ b/lib/utils/include/utils/orthotope/eq_projection.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_EQ_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_EQ_PROJECTION_H + +#include "utils/orthotope/eq_projection.dtg.h" +#include "utils/bidict/algorithms/exhaustive_relational_join.h" + +namespace FlexFlow { + +template +EqProjection invert_eq_projection(EqProjection const &input) { + return EqProjection{ + input.dim_mapping.reversed(), + }; +} + +template +EqProjection compose_eq_projections(EqProjection const &fst, EqProjection const &snd) { + return EqProjection{ + exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping) + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h new file mode 100644 index 0000000000..db2291b3db --- /dev/null +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_UP_PROJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_UP_PROJECTION_H + +#include "utils/orthotope/up_projection.dtg.h" +#include "utils/orthotope/eq_projection.dtg.h" +#include "utils/orthotope/down_projection.dtg.h" +#include "utils/one_to_many/one_to_many_from_bidict.h" +#include "utils/one_to_many/exhaustive_relational_join.h" + +namespace FlexFlow { + +template +UpProjection make_empty_up_projection() { + return UpProjection{OneToMany{}}; +} + +template +void project_dims(UpProjection &proj, L const &onto, std::unordered_set const &from) { + for (R const &r : from) { + 
proj.dim_mapping.insert({onto, r}); + } +} + +template +DownProjection invert_up_projection(UpProjection const &up_proj) { + return DownProjection{ + invert_one_to_many(up_proj.dim_mapping), + }; +} + +template +UpProjection compose_up_projections(UpProjection const &fst, UpProjection const &snd) { + return UpProjection{ + exhaustive_relational_join(fst.dim_mapping, snd.dim_mapping), + }; +} + +template +UpProjection up_from_eq_proj(EqProjection const &eq) { + return UpProjection{ + one_to_many_from_bidict(eq.dim_mapping), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc new file mode 100644 index 0000000000..d25b9d6311 --- /dev/null +++ b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc @@ -0,0 +1,12 @@ +#include "utils/many_to_one/invert_many_to_one.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + OneToMany invert_many_to_one(ManyToOne const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/many_to_one/many_to_one_from_bidict.cc b/lib/utils/src/utils/many_to_one/many_to_one_from_bidict.cc new file mode 100644 index 0000000000..3fcd37e301 --- /dev/null +++ b/lib/utils/src/utils/many_to_one/many_to_one_from_bidict.cc @@ -0,0 +1,12 @@ +#include "utils/many_to_one/many_to_one_from_bidict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + ManyToOne many_to_one_from_bidict(bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/invert_one_to_many.cc b/lib/utils/src/utils/one_to_many/invert_one_to_many.cc new file mode 100644 index 0000000000..2783fde9b4 --- /dev/null +++ b/lib/utils/src/utils/one_to_many/invert_one_to_many.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/invert_one_to_many.h" +#include 
"utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + ManyToOne invert_one_to_many(OneToMany const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc new file mode 100644 index 0000000000..06f3029aeb --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/one_to_many_from_bidict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + OneToMany one_to_many_from_bidict(bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/down_projection.cc b/lib/utils/src/utils/orthotope/down_projection.cc new file mode 100644 index 0000000000..e4f6477c0e --- /dev/null +++ b/lib/utils/src/utils/orthotope/down_projection.cc @@ -0,0 +1,28 @@ +#include "utils/orthotope/down_projection.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; + +template + DownProjection compose_down_projections(DownProjection const &, DownProjection const &); + +using L = value_type<0>; +using R = value_type<1>; + +template + DownProjection make_empty_down_projection(); + +template + void project_dims(DownProjection &, std::unordered_set const &, R const &); + +template + UpProjection invert_down_projection(DownProjection const &); + +template + DownProjection down_from_eq_proj(EqProjection const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/eq_projection.cc b/lib/utils/src/utils/orthotope/eq_projection.cc new file mode 100644 index 0000000000..cb4563c78b --- /dev/null +++ b/lib/utils/src/utils/orthotope/eq_projection.cc @@ -0,0 +1,19 @@ +#include "utils/orthotope/eq_projection.h" +#include "utils/archetypes/value_type.h" + 
+namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + EqProjection invert_eq_projection(EqProjection const &); + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; + +template + EqProjection compose_eq_projections(EqProjection const &, EqProjection const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/up_projection.cc b/lib/utils/src/utils/orthotope/up_projection.cc new file mode 100644 index 0000000000..e8b5fb4db4 --- /dev/null +++ b/lib/utils/src/utils/orthotope/up_projection.cc @@ -0,0 +1,28 @@ +#include "utils/orthotope/up_projection.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; + +template + UpProjection compose_up_projections(UpProjection const &, UpProjection const &); + +using L = value_type<0>; +using R = value_type<1>; + +template + UpProjection make_empty_up_projection(); + +template + void project_dims(UpProjection &, L const &, std::unordered_set const &); + +template + DownProjection invert_up_projection(UpProjection const &); + +template + UpProjection up_from_eq_proj(EqProjection const &); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many_from_bidict.cc b/lib/utils/test/src/utils/one_to_many/one_to_many_from_bidict.cc new file mode 100644 index 0000000000..ef083a96a9 --- /dev/null +++ b/lib/utils/test/src/utils/one_to_many/one_to_many_from_bidict.cc @@ -0,0 +1,47 @@ +#include "utils/one_to_many/one_to_many_from_bidict.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("one_to_many_from_bidict(bidict)") { + SUBCASE("input is empty") { + bidict input = {}; + + OneToMany result = one_to_many_from_bidict(input); + OneToMany correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is nonempty") { + bidict input = { + {1, "one"}, + {2, "two"}, + }; + + OneToMany result = 
one_to_many_from_bidict(input); + OneToMany correct = { + {1, {"one"}}, + {2, {"two"}}, + }; + + CHECK(result == correct); + } + + SUBCASE("key and value types are the same") { + bidict input = { + {1, -1}, + {2, -2}, + }; + + OneToMany result = one_to_many_from_bidict(input); + OneToMany correct = { + {1, {-1}}, + {2, {-2}}, + }; + + CHECK(result == correct); + } + } +} From b6aed076482385bc4771bf389082d27e887b7b87 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 16 Jan 2025 18:50:59 -0800 Subject: [PATCH 08/62] Additional changes. --- lib/utils/include/utils/many_to_one/invert_many_to_one.h | 4 ++-- lib/utils/include/utils/orthotope/down_projection.h | 1 + lib/utils/include/utils/orthotope/up_projection.h | 1 + lib/utils/src/utils/many_to_one/invert_many_to_one.cc | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/utils/include/utils/many_to_one/invert_many_to_one.h b/lib/utils/include/utils/many_to_one/invert_many_to_one.h index 83423541bb..7fdf36859f 100644 --- a/lib/utils/include/utils/many_to_one/invert_many_to_one.h +++ b/lib/utils/include/utils/many_to_one/invert_many_to_one.h @@ -7,8 +7,8 @@ namespace FlexFlow { template -OneToMany invert_many_to_one(ManyToOne const &many_to_one) { - OneToMany result; +OneToMany invert_many_to_one(ManyToOne const &many_to_one) { + OneToMany result; for (L const &l : many_to_one.left_values()) { result.insert({many_to_one.at_l(l), l}); diff --git a/lib/utils/include/utils/orthotope/down_projection.h b/lib/utils/include/utils/orthotope/down_projection.h index 472ffcc4d4..9ff24381b3 100644 --- a/lib/utils/include/utils/orthotope/down_projection.h +++ b/lib/utils/include/utils/orthotope/down_projection.h @@ -6,6 +6,7 @@ #include "utils/orthotope/up_projection.dtg.h" #include "utils/many_to_one/many_to_one_from_bidict.h" #include "utils/many_to_one/exhaustive_relational_join.h" +#include "utils/many_to_one/invert_many_to_one.h" namespace FlexFlow { diff --git 
a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h index db2291b3db..cd81e16b7b 100644 --- a/lib/utils/include/utils/orthotope/up_projection.h +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -6,6 +6,7 @@ #include "utils/orthotope/down_projection.dtg.h" #include "utils/one_to_many/one_to_many_from_bidict.h" #include "utils/one_to_many/exhaustive_relational_join.h" +#include "utils/one_to_many/invert_one_to_many.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc index d25b9d6311..a422a607d5 100644 --- a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc +++ b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc @@ -7,6 +7,6 @@ using L = value_type<0>; using R = value_type<1>; template - OneToMany invert_many_to_one(ManyToOne const &); + OneToMany invert_many_to_one(ManyToOne const &); } // namespace FlexFlow From 79f7098eecf7d67f9c57aa1b7bcccafafb3b2fa1 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 21 Jan 2025 22:23:56 -0800 Subject: [PATCH 09/62] Move over to new projection infra --- .proj.toml | 4 +- .../include/op-attrs/dim_ordered/get_idxs.h | 1 + lib/op-attrs/include/op-attrs/ff_dim_t.h | 2 +- ...ator_space_parallel_tensor_space_mapping.h | 6 +- .../op-attrs/operator_task_space_dim_idx_t.h | 14 + .../operator_task_space_dim_idx_t.struct.toml | 6 +- lib/op-attrs/include/op-attrs/ops/linear.h | 12 +- .../op-attrs/parallel_tensor_dim_idx_t.h | 1 + .../include/op-attrs/relative_ff_dim_t.h | 2 +- .../include/op-attrs/tensor_num_dims.h | 13 - .../op-attrs/tensor_num_dims.struct.toml | 14 - .../src/op-attrs/dim_ordered/get_idxs.cc | 2 +- lib/op-attrs/src/op-attrs/ff_dim_t.cc | 6 +- ...tor_space_parallel_tensor_space_mapping.cc | 12 +- .../op-attrs/operator_task_space_dim_idx_t.cc | 13 + lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 5 +- lib/op-attrs/src/op-attrs/ops/linear.cc | 57 +- 
.../src/op-attrs/parallel_tensor_dim_idx_t.cc | 12 + .../src/op-attrs/relative_ff_dim_t.cc | 4 +- lib/op-attrs/src/op-attrs/tensor_num_dims.cc | 10 - ...tor_space_parallel_tensor_space_mapping.cc | 23 + .../op-attrs/operator_task_space_dim_idx_t.cc | 13 + .../op-attrs/parallel_tensor_dim_degrees.cc | 81 ++- .../test/src/op-attrs/relative_ff_dim_t.cc | 2 +- .../utils/archetypes/ordered_value_type.h | 6 + lib/utils/include/utils/bijection/bijection.h | 18 + .../utils/bijection/bijection.struct.toml | 17 + .../include/utils/bijection/to.struct.toml | 13 + .../utils/cli/cli_flag_key.struct.toml | 1 + .../include/utils/containers/filter_idxs.h | 9 +- .../utils/containers/generate_vector.h | 21 + lib/utils/include/utils/containers/product.h | 2 +- .../include/utils/containers/zip3_with.h | 23 + .../utils/containers/zip3_with_strict.h | 25 + .../include/utils/containers/zip_strict.h | 22 + .../utils/containers/zip_with_strict.h | 22 + .../utils/nonnegative_int/nonnegative_int.h | 9 + .../utils/nonnegative_int/num_elements.h | 32 + .../include/utils/nonnegative_int/range.h | 14 + lib/utils/include/utils/orthotope/dim_coord.h | 70 +++ .../utils/orthotope/dim_coord.struct.toml | 26 + .../include/utils/orthotope/dim_domain.h | 28 + .../utils/orthotope/dim_domain.struct.toml | 27 + .../include/utils/orthotope/dim_projection.h | 32 +- .../include/utils/orthotope/down_projection.h | 40 ++ .../orthotope/down_projection.struct.toml | 1 - .../include/utils/orthotope/eq_projection.h | 10 + lib/utils/include/utils/orthotope/orthotope.h | 20 +- .../utils/orthotope/orthotope.struct.toml | 8 +- .../orthotope_bijective_projection.h | 43 -- ...orthotope_bijective_projection.struct.toml | 25 - .../include/utils/orthotope/orthotope_coord.h | 12 + .../orthotope/orthotope_coord.struct.toml | 23 + .../utils/orthotope/orthotope_coordinate.h | 16 - .../orthotope_coordinate.struct.toml | 21 - .../utils/orthotope/orthotope_dim_idx_t.h | 13 - .../orthotope/orthotope_dim_idx_t.struct.toml | 12 - 
.../orthotope/orthotope_dim_indexed/all_of.h | 11 - .../orthotope_dim_indexed/drop_idxs_except.h | 31 - .../orthotope/orthotope_dim_indexed/json.h | 23 - .../orthotope_dim_indexed.h | 183 ------ .../orthotope_dim_indexed_from_idx_map.h | 29 - .../orthotope_dim_indexed_of.h | 16 - .../orthotope_dim_indexed/transform.h | 17 - .../orthotope_dim_indexed/zip_with.h | 22 - lib/utils/include/utils/orthotope/orthtope.h | 10 - .../include/utils/orthotope/up_projection.h | 39 ++ .../utils/orthotope/up_projection.struct.toml | 1 - lib/utils/src/utils/bijection/bijection.cc | 12 + lib/utils/src/utils/cli/cli_spec.cc | 1 + lib/utils/src/utils/containers/filter_idxs.cc | 2 +- .../src/utils/containers/generate_vector.cc | 9 + lib/utils/src/utils/containers/range.cc | 6 +- lib/utils/src/utils/containers/zip3_with.cc | 18 + .../src/utils/containers/zip3_with_strict.cc | 19 + lib/utils/src/utils/containers/zip_strict.cc | 12 + .../src/utils/containers/zip_with_strict.cc | 14 + .../instances/unordered_set_dataflow_graph.cc | 1 + .../utils/nonnegative_int/nonnegative_int.cc | 28 +- .../src/utils/nonnegative_int/num_elements.cc | 20 + lib/utils/src/utils/nonnegative_int/range.cc | 15 + lib/utils/src/utils/orthotope/dim_coord.cc | 21 + lib/utils/src/utils/orthotope/dim_domain.cc | 18 + .../src/utils/orthotope/dim_projection.cc | 15 + .../src/utils/orthotope/down_projection.cc | 31 + .../src/utils/orthotope/eq_projection.cc | 6 + lib/utils/src/utils/orthotope/orthotope.cc | 86 +-- .../orthotope_bijective_projection.cc | 283 --------- .../src/utils/orthotope/orthotope_coord.cc | 13 + .../utils/orthotope/orthotope_coordinate.cc | 21 - .../utils/orthotope/orthotope_dim_idx_t.cc | 12 - .../orthotope/orthotope_dim_indexed/all_of.cc | 10 - .../orthotope_dim_indexed/drop_idxs_except.cc | 11 - .../orthotope/orthotope_dim_indexed/json.cc | 8 - .../orthotope_dim_indexed.cc | 17 - .../orthotope_dim_indexed_from_idx_map.cc | 11 - .../orthotope_dim_indexed_of.cc | 11 - 
.../orthotope_dim_indexed/transform.cc | 13 - .../orthotope_dim_indexed/zip_with.cc | 14 - .../src/utils/orthotope/up_projection.cc | 6 + lib/utils/test/src/utils/containers/range.cc | 10 + .../test/src/utils/containers/zip_strict.cc | 28 + .../utils/nonnegative_int/nonnegative_int.cc | 7 + .../test/src/utils/nonnegative_int/range.cc | 70 +++ .../test/src/utils/orthotope/dim_coord.cc | 25 + .../test/src/utils/orthotope/orthotope.cc | 288 ++++----- .../orthotope_bijective_projection.cc | 562 +++++++++--------- 107 files changed, 1640 insertions(+), 1472 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h delete mode 100644 lib/op-attrs/include/op-attrs/tensor_num_dims.h delete mode 100644 lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc delete mode 100644 lib/op-attrs/src/op-attrs/tensor_num_dims.cc create mode 100644 lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc create mode 100644 lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc create mode 100644 lib/utils/include/utils/bijection/bijection.h create mode 100644 lib/utils/include/utils/bijection/bijection.struct.toml create mode 100644 lib/utils/include/utils/bijection/to.struct.toml create mode 100644 lib/utils/include/utils/containers/generate_vector.h create mode 100644 lib/utils/include/utils/containers/zip3_with.h create mode 100644 lib/utils/include/utils/containers/zip3_with_strict.h create mode 100644 lib/utils/include/utils/containers/zip_strict.h create mode 100644 lib/utils/include/utils/containers/zip_with_strict.h create mode 100644 lib/utils/include/utils/nonnegative_int/num_elements.h create mode 100644 lib/utils/include/utils/nonnegative_int/range.h create mode 100644 lib/utils/include/utils/orthotope/dim_coord.h create mode 100644 lib/utils/include/utils/orthotope/dim_coord.struct.toml create mode 100644 
lib/utils/include/utils/orthotope/dim_domain.h create mode 100644 lib/utils/include/utils/orthotope/dim_domain.struct.toml delete mode 100644 lib/utils/include/utils/orthotope/orthotope_bijective_projection.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml create mode 100644 lib/utils/include/utils/orthotope/orthotope_coord.h create mode 100644 lib/utils/include/utils/orthotope/orthotope_coord.struct.toml delete mode 100644 lib/utils/include/utils/orthotope/orthotope_coordinate.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h delete mode 100644 lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h delete mode 100644 lib/utils/include/utils/orthotope/orthtope.h create mode 100644 lib/utils/src/utils/bijection/bijection.cc create mode 100644 lib/utils/src/utils/containers/generate_vector.cc create mode 100644 lib/utils/src/utils/containers/zip3_with.cc create mode 100644 lib/utils/src/utils/containers/zip3_with_strict.cc create mode 100644 lib/utils/src/utils/containers/zip_strict.cc create mode 100644 lib/utils/src/utils/containers/zip_with_strict.cc create mode 
100644 lib/utils/src/utils/nonnegative_int/num_elements.cc create mode 100644 lib/utils/src/utils/nonnegative_int/range.cc create mode 100644 lib/utils/src/utils/orthotope/dim_coord.cc create mode 100644 lib/utils/src/utils/orthotope/dim_domain.cc create mode 100644 lib/utils/src/utils/orthotope/dim_projection.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc create mode 100644 lib/utils/src/utils/orthotope/orthotope_coord.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_coordinate.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc delete mode 100644 lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc create mode 100644 lib/utils/test/src/utils/containers/zip_strict.cc create mode 100644 lib/utils/test/src/utils/nonnegative_int/range.cc create mode 100644 lib/utils/test/src/utils/orthotope/dim_coord.cc diff --git a/.proj.toml b/.proj.toml index 5e503677ab..04fcc0a573 100644 --- a/.proj.toml +++ b/.proj.toml @@ -5,7 +5,7 @@ header_extension = ".h" build_targets = [ "utils", - # "op-attrs", + "op-attrs", # "kernels", # "pcg", # "substitutions", @@ -20,7 +20,7 @@ build_targets = [ test_targets = [ # "kernels-tests", "utils-tests", - # "op-attrs-tests", + "op-attrs-tests", # "pcg-tests", # "substitutions-tests", # "compiler-tests", diff --git 
a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index bdd8c1a8f0..60792d572e 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -5,6 +5,7 @@ #include "op-attrs/ff_dim_t.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" +#include "utils/containers/set_of.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h index 2727482b52..f4b85a929b 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -9,7 +9,7 @@ namespace FlexFlow { relative_ff_dim_t relative_ff_dim_t_from_ff_dim_t(ff_dim_t ff_dim); -std::set ff_dim_range(int end); +std::set ff_dim_range(nonnegative_int end); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h index b659e2ffe5..908f25aaa6 100644 --- a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h +++ b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h @@ -2,12 +2,14 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H #include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" -#include "op-attrs/tensor_num_dims.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" namespace FlexFlow { OperatorSpaceParallelTensorSpaceMapping - get_identity_mapping(TensorNumDims const &); + get_identity_mapping(nonnegative_int num_dims); + +compute_ } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h new file mode 100644 index 0000000000..af77fcfaa1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.h 
@@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_DIM_IDX_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_DIM_IDX_T_H + +#include "op-attrs/operator_task_space_dim_idx_t.dtg.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +std::set operator_task_space_dim_idx_range(nonnegative_int end); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml index 124b46013a..95e4b72977 100644 --- a/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml +++ b/lib/op-attrs/include/op-attrs/operator_task_space_dim_idx_t.struct.toml @@ -8,6 +8,10 @@ features = [ "fmt", ] +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + [[fields]] name = "raw_idx" -type = "int" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 2cde2bd22a..2d9dc1cee7 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -6,7 +6,6 @@ #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/tensor_num_dims.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" #include @@ -49,15 +48,12 @@ tl::expected get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +tl::expected + get_input_to_output_mapping(LinearAttrs const &attrs, nonnegative_int input_num_dims); tl::expected - get_projection_space_mapping(LinearAttrs const &attrs, - ParallelTensorDimDegrees const &input); -tl::expected - get_bias_space_mapping(LinearAttrs const &attrs, - ParallelTensorDimDegrees const &input); + get_operator_to_input_mapping(LinearAttrs const &attrs, nonnegative_int 
input_num_dims); tl::expected - get_output_space_mapping(LinearAttrs const &attrs, - TensorNumDims const &input_num_dims); + get_operator_to_output_mapping(LinearAttrs const &attrs, nonnegative_int input_num_dims); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h index 61dfdd2b17..b39362b0c2 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.h @@ -8,6 +8,7 @@ namespace FlexFlow { parallel_tensor_dim_idx_t sum_dim_idx(); parallel_tensor_dim_idx_t discard_copy_dim_idx(); parallel_tensor_dim_idx_t shard_dim_idx(ff_dim_t); +std::set dim_idxs_for_num_shard_dims(nonnegative_int num_shard_dims); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h index af51cc69be..5205b1ead8 100644 --- a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h @@ -7,7 +7,7 @@ namespace FlexFlow { ff_dim_t ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t ff_dim, - int input_dim); + nonnegative_int input_dim); } // namespace FlexFlow namespace rc { diff --git a/lib/op-attrs/include/op-attrs/tensor_num_dims.h b/lib/op-attrs/include/op-attrs/tensor_num_dims.h deleted file mode 100644 index 372c1ccd5c..0000000000 --- a/lib/op-attrs/include/op-attrs/tensor_num_dims.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_NUM_DIMS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_NUM_DIMS_H - -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/tensor_num_dims.dtg.h" - -namespace FlexFlow { - -std::set ff_dim_idxs_for_num_dims(TensorNumDims const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml deleted file mode 
100644 index f421425290..0000000000 --- a/lib/op-attrs/include/op-attrs/tensor_num_dims.struct.toml +++ /dev/null @@ -1,14 +0,0 @@ -namespace = "FlexFlow" -name = "TensorNumDims" -features = [ - "eq", - "hash", - "ord", - "fmt", - "rapidcheck", - "json", -] - -[[fields]] -name = "value" -type = "int" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc b/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc index baeb130324..6bf5f97895 100644 --- a/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc +++ b/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc @@ -6,6 +6,6 @@ namespace FlexFlow { using T = value_type<0>; template - std::vector get_idxs(FFOrdered const &); + std::set get_idxs(FFOrdered const &); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/src/op-attrs/ff_dim_t.cc index c79de97e8d..2c8b8c455d 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim_t.cc @@ -1,5 +1,5 @@ #include "op-attrs/ff_dim_t.h" -#include "utils/containers/range.h" +#include "utils/nonnegative_int/range.h" #include "utils/containers/set_of.h" #include "utils/containers/transform.h" @@ -8,8 +8,8 @@ relative_ff_dim_t relative_ff_dim_t_from_ff_dim_t(ff_dim_t ff_dim) { return relative_ff_dim_t{ff_dim.value.get_value()}; } -std::set ff_dim_range(int end) { - return set_of(transform(range(end), [](int i) { return ff_dim_t{i}; })); +std::set ff_dim_range(nonnegative_int end) { + return set_of(transform(range(end), [](nonnegative_int i) { return ff_dim_t{i}; })); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc b/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc index f3fcf84f6e..6db994afb1 100644 --- a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc +++ b/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc @@ -1,21 +1,23 @@ #include 
"op-attrs/operator_space_parallel_tensor_space_mapping.h" #include "op-attrs/parallel_tensor_dim_degrees.h" -#include "utils/containers/range.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "utils/nonnegative_int/range.h" #include "utils/containers/set_of.h" #include "utils/containers/transform.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/nonnegative_int/num_elements.h" namespace FlexFlow { OperatorSpaceParallelTensorSpaceMapping - get_identity_mapping(TensorNumDims const &tensor_num_dims) { + get_identity_mapping(nonnegative_int num_shard_dims) { std::set parallel_tensor_dim_indices - = get_nontrivial_parallel_tensor_dim_indices(degrees); + = dim_idxs_for_num_shard_dims(num_shard_dims); std::set operator_space_dim_indices - = transform(set_of(range(parallel_tensor_dim_indices.size())), - [](int raw_idx) { return operator_task_space_dim_idx_t{raw_idx}; }); + = transform(set_of(range(num_elements(parallel_tensor_dim_indices))), + [](nonnegative_int raw_idx) { return operator_task_space_dim_idx_t{raw_idx}; }); bidict raw_bidict = bidict_from_keys_and_values(vector_of(operator_space_dim_indices), vector_of(parallel_tensor_dim_indices)); diff --git a/lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc b/lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc new file mode 100644 index 0000000000..ed94ae4f90 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_task_space_dim_idx_t.cc @@ -0,0 +1,13 @@ +#include "op-attrs/operator_task_space_dim_idx_t.h" +#include "utils/nonnegative_int/range.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::set operator_task_space_dim_idx_range(nonnegative_int end) { + return transform(set_of(range(end)), + [](nonnegative_int raw_idx) { return operator_task_space_dim_idx_t{raw_idx}; }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc 
b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index 86426dd18f..687ccdb208 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -8,6 +8,7 @@ #include "utils/containers/contains.h" #include "utils/containers/extend.h" #include "utils/containers/filter.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { @@ -69,7 +70,7 @@ tl::expected } std::vector non_layer_norm_dim_idxs = filter( - get_idxs(input_shape.dims.ff_ordered), + vector_of(get_idxs(input_shape.dims.ff_ordered)), [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { @@ -162,7 +163,7 @@ tl::expected } std::vector non_layer_norm_dim_idxs = filter( - get_idxs(input_shape.dims.shard_dims), + vector_of(get_idxs(input_shape.dims.shard_dims)), [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_shard_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index a846b69acc..ac575884a7 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -5,7 +5,7 @@ #include "op-attrs/operator_space_parallel_tensor_space_mapping.h" #include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/tensor_num_dims.dtg.h" +#include "op-attrs/relative_ff_dim_t.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/product.h" #include "utils/integer_conversions.h" @@ -100,7 +100,7 @@ tl::expected }; return lift_to_parallel_with_degrees( - unpar, sum_degree, discard_copy_degree, shard_degrees); + unpar, ParallelTensorDimDegrees{sum_degree, discard_copy_degree, shard_degrees}); } tl::expected @@ -122,7 +122,7 @@ tl::expected FFOrdered shard_degrees = FFOrdered{get_discard_copy_degree(input)}; return 
lift_to_parallel_with_degrees( - unpar, sum_degree, discard_copy_degree, shard_degrees); + unpar, ParallelTensorDimDegrees{sum_degree, discard_copy_degree, shard_degrees}); } tl::expected @@ -145,7 +145,7 @@ tl::expected shard_degrees.at(relative_ff_dim_t{-1}) = get_discard_copy_degree(input); return lift_to_parallel_with_degrees( - unpar, sum_degree, discard_copy_degree, shard_degrees); + unpar, ParallelTensorDimDegrees{sum_degree, discard_copy_degree, shard_degrees}); } // tl::expected @@ -157,18 +157,26 @@ tl::expected // } tl::expected - get_input_to_output_projection(LinearAttrs const &attrs, TensorNumDims const &input_num_dims) { + get_input_to_output_projection(LinearAttrs const &attrs, nonnegative_int input_num_dims) { - auto inp_to_out = make_empty_down_projection(); + DownProjection< + parallel_tensor_dim_idx_t, + parallel_tensor_dim_idx_t + > inp_to_out = make_empty_down_projection(); + + ff_dim_t input_channel_dim = ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t{-1}, input_num_dims); + + nonnegative_int output_num_dims = input_num_dims; + ff_dim_t output_channel_dim = ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t{-1}, output_num_dims); project_dims(inp_to_out, - /*from=*/{sum_dim_idx(), shard_dim_idx(ff_dim_t{-1})}, + /*from=*/{sum_dim_idx(), shard_dim_idx(input_channel_dim)}, /*onto=*/sum_dim_idx()); project_dims(inp_to_out, /*from=*/{discard_copy_dim_idx()}, - /*onto=*/shard_dim_idx(ff_dim_t{-1})); + /*onto=*/shard_dim_idx(output_channel_dim)); - for (ff_dim_t const &idx : ff_dim_range(input_num_dims.value - 1)) { + for (ff_dim_t const &idx : ff_dim_range(nonnegative_int{input_num_dims.get_value() - 1})) { project_dims(inp_to_out, /*from=*/{shard_dim_idx(idx)}, /*onto=*/shard_dim_idx(idx)); @@ -179,15 +187,15 @@ tl::expected tl::expected get_operator_to_input_projection(LinearAttrs const &attrs, - TensorNumDims const &input_num_dims) { + nonnegative_int input_num_dims) { - TensorNumDims output_num_dims = input_num_dims; + nonnegative_int 
output_num_dims = input_num_dims; UpProjection out_to_inp = invert_down_projection(throw_if_unexpected(get_input_to_output_projection(attrs, input_num_dims)).raw_projection.require_down_proj()); EqProjection - op_to_out = throw_if_unexpected(get_output_space_mapping(attrs, output_num_dims)).raw_projection.require_eq_proj(); + op_to_out = throw_if_unexpected(get_operator_to_output_mapping(attrs, input_num_dims)).raw_projection.require_eq_proj(); return OperatorSpaceParallelTensorSpaceMapping{ DimProjection{ @@ -197,29 +205,12 @@ tl::expected } tl::expected - get_operator_to_output_projection(LinearAttrs const &attrs, - TensorNumDims const &output_num_dims) { - - return OperatorSpaceParallelTensorSpaceMapping{ - get + get_operator_to_output_mapping(LinearAttrs const &attrs, + nonnegative_int input_num_shard_dims) { + nonnegative_int output_num_shard_dims = input_num_shard_dims; - SumDegree sum_degree = SumDegree{1}; - DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - get_sum_degree(input) * - product( - slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; - FFOrdered shard_degrees = FFOrdered{ - shard_dim_at_idx(input, ff_dim_t{-1}).degree, - get_discard_copy_degree(input), - }; - - return + return get_identity_mapping(output_num_shard_dims); } -tl::expected - get_output_space_mapping(LinearAttrs const &attrs, - ParallelTensorDimDegrees const &input) { - return get_identity_mapping(input); -} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc index ef0fd4616c..3ac9ce2015 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_idx_t.cc @@ -1,4 +1,7 @@ #include "op-attrs/parallel_tensor_dim_idx_t.h" +#include "op-attrs/ff_dim_t.h" +#include "utils/containers/set_of.h" +#include "utils/containers/transform.h" namespace FlexFlow { @@ -14,4 +17,13 @@ parallel_tensor_dim_idx_t 
shard_dim_idx(ff_dim_t idx) { return parallel_tensor_dim_idx_t{idx}; } +std::set dim_idxs_for_num_shard_dims(nonnegative_int num_shard_dims) { + std::set result = + transform(set_of(ff_dim_range(num_shard_dims)), shard_dim_idx); + result.insert(sum_dim_idx()); + result.insert(discard_copy_dim_idx()); + + return result; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc index 0671bb05f2..3ef2b436d3 100644 --- a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc @@ -3,10 +3,10 @@ namespace FlexFlow { ff_dim_t ff_dim_t_from_relative_ff_dim_t(relative_ff_dim_t ff_dim, - int input_dim) { + nonnegative_int input_dim) { int raw = ff_dim.value; if (raw < 0) { - raw = input_dim + raw; + raw = input_dim.get_value() + raw; } return ff_dim_t{nonnegative_int{raw}}; } diff --git a/lib/op-attrs/src/op-attrs/tensor_num_dims.cc b/lib/op-attrs/src/op-attrs/tensor_num_dims.cc deleted file mode 100644 index 94689c06de..0000000000 --- a/lib/op-attrs/src/op-attrs/tensor_num_dims.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "op-attrs/tensor_num_dims.h" -#include "op-attrs/ff_dim_t.h" - -namespace FlexFlow { - -std::set ff_dim_idxs_for_num_dims(TensorNumDims const &num_dims) { - return ff_dim_range(num_dims.value); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc b/lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc new file mode 100644 index 0000000000..bab662d9f0 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc @@ -0,0 +1,23 @@ +#include "op-attrs/operator_space_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" +#include + +using namespace ::FlexFlow; + +static parallel_tensor_dim_idx_t shard_dim_idx_from_raw(int idx) { + return 
parallel_tensor_dim_idx_t{ff_dim_t{nonnegative_int{idx}}}; +} + +static operator_task_space_dim_idx_t op_task_space_dim_from_raw(int idx) { + return operator_task_space_dim_idx_t{nonnegative_int{idx}}; +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_identity_mapping(ParallelTensorDimDegrees)") { + nonnegative_int num_shard_dims = nonnegative_int{2}; + + OperatorSpaceParallelTensorSpaceMapping result = get_identity_mapping(num_shard_dims); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc b/lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc new file mode 100644 index 0000000000..5e19241f5e --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/operator_task_space_dim_idx_t.cc @@ -0,0 +1,13 @@ +#include "op-attrs/operator_task_space_dim_idx_t.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("operator_task_space_dim_idx_range(nonnegative_int)") { + SUBCASE("end is zero") { + std::set result = operator_task_space_dim_idx_range(nonnegative_int{0}); + std::set correct = {operator_task_space_dim_idx_t{nonnegative_int{0}}}; + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc index f45de6883a..314a7f2ae5 100644 --- a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc +++ b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -1,10 +1,16 @@ #include "op-attrs/parallel_tensor_dim_degrees.h" #include +#include "op-attrs/parallel_tensor_dim_idx_t.h" #include "test/utils/doctest/fmt/unordered_map.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/set.h" using namespace ::FlexFlow; +static parallel_tensor_dim_idx_t shard_dim_idx_from_raw(int idx) { + return parallel_tensor_dim_idx_t{ff_dim_t{nonnegative_int{idx}}}; +} + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_parallel_tensor_degree_map") { 
ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ @@ -21,9 +27,9 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_map correct = { {parallel_tensor_dim_idx_t{ReplicaType::SUM}, 3}, {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, 1}, - {parallel_tensor_dim_idx_t{ff_dim_t{0}}, 1}, - {parallel_tensor_dim_idx_t{ff_dim_t{1}}, 2}, - {parallel_tensor_dim_idx_t{ff_dim_t{2}}, 1}, + {shard_dim_idx_from_raw(0), 1}, + {shard_dim_idx_from_raw(1), 2}, + {shard_dim_idx_from_raw(2), 1}, }; CHECK(result == correct); @@ -76,4 +82,73 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + + TEST_CASE("get_nontrivial_parallel_tensor_dim_indices(ParallelTensorDimDegrees)") { + SUBCASE("a replica dim has degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3}, + DiscardCopyDegree{1}, + FFOrdered{4, 2, 4}, + }; + + std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = { + parallel_tensor_dim_idx_t{ReplicaType::SUM}, + shard_dim_idx_from_raw(0), + shard_dim_idx_from_raw(1), + shard_dim_idx_from_raw(2), + }; + + CHECK(result == correct); + } + + SUBCASE("a shard dim has degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3}, + DiscardCopyDegree{2}, + FFOrdered{1, 4, 1}, + }; + + std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = { + parallel_tensor_dim_idx_t{ReplicaType::SUM}, + parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, + shard_dim_idx_from_raw(1), + }; + + CHECK(result == correct); + } + + SUBCASE("no dims have degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{3}, + DiscardCopyDegree{2}, + FFOrdered{4, 2, 5}, + }; + + std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = { + parallel_tensor_dim_idx_t{ReplicaType::SUM}, + parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, + shard_dim_idx_from_raw(0), + shard_dim_idx_from_raw(1), + 
shard_dim_idx_from_raw(2), + }; + + CHECK(result == correct); + } + + SUBCASE("all dims have degree 1") { + ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; + + std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); + std::set correct = {}; + + CHECK(result == correct); + } + } } diff --git a/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc index c09c1ec3df..1c317b6ad7 100644 --- a/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc +++ b/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc @@ -5,7 +5,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ff_dim_t_from_relative_ff_dim_t") { - int input_dim = 5; + nonnegative_int input_dim = nonnegative_int{5}; SUBCASE("relative index is zero") { relative_ff_dim_t relative_ff_dim = relative_ff_dim_t{0}; diff --git a/lib/utils/include/utils/archetypes/ordered_value_type.h b/lib/utils/include/utils/archetypes/ordered_value_type.h index 154a47c565..e721a90e34 100644 --- a/lib/utils/include/utils/archetypes/ordered_value_type.h +++ b/lib/utils/include/utils/archetypes/ordered_value_type.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -48,6 +49,11 @@ struct ordered_value_type { } }; +template +std::string format_as(ordered_value_type const &) { + assert (false); +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/bijection/bijection.h b/lib/utils/include/utils/bijection/bijection.h new file mode 100644 index 0000000000..5d302dd4c5 --- /dev/null +++ b/lib/utils/include/utils/bijection/bijection.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIJECTION_BIJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIJECTION_BIJECTION_H + +#include "utils/bijection/bijection.dtg.h" + +namespace FlexFlow { + +template +Bijection flip_bijection(Bijection const &b) { + return Bijection{ + 
/*l_to_r=*/b.r_to_l, + /*r_to_l=*/b.l_to_r, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bijection/bijection.struct.toml b/lib/utils/include/utils/bijection/bijection.struct.toml new file mode 100644 index 0000000000..fb7702318c --- /dev/null +++ b/lib/utils/include/utils/bijection/bijection.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "Bijection" +features = [] + +template_params = [ "L", "R" ] + +includes = [ + "" +] + +[[fields]] +name = "l_to_r" +type = "std::function" + +[[fields]] +name = "r_to_l" +type = "std::function" diff --git a/lib/utils/include/utils/bijection/to.struct.toml b/lib/utils/include/utils/bijection/to.struct.toml new file mode 100644 index 0000000000..76d47ae488 --- /dev/null +++ b/lib/utils/include/utils/bijection/to.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "To" +features = [] + +template_params = [ "L", "R" ] + +includes = [ + "" +] + +[[fields]] +name = "func" +type = "std::function" diff --git a/lib/utils/include/utils/cli/cli_flag_key.struct.toml b/lib/utils/include/utils/cli/cli_flag_key.struct.toml index 790a752911..ee04f1aeb9 100644 --- a/lib/utils/include/utils/cli/cli_flag_key.struct.toml +++ b/lib/utils/include/utils/cli/cli_flag_key.struct.toml @@ -2,6 +2,7 @@ namespace = "FlexFlow" name = "CLIFlagKey" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/containers/filter_idxs.h b/lib/utils/include/utils/containers/filter_idxs.h index c71ca5e2c5..e9d9777e74 100644 --- a/lib/utils/include/utils/containers/filter_idxs.h +++ b/lib/utils/include/utils/containers/filter_idxs.h @@ -1,18 +1,21 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_IDXS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_IDXS_H +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/nonnegative_int/range.h" +#include "utils/nonnegative_int/num_elements.h" #include #include namespace FlexFlow { template -std::vector 
filter_idxs(std::vector const &input, std::function const &f) { +std::vector filter_idxs(std::vector const &input, std::function const &f) { std::vector result; - for (int idx = 0; idx < input.size(); idx++) { + for (nonnegative_int idx : range(num_elements(input))) { if (f(idx)) { - result.push_back(input.at(idx)); + result.push_back(input.at(idx.get_value())); } } diff --git a/lib/utils/include/utils/containers/generate_vector.h b/lib/utils/include/utils/containers/generate_vector.h new file mode 100644 index 0000000000..40025ca65e --- /dev/null +++ b/lib/utils/include/utils/containers/generate_vector.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_VECTOR_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/nonnegative_int/range.h" +#include + +namespace FlexFlow { + +template > +std::vector generate_vector(nonnegative_int length, F &&f) { + std::vector result; + for (nonnegative_int idx : range(length)) { + result.push_back(f(idx)); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/product.h b/lib/utils/include/utils/containers/product.h index af04edcb81..30aac2681a 100644 --- a/lib/utils/include/utils/containers/product.h +++ b/lib/utils/include/utils/containers/product.h @@ -10,7 +10,7 @@ namespace FlexFlow { **/ template Element product(Container const &container) { - Element result = 1; + Element result = Element{1}; for (Element const &element : container) { result *= element; } diff --git a/lib/utils/include/utils/containers/zip3_with.h b/lib/utils/include/utils/containers/zip3_with.h new file mode 100644 index 0000000000..70ed2a73ba --- /dev/null +++ b/lib/utils/include/utils/containers/zip3_with.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_H + +#include + +namespace 
FlexFlow { + +template > +std::vector zip3_with(std::vector const &v_a, + std::vector const &v_b, + std::vector const &v_c, + F &&f) { + std::vector result; + for (int i = 0; i < std::min(v_a.size(), v_b.size(), v_c.size()); i++) { + result.push_back(v_a.at(i), v_b.at(i), v_c.at(i)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip3_with_strict.h b/lib/utils/include/utils/containers/zip3_with_strict.h new file mode 100644 index 0000000000..ae7239f5d8 --- /dev/null +++ b/lib/utils/include/utils/containers/zip3_with_strict.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_STRICT_H + +#include +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include "utils/containers/zip3_with.h" + +namespace FlexFlow { + +template > +std::vector zip3_with_strict(std::vector const &v_a, + std::vector const &v_b, + std::vector const &v_c, + F &&f) { + if (!(v_a.size() == v_b.size() && v_b.size() == v_c.size())) { + throw mk_runtime_error(fmt::format("zip3_with_strict requires inputs to have the same length, but received v_a = {} (length {}), v_b = {} (length {}), and v_c = {} (length {})", v_a, v_a.size(), v_b, v_b.size(), v_c, v_c.size())); + } + + return zip3_with(v_a, v_b, v_c, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_strict.h b/lib/utils/include/utils/containers/zip_strict.h new file mode 100644 index 0000000000..42f32e64d2 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_strict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H + +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include "utils/containers/zip.h" + +namespace FlexFlow { + +template +std::vector> zip_strict(std::vector const &lhs, + std::vector const &rhs) { 
+ if (lhs.size() != rhs.size()) { + throw mk_runtime_error(fmt::format("zip_strict requires lhs and rhs to have the same length, but received lhs={} (length {}), rhs={} (length {})", lhs, lhs.size(), rhs, rhs.size())); + } + + return zip(lhs, rhs); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with_strict.h b/lib/utils/include/utils/containers/zip_with_strict.h new file mode 100644 index 0000000000..357d0e94c6 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_with_strict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H + +#include "utils/exception.h" +#include +#include "utils/containers/zip_with.h" +#include "utils/fmt/vector.h" + +namespace FlexFlow { + +template > +std::vector zip_with_strict(std::vector const &lhs, std::vector const &rhs, F &&f) { + if (lhs.size() != rhs.size()) { + throw mk_runtime_error(fmt::format("zip_with_strict requires inputs to have the same length, but received lhs = {} (length {}) and rhs = {} (length {})", lhs, lhs.size(), rhs, rhs.size())); + } + + return zip_with(lhs, rhs, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h index 01bee29f63..4abe8f6090 100644 --- a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -17,6 +17,9 @@ class nonnegative_int { explicit operator int() const noexcept; + nonnegative_int operator*(nonnegative_int other) const; + nonnegative_int &operator*=(nonnegative_int other); + bool operator<(nonnegative_int const &other) const; bool operator==(nonnegative_int const &other) const; bool operator>(nonnegative_int const &other) const; @@ -44,9 +47,15 @@ class nonnegative_int { int get_value() const; +private: + nonnegative_int &set_value(int); + 
private: int value_; }; + +nonnegative_int operator ""_n(unsigned long long int); + } // namespace FlexFlow namespace nlohmann { diff --git a/lib/utils/include/utils/nonnegative_int/num_elements.h b/lib/utils/include/utils/nonnegative_int/num_elements.h new file mode 100644 index 0000000000..9211783dd8 --- /dev/null +++ b/lib/utils/include/utils/nonnegative_int/num_elements.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NUM_ELEMENTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NUM_ELEMENTS_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +template +nonnegative_int num_elements(std::vector const &v) { + return nonnegative_int{int_from_size_t(v.size())}; +} + +template +nonnegative_int num_elements(std::list const &v) { + return nonnegative_int{int_from_size_t(v.size())}; +} + +template +nonnegative_int num_elements(std::set const &v) { + return nonnegative_int{int_from_size_t(v.size())}; +} + +template +nonnegative_int num_elements(std::unordered_set const &v) { + return nonnegative_int{int_from_size_t(v.size())}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/nonnegative_int/range.h b/lib/utils/include/utils/nonnegative_int/range.h new file mode 100644 index 0000000000..ab8fe0ca7f --- /dev/null +++ b/lib/utils/include/utils/nonnegative_int/range.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_RANGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_RANGE_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +std::vector range(nonnegative_int start, nonnegative_int end, int step = 1); +std::vector range(nonnegative_int end); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_coord.h b/lib/utils/include/utils/orthotope/dim_coord.h new file mode 100644 index 0000000000..b57fd823d1 --- /dev/null +++ 
b/lib/utils/include/utils/orthotope/dim_coord.h @@ -0,0 +1,70 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_COORD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_COORD_H + +#include "utils/containers/subvec.h" +#include "utils/containers/zip_with_strict.h" +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/orthotope.h" +#include "utils/exception.h" +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" +#include "utils/exception.h" +#include "utils/containers/sorted.h" +#include "utils/containers/transform.h" +#include "utils/containers/scanr.h" +#include "utils/containers/product.h" +#include "utils/containers/map_from_keys_and_values.h" + +namespace FlexFlow { + +template +std::unordered_set get_coord_dims(DimCoord const &coord) { + return keys(coord.raw); +} + +template +DimCoord restrict_coord_to_dims(DimCoord const &coord, std::unordered_set const &dims) { + return DimCoord{ + restrict_keys(coord.raw, dims), + }; +} + +template +OrthotopeCoord orthotope_coord_from_dim_coord(DimCoord const &coord) { + return OrthotopeCoord{ + transform(sorted(get_coord_dims(coord)), [&](T const &t) { return coord.raw.at(t); }), + }; +} + +template +DimCoord dim_coord_from_orthotope_coord(OrthotopeCoord const &coord, DimDomain const &domain) { + return DimCoord{ + map_from_keys_and_values(coord.raw, get_domain_dims(domain)), + }; +} + +template +nonnegative_int flatten_coord(DimCoord const &coord, + DimDomain const &domain) { + if (get_coord_dims(coord) != get_dims_for_domain(domain)) { + throw mk_runtime_error(fmt::format("flatten_dims expected coord dimensions to match domain dimensions, but received coord={} and domain={}", coord, domain)); + } + + OrthotopeCoord orthotope_coord = orthotope_coord_from_dim_coord(coord); + Orthotope orthotope_domain = orthotope_from_dim_domain(domain); + + return flatten_orthotope_coord(orthotope_coord, orthotope_domain); +} + 
+template +DimCoord unflatten_coord(nonnegative_int flattened, DimDomain const &domain) { + Orthotope orthotope_domain = orthotope_from_dim_domain(domain); + OrthotopeCoord orthotope_coord = unflatten_orthotope_coord(flattened, orthotope_domain); + + return dim_coord_from_orthotope_coord(orthotope_coord, domain); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/dim_coord.struct.toml b/lib/utils/include/utils/orthotope/dim_coord.struct.toml new file mode 100644 index 0000000000..f6f565922a --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_coord.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DimCoord" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +template_params = [ + "T", +] + +includes = [ + "", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw" +type = "std::unordered_map" diff --git a/lib/utils/include/utils/orthotope/dim_domain.h b/lib/utils/include/utils/orthotope/dim_domain.h new file mode 100644 index 0000000000..2c24e11943 --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_domain.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_H + +#include "utils/orthotope/dim_domain.dtg.h" +#include "utils/orthotope/orthotope.dtg.h" + +namespace FlexFlow { + +template +std::set get_domain_dims(DimDomain const &domain) { + return keys(domain.dims); +} + +template +DimDomain restrict_domain_to_dims(DimDomain const &domain, std::unordered_set const &allowed) { + return DimDomain{restrict_keys(domain.dims, allowed)}; +} + +template +Orthotope orthotope_from_dim_domain(DimDomain const &domain) { + return Orthotope{ + transform(sorted(get_domain_dims(domain)), [&](T const &t) { return domain.dims.at(t); }), + }; +} + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/orthotope/dim_domain.struct.toml b/lib/utils/include/utils/orthotope/dim_domain.struct.toml new file mode 100644 index 0000000000..86aa2fb7ff --- /dev/null +++ b/lib/utils/include/utils/orthotope/dim_domain.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "DimDomain" +features = [ + "eq", + "ord", + "fmt", + "hash", + "json", +] + +template_params = [ + "T", +] + +includes = [ + "", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "dims" +type = "std::unordered_map" diff --git a/lib/utils/include/utils/orthotope/dim_projection.h b/lib/utils/include/utils/orthotope/dim_projection.h index 07e5885d49..aac858cbfe 100644 --- a/lib/utils/include/utils/orthotope/dim_projection.h +++ b/lib/utils/include/utils/orthotope/dim_projection.h @@ -1,12 +1,38 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_PROJECTION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_PROJECTION_H -#include "utils/orthotope/down_projection.dtg.h" -#include "utils/orthotope/eq_projection.dtg.h" -#include "utils/orthotope/up_projection.dtg.h" +#include "utils/orthotope/dim_projection.dtg.h" +#include "utils/orthotope/down_projection.h" +#include "utils/orthotope/eq_projection.h" +#include "utils/orthotope/up_projection.h" +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/overload.h" namespace FlexFlow { +template +std::unordered_set input_dims_of_projection(DimProjection const &projection) { + return projection.template visit>(overload { + [](UpProjection const &p) { return input_dims_of_up_projection(p); }, + [](EqProjection const &p) { return input_dims_of_eq_projection(p); }, + [](DownProjection const &p) { return input_dims_of_down_projection(p); }, + }); +} + +template +std::unordered_set output_dims_of_projection(DimProjection const &projection) { + return projection.template visit>(overload { + [](UpProjection const &p) { 
return output_dims_of_up_projection(p); }, + [](EqProjection const &p) { return output_dims_of_eq_projection(p); }, + [](DownProjection const &p) { return output_dims_of_down_projection(p); }, + }); +}; + +// template +// DimCoord compute_projection(DimProjection const &projection, DimCoord const &coord) { +// if (coord_dims(coord) != +// } + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/orthotope/down_projection.h b/lib/utils/include/utils/orthotope/down_projection.h index 9ff24381b3..cec83e9d8c 100644 --- a/lib/utils/include/utils/orthotope/down_projection.h +++ b/lib/utils/include/utils/orthotope/down_projection.h @@ -1,8 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DOWN_PROJECTION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DOWN_PROJECTION_H +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" #include "utils/orthotope/down_projection.dtg.h" #include "utils/orthotope/eq_projection.dtg.h" +#include "utils/orthotope/orthotope.dtg.h" +#include "utils/orthotope/orthotope.h" +#include "utils/orthotope/orthotope_coord.dtg.h" #include "utils/orthotope/up_projection.dtg.h" #include "utils/many_to_one/many_to_one_from_bidict.h" #include "utils/many_to_one/exhaustive_relational_join.h" @@ -15,6 +20,41 @@ DownProjection make_empty_down_projection() { return DownProjection{ManyToOne{}}; } +template +std::unordered_set input_dims_of_down_projection(DownProjection const &projection) { + return projection.dim_mapping.left_values(); +} + +template +std::unordered_set output_dims_of_down_projection(DownProjection const &projection) { + return projection.dim_mapping.right_values(); +} + +template +DimCoord compute_down_projection(DownProjection const &projection, + DimCoord const &coord, + DimDomain const &domain) { + std::unordered_set input_dims = input_dims_of_down_projection(projection); + std::unordered_set coord_dims = get_coord_dims(coord); + if (input_dims != coord_dims) { + throw 
mk_runtime(fmt::format("compute_down_projection expected coord dimensions to match projection input dimensions, but received inputs_dims={} and coord_dims={}", input_dims, coord_dims)); + } + + std::unordered_set output_dims = output_dims_of_down_projection(projection); + + return DimCoord{ + generate_map(output_dims, + [&](R const &output_dim) { + std::unordered_set src_dims = projection.dim_mapping.at_r(output_dim); + + DimCoord src_coord = restrict_coord_to_dims(coord, src_dims); + DimDomain src_domain = restrict_domain_to_dims(domain, src_dims); + + return flatten_coord(src_coord, src_domain); + }), + }; +} + template void project_dims(DownProjection &proj, std::unordered_set const &from, R const &onto) { for (L const &l : from) { diff --git a/lib/utils/include/utils/orthotope/down_projection.struct.toml b/lib/utils/include/utils/orthotope/down_projection.struct.toml index e9d9747bec..419434905b 100644 --- a/lib/utils/include/utils/orthotope/down_projection.struct.toml +++ b/lib/utils/include/utils/orthotope/down_projection.struct.toml @@ -12,7 +12,6 @@ template_params = [ includes = [ "utils/many_to_one/many_to_one.h", - "utils/orthotope/orthotope_dim_idx_t.dtg.h", ] [[fields]] diff --git a/lib/utils/include/utils/orthotope/eq_projection.h b/lib/utils/include/utils/orthotope/eq_projection.h index f9b0766f5d..0c764cfcc5 100644 --- a/lib/utils/include/utils/orthotope/eq_projection.h +++ b/lib/utils/include/utils/orthotope/eq_projection.h @@ -6,6 +6,16 @@ namespace FlexFlow { +template +std::unordered_set input_dims_of_eq_projection(EqProjection const &projection) { + return projection.dim_mapping.left_values(); +} + +template +std::unordered_set output_dims_of_eq_projection(EqProjection const &projection) { + return projection.dim_mapping.right_values(); +} + template EqProjection invert_eq_projection(EqProjection const &input) { return EqProjection{ diff --git a/lib/utils/include/utils/orthotope/orthotope.h b/lib/utils/include/utils/orthotope/orthotope.h index 
8660173e52..04d1f68f4b 100644 --- a/lib/utils/include/utils/orthotope/orthotope.h +++ b/lib/utils/include/utils/orthotope/orthotope.h @@ -2,21 +2,23 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_H #include "utils/orthotope/orthotope.dtg.h" -#include "utils/orthotope/orthotope_coordinate.dtg.h" -#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" -#include +#include "utils/orthotope/orthotope_coord.dtg.h" namespace FlexFlow { -std::set get_orthotope_dims(Orthotope const &); -int orthotope_num_dims(Orthotope const &); +nonnegative_int get_orthotope_num_dims(Orthotope const &); -bool orthotope_contains_coord(Orthotope const &, OrthotopeCoordinate const &); -std::unordered_set orthotope_get_contained_coordinates(Orthotope const &); +nonnegative_int get_orthotope_volume(Orthotope const &); -int orthotope_get_volume(Orthotope const &); +std::unordered_set get_all_coords_in_orthotope(Orthotope const &); -Orthotope orthotope_drop_dims_except(Orthotope const &, std::set const &); +bool orthotope_contains_coord(Orthotope const &, OrthotopeCoord const &); + +Orthotope restrict_orthotope_to_dims(Orthotope const &, std::set const &); + +nonnegative_int flatten_orthotope_coord(OrthotopeCoord const &, Orthotope const &); + +OrthotopeCoord unflatten_orthotope_coord(nonnegative_int, Orthotope const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/orthotope/orthotope.struct.toml b/lib/utils/include/utils/orthotope/orthotope.struct.toml index 0117b44893..a1fcb2a80e 100644 --- a/lib/utils/include/utils/orthotope/orthotope.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope.struct.toml @@ -9,13 +9,15 @@ features = [ ] includes = [ - "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h", + "", + "utils/nonnegative_int/nonnegative_int.h", ] src_includes = [ - "utils/orthotope/orthotope_dim_indexed/json.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", ] [[fields]] name = "dims" -type = "::FlexFlow::OrthotopeDimIndexed" +type = 
"std::vector<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h deleted file mode 100644 index cb5a64cf23..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_BIJECTIVE_PROJECTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_BIJECTIVE_PROJECTION_H - -#include "utils/orthotope/orthotope.dtg.h" -#include "utils/orthotope/orthotope_bijective_projection.dtg.h" -#include "utils/orthotope/orthotope_coordinate.dtg.h" -#include -#include - -namespace FlexFlow { - -OrthotopeBijectiveProjection - make_orthotope_projection_from_map(std::unordered_map const &, bool reversed); - -bool is_valid_projection_between(OrthotopeBijectiveProjection const &proj, Orthotope const &src, Orthotope const &dst); - -std::unordered_map get_src_to_dst_dim_map(OrthotopeBijectiveProjection const &); -std::unordered_map> get_dst_dims_by_src_dim_map(OrthotopeBijectiveProjection const &); -std::unordered_map> get_src_dims_by_dst_dim_map(OrthotopeBijectiveProjection const &); - -orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeBijectiveProjection const &, orthotope_dim_idx_t const &); -orthotope_dim_idx_t get_src_dim_for_dst_dim(OrthotopeBijectiveProjection const &, orthotope_dim_idx_t const &); - -int get_src_num_dims(OrthotopeBijectiveProjection const &); -int get_dst_num_dims(OrthotopeBijectiveProjection const &); - -OrthotopeBijectiveProjection reverse_projection(OrthotopeBijectiveProjection const &); - -std::unordered_set get_all_bijective_projections_between(Orthotope const &src, Orthotope const &dst); -std::unordered_set get_all_bijective_projections_between_dim_numbers(int src_num_dims, int dst_num_dims); - -int project_into_1d(Orthotope const &, OrthotopeCoordinate const &); -OrthotopeCoordinate project_out_of_1d(int, 
Orthotope const &); - - -OrthotopeCoordinate project_coordinate_through(OrthotopeBijectiveProjection const &projection, - Orthotope const &src_orthotope, - OrthotopeCoordinate const &src_coord, - Orthotope const &dst_orthotope); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml b/lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml deleted file mode 100644 index 08a29248b9..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_bijective_projection.struct.toml +++ /dev/null @@ -1,25 +0,0 @@ -namespace = "FlexFlow" -name = "OrthotopeBijectiveProjection" -features = [ - "eq", - "fmt", - "hash", -] - -includes = [ - "", - "utils/orthotope/orthotope_dim_idx_t.dtg.h", -] - -src_includes = [ - "utils/hash/vector.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "dim_mapping" -type = "std::vector<::FlexFlow::orthotope_dim_idx_t>" - -[[fields]] -name = "reversed" -type = "bool" diff --git a/lib/utils/include/utils/orthotope/orthotope_coord.h b/lib/utils/include/utils/orthotope/orthotope_coord.h new file mode 100644 index 0000000000..cf105780f0 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_coord.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORD_H + +#include "utils/orthotope/orthotope_coord.dtg.h" + +namespace FlexFlow { + +OrthotopeCoord restrict_orthotope_coord_dims_to(OrthotopeCoord const &coord, std::set const &allowed_dims); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_coord.struct.toml b/lib/utils/include/utils/orthotope/orthotope_coord.struct.toml new file mode 100644 index 0000000000..a66220c611 --- /dev/null +++ b/lib/utils/include/utils/orthotope/orthotope_coord.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OrthotopeCoord" +features = [ + "eq", + "ord", + "fmt", + 
"hash", + "json", +] + +includes = [ + "", + "utils/nonnegative_int/nonnegative_int.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "raw" +type = "std::vector<::FlexFlow::nonnegative_int>" diff --git a/lib/utils/include/utils/orthotope/orthotope_coordinate.h b/lib/utils/include/utils/orthotope/orthotope_coordinate.h deleted file mode 100644 index c4ac1114bd..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_coordinate.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORDINATE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_COORDINATE_H - -#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" -#include "utils/orthotope/orthotope_coordinate.dtg.h" -#include - -namespace FlexFlow { - -std::set get_orthotope_coord_dims(OrthotopeCoordinate const &); - -OrthotopeCoordinate orthotope_coord_drop_dims_except(OrthotopeCoordinate const &, std::set const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml b/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml deleted file mode 100644 index fdaef519fa..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_coordinate.struct.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "OrthotopeCoordinate" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", -] - -includes = [ - "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h", -] - -src_includes = [ - "utils/orthotope/orthotope_dim_indexed/json.h", -] - -[[fields]] -name = "idxs" -type = "::FlexFlow::OrthotopeDimIndexed" diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h b/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h deleted file mode 100644 index d14e2633d7..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_IDX_T_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_IDX_T_H - -#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" -#include - -namespace FlexFlow { - -std::set dim_idxs_for_orthotope_with_num_dims(int num_dims); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml b/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml deleted file mode 100644 index 68ee54c40f..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_idx_t.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "orthotope_dim_idx_t" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -[[fields]] -name = "raw_idx" -type = "int" diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h deleted file mode 100644 index b1f4531af3..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/all_of.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ALL_OF_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ALL_OF_H - -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -namespace FlexFlow { - -bool all_of(OrthotopeDimIndexed const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h deleted file mode 100644 index 4ee9ee93ac..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_DROP_IDXS_EXCEPT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_DROP_IDXS_EXCEPT_H - -#include "utils/containers/contains.h" 
-#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include "utils/exception.h" -#include "utils/containers/is_subseteq_of.h" -#include "utils/fmt/set.h" - -namespace FlexFlow { - -template -OrthotopeDimIndexed drop_idxs_except(OrthotopeDimIndexed const &d, std::set const &keep) { - OrthotopeDimIndexed result; - - if (!is_subseteq_of(d.indices(), keep)) { - throw mk_runtime_error(fmt::format("drop_idxs_except expected keep to be a subset of d's dims, but got d={}, keep={}", d, keep)); - } - - for (orthotope_dim_idx_t const &idx : d.indices()) { - if (contains(keep, idx)) { - result.push_back(d.at(idx)); - } - } - - return result; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h deleted file mode 100644 index 7f1f8ed9e9..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/json.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_JSON_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_JSON_H - -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" -#include - -namespace nlohmann { - -template -struct adl_serializer<::FlexFlow::OrthotopeDimIndexed> { - static ::FlexFlow::OrthotopeDimIndexed from_json(json const &j) { - return ::FlexFlow::orthotope_dim_indexed_of(j.get>()); - } - - static void to_json(json &j, ::FlexFlow::OrthotopeDimIndexed const &d) { - j = d.get_contents(); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h deleted file mode 100644 index 23f76037ac..0000000000 --- 
a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h +++ /dev/null @@ -1,183 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_H - -#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" -#include -#include "utils/hash-utils.h" -#include -#include "utils/hash/vector.h" -#include "utils/hash/tuple.h" -#include "utils/fmt/vector.h" -#include "utils/type_traits_core.h" -#include "utils/orthotope/orthotope_dim_idx_t.h" -#include -#include "utils/ord/vector.h" - -namespace FlexFlow { - -template -struct OrthotopeDimIndexed { -public: - OrthotopeDimIndexed() - : contents() - { } - - OrthotopeDimIndexed(std::initializer_list const &l) - : contents(l) - { } - - template - OrthotopeDimIndexed(Iter begin, Iter end) - : contents(begin, end) - { } - - T const &at(orthotope_dim_idx_t const &idx) const { - return this->contents.at(idx.raw_idx); - } - - T &at(orthotope_dim_idx_t const &idx) { - return this->contents.at(idx.raw_idx); - } - - T const &back() const { - return this->contents.back(); - } - - T &back() { - return this->contents.back(); - } - - T const &front() const { - return this->contents.front(); - } - - T &front() { - return this->contents.front(); - } - - void push_back(T const &t) { - this->contents.push_back(t); - } - - std::vector const &get_contents() const { - return this->contents; - } - - bool operator==(OrthotopeDimIndexed const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(OrthotopeDimIndexed const &other) const { - return this->tie() != other.tie(); - } - - std::set indices() const { - return dim_idxs_for_orthotope_with_num_dims(this->size()); - } - - std::tuple const &> tie() const { - return std::tie(contents); - } -private: - std::vector contents; -public: - using iterator = typename decltype(contents)::iterator; - using const_iterator = 
typename decltype(contents)::const_iterator; - using reverse_iterator = typename decltype(contents)::reverse_iterator; - using const_reverse_iterator = typename decltype(contents)::const_reverse_iterator; - - using value_type = typename decltype(contents)::value_type; - using pointer = typename decltype(contents)::pointer; - using const_pointer = typename decltype(contents)::const_pointer; - using reference = typename decltype(contents)::reference; - using const_reference = typename decltype(contents)::const_reference; - - iterator begin() { - return this->contents.begin(); - } - - const_iterator begin() const { - return this->cbegin(); - } - - const_iterator cbegin() const { - return this->contents.cbegin(); - } - - iterator end() { - return this->contents.end(); - } - - const_iterator end() const { - return this->cend(); - } - - const_iterator cend() const { - return this->contents.cend(); - } - - reverse_iterator rbegin() { - return this->contents.rbegin(); - } - - const_reverse_iterator rbegin() const { - return this->crbegin(); - } - - const_reverse_iterator crbegin() const { - return this->contents.crbegin(); - } - - reverse_iterator rend() { - return this->contents.rend(); - } - - const_reverse_iterator rend() const { - return this->crend(); - } - - const_reverse_iterator crend() const { - return this->contents.crend(); - } - - size_t size() const { - return this->contents.size(); - } - - size_t empty() const { - return this->contents.empty(); - } -}; - -template -std::enable_if_t, bool> operator<(OrthotopeDimIndexed const &lhs, OrthotopeDimIndexed const &rhs) { - return lhs.tie() < rhs.tie(); -} - -template -std::vector format_as(OrthotopeDimIndexed const &d) { - return d.get_contents(); -} - -template -std::ostream &operator<<(std::ostream &s, OrthotopeDimIndexed const &d) { - return (s << fmt::to_string(d)); -} - -} // namespace FlexFlow - -namespace std { - -template -struct hash<::FlexFlow::OrthotopeDimIndexed> { - size_t 
operator()(::FlexFlow::OrthotopeDimIndexed const &t) const { - static_assert(::FlexFlow::is_hashable::value, - "Elements must be hashable"); - - return ::FlexFlow::get_std_hash(t.tie()); - } -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h deleted file mode 100644 index 858224dee5..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_FROM_IDX_MAP_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_FROM_IDX_MAP_H - -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" -#include "utils/containers/vector_from_idx_map.h" -#include "utils/containers/map_keys.h" - -namespace FlexFlow { - -template -std::optional> orthotope_dim_indexed_from_idx_map(std::unordered_map const &m) { - std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t idx) { return idx.raw_idx; }); - - std::vector raw_vec = ({ - std::optional> returned = vector_from_idx_map(raw_idx_map); - if (!returned.has_value()) { - return std::nullopt; - } - - returned.value(); - }); - - return orthotope_dim_indexed_of(raw_vec); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h deleted file mode 100644 index bb1794aeb3..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_OF_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ORTHOTOPE_DIM_INDEXED_OF_H - -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include - -namespace FlexFlow { - -template -OrthotopeDimIndexed orthotope_dim_indexed_of(std::vector const &v) { - return OrthotopeDimIndexed(v.cbegin(), v.cend()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h deleted file mode 100644 index 9871284864..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/transform.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_TRANSFORM_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_TRANSFORM_H - -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" -#include "utils/containers/transform.h" - -namespace FlexFlow { - -template > -OrthotopeDimIndexed transform(OrthotopeDimIndexed const &d, F &&f) { - return orthotope_dim_indexed_of(transform(d.get_contents(), f)); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h b/lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h deleted file mode 100644 index f37c9725bc..0000000000 --- a/lib/utils/include/utils/orthotope/orthotope_dim_indexed/zip_with.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ZIP_WITH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHOTOPE_DIM_INDEXED_ZIP_WITH_H - -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include "utils/containers/intersection.h" - -namespace FlexFlow { - 
-template > -OrthotopeDimIndexed zip_with(OrthotopeDimIndexed const &l, OrthotopeDimIndexed const &r, F &&f) { - OrthotopeDimIndexed result; - for (orthotope_dim_idx_t i : intersection(l.indices(), r.indices())) { - result.push_back(f(l.at(i), r.at(i))); - } - - return result; -} - - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/orthtope.h b/lib/utils/include/utils/orthotope/orthtope.h deleted file mode 100644 index b723e8057f..0000000000 --- a/lib/utils/include/utils/orthotope/orthtope.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHTOPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_ORTHTOPE_H - -namespace FlexFlow { - -Orthotope orthotope_from_dim_map(std::unordered_map const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h index cd81e16b7b..ba9bf8a8ad 100644 --- a/lib/utils/include/utils/orthotope/up_projection.h +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -1,12 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_UP_PROJECTION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_UP_PROJECTION_H +#include "utils/orthotope/dim_coord.dtg.h" +#include "utils/orthotope/dim_domain.dtg.h" #include "utils/orthotope/up_projection.dtg.h" #include "utils/orthotope/eq_projection.dtg.h" #include "utils/orthotope/down_projection.dtg.h" #include "utils/one_to_many/one_to_many_from_bidict.h" #include "utils/one_to_many/exhaustive_relational_join.h" #include "utils/one_to_many/invert_one_to_many.h" +#include "utils/containers/keys.h" +#include "utils/containers/values.h" namespace FlexFlow { @@ -15,6 +19,41 @@ UpProjection make_empty_up_projection() { return UpProjection{OneToMany{}}; } +template +std::unordered_set input_dims_of_up_projection(UpProjection const &projection) { + return projection.dim_mapping.left_values(); +} + +template +std::unordered_set 
output_dims_of_up_projection(UpProjection const &projection) { + return projection.dim_mapping.right_values(); +} + +template +DimCoord compute_up_projection(UpProjection const &projection, + DimCoord const &coord, + DimDomain const &domain) { + std::unordered_set input_dims = input_dims_of_up_projection(projection); + std::unordered_set coord_dims = get_coord_dims(coord); + if (input_dims != coord_dims) { + throw mk_runtime(fmt::format("compute_up_projection expected coord dimensions to match projection input dimensions, but received inputs_dims={} and coord_dims={}", input_dims, coord_dims)); + } + + std::unordered_set output_dims = output_dims_of_up_projection(projection); + + return DimCoord{ + generate_map(output_dims, + [&](R const &output_dim) { + std::unordered_set src_dims = projection.dim_mapping.at_r(output_dim); + + DimCoord src_coord = restrict_coord_to_dims(coord, src_dims); + DimDomain src_domain = restrict_domain_to_dims(domain, src_dims); + + return flatten_coord(src_coord, src_domain); + }), + }; +} + template void project_dims(UpProjection &proj, L const &onto, std::unordered_set const &from) { for (R const &r : from) { diff --git a/lib/utils/include/utils/orthotope/up_projection.struct.toml b/lib/utils/include/utils/orthotope/up_projection.struct.toml index b37aba037a..a5ec5acec5 100644 --- a/lib/utils/include/utils/orthotope/up_projection.struct.toml +++ b/lib/utils/include/utils/orthotope/up_projection.struct.toml @@ -12,7 +12,6 @@ template_params = [ includes = [ "utils/one_to_many/one_to_many.h", - "utils/orthotope/orthotope_dim_idx_t.dtg.h", ] [[fields]] diff --git a/lib/utils/src/utils/bijection/bijection.cc b/lib/utils/src/utils/bijection/bijection.cc new file mode 100644 index 0000000000..c23577ba35 --- /dev/null +++ b/lib/utils/src/utils/bijection/bijection.cc @@ -0,0 +1,12 @@ +#include "utils/bijection/bijection.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + 
+template + Bijection flip_bijection(Bijection const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_spec.cc b/lib/utils/src/utils/cli/cli_spec.cc index e7ad5e8df4..c93bdd1152 100644 --- a/lib/utils/src/utils/cli/cli_spec.cc +++ b/lib/utils/src/utils/cli/cli_spec.cc @@ -1,6 +1,7 @@ #include "utils/cli/cli_spec.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/containers/filter_idxs.cc b/lib/utils/src/utils/containers/filter_idxs.cc index fd0d61dcf8..91df95eded 100644 --- a/lib/utils/src/utils/containers/filter_idxs.cc +++ b/lib/utils/src/utils/containers/filter_idxs.cc @@ -6,6 +6,6 @@ namespace FlexFlow { using T = value_type<0>; template - std::vector filter_idxs(std::vector const &, std::function const &); + std::vector filter_idxs(std::vector const &, std::function const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/generate_vector.cc b/lib/utils/src/utils/containers/generate_vector.cc new file mode 100644 index 0000000000..03b5452b33 --- /dev/null +++ b/lib/utils/src/utils/containers/generate_vector.cc @@ -0,0 +1,9 @@ +#include "utils/containers/generate_vector.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; +using F = std::function; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/range.cc b/lib/utils/src/utils/containers/range.cc index d3ebd1063b..009bc1cffb 100644 --- a/lib/utils/src/utils/containers/range.cc +++ b/lib/utils/src/utils/containers/range.cc @@ -1,10 +1,14 @@ #include "utils/containers/range.h" #include +#include +#include "utils/exception.h" namespace FlexFlow { std::vector range(int start, int end, int step) { - assert(step != 0); + if (step == 0) { + throw mk_runtime_error(fmt::format("range expected step != 0, but received: {}", step)); + } std::vector result; if 
(step > 0) { diff --git a/lib/utils/src/utils/containers/zip3_with.cc b/lib/utils/src/utils/containers/zip3_with.cc new file mode 100644 index 0000000000..de3a7021b6 --- /dev/null +++ b/lib/utils/src/utils/containers/zip3_with.cc @@ -0,0 +1,18 @@ +#include "utils/containers/zip3_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A = value_type<0>; +using B = value_type<1>; +using C = value_type<2>; +using Result = value_type<3>; +using F = std::function; + +template + std::vector zip3_with(std::vector const &, + std::vector const &, + std::vector const &, + F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3_with_strict.cc b/lib/utils/src/utils/containers/zip3_with_strict.cc new file mode 100644 index 0000000000..ed1161505e --- /dev/null +++ b/lib/utils/src/utils/containers/zip3_with_strict.cc @@ -0,0 +1,19 @@ +#include "utils/containers/zip3_with_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A = value_type<0>; +using B = value_type<1>; +using C = value_type<2>; +using Result = value_type<3>; +using F = std::function; + +template + std::vector zip3_with_strict(std::vector const &, + std::vector const &, + std::vector const &, + F &&); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_strict.cc b/lib/utils/src/utils/containers/zip_strict.cc new file mode 100644 index 0000000000..bbc31c708e --- /dev/null +++ b/lib/utils/src/utils/containers/zip_strict.cc @@ -0,0 +1,12 @@ +#include "utils/containers/zip_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + std::vector> zip_strict(std::vector const &, std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with_strict.cc b/lib/utils/src/utils/containers/zip_with_strict.cc new file mode 100644 index 0000000000..349ee9a37c --- /dev/null +++ 
b/lib/utils/src/utils/containers/zip_with_strict.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip_with_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template + std::vector zip_with_strict(std::vector const &, std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc index dfb26cb4e1..a874243632 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc @@ -5,6 +5,7 @@ #include "utils/containers/extend.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/vector_of.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" diff --git a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc index 2c317dce86..772ccee14e 100644 --- a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc @@ -3,17 +3,21 @@ namespace FlexFlow { nonnegative_int::nonnegative_int(int value) { - if (value < 0) { - throw std::invalid_argument( - "Value of nonnegative_int type must be nonnegative."); - } - this->value_ = value; + this->set_value(value); } nonnegative_int::operator int() const noexcept { return this->value_; } +nonnegative_int nonnegative_int::operator*(nonnegative_int other) const { + return nonnegative_int{this->value_ * other.value_}; +} + +nonnegative_int &nonnegative_int::operator*=(nonnegative_int other) { + return this->set_value(this->value_ * other.value_); +} + bool 
nonnegative_int::operator<(nonnegative_int const &other) const { return this->value_ < other.value_; } @@ -83,6 +87,20 @@ int nonnegative_int::get_value() const { int format_as(nonnegative_int const &x) { return x.get_value(); } + +nonnegative_int &nonnegative_int::set_value(int value) { + if (value < 0) { + throw std::invalid_argument( + "Value of nonnegative_int type must be nonnegative."); + } + this->value_ = value; + return *this; +} + +nonnegative_int operator ""_n(unsigned long long int value) { + return nonnegative_int{static_cast(value)}; +} + } // namespace FlexFlow namespace nlohmann { diff --git a/lib/utils/src/utils/nonnegative_int/num_elements.cc b/lib/utils/src/utils/nonnegative_int/num_elements.cc new file mode 100644 index 0000000000..0f5716353b --- /dev/null +++ b/lib/utils/src/utils/nonnegative_int/num_elements.cc @@ -0,0 +1,20 @@ +#include "utils/nonnegative_int/num_elements.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T = value_type<0>; + +template + nonnegative_int num_elements(std::vector const &); + +template + nonnegative_int num_elements(std::list const &); + +template + nonnegative_int num_elements(std::set const &); + +template + nonnegative_int num_elements(std::unordered_set const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/nonnegative_int/range.cc b/lib/utils/src/utils/nonnegative_int/range.cc new file mode 100644 index 0000000000..6bbdd6e522 --- /dev/null +++ b/lib/utils/src/utils/nonnegative_int/range.cc @@ -0,0 +1,15 @@ +#include "utils/nonnegative_int/range.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::vector range(nonnegative_int start, nonnegative_int end, int step) { + return transform(range(start.get_value(), end.get_value(), step), [](int x) { return nonnegative_int{x}; }); +} + +std::vector range(nonnegative_int end) { + return transform(range(end.get_value()), [](int x) { return nonnegative_int{x}; }); +} 
+ +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/dim_coord.cc b/lib/utils/src/utils/orthotope/dim_coord.cc new file mode 100644 index 0000000000..2ef50004c0 --- /dev/null +++ b/lib/utils/src/utils/orthotope/dim_coord.cc @@ -0,0 +1,21 @@ +#include "utils/orthotope/dim_coord.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T = ordered_value_type<0>; + +template + std::unordered_set get_coord_dims(DimCoord const &); + +template + DimCoord restrict_coord_to_dims(DimCoord const &, std::unordered_set const &); + +template + OrthotopeCoord orthotope_coord_from_dim_coord(DimCoord const &); + +template + nonnegative_int flatten_coord(DimCoord const &coord, + DimDomain const &domain); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/dim_domain.cc b/lib/utils/src/utils/orthotope/dim_domain.cc new file mode 100644 index 0000000000..30f3f0d6a6 --- /dev/null +++ b/lib/utils/src/utils/orthotope/dim_domain.cc @@ -0,0 +1,18 @@ +#include "utils/orthotope/dim_domain.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using T = ordered_value_type<0>; + +template + std::set get_domain_dims(DimDomain const &); + +template + DimDomain restrict_domain_to_dims(DimDomain const &, std::set const &); + +template + Orthotope orthotope_from_dim_domain(DimDomain const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/dim_projection.cc b/lib/utils/src/utils/orthotope/dim_projection.cc new file mode 100644 index 0000000000..d96e6eb67f --- /dev/null +++ b/lib/utils/src/utils/orthotope/dim_projection.cc @@ -0,0 +1,15 @@ +#include "utils/orthotope/dim_projection.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + std::unordered_set input_dims_of_projection(DimProjection const &); + +template + std::unordered_set output_dims_of_projection(DimProjection const &); + +} // namespace 
FlexFlow diff --git a/lib/utils/src/utils/orthotope/down_projection.cc b/lib/utils/src/utils/orthotope/down_projection.cc index e4f6477c0e..00fa2ec454 100644 --- a/lib/utils/src/utils/orthotope/down_projection.cc +++ b/lib/utils/src/utils/orthotope/down_projection.cc @@ -1,5 +1,8 @@ #include "utils/orthotope/down_projection.h" #include "utils/archetypes/value_type.h" +#include "utils/containers/generate_vector.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/nonnegative_int/range.h" namespace FlexFlow { @@ -25,4 +28,32 @@ template template DownProjection down_from_eq_proj(EqProjection const &); +Orthotope compute_down_projection(DownProjection const &projection, + Orthotope const &domain) { + NOT_IMPLEMENTED(); +} + +OrthotopeCoord compute_down_projection(DownProjection const &projection, + OrthotopeCoord const &coord, + Orthotope const &domain) { + std::unordered_set input_dims = input_dims_of_down_projection(projection); + std::unordered_set orthotope_dims = unordered_set_of(range(get_orthotope_num_dims(domain))); + + if (input_dims != orthotope_dims) { + throw mk_runtime_error(fmt::format("compute_down_projection expected projection input dims to match orthotope dims, but received input_dims={} and orthotope_dims={}", input_dims, orthotope_dims)); + } + + std::unordered_set output_dims = output_dims_of_down_projection(projection); + + return generate_vector( + [&](R const &output_dim) { + std::unordered_set src_dims = projection.dim_mapping.at_r(output_dim); + + DimCoord src_coord = restrict_coord_to_dims(coord, src_dims); + Orthotope src_domain = restrict_orthotope_to_dims(domain, src_dims); + + return flatten_dims(src_coord, src_domain); + }), + }; +} } // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/eq_projection.cc b/lib/utils/src/utils/orthotope/eq_projection.cc index cb4563c78b..af6f60c2d9 100644 --- a/lib/utils/src/utils/orthotope/eq_projection.cc +++ b/lib/utils/src/utils/orthotope/eq_projection.cc @@ -6,6 +6,12 @@ 
namespace FlexFlow { using L = value_type<0>; using R = value_type<1>; +template + std::unordered_set input_dims_of_eq_projection(EqProjection const &); + +template + std::unordered_set output_dims_of_eq_projection(EqProjection const &); + template EqProjection invert_eq_projection(EqProjection const &); diff --git a/lib/utils/src/utils/orthotope/orthotope.cc b/lib/utils/src/utils/orthotope/orthotope.cc index fe570bc978..4057798585 100644 --- a/lib/utils/src/utils/orthotope/orthotope.cc +++ b/lib/utils/src/utils/orthotope/orthotope.cc @@ -1,56 +1,74 @@ #include "utils/orthotope/orthotope.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/get_all_assignments.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/filter_idxs.h" #include "utils/containers/product.h" -#include "utils/containers/range.h" +#include "utils/containers/scanr.h" +#include "utils/containers/subvec.h" #include "utils/containers/unordered_set_of.h" -#include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h" -#include "utils/orthotope/orthotope_dim_indexed/zip_with.h" -#include "utils/orthotope/orthotope_dim_indexed/all_of.h" -#include "utils/containers/all_of.h" -#include "utils/containers/transform.h" +#include "utils/containers/zip_strict.h" +#include "utils/containers/zip_with_strict.h" #include "utils/exception.h" +#include "utils/nonnegative_int/num_elements.h" +#include "utils/nonnegative_int/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/all_of.h" +#include "utils/containers/filter_idxs.h" +#include "utils/containers/contains.h" namespace FlexFlow { -std::set get_orthotope_dims(Orthotope const &orthotope) { - return orthotope.dims.indices(); +nonnegative_int get_orthotope_num_dims(Orthotope const &orthotope) { + return num_elements(orthotope.dims); +} + +nonnegative_int get_orthotope_volume(Orthotope const 
&orthotope) { + return product(orthotope.dims); } -int orthotope_num_dims(Orthotope const &orthotope) { - return orthotope.dims.size(); +std::unordered_set get_all_coords_in_orthotope(Orthotope const &orthotope) { + std::unordered_multiset> raw_coords = cartesian_product(transform(orthotope.dims, [](nonnegative_int dim_size) { return range(dim_size); })); + + return unordered_set_of(transform(raw_coords, [](std::vector const &raw_coord) { return OrthotopeCoord{raw_coord}; })); } -bool orthotope_contains_coord(Orthotope const &o, OrthotopeCoordinate const &c) { - if (o.dims.size() != c.idxs.size()) { - throw mk_runtime_error(fmt::format("orthotope_contains_coord expected orthotope and coord to have the same number of dims, but received o={}, c={}", o, c)); +bool orthotope_contains_coord(Orthotope const &orthotope, OrthotopeCoord const &coord) { + if (orthotope.dims.size() != coord.raw.size()) { + throw mk_runtime_error(fmt::format("orthotope_contains_coord expected orthotope and coord to have the same number of dims, but received orthotope={}, coord={}", orthotope, coord)); } - return all_of(zip_with(o.dims, c.idxs, [](int dim_size, int dim_coord) { return dim_coord >= 0 && dim_coord < dim_size; })); + return all_of(zip_strict(coord.raw, orthotope.dims), [](nonnegative_int c, nonnegative_int o) { return c < o; }); } -std::unordered_set orthotope_get_contained_coordinates(Orthotope const &orthotope) { - std::unordered_map> possible_coord_assignments = - generate_map(get_orthotope_dims(orthotope), - [&](orthotope_dim_idx_t const &dim_idx) { - return unordered_set_of(range(orthotope.dims.at(dim_idx))); - }); - - return transform(get_all_assignments(possible_coord_assignments), - [](std::unordered_map const &assignment) { - return OrthotopeCoordinate{ - orthotope_dim_indexed_from_idx_map(assignment).value(), - }; - }); +Orthotope restrict_orthotope_dims_to(Orthotope const &orthotope, std::set const &allowed_dims) { + return Orthotope{ + filter_idxs(orthotope.dims, 
[&](nonnegative_int idx) { return contains(allowed_dims, idx); }), + }; } -int orthotope_get_volume(Orthotope const &o) { - return product(o.dims.get_contents()); +nonnegative_int flatten_orthotope_coord(OrthotopeCoord const &coord, Orthotope const &orthotope) { + if (orthotope.dims.size() != coord.raw.size()) { + throw mk_runtime_error(fmt::format("flatten_orthotope_coord expected orthotope and coord to have the same number of dims, but received orthotope={}, coord={}", orthotope, coord)); + } + + std::vector steps = scanr(orthotope.dims, nonnegative_int{1}, + [](nonnegative_int r, nonnegative_int accum) { + return r * accum; + }); + + return product(zip_with_strict(coord.raw, subvec(steps, 0, -1), + [](nonnegative_int coord_val, nonnegative_int step) { return coord_val * step; })); + } -Orthotope orthotope_drop_dims_except(Orthotope const &o, std::set const &keep) { - return Orthotope{drop_idxs_except(o.dims, keep)}; +OrthotopeCoord unflatten_orthotope_coord(nonnegative_int flattened, Orthotope const &orthotope) { + std::vector steps = scanr(orthotope.dims, nonnegative_int{1}, + [](nonnegative_int r, nonnegative_int accum) { + return r * accum; + }); + + return zip3_with_strict(orthotope.dims, + subvec(steps, 0, -1), + subvec(orthotope.dims, 1, 0), []() { TODO_COLIN_THIS_IS_WHAT_YOU_WERE_DOING }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc deleted file mode 100644 index 001722e006..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_bijective_projection.cc +++ /dev/null @@ -1,283 +0,0 @@ -#include "utils/orthotope/orthotope_bijective_projection.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/get_all_assignments.h" -#include "utils/containers/group_by.h" -#include "utils/containers/map_from_keys_and_values.h" -#include "utils/containers/map_keys.h" -#include "utils/containers/map_values.h" -#include 
"utils/containers/merge_maps.h" -#include "utils/containers/product.h" -#include "utils/containers/range.h" -#include "utils/containers/set_of.h" -#include "utils/containers/subvec.h" -#include "utils/containers/sum.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/containers/filter.h" -#include "utils/containers/values.h" -#include "utils/containers/zip_with.h" -#include "utils/exception.h" -#include "utils/orthotope/orthotope.h" -#include "utils/orthotope/orthotope_coordinate.h" -#include "utils/orthotope/orthotope_dim_idx_t.dtg.h" -#include "utils/orthotope/orthotope_dim_idx_t.h" -#include "utils/containers/vector_from_idx_map.h" -#include "utils/containers/scanr.h" -#include "utils/containers/scanr1.h" -#include "utils/containers/all_of.h" -#include "utils/fmt/vector.h" -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h" - -namespace FlexFlow { - -OrthotopeBijectiveProjection - make_orthotope_projection_from_map(std::unordered_map const &m, bool reversed) { - std::unordered_map raw_idx_map = map_keys(m, [](orthotope_dim_idx_t const &k) { return k.raw_idx; }); - return OrthotopeBijectiveProjection{ - /*dim_mapping=*/vector_from_idx_map(raw_idx_map).value(), - /*reversed=*/reversed, - }; -} - -bool is_valid_projection_between(OrthotopeBijectiveProjection const &proj, Orthotope const &src, Orthotope const &dst) { - if (proj.reversed) { - return is_valid_projection_between(reverse_projection(proj), dst, src); - } - - auto get_src_dim_size = [&](orthotope_dim_idx_t const &src_idx) -> int { - return src.dims.at(src_idx); - }; - - auto get_dst_dim_size = [&](orthotope_dim_idx_t const &dst_idx) -> int { - return dst.dims.at(dst_idx); - }; - - std::unordered_map> src_dims_by_dst_dim - = get_src_dims_by_dst_dim_map(proj); - - return all_of(src_dims_by_dst_dim, - [&](orthotope_dim_idx_t const &dst_idx, std::set const &src_idxs) -> bool { - std::vector src_dim_sizes = 
transform(vector_of(src_idxs), get_src_dim_size); - - return get_dst_dim_size(dst_idx) == product(src_dim_sizes); - }); -} -std::unordered_map> - get_src_dims_by_dst_dim_map(OrthotopeBijectiveProjection const &p) { - if (p.reversed) { - throw mk_runtime_error(fmt::format("get_src_dims_by_dst_dim_map expected p.reversed=false, but received p={}", p)); - } - - std::set src_dim_idxs = dim_idxs_for_orthotope_with_num_dims(get_src_num_dims(p)); - - return group_by(src_dim_idxs, - [&](orthotope_dim_idx_t const &src_dim_idx) { return get_dst_dim_for_src_dim(p, src_dim_idx); }); -} - -std::unordered_map> - get_dst_dims_by_src_dim_map(OrthotopeBijectiveProjection const &p) { - if (!p.reversed) { - throw mk_runtime_error(fmt::format("get_dst_dims_by_src_dim_map expected p.reversed=true, but received p={}", p)); - } - - std::set dst_dim_idxs = dim_idxs_for_orthotope_with_num_dims(get_dst_num_dims(p)); - - return group_by(dst_dim_idxs, - [&](orthotope_dim_idx_t const &dst_dim_idx) { return get_src_dim_for_dst_dim(p, dst_dim_idx); }); -} - -std::unordered_map get_src_to_dst_dim_map(OrthotopeBijectiveProjection const &p) { - if (p.reversed) { - throw mk_runtime_error(fmt::format("get_src_to_dst_dim_map expected p.reversed=false, but received p={}", p)); - } - - std::unordered_map raw_idx_map = generate_map(range(p.dim_mapping.size()), [&](int x) { return p.dim_mapping.at(x); }); - return map_keys(raw_idx_map, [](int src_dim_idx) { return orthotope_dim_idx_t{src_dim_idx}; }); -} - -orthotope_dim_idx_t get_dst_dim_for_src_dim(OrthotopeBijectiveProjection const &p, orthotope_dim_idx_t const &src_idx) { - if (p.reversed) { - throw mk_runtime_error(fmt::format("get_dst_dim_for_src_dim expected a non-reversed projection, but received: projection={}", p)); - } - - return p.dim_mapping.at(src_idx.raw_idx); -} - -orthotope_dim_idx_t get_src_dim_for_dst_dim(OrthotopeBijectiveProjection const &p, orthotope_dim_idx_t const &dst_idx) { - if (!p.reversed) { - throw 
mk_runtime_error(fmt::format("get_src_dim_for_dst_dim expected a reversed projection, but received: projection={}", p)); - } - - return get_dst_dim_for_src_dim(reverse_projection(p), dst_idx); -} - -int get_src_num_dims(OrthotopeBijectiveProjection const &p) { - if (p.reversed) { - return get_dst_num_dims(reverse_projection(p)); - } - - return p.dim_mapping.size(); -} - -int get_dst_num_dims(OrthotopeBijectiveProjection const &p) { - if (p.reversed) { - return get_src_num_dims(reverse_projection(p)); - } - - return unordered_set_of(p.dim_mapping).size(); -} - -OrthotopeBijectiveProjection reverse_projection(OrthotopeBijectiveProjection const &p) { - OrthotopeBijectiveProjection result = p; - result.reversed = !p.reversed; - return result; -} - -std::unordered_set get_all_bijective_projections_between_dim_numbers(int src_num_dims, int dst_num_dims) { - if (src_num_dims < dst_num_dims) { - return transform(get_all_bijective_projections_between_dim_numbers(dst_num_dims, src_num_dims), - [](OrthotopeBijectiveProjection const &p) { return reverse_projection(p); }); - } - - std::set src_dim_idxs = dim_idxs_for_orthotope_with_num_dims(src_num_dims); - std::set dst_dim_idxs = dim_idxs_for_orthotope_with_num_dims(dst_num_dims); - - std::unordered_map> src_to_dst_idxs = - generate_map(src_dim_idxs, [&](orthotope_dim_idx_t) { return unordered_set_of(dst_dim_idxs); }); - - std::unordered_set> valid_mappings = - filter(get_all_assignments(src_to_dst_idxs), - [&](std::unordered_map const &src_to_dst_idx) { - return set_of(values(src_to_dst_idx)) == dst_dim_idxs; - }); - - return transform(valid_mappings, [](std::unordered_map const &m) { return make_orthotope_projection_from_map(m, /*reversed=*/false); }); -} - -std::unordered_set get_all_bijective_projections_between(Orthotope const &src, Orthotope const &dst) { - return filter(get_all_bijective_projections_between_dim_numbers(/*src_num_dims=*/orthotope_num_dims(src), /*dst_num_dims=*/orthotope_num_dims(dst)), - 
[&](OrthotopeBijectiveProjection const &p) { - return is_valid_projection_between(p, /*src=*/src, /*dst=*/dst); - }); -} - -int project_into_1d(Orthotope const &orthotope, OrthotopeCoordinate const &coord) { - if (!orthotope_contains_coord(orthotope, coord)) { - throw mk_runtime_error(fmt::format("coord out of bounds of orthotope: orthotope={}, coord={}", orthotope, coord)); - } - - if (orthotope.dims.size() == 0) { - return 0; - } - - std::vector> coords_and_sizes = zip(coord.idxs.get_contents(), - orthotope.dims.get_contents()); - - std::vector coords = transform(coords_and_sizes, [](std::pair const &p) { return p.first; }); - std::vector dim_sizes = transform(coords_and_sizes, [](std::pair const &p) { return p.second; }); - - std::vector strides = scanr(subvec(dim_sizes, 1, std::nullopt), 1, [](int next, int accum) { return accum * next; }); - return sum(zip_with(coords, strides, [](int coord, int stride) { return coord * stride; })); -} - -OrthotopeCoordinate project_out_of_1d(int one_dimensional_coord, Orthotope const &dst_orthotope) { - if (dst_orthotope.dims.size() == 0) { - if (one_dimensional_coord == 0) { - return OrthotopeCoordinate{{}}; - } else { - throw mk_runtime_error(fmt::format("Only valid one_dimensional_coord for zero-dimensional orthotope is 0, but receieved one_dimensional_coord={}", one_dimensional_coord)); - } - } - - if (one_dimensional_coord >= orthotope_get_volume(dst_orthotope)) { - throw mk_runtime_error(fmt::format("project_out_of_1d received coordinate that would be out of bounds of dst orthotope: dst_orthotope={}, coordinate={}", dst_orthotope, one_dimensional_coord)); - } - - std::vector dim_sizes = dst_orthotope.dims.get_contents(); - std::vector strides = scanr(subvec(dim_sizes, 1, std::nullopt), 1, [](int next, int accum) { return accum * next; }); - - OrthotopeCoordinate result = OrthotopeCoordinate{ - orthotope_dim_indexed_of(zip_with(dim_sizes, strides, [&](int dim_size, int stride) { return (one_dimensional_coord / stride) % 
dim_size; })), - }; - return result; -} - -OrthotopeCoordinate project_coordinate_through(OrthotopeBijectiveProjection const &p, Orthotope const &src_orthotope, OrthotopeCoordinate const &src_coord, Orthotope const &dst_orthotope) { - std::set dst_dim_idxs = get_orthotope_dims(dst_orthotope); - std::set src_dim_idxs = get_orthotope_dims(src_orthotope); - - if (src_coord.idxs.size() != get_src_num_dims(p)) { - throw mk_runtime_error(fmt::format("project_coordinate_through requires projection src and coordinate to have same num dims, but got {} and {} respectively", - get_src_num_dims(p), - src_coord.idxs.size())); - } - - if (!orthotope_contains_coord(src_orthotope, src_coord)) { - throw mk_runtime_error(fmt::format("project_coordinate_through requires coord to be in the orthotope, but got coord={} and orthotope={} respectively", src_coord, src_orthotope)); - } - - if (p.reversed) { - std::unordered_map> - dst_dim_idxs_by_src_dim_idx = - group_by(dst_dim_idxs, - [&](orthotope_dim_idx_t const &dst_dim_idx) { return get_src_dim_for_dst_dim(p, dst_dim_idx); }); - - - std::unordered_map dst_sub_orthotopes_by_src_dim_idx = - map_values(dst_dim_idxs_by_src_dim_idx, - [&](std::set const &dst_dim_idxs) { - return orthotope_drop_dims_except(dst_orthotope, dst_dim_idxs); - }); - - std::unordered_map dst_coords_by_src_dim_idx = - generate_map(src_dim_idxs, - [&](orthotope_dim_idx_t const &src_idx) -> OrthotopeCoordinate { - return project_out_of_1d(src_coord.idxs.at(src_idx), - dst_sub_orthotopes_by_src_dim_idx.at(src_idx)); - }); - - std::unordered_map dst_coords = merge_maps( - transform(vector_of(src_dim_idxs), [&](orthotope_dim_idx_t const &src_idx) -> std::unordered_map { - return map_from_keys_and_values( - vector_of(dst_dim_idxs_by_src_dim_idx.at(src_idx)), - dst_coords_by_src_dim_idx.at(src_idx).idxs.get_contents()); - })); - - return OrthotopeCoordinate{ - orthotope_dim_indexed_from_idx_map(dst_coords).value(), - }; - } else { - std::unordered_map> 
src_dim_idxs_by_dst_dim_idx = - group_by(src_dim_idxs, - [&](orthotope_dim_idx_t const &src_dim_idx) { return get_dst_dim_for_src_dim(p, src_dim_idx); }); - - - std::unordered_map src_sub_orthotopes_by_dst_dim_idx = - map_values(src_dim_idxs_by_dst_dim_idx, - [&](std::set const &src_dim_idxs) { - return orthotope_drop_dims_except(src_orthotope, src_dim_idxs); - }); - - std::unordered_map src_sub_coords_by_dst_dim_idx = - map_values(src_dim_idxs_by_dst_dim_idx, - [&](std::set const &src_dim_idxs) { - return orthotope_coord_drop_dims_except(src_coord, src_dim_idxs); - }); - - std::unordered_map dst_coords = - generate_map(dst_dim_idxs, - [&](orthotope_dim_idx_t const &dst_idx) { - return project_into_1d( - src_sub_orthotopes_by_dst_dim_idx.at(dst_idx), - src_sub_coords_by_dst_dim_idx.at(dst_idx)); - }); - - - return OrthotopeCoordinate{ - orthotope_dim_indexed_from_idx_map(dst_coords).value(), - }; - } -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_coord.cc b/lib/utils/src/utils/orthotope/orthotope_coord.cc new file mode 100644 index 0000000000..63cb796e29 --- /dev/null +++ b/lib/utils/src/utils/orthotope/orthotope_coord.cc @@ -0,0 +1,13 @@ +#include "utils/orthotope/orthotope_coord.h" +#include "utils/containers/filter_idxs.h" +#include "utils/containers/contains.h" + +namespace FlexFlow { + +OrthotopeCoord restrict_orthotope_coord_dims_to(OrthotopeCoord const &coord, std::set const &allowed_dims) { + return OrthotopeCoord{ + filter_idxs(coord.raw, [&](nonnegative_int idx) { return contains(allowed_dims, idx); }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_coordinate.cc b/lib/utils/src/utils/orthotope/orthotope_coordinate.cc deleted file mode 100644 index bcf310c043..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_coordinate.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "utils/orthotope/orthotope_coordinate.h" -#include "utils/containers/filter_idxs.h" -#include 
"utils/containers/is_subseteq_of.h" -#include "utils/exception.h" -#include "utils/orthotope/orthotope_dim_idx_t.h" -#include "utils/fmt/set.h" -#include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" - -namespace FlexFlow { - -std::set get_orthotope_coord_dims(OrthotopeCoordinate const &coord) { - return coord.idxs.indices(); -} - -OrthotopeCoordinate orthotope_coord_drop_dims_except(OrthotopeCoordinate const &coord, std::set const &mask) { - OrthotopeDimIndexed new_idxs = drop_idxs_except(coord.idxs, mask); - - return OrthotopeCoordinate{new_idxs}; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc b/lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc deleted file mode 100644 index a26645a521..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_idx_t.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/orthotope/orthotope_dim_idx_t.h" -#include "utils/containers/set_of.h" -#include "utils/containers/transform.h" -#include "utils/containers/range.h" - -namespace FlexFlow { - -std::set dim_idxs_for_orthotope_with_num_dims(int num_dims) { - return set_of(transform(range(num_dims), [](int idx) { return orthotope_dim_idx_t{idx}; })); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc deleted file mode 100644 index 3a4d680392..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/all_of.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/all_of.h" -#include "utils/containers/all_of.h" - -namespace FlexFlow { - -bool all_of(OrthotopeDimIndexed const &d) { - return all_of(d.get_contents()); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc deleted file mode 100644 index 13d9fa2779..0000000000 --- 
a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/drop_idxs_except.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/drop_idxs_except.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T = value_type<0>; - -template - OrthotopeDimIndexed drop_idxs_except(OrthotopeDimIndexed const &, std::set const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc deleted file mode 100644 index 539e7e7ba0..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/json.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/json.h" - -namespace nlohmann { - -template - struct adl_serializer<::FlexFlow::OrthotopeDimIndexed>; - -} // namespace nlohmann diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc deleted file mode 100644 index 56662bd7de..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed.h" -#include "utils/archetypes/ordered_value_type.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T = value_type<0>; - -template - struct OrthotopeDimIndexed; - -using T2 = ordered_value_type<0>; - -// template -// bool operator<(OrthotopeDimIndexed const &, OrthotopeDimIndexed const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc deleted file mode 100644 index f7d4012688..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include 
"utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_from_idx_map.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T = value_type<0>; - -template - std::optional> orthotope_dim_indexed_from_idx_map(std::unordered_map const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc deleted file mode 100644 index 7d4a515205..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/orthotope_dim_indexed_of.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T = value_type<0>; - -template - OrthotopeDimIndexed orthotope_dim_indexed_of(std::vector const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc deleted file mode 100644 index 593e673ee9..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/transform.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/transform.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T = value_type<0>; -using Result = value_type<1>; -using F = std::function; - -template - OrthotopeDimIndexed transform(OrthotopeDimIndexed const &, F &&); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc b/lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc deleted file mode 100644 index d9f9e4bb1f..0000000000 --- a/lib/utils/src/utils/orthotope/orthotope_dim_indexed/zip_with.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "utils/orthotope/orthotope_dim_indexed/zip_with.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using T1 = value_type<0>; -using 
T2 = value_type<1>; -using Result = value_type<2>; -using F = std::function; - -template - OrthotopeDimIndexed zip_with(OrthotopeDimIndexed const &, OrthotopeDimIndexed const &, F &&); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/up_projection.cc b/lib/utils/src/utils/orthotope/up_projection.cc index e8b5fb4db4..710bffcacc 100644 --- a/lib/utils/src/utils/orthotope/up_projection.cc +++ b/lib/utils/src/utils/orthotope/up_projection.cc @@ -13,6 +13,12 @@ template using L = value_type<0>; using R = value_type<1>; +template + std::unordered_set input_dims_of_up_projection(UpProjection const &); + +template + std::unordered_set output_dims_of_up_projection(UpProjection const &); + template UpProjection make_empty_up_projection(); diff --git a/lib/utils/test/src/utils/containers/range.cc b/lib/utils/test/src/utils/containers/range.cc index f115855323..14a3710ac0 100644 --- a/lib/utils/test/src/utils/containers/range.cc +++ b/lib/utils/test/src/utils/containers/range.cc @@ -50,5 +50,15 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector correct = {}; CHECK(result == correct); } + + SUBCASE("step = 0") { + SUBCASE("output is nonempty") { + CHECK_THROWS(range(2, 5, 0)); + } + + SUBCASE("output is empty") { + CHECK_THROWS(range(3, 3, 0)); + } + } } } diff --git a/lib/utils/test/src/utils/containers/zip_strict.cc b/lib/utils/test/src/utils/containers/zip_strict.cc new file mode 100644 index 0000000000..5a6cc4117e --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_strict.cc @@ -0,0 +1,28 @@ +#include +#include "utils/containers/zip_strict.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_strict(std::vector, std::vector)") { + SUBCASE("input lengths are the same") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{"a", 5}, {"b", 4}, {"b", 
8}}; + + CHECK(result == correct); + } + + SUBCASE("input lengths are not the same") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4}; + + CHECK_THROWS(zip(lhs, rhs)); + } + } +} diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc index a7b813d1f1..a442ea9a0c 100644 --- a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -239,6 +239,13 @@ TEST_SUITE(FF_TEST_SUITE) { } } + TEST_CASE("_n suffix") { + nonnegative_int result = 5_n; + nonnegative_int correct = nonnegative_int{5}; + + CHECK(result == correct); + } + TEST_CASE("nonnegative int >> operator") { nonnegative_int nn_int_1 = nonnegative_int{1}; std::ostringstream oss; diff --git a/lib/utils/test/src/utils/nonnegative_int/range.cc b/lib/utils/test/src/utils/nonnegative_int/range.cc new file mode 100644 index 0000000000..16ec101b07 --- /dev/null +++ b/lib/utils/test/src/utils/nonnegative_int/range.cc @@ -0,0 +1,70 @@ +#include "utils/nonnegative_int/range.h" +#include +#include "test/utils/doctest/fmt/vector.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("range(nonnegative_int, nonnegative_int, int)") { + SUBCASE("step = 1") { + nonnegative_int start = nonnegative_int{3}; + nonnegative_int end = nonnegative_int{5}; + + std::vector result = range(start, end); + std::vector correct = { + nonnegative_int{3}, + nonnegative_int{4}, + }; + + CHECK(result == correct); + } + + SUBCASE("step = -1") { + nonnegative_int start = nonnegative_int{7}; + nonnegative_int end = nonnegative_int{4}; + + std::vector result = range(start, end, -1); + std::vector correct = { + nonnegative_int{7}, + nonnegative_int{6}, + nonnegative_int{5}, + }; + + CHECK(result == correct); + } + + SUBCASE("step = 0") { + SUBCASE("output is nonempty") { + CHECK_THROWS(range(nonnegative_int{2}, nonnegative_int{5}, 0)); + } + + SUBCASE("output is 
empty") { + CHECK_THROWS(range(nonnegative_int{2}, nonnegative_int{2}, 0)); + } + } + } + + TEST_CASE("range(nonnegative_int)") { + SUBCASE("end is zero") { + nonnegative_int end = nonnegative_int{0}; + + std::vector result = range(end); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("end is nonzero") { + nonnegative_int end = nonnegative_int{3}; + + std::vector result = range(end); + std::vector correct = { + nonnegative_int{0}, + nonnegative_int{1}, + nonnegative_int{2}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/orthotope/dim_coord.cc b/lib/utils/test/src/utils/orthotope/dim_coord.cc new file mode 100644 index 0000000000..9b18301b84 --- /dev/null +++ b/lib/utils/test/src/utils/orthotope/dim_coord.cc @@ -0,0 +1,25 @@ +#include "utils/orthotope/dim_coord.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatten_coord") { + DimCoord coord = DimCoord{{ + {3, 4_n}, + {7, 0_n}, + {1, 1_n}, + }}; + + Orthotope domain = Orthotope{{ + {3, 5_n}, + {7, 2_n}, + {1, 3_n}, + }}; + + nonnegative_int result = flatten_coord(coord, domain); + nonnegative_int correct = nonnegative_int{1 * 2 * 5 + 4 * 2 + 0}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/orthotope/orthotope.cc b/lib/utils/test/src/utils/orthotope/orthotope.cc index c4cdb91f56..0dd4334814 100644 --- a/lib/utils/test/src/utils/orthotope/orthotope.cc +++ b/lib/utils/test/src/utils/orthotope/orthotope.cc @@ -1,144 +1,144 @@ -#include "utils/orthotope/orthotope.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("orthotope_contains_coord") { - Orthotope orthotope = Orthotope{ - {3, 1}, - }; - - SUBCASE("returns true if coord is in orthotope bounds") { - SUBCASE("smallest allowed coord") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {0, 0}, - }; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = true; - - CHECK(result 
== correct); - } - - SUBCASE("largest allowed coord") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {2, 0}, - }; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = true; - - CHECK(result == correct); - } - } - - SUBCASE("returns false if coord is out of orthotope bounds") { - SUBCASE("too low") { - // exhaustively check all dims because we can - SUBCASE("dim 0") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {-1, 0}, - }; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = false; - - CHECK(result == correct); - } - - SUBCASE("dim 1") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {1, -1}, - }; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = false; - - CHECK(result == correct); - } - } - - SUBCASE("too high") { - // exhaustively check all dims because we can - SUBCASE("dim 0") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {3, 0}, - }; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = false; - - CHECK(result == correct); - } - - SUBCASE("dim 1") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {1, 1}, - }; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = false; - - CHECK(result == correct); - } - } - } - - SUBCASE("throws if num dims of coord does not match num dims of the orthotope") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {0, 0, 0}, - }; - - CHECK_THROWS(orthotope_contains_coord(orthotope, coord)); - } - - SUBCASE("works if the orthotope is zero-dimensional") { - Orthotope orthotope = Orthotope{{}}; - OrthotopeCoordinate coord = OrthotopeCoordinate{{}}; - - bool result = orthotope_contains_coord(orthotope, coord); - bool correct = true; - - CHECK(result == correct); - } - } - - TEST_CASE("orthotope_get_volume") { - SUBCASE("1d orthotope volume is just dim size") { - Orthotope input = Orthotope{{8}}; - - int result = orthotope_get_volume(input); - int correct = 8; - - 
CHECK(result == correct); - } - - SUBCASE("multi-dimensional orthotope") { - Orthotope input = Orthotope{{3, 5, 1, 2}}; - - int result = orthotope_get_volume(input); - int correct = 30; - - CHECK(result == correct); - } - - SUBCASE("any dim size being zero makes the volume zero") { - Orthotope input = Orthotope{{3, 5, 0, 2}}; - - int result = orthotope_get_volume(input); - int correct = 0; - - CHECK(result == correct); - } - - SUBCASE("zero-dimensional orthotope has volume 1") { - Orthotope input = Orthotope{{}}; - - int result = orthotope_get_volume(input); - int correct = 1; - - CHECK(result == correct); - } - } -} +// #include "utils/orthotope/orthotope.h" +// #include +// +// using namespace ::FlexFlow; +// +// TEST_SUITE(FF_TEST_SUITE) { +// TEST_CASE("orthotope_contains_coord") { +// Orthotope orthotope = Orthotope{ +// {3, 1}, +// }; +// +// SUBCASE("returns true if coord is in orthotope bounds") { +// SUBCASE("smallest allowed coord") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {0, 0}, +// }; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = true; +// +// CHECK(result == correct); +// } +// +// SUBCASE("largest allowed coord") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {2, 0}, +// }; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = true; +// +// CHECK(result == correct); +// } +// } +// +// SUBCASE("returns false if coord is out of orthotope bounds") { +// SUBCASE("too low") { +// // exhaustively check all dims because we can +// SUBCASE("dim 0") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {-1, 0}, +// }; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = false; +// +// CHECK(result == correct); +// } +// +// SUBCASE("dim 1") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {1, -1}, +// }; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = false; +// +// 
CHECK(result == correct); +// } +// } +// +// SUBCASE("too high") { +// // exhaustively check all dims because we can +// SUBCASE("dim 0") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {3, 0}, +// }; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = false; +// +// CHECK(result == correct); +// } +// +// SUBCASE("dim 1") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {1, 1}, +// }; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = false; +// +// CHECK(result == correct); +// } +// } +// } +// +// SUBCASE("throws if num dims of coord does not match num dims of the orthotope") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {0, 0, 0}, +// }; +// +// CHECK_THROWS(orthotope_contains_coord(orthotope, coord)); +// } +// +// SUBCASE("works if the orthotope is zero-dimensional") { +// Orthotope orthotope = Orthotope{{}}; +// OrthotopeCoordinate coord = OrthotopeCoordinate{{}}; +// +// bool result = orthotope_contains_coord(orthotope, coord); +// bool correct = true; +// +// CHECK(result == correct); +// } +// } +// +// TEST_CASE("orthotope_get_volume") { +// SUBCASE("1d orthotope volume is just dim size") { +// Orthotope input = Orthotope{{8}}; +// +// int result = orthotope_get_volume(input); +// int correct = 8; +// +// CHECK(result == correct); +// } +// +// SUBCASE("multi-dimensional orthotope") { +// Orthotope input = Orthotope{{3, 5, 1, 2}}; +// +// int result = orthotope_get_volume(input); +// int correct = 30; +// +// CHECK(result == correct); +// } +// +// SUBCASE("any dim size being zero makes the volume zero") { +// Orthotope input = Orthotope{{3, 5, 0, 2}}; +// +// int result = orthotope_get_volume(input); +// int correct = 0; +// +// CHECK(result == correct); +// } +// +// SUBCASE("zero-dimensional orthotope has volume 1") { +// Orthotope input = Orthotope{{}}; +// +// int result = orthotope_get_volume(input); +// int correct = 1; +// +// CHECK(result == 
correct); +// } +// } +// } diff --git a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc index 7fff0a709f..4cc4cae126 100644 --- a/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc +++ b/lib/utils/test/src/utils/orthotope/orthotope_bijective_projection.cc @@ -1,281 +1,281 @@ -#include "utils/orthotope/orthotope_bijective_projection.h" -#include "utils/containers/zip.h" -#include -#include "test/utils/doctest/fmt/unordered_set.h" - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("operator==(OrthotopeBijectiveProjection, OrthotopeBijectiveProjection)") { - SUBCASE("if src num dims and dst num dims are the same, projections are equivalent") { - orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; - - orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; - - OrthotopeBijectiveProjection p = make_orthotope_projection_from_map( - { - {src0, dst0}, - {src1, dst1}, - }, - /*reversed=*/false); - - CHECK(p == reverse_projection(p)); - } - } - - TEST_CASE("get_all_bijective_projections_between") { - SUBCASE("dst num dims greater than src num dims") { - Orthotope src = Orthotope{{6, 4}}; - orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; - - Orthotope dst = Orthotope{{3, 4, 2}}; - orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; - orthotope_dim_idx_t dst2 = orthotope_dim_idx_t{2}; - - std::unordered_set result = get_all_bijective_projections_between(src, dst); - std::unordered_set correct = { - make_orthotope_projection_from_map({ - {dst0, src0}, - {dst1, src1}, - {dst2, src0}, - }, /*reversed=*/true), - }; - - CHECK(result == correct); - } - - SUBCASE("src num dims greater than dst num dims") { - Orthotope src = Orthotope{{3, 4, 2}}; - 
orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; - orthotope_dim_idx_t src2 = orthotope_dim_idx_t{2}; - - Orthotope dst = Orthotope{{6, 4}}; - orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; - - std::unordered_set result = get_all_bijective_projections_between(src, dst); - std::unordered_set correct = { - make_orthotope_projection_from_map({ - {src0, dst0}, - {src1, dst1}, - {src2, dst0}, - }, /*reversed=*/false), - }; - - CHECK(result == correct); - } - - SUBCASE("multiple possible mappings") { - Orthotope src = Orthotope{{3, 3}}; - orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; - - Orthotope dst = Orthotope{{3, 3}}; - orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; - orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; - - std::unordered_set result = get_all_bijective_projections_between(src, dst); - std::unordered_set correct = { - make_orthotope_projection_from_map({ - {src0, dst0}, - {src1, dst1}, - }, /*reversed=*/false), - make_orthotope_projection_from_map({ - {src0, dst1}, - {src1, dst0}, - }, /*reversed=*/false), - }; - - CHECK(result == correct); - } - - SUBCASE("no possible mappings") { - Orthotope src = Orthotope{{4, 3}}; - Orthotope dst = Orthotope{{6, 2}}; - - std::unordered_set result = get_all_bijective_projections_between(src, dst); - std::unordered_set correct = {}; - - CHECK(result == correct); - } - } - - TEST_CASE("project_into_1d") { - SUBCASE("to 1d from 1d is identity") { - OrthotopeCoordinate coord = OrthotopeCoordinate{{2}}; - Orthotope orthotope = Orthotope{{5}}; - - int result = project_into_1d(orthotope, coord); - int correct = 2; - - CHECK(result == correct); - } - - SUBCASE("basic example") { - OrthotopeCoordinate coord = OrthotopeCoordinate{{4, 1}}; - Orthotope orthotope = Orthotope{{5, 3}}; - - int result = project_into_1d(orthotope, coord); - int correct = 4 * 3 + 
1; - - CHECK(result == correct); - } - - SUBCASE("order matters") { - OrthotopeCoordinate coord = OrthotopeCoordinate{{1, 4}}; - Orthotope orthotope = Orthotope{{3, 5}}; - - int result = project_into_1d(orthotope, coord); - int correct = 1 * 5 + 4; - - CHECK(result == correct); - } - - SUBCASE("throws if coord is outside of orthotope") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {2, 3, 1}, - }; - - Orthotope orthotope = Orthotope{ - {5, 3, 2}, - }; - - CHECK_THROWS(project_into_1d(orthotope, coord)); - } - - SUBCASE("throws if coord does not have same dimension as orthotope") { - OrthotopeCoordinate coord = OrthotopeCoordinate{ - {2, 3, 1}, - }; - - Orthotope orthotope = Orthotope{ - {5, 3}, - }; - - CHECK_THROWS(project_into_1d(orthotope, coord)); - } - - SUBCASE("returns 0 if orthotope is 0-dimensional") { - OrthotopeCoordinate coord = OrthotopeCoordinate{{}}; - Orthotope orthotope = Orthotope{{}}; - - int result = project_into_1d(orthotope, coord); - int correct = 0; - - CHECK(result == correct); - } - } - - TEST_CASE("project_out_of_1d") { - SUBCASE("from 1d to 1d is identity") { - Orthotope orthotope = Orthotope{{5}}; - int coord = 2; - - OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); - OrthotopeCoordinate correct = OrthotopeCoordinate{{2}}; - - CHECK(result == correct); - } - - SUBCASE("basic example") { - Orthotope orthotope = Orthotope{{5, 3}}; - int coord = 4 * 3 + 1; - - OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); - OrthotopeCoordinate correct = OrthotopeCoordinate{{4, 1}}; - - CHECK(result == correct); - } - - SUBCASE("orthotope dimension order matters") { - Orthotope orthotope = Orthotope{{3, 5}}; - int coord = 1 * 5 + 4; - - OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); - OrthotopeCoordinate correct = OrthotopeCoordinate{{1, 4}}; - - CHECK(result == correct); - } - - SUBCASE("throws if coord would be projected outside of orthotope") { - Orthotope orthotope = Orthotope{{5, 3}}; - - 
SUBCASE("smallest coord outside of orthotope") { - int coord = 15; - - CHECK_THROWS(project_out_of_1d(coord, orthotope)); - } - - SUBCASE("largest coord inside of orthotope") { - int coord = 14; - - OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); - OrthotopeCoordinate correct = OrthotopeCoordinate{{4, 2}}; - - CHECK(result == correct); - } - } - - SUBCASE("if dst orthotope is 0-dimensional") { - Orthotope orthotope = Orthotope{{}}; - - SUBCASE("returns 0-d coord if input coord is 0") { - int input_coord = 0; - - OrthotopeCoordinate result = project_out_of_1d(input_coord, orthotope); - OrthotopeCoordinate correct = OrthotopeCoordinate{{}}; - - CHECK(result == correct); - } - - SUBCASE("throws if input coord is anything other than zero") { - int input_coord = 1; - - CHECK_THROWS(project_out_of_1d(input_coord, orthotope)); - } - } - } - - TEST_CASE("project_coordinate_through") { - Orthotope src = Orthotope{ - {2, 3}, - }; - - Orthotope dst = Orthotope{ - {6}, - }; - - OrthotopeBijectiveProjection proj = OrthotopeBijectiveProjection{ - {orthotope_dim_idx_t{0}, orthotope_dim_idx_t{0}}, - /*reversed=*/false, - }; - - OrthotopeCoordinate src_coord = OrthotopeCoordinate{ - {1, 2}, - }; - OrthotopeCoordinate dst_coord = OrthotopeCoordinate{ - {1*3+2}, - }; - - SUBCASE("forward") { - OrthotopeCoordinate result = project_coordinate_through(proj, src, src_coord, dst); - OrthotopeCoordinate correct = dst_coord; - - CHECK(result == correct); - } - - SUBCASE("backward") { - OrthotopeBijectiveProjection reversed = reverse_projection(proj); - - OrthotopeCoordinate result = project_coordinate_through(reversed, dst, dst_coord, src); - OrthotopeCoordinate correct = src_coord; - - CHECK(result == correct); - } - } -} +// #include "utils/orthotope/orthotope_bijective_projection.h" +// #include "utils/containers/zip.h" +// #include +// #include "test/utils/doctest/fmt/unordered_set.h" +// +// using namespace ::FlexFlow; +// +// TEST_SUITE(FF_TEST_SUITE) { +// 
TEST_CASE("operator==(OrthotopeBijectiveProjection, OrthotopeBijectiveProjection)") { +// SUBCASE("if src num dims and dst num dims are the same, projections are equivalent") { +// orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; +// +// orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; +// +// OrthotopeBijectiveProjection p = make_orthotope_projection_from_map( +// { +// {src0, dst0}, +// {src1, dst1}, +// }, +// /*reversed=*/false); +// +// CHECK(p == reverse_projection(p)); +// } +// } +// +// TEST_CASE("get_all_bijective_projections_between") { +// SUBCASE("dst num dims greater than src num dims") { +// Orthotope src = Orthotope{{6, 4}}; +// orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; +// +// Orthotope dst = Orthotope{{3, 4, 2}}; +// orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; +// orthotope_dim_idx_t dst2 = orthotope_dim_idx_t{2}; +// +// std::unordered_set result = get_all_bijective_projections_between(src, dst); +// std::unordered_set correct = { +// make_orthotope_projection_from_map({ +// {dst0, src0}, +// {dst1, src1}, +// {dst2, src0}, +// }, /*reversed=*/true), +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("src num dims greater than dst num dims") { +// Orthotope src = Orthotope{{3, 4, 2}}; +// orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; +// orthotope_dim_idx_t src2 = orthotope_dim_idx_t{2}; +// +// Orthotope dst = Orthotope{{6, 4}}; +// orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; +// +// std::unordered_set result = get_all_bijective_projections_between(src, dst); +// std::unordered_set correct = { +// make_orthotope_projection_from_map({ +// {src0, dst0}, +// {src1, dst1}, +// 
{src2, dst0}, +// }, /*reversed=*/false), +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("multiple possible mappings") { +// Orthotope src = Orthotope{{3, 3}}; +// orthotope_dim_idx_t src0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t src1 = orthotope_dim_idx_t{1}; +// +// Orthotope dst = Orthotope{{3, 3}}; +// orthotope_dim_idx_t dst0 = orthotope_dim_idx_t{0}; +// orthotope_dim_idx_t dst1 = orthotope_dim_idx_t{1}; +// +// std::unordered_set result = get_all_bijective_projections_between(src, dst); +// std::unordered_set correct = { +// make_orthotope_projection_from_map({ +// {src0, dst0}, +// {src1, dst1}, +// }, /*reversed=*/false), +// make_orthotope_projection_from_map({ +// {src0, dst1}, +// {src1, dst0}, +// }, /*reversed=*/false), +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("no possible mappings") { +// Orthotope src = Orthotope{{4, 3}}; +// Orthotope dst = Orthotope{{6, 2}}; +// +// std::unordered_set result = get_all_bijective_projections_between(src, dst); +// std::unordered_set correct = {}; +// +// CHECK(result == correct); +// } +// } +// +// TEST_CASE("project_into_1d") { +// SUBCASE("to 1d from 1d is identity") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{{2}}; +// Orthotope orthotope = Orthotope{{5}}; +// +// int result = project_into_1d(orthotope, coord); +// int correct = 2; +// +// CHECK(result == correct); +// } +// +// SUBCASE("basic example") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{{4, 1}}; +// Orthotope orthotope = Orthotope{{5, 3}}; +// +// int result = project_into_1d(orthotope, coord); +// int correct = 4 * 3 + 1; +// +// CHECK(result == correct); +// } +// +// SUBCASE("order matters") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{{1, 4}}; +// Orthotope orthotope = Orthotope{{3, 5}}; +// +// int result = project_into_1d(orthotope, coord); +// int correct = 1 * 5 + 4; +// +// CHECK(result == correct); +// } +// +// SUBCASE("throws if coord is outside of orthotope") { +// 
OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {2, 3, 1}, +// }; +// +// Orthotope orthotope = Orthotope{ +// {5, 3, 2}, +// }; +// +// CHECK_THROWS(project_into_1d(orthotope, coord)); +// } +// +// SUBCASE("throws if coord does not have same dimension as orthotope") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{ +// {2, 3, 1}, +// }; +// +// Orthotope orthotope = Orthotope{ +// {5, 3}, +// }; +// +// CHECK_THROWS(project_into_1d(orthotope, coord)); +// } +// +// SUBCASE("returns 0 if orthotope is 0-dimensional") { +// OrthotopeCoordinate coord = OrthotopeCoordinate{{}}; +// Orthotope orthotope = Orthotope{{}}; +// +// int result = project_into_1d(orthotope, coord); +// int correct = 0; +// +// CHECK(result == correct); +// } +// } +// +// TEST_CASE("project_out_of_1d") { +// SUBCASE("from 1d to 1d is identity") { +// Orthotope orthotope = Orthotope{{5}}; +// int coord = 2; +// +// OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); +// OrthotopeCoordinate correct = OrthotopeCoordinate{{2}}; +// +// CHECK(result == correct); +// } +// +// SUBCASE("basic example") { +// Orthotope orthotope = Orthotope{{5, 3}}; +// int coord = 4 * 3 + 1; +// +// OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); +// OrthotopeCoordinate correct = OrthotopeCoordinate{{4, 1}}; +// +// CHECK(result == correct); +// } +// +// SUBCASE("orthotope dimension order matters") { +// Orthotope orthotope = Orthotope{{3, 5}}; +// int coord = 1 * 5 + 4; +// +// OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); +// OrthotopeCoordinate correct = OrthotopeCoordinate{{1, 4}}; +// +// CHECK(result == correct); +// } +// +// SUBCASE("throws if coord would be projected outside of orthotope") { +// Orthotope orthotope = Orthotope{{5, 3}}; +// +// SUBCASE("smallest coord outside of orthotope") { +// int coord = 15; +// +// CHECK_THROWS(project_out_of_1d(coord, orthotope)); +// } +// +// SUBCASE("largest coord inside of orthotope") { +// int coord = 
14; +// +// OrthotopeCoordinate result = project_out_of_1d(coord, orthotope); +// OrthotopeCoordinate correct = OrthotopeCoordinate{{4, 2}}; +// +// CHECK(result == correct); +// } +// } +// +// SUBCASE("if dst orthotope is 0-dimensional") { +// Orthotope orthotope = Orthotope{{}}; +// +// SUBCASE("returns 0-d coord if input coord is 0") { +// int input_coord = 0; +// +// OrthotopeCoordinate result = project_out_of_1d(input_coord, orthotope); +// OrthotopeCoordinate correct = OrthotopeCoordinate{{}}; +// +// CHECK(result == correct); +// } +// +// SUBCASE("throws if input coord is anything other than zero") { +// int input_coord = 1; +// +// CHECK_THROWS(project_out_of_1d(input_coord, orthotope)); +// } +// } +// } +// +// TEST_CASE("project_coordinate_through") { +// Orthotope src = Orthotope{ +// {2, 3}, +// }; +// +// Orthotope dst = Orthotope{ +// {6}, +// }; +// +// OrthotopeBijectiveProjection proj = OrthotopeBijectiveProjection{ +// {orthotope_dim_idx_t{0}, orthotope_dim_idx_t{0}}, +// /*reversed=*/false, +// }; +// +// OrthotopeCoordinate src_coord = OrthotopeCoordinate{ +// {1, 2}, +// }; +// OrthotopeCoordinate dst_coord = OrthotopeCoordinate{ +// {1*3+2}, +// }; +// +// SUBCASE("forward") { +// OrthotopeCoordinate result = project_coordinate_through(proj, src, src_coord, dst); +// OrthotopeCoordinate correct = dst_coord; +// +// CHECK(result == correct); +// } +// +// SUBCASE("backward") { +// OrthotopeBijectiveProjection reversed = reverse_projection(proj); +// +// OrthotopeCoordinate result = project_coordinate_through(reversed, dst, dst_coord, src); +// OrthotopeCoordinate correct = src_coord; +// +// CHECK(result == correct); +// } +// } +// } From b081f11e1a30e0fed5e53d5f77e10c0f339b07b1 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 26 Jul 2025 23:34:34 -0700 Subject: [PATCH 10/62] Fix build errors --- CMakeLists.txt | 1 - cmake/visit_struct.cmake | 16 - .../machine_mapping/machine_mapping.h | 5 +- .../cost_estimator/network_cost_model.cc | 
1 + .../cost_estimator/op_cost_estimate_key.cc | 1 - .../machine_mapping/machine_mapping_cache.cc | 10 +- .../machine_mapping_with_memory_cache.cc | 10 +- .../unstructured_device_mapping.cc | 5 +- .../task_graph_simulator/pcg_task_graph.cc | 2 +- .../test/src/allowed_machine_views.cc | 2 +- .../machine_mapping/machine_mapping.cc | 2 +- lib/compiler/test/src/graph_optimize_state.cc | 2 +- lib/compiler/test/src/unity_algorithm.cc | 2 +- .../include/kernels/batch_matmul_kernels.h | 34 +- .../kernels/batch_matmul_kernels_cpu.h | 32 +- .../kernels/batch_matmul_kernels_gpu.h | 16 +- lib/kernels/include/kernels/perf_metrics.h | 40 +- .../include/kernels/perf_metrics.struct.toml | 59 ++ .../src/cuda/ops/element_binary_kernels.cu | 1 + lib/kernels/src/cuda/ops/pool_2d_kernels.cu | 1 + lib/kernels/src/cuda/ops/softmax_kernels.cu | 1 + lib/kernels/src/kernels/accessor.cc | 2 +- .../src/kernels/batch_matmul_kernels.cc | 143 +++-- .../src/kernels/batch_matmul_kernels_cpu.cc | 32 +- .../src/kernels/element_binary_kernels.cc | 1 + lib/kernels/src/kernels/pool_2d_kernels.cc | 1 + lib/kernels/src/kernels/reduce_kernels.cc | 1 + lib/kernels/src/perf_metrics.cc | 16 - .../local-execution/local_cost_estimator.cc | 2 +- .../local_task_argument_accessor.cc | 2 +- .../local-execution/local_task_registry.cc | 8 +- .../src/local-execution/loss_functions.cc | 2 +- .../op-attrs/dim_ordered/dim_ordered.h | 200 ------- .../include/op-attrs/dim_ordered/slice.h | 32 -- .../include/op-attrs/dim_ordered/transform.h | 20 - .../include/op-attrs/dim_ordered/zip.h | 19 - lib/op-attrs/include/op-attrs/ff_dim_t.h | 2 + .../include/op-attrs/ff_ordered/get_idxs.h | 7 +- .../include/op-attrs/operator_attrs.h | 5 +- ...ator_space_parallel_tensor_space_mapping.h | 16 - ...r_space_to_parallel_tensor_space_mapping.h | 14 + ...parallel_tensor_space_mapping.struct.toml} | 2 +- .../include/op-attrs/operator_task_space.h | 14 +- lib/op-attrs/include/op-attrs/ops/attention.h | 6 +- 
.../include/op-attrs/ops/batch_matmul.h | 3 - .../ops/batch_matmul_attrs.struct.toml | 6 +- .../include/op-attrs/ops/batch_norm.h | 4 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 4 +- lib/op-attrs/include/op-attrs/ops/cast.h | 7 +- lib/op-attrs/include/op-attrs/ops/combine.h | 3 - lib/op-attrs/include/op-attrs/ops/concat.h | 4 +- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 7 +- lib/op-attrs/include/op-attrs/ops/core.h | 15 - lib/op-attrs/include/op-attrs/ops/dropout.h | 8 +- .../include/op-attrs/ops/element_binary.h | 12 +- .../include/op-attrs/ops/element_unary.h | 7 +- lib/op-attrs/include/op-attrs/ops/embedding.h | 7 +- lib/op-attrs/include/op-attrs/ops/flat.h | 11 +- lib/op-attrs/include/op-attrs/ops/gather.h | 7 +- lib/op-attrs/include/op-attrs/ops/input.h | 3 - .../include/op-attrs/ops/layer_norm.h | 8 +- lib/op-attrs/include/op-attrs/ops/linear.h | 13 +- .../include/op-attrs/ops/loss_functions.h | 1 - lib/op-attrs/include/op-attrs/ops/noop.h | 7 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 8 +- lib/op-attrs/include/op-attrs/ops/reduce.h | 7 +- lib/op-attrs/include/op-attrs/ops/reduction.h | 7 +- .../include/op-attrs/ops/repartition.h | 7 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 7 +- lib/op-attrs/include/op-attrs/ops/reshape.h | 7 +- lib/op-attrs/include/op-attrs/ops/reverse.h | 3 - lib/op-attrs/include/op-attrs/ops/softmax.h | 8 +- lib/op-attrs/include/op-attrs/ops/split.h | 7 +- lib/op-attrs/include/op-attrs/ops/topk.h | 7 +- lib/op-attrs/include/op-attrs/ops/transpose.h | 7 +- lib/op-attrs/include/op-attrs/ops/weight.h | 3 - .../op-attrs/parallel_tensor_dim_degrees.h | 2 +- .../include/op-attrs/parallel_tensor_shape.h | 1 - .../parallel_tensor_space_coordinate.h | 2 +- ...rallel_tensor_space_coordinate.struct.toml | 9 +- .../include/op-attrs/task_space_coordinate.h | 12 + .../task_space_coordinate.struct.toml | 4 +- .../src/op-attrs/dim_ordered/dim_ordered.cc | 1 - .../src/op-attrs/dim_ordered/slice.cc | 1 - 
.../src/op-attrs/dim_ordered/transform.cc | 1 - lib/op-attrs/src/op-attrs/dim_ordered/zip.cc | 1 - lib/op-attrs/src/op-attrs/ff_dim_t.cc | 7 + .../src/op-attrs/ff_ordered/get_idxs.cc | 2 +- ...space_to_parallel_tensor_space_mapping.cc} | 6 +- .../src/op-attrs/operator_task_space.cc | 22 +- lib/op-attrs/src/op-attrs/ops/attention.cc | 7 +- lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 1 + lib/op-attrs/src/op-attrs/ops/broadcast.cc | 1 + lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 7 +- .../ops/conv_2d/conv_2d_input_shape.cc | 1 + .../src/op-attrs/ops/element_binary.cc | 32 +- lib/op-attrs/src/op-attrs/ops/flat.cc | 24 +- lib/op-attrs/src/op-attrs/ops/gather.cc | 1 + lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 5 +- lib/op-attrs/src/op-attrs/ops/linear.cc | 11 +- lib/op-attrs/src/op-attrs/ops/reduce.cc | 1 + lib/op-attrs/src/op-attrs/ops/reshape.cc | 1 + lib/op-attrs/src/op-attrs/ops/reverse.cc | 1 + lib/op-attrs/src/op-attrs/ops/split.cc | 1 + lib/op-attrs/src/op-attrs/ops/topk.cc | 1 + lib/op-attrs/src/op-attrs/ops/transpose.cc | 1 + .../src/op-attrs/parallel_op_attrs.cc | 1 + .../op-attrs/parallel_tensor_dim_degrees.cc | 19 +- .../src/op-attrs/parallel_tensor_shape.cc | 16 +- .../parallel_tensor_space_coordinate.cc | 6 +- lib/op-attrs/src/op-attrs/shape_inference.cc | 7 +- .../src/op-attrs/task_space_coordinate.cc | 9 + .../src/parallel_dim_mapping_record.cc | 60 -- .../src/parallel_dim_mapping_record.h | 54 -- .../src/op-attrs/dim_ordered/dim_ordered.cc | 13 - .../test/src/op-attrs/dim_ordered/zip.cc | 41 -- ...space_to_parallel_tensor_space_mapping.cc} | 7 +- .../test/src/op-attrs/operator_task_space.cc | 24 +- .../test/src/op-attrs/ops/batch_matmul.cc | 8 +- lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc | 2 +- lib/op-attrs/test/src/op-attrs/ops/flat.cc | 18 +- .../op-attrs/parallel_tensor_dim_degrees.cc | 98 ++-- lib/pcg/include/pcg/model_compilation.h | 26 - .../parallel_computation_graph.h | 4 + .../pcg/cg_operator_tensor_shape_signature.cc | 1 + 
.../file_format/v1/v1_computation_graph.cc | 3 +- lib/pcg/src/pcg/machine_view.cc | 6 +- .../generate_weight_transform.cc | 9 +- .../parallel_computation_graph.cc | 14 + .../test/src/pcg/computation_graph_builder.cc | 2 +- lib/pcg/test/src/pcg/machine_view.cc | 44 +- .../src/pcg/start_invariant_machine_view.cc | 15 +- lib/runtime/src/fused_op_attrs.h | 1 - lib/runtime/src/ops/fused_parallel_op_attrs.h | 5 +- lib/runtime/test/src/main.cc | 2 +- lib/runtime/test/src/test_op_task_spec.cc | 2 +- lib/runtime/test/src/test_serialization.cc | 2 +- .../substitution-generator/legacy_rules.cc | 2 +- .../evaluate_substitution_output.cc | 8 +- .../operator_pattern/eval_list_access.cc | 3 +- .../operator_pattern/satisfies_constraint.cc | 5 +- .../materialize_operator_from_attrs_map.cc | 9 +- .../src/substitutions/pcg_pattern.cc | 5 +- .../src/substitutions/pcg_pattern_match.cc | 3 +- .../src/substitutions/substitution.cc | 1 + .../tensor_pattern/eval_list_access.cc | 5 +- .../tensor_pattern/satisfies_constraint.cc | 5 +- .../unlabelled/unlabelled_graph_pattern.cc | 1 + lib/task-spec/include/task-spec/arg_ref.h | 7 +- lib/task-spec/include/task-spec/config.h | 179 ------ .../include/task-spec/ff_config.struct.toml | 115 ++++ .../task-spec/ff_init_info.struct.toml | 22 + .../task-spec/ff_iteration_config.struct.toml | 18 + .../include/task-spec/op_task_signature.h | 26 +- .../include/task-spec/runtime_arg_ref.h | 3 +- .../include/task-spec/serialization.h | 41 +- .../src/task-spec/op_task_signature.cc | 32 ++ .../src/task-spec/ops/batch_matmul.cc | 80 +-- .../task-spec/training_layer_plus_context.cc | 1 + lib/utils/CMakeLists.txt | 1 - .../include/utils/containers/filter_idxs.h | 2 +- .../utils/containers/generate_vector.h | 2 +- .../include/utils/containers/require_same.h | 15 +- lib/utils/include/utils/containers/slice.h | 2 +- .../include/utils/containers/zip3_with.h | 6 +- .../utils/containers/zip3_with_strict.h | 11 +- lib/utils/include/utils/json/visitable.h | 152 ----- 
lib/utils/include/utils/orthotope/dim_coord.h | 14 +- .../include/utils/orthotope/dim_domain.h | 6 +- .../utils/orthotope/dim_domain.struct.toml | 4 +- lib/utils/include/utils/orthotope/orthotope.h | 4 +- .../utils/orthotope/orthotope.struct.toml | 4 +- .../include/utils/orthotope/orthotope_coord.h | 2 +- lib/utils/include/utils/sequence.h | 8 - .../include/utils/stack_vector/stack_vector.h | 1 + lib/utils/include/utils/type_traits.h | 8 - lib/utils/include/utils/variant.h | 2 +- lib/utils/include/utils/visitable.h | 517 ------------------ lib/utils/include/utils/visitable_core.h | 120 ---- lib/utils/src/utils/containers/merge_maps.cc | 22 +- lib/utils/src/utils/containers/range.cc | 1 + .../src/utils/containers/require_same.cc | 16 + .../digraph/algorithms/transitive_closure.cc | 5 +- .../algorithms/transitive_reduction.cc | 3 +- lib/utils/src/utils/nonnegative_int/range.cc | 4 +- lib/utils/src/utils/orthotope/dim_coord.cc | 8 +- lib/utils/src/utils/orthotope/dim_domain.cc | 4 +- .../src/utils/orthotope/down_projection.cc | 26 +- lib/utils/src/utils/orthotope/orthotope.cc | 61 ++- .../src/utils/orthotope/orthotope_coord.cc | 2 +- .../utils/doctest/check_without_stringify.h | 2 +- .../common/include/test/utils/rapidcheck.h | 1 - .../include/test/utils/rapidcheck/doctest.h | 2 +- .../include/test/utils/rapidcheck/visitable.h | 58 -- .../test/src/utils/containers/filter_idxs.cc | 2 +- .../test/src/utils/orthotope/dim_coord.cc | 14 +- .../test/src/utils/orthotope/orthotope.cc | 434 ++++++++++----- lib/utils/test/src/utils/variant.cc | 4 +- 198 files changed, 1270 insertions(+), 2542 deletions(-) delete mode 100644 cmake/visit_struct.cmake create mode 100644 lib/kernels/include/kernels/perf_metrics.struct.toml delete mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h delete mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/slice.h delete mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/transform.h delete mode 100644 
lib/op-attrs/include/op-attrs/dim_ordered/zip.h delete mode 100644 lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h create mode 100644 lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h rename lib/op-attrs/include/op-attrs/{operator_space_parallel_tensor_space_mapping.struct.toml => operator_space_to_parallel_tensor_space_mapping.struct.toml} (87%) delete mode 100644 lib/op-attrs/include/op-attrs/ops/core.h create mode 100644 lib/op-attrs/include/op-attrs/task_space_coordinate.h delete mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc delete mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/slice.cc delete mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/transform.cc delete mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/zip.cc rename lib/op-attrs/src/op-attrs/{operator_space_parallel_tensor_space_mapping.cc => operator_space_to_parallel_tensor_space_mapping.cc} (82%) create mode 100644 lib/op-attrs/src/op-attrs/task_space_coordinate.cc delete mode 100644 lib/op-attrs/src/parallel_dim_mapping_record.cc delete mode 100644 lib/op-attrs/src/parallel_dim_mapping_record.h delete mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc delete mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc rename lib/op-attrs/test/src/op-attrs/{operator_space_parallel_tensor_space_mapping.cc => operator_space_to_parallel_tensor_space_mapping.cc} (71%) delete mode 100644 lib/pcg/include/pcg/model_compilation.h delete mode 100644 lib/task-spec/include/task-spec/config.h create mode 100644 lib/task-spec/include/task-spec/ff_config.struct.toml create mode 100644 lib/task-spec/include/task-spec/ff_init_info.struct.toml create mode 100644 lib/task-spec/include/task-spec/ff_iteration_config.struct.toml delete mode 100644 lib/utils/include/utils/json/visitable.h delete mode 100644 lib/utils/include/utils/visitable.h delete mode 100644 lib/utils/include/utils/visitable_core.h delete mode 100644 
lib/utils/test/common/include/test/utils/rapidcheck/visitable.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a5f5a6fa11..8b313f5d4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -102,7 +102,6 @@ include(spdlog) include(doctestlib) # named doctestlib to avoid a name collision with doctest.cmake in rapidcheck include(gbenchmark) include(libassert) -include(visit_struct) include(CTest) include(fmt) include(legion) diff --git a/cmake/visit_struct.cmake b/cmake/visit_struct.cmake deleted file mode 100644 index 108745fc14..0000000000 --- a/cmake/visit_struct.cmake +++ /dev/null @@ -1,16 +0,0 @@ -add_library( - visit_struct - INTERFACE -) -target_include_directories( - visit_struct - INTERFACE - ${CMAKE_CURRENT_SOURCE_DIR}/deps/visit_struct/include/ -) -set_target_properties( - visit_struct - PROPERTIES - CXX_STANDARD 11 - CXX_STANDARD_REQUIRED YES - CXX_EXTENSIONS NO -) diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h index 7375cde985..088e9fd2b0 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -1,10 +1,9 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_H #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_specification.dtg.h" -#include "pcg/operator_task_space.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc index 76fe66e88c..99701cae18 100644 --- 
a/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc +++ b/lib/compiler/src/compiler/cost_estimator/network_cost_model.cc @@ -1,4 +1,5 @@ #include "compiler/cost_estimator/network_cost_model.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc b/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc index 92b07bbe23..edd9eba9b4 100644 --- a/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/cost_estimator/op_cost_estimate_key.cc @@ -6,7 +6,6 @@ #include "pcg/machine_specification.dtg.h" #include "pcg/machine_view.dtg.h" #include "pcg/machine_view.h" -#include "pcg/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc index fbfccf737f..a430eed7a5 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" #include "utils/containers/contains_key.h" #include "utils/containers/try_at.h" +#include namespace FlexFlow { @@ -17,12 +18,9 @@ std::optional void machine_mapping_cache_save(MachineMappingCache &cache, MachineMappingState const &k, MachineMappingResult const &v) { - if (contains_key(cache.raw_map, k)) { - throw mk_runtime_error( - fmt::format("machine_mapping_cache_save expected key to not already " - "exist, but received existing key {}", - k)); - } + ASSERT(!contains_key(cache.raw_map, k), + "machine_mapping_cache_save expected key to not already exist", + k); cache.raw_map.emplace(k, v); } diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc 
b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc index 617ba682be..ce9d08f79a 100644 --- a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" #include "utils/containers/contains_key.h" #include "utils/containers/try_at.h" +#include namespace FlexFlow { @@ -19,12 +20,9 @@ void machine_mapping_with_memory_cache_save( MachineMappingWithMemoryCache &cache, MachineMappingState const &k, MachineMappingWithMemoryResult const &v) { - if (contains_key(cache.raw_map, k)) { - throw mk_runtime_error(fmt::format( - "machine_mapping_with_memory_cache_save expected key to not already " - "exist, but received existing key {}", - k)); - } + ASSERT(!contains_key(cache.raw_map, k), + "machine_mapping_with_memory_cache_save expected key to not already exist", + k); cache.raw_map.emplace(k, v); } diff --git a/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc b/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc index 63e359d9ac..0e58eff5fb 100644 --- a/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/unstructured_device_mapping.cc @@ -1,10 +1,9 @@ - #include "compiler/machine_mapping/unstructured_device_mapping.h" #include "compiler/machine_mapping/unstructured_device_mapping.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/keys.h" #include "utils/containers/map_values.h" 
diff --git a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc index c072b0e61e..dd34112088 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc @@ -6,7 +6,7 @@ #include "pcg/machine_specification.dtg.h" #include "pcg/machine_view.dtg.h" #include "pcg/machine_view.h" -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" diff --git a/lib/compiler/test/src/allowed_machine_views.cc b/lib/compiler/test/src/allowed_machine_views.cc index 15f7d60060..426bce6341 100644 --- a/lib/compiler/test/src/allowed_machine_views.cc +++ b/lib/compiler/test/src/allowed_machine_views.cc @@ -1,5 +1,5 @@ #include "compiler/allowed_machine_views.h" -#include "doctest/doctest.h" +#include #include "utils/containers/extend.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc index 928d30ecaa..a5aea74021 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -1,5 +1,5 @@ #include "compiler/machine_mapping/machine_mapping.h" -#include "doctest/doctest.h" +#include #include "pcg/machine_view.h" using namespace FlexFlow; diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc index e7060ef421..2f4acf7be0 100644 --- a/lib/compiler/test/src/graph_optimize_state.cc +++ b/lib/compiler/test/src/graph_optimize_state.cc @@ -1,5 +1,5 @@ #include 
"compiler/graph_optimize_state.h" -#include "doctest/doctest.h" +#include #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" using namespace FlexFlow; diff --git a/lib/compiler/test/src/unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc index 8ff0978ea5..9004fc2c66 100644 --- a/lib/compiler/test/src/unity_algorithm.cc +++ b/lib/compiler/test/src/unity_algorithm.cc @@ -1,5 +1,5 @@ #include "compiler/unity_algorithm.h" -#include "doctest/doctest.h" +#include TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index db377162b6..d54663f110 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -1,37 +1,31 @@ #ifndef _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H #define _FLEXFLOW_OPS_KERNELS_BATCH_MATMUL_KERNELS_H +#include "kernels/accessor.h" #include "kernels/device_handle_t.dtg.h" #include "kernels/device_stream_t.dtg.h" #include "kernels/ff_handle.h" +#include "utils/nonnegative_int/nonnegative_int.h" namespace FlexFlow::Kernels::BatchMatmul { void forward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim); + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim); void backward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch); + GenericTensorAccessorR const &output, + 
GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad); } // namespace FlexFlow::Kernels::BatchMatmul diff --git a/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h b/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h index fdef3d7fa1..6d9c804be2 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels_cpu.h @@ -5,27 +5,19 @@ namespace FlexFlow::Kernels::BatchMatmul { -void cpu_forward_kernel(float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim); +void cpu_forward_kernel(GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim); -void cpu_backward_kernel(float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch); +void cpu_backward_kernel(GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad); } // namespace FlexFlow::Kernels::BatchMatmul diff --git a/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h b/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h index 4a35c000c3..1e13755b81 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels_gpu.h @@ -10,8 +10,8 @@ namespace FlexFlow::Kernels::BatchMatmul { void gpu_forward_kernel(ffStream_t stream, PerDeviceFFHandle const 
&handle, float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, + float const *input_a_ptr, + float const *input_b_ptr, int m, int n, int k, @@ -22,12 +22,12 @@ void gpu_forward_kernel(ffStream_t stream, void gpu_backward_kernel(ffStream_t stream, PerDeviceFFHandle const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, + float const *output_ptr, + float const *output_grad_ptr, + float const *input_a_ptr, + float *input_a_grad_ptr, + float const *input_b_ptr, + float *input_b_grad_ptr, int m, int n, int k, diff --git a/lib/kernels/include/kernels/perf_metrics.h b/lib/kernels/include/kernels/perf_metrics.h index c4a34e4f79..69f96491e0 100644 --- a/lib/kernels/include/kernels/perf_metrics.h +++ b/lib/kernels/include/kernels/perf_metrics.h @@ -1,37 +1,11 @@ #ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_PERF_METRICS_H #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_PERF_METRICS_H -#include "utils/fmt.h" -#include "utils/visitable.h" +#include "kernels/perf_metrics.dtg.h" +#include namespace FlexFlow { -struct PerfMetrics : public use_visitable_cmp { - PerfMetrics() = delete; - PerfMetrics(double start_time); - PerfMetrics(int train_all, - std::optional train_correct, - std::optional cce_loss, - std::optional sparse_cce_loss, - std::optional mse_loss, - std::optional rmse_loss, - std::optional mae_loss, - double start_time_micro, - double current_time_micro); - - int train_all = 0; // measure_accuracy_denominator - std::optional train_correct = 0; // measure_accuracy numerator - std::optional cce_loss = - std::nullopt; // measure_categorical_crossentropy - std::optional sparse_cce_loss = - 0.0f; // measure_sparse_categorical_crossentropy - std::optional mse_loss = 0.0f; // measure_mean_squared_error - std::optional rmse_loss = 0.0f; // measure_root_mean_squared_error - std::optional mae_loss = 0.0f; // measure_mean_absolute_error - double start_time; - double 
current_time; -}; - float get_throughput(PerfMetrics const &); float get_accuracy(PerfMetrics const &); @@ -40,16 +14,6 @@ PerfMetrics apply_scale(PerfMetrics const &, float scale); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::PerfMetrics, - train_all, - train_correct, - cce_loss, - sparse_cce_loss, - mse_loss, - rmse_loss, - mae_loss, - start_time); - namespace fmt { template <> diff --git a/lib/kernels/include/kernels/perf_metrics.struct.toml b/lib/kernels/include/kernels/perf_metrics.struct.toml new file mode 100644 index 0000000000..d7f1b67a35 --- /dev/null +++ b/lib/kernels/include/kernels/perf_metrics.struct.toml @@ -0,0 +1,59 @@ +namespace = "FlexFlow" +name = "PerfMetrics" +features = [ + "eq", + "hash", + "json", +] + +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "train_all" +type = "int" +docstring = "measure_accuracy denominator" + +[[fields]] +name = "train_correct" +type = "std::optional" +docstring = "measure_accuracy numerator" + +[[fields]] +name = "cce_loss" +type = "std::optional" +docstring = "measure_categorical_crossentropy" + +[[fields]] +name = "sparse_cce_loss" +type = "std::optional" +docstring = "measure_sparse_categorical_crossentropy" + +[[fields]] +name = "mse_loss" +type = "std::optional" +docstring = "measure_mean_squared_error" + +[[fields]] +name = "rmse_loss" +type = "std::optional" +docstring = "measure_root_mean_squared_error" + +[[fields]] +name = "mae_loss" +type = "std::optional" +docstring = "measure_mean_absolute_error" + +[[fields]] +name = "start_time" +type = "double" + +[[fields]] +name = "current_time" +type = "double" diff --git a/lib/kernels/src/cuda/ops/element_binary_kernels.cu b/lib/kernels/src/cuda/ops/element_binary_kernels.cu index 7e13486429..e4c698ad02 100644 --- a/lib/kernels/src/cuda/ops/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_binary_kernels.cu @@ -18,6 +18,7 @@ #include "kernels/ff_handle.h" #include 
"op-attrs/datatype.h" #include "op-attrs/operator_type.h" +#include "utils/exception.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/src/cuda/ops/pool_2d_kernels.cu b/lib/kernels/src/cuda/ops/pool_2d_kernels.cu index ec185a360e..4e06f2da02 100644 --- a/lib/kernels/src/cuda/ops/pool_2d_kernels.cu +++ b/lib/kernels/src/cuda/ops/pool_2d_kernels.cu @@ -15,6 +15,7 @@ #include "internal/device.h" #include "kernels/pool_2d_kernels_gpu.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/ops/softmax_kernels.cu b/lib/kernels/src/cuda/ops/softmax_kernels.cu index 85575d7bf6..1ecf42dfd8 100644 --- a/lib/kernels/src/cuda/ops/softmax_kernels.cu +++ b/lib/kernels/src/cuda/ops/softmax_kernels.cu @@ -15,6 +15,7 @@ #include "internal/device.h" #include "kernels/softmax_kernels_gpu.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/kernels/src/kernels/accessor.cc b/lib/kernels/src/kernels/accessor.cc index 868940bf6c..31536caf9c 100644 --- a/lib/kernels/src/kernels/accessor.cc +++ b/lib/kernels/src/kernels/accessor.cc @@ -19,7 +19,7 @@ nonnegative_int calculate_accessor_offset(TensorDimsCoord const &coord, nonnegative_int offset = 0_n; positive_int multiplier = 1_p; - for (ff_dim_t dim : reversed(get_idxs(tensor_dims.ff_ordered))) { + for (ff_dim_t dim : reversed(vector_of(get_idxs(tensor_dims.ff_ordered)))) { ASSERT(coord.ff_ordered.at(dim) < dim_at_idx(tensor_dims, dim), "Out of bounds access", dim); diff --git a/lib/kernels/src/kernels/batch_matmul_kernels.cc b/lib/kernels/src/kernels/batch_matmul_kernels.cc index 652d4fb137..ea9791ddc3 100644 --- a/lib/kernels/src/kernels/batch_matmul_kernels.cc +++ b/lib/kernels/src/kernels/batch_matmul_kernels.cc @@ -1,46 +1,77 @@ #include "kernels/batch_matmul_kernels.h" #include "kernels/batch_matmul_kernels_cpu.h" #include "kernels/batch_matmul_kernels_gpu.h" +#include "utils/containers/require_same.h" namespace FlexFlow::Kernels::BatchMatmul { +static std::tuple + 
get_params(TensorDims const &input_a_dims, + TensorDims const &input_b_dims, + TensorDims const &output_dims) { + positive_int m = require_same( + dim_at_idx(input_b_dims, relative_ff_dim_t{-1}), + dim_at_idx(output_dims, relative_ff_dim_t{-1})); + + positive_int n = require_same( + dim_at_idx(input_a_dims, relative_ff_dim_t{-2}), + dim_at_idx(output_dims, relative_ff_dim_t{-2})); + + positive_int k = require_same( + dim_at_idx(input_a_dims, relative_ff_dim_t{-1}), + dim_at_idx(input_b_dims, relative_ff_dim_t{-2})); + + TensorDims leading_dims = require_same( + slice_tensor_dims(input_a_dims, + relative_ff_dim_t{0}, + relative_ff_dim_t{-2}), + slice_tensor_dims(input_b_dims, + relative_ff_dim_t{0}, + relative_ff_dim_t{-2})); + + positive_int batch = get_num_elements(leading_dims); + + return {m, n, k, batch}; +} + void forward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim) { + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim) { + + auto [m, n, k, batch] = get_params(input_a.shape.dims, input_b.shape.dims, output.shape.dims); + + auto get_raw_seq_len = [](std::optional seq_len) -> int { + return transform(seq_len, + [](positive_int x) { return x.int_from_positive_int(); }) + .value_or(-1); + }; + if (stream.is_gpu()) { gpu_forward_kernel( /*stream=*/stream.require_gpu(), /*handle=*/handle.require_for_gpu(), - /*output_ptr=*/output_ptr, - /*a_input_ptr=*/a_input_ptr, - /*b_input_ptr=*/b_input_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch, - /*seq_length=*/seq_length, - /*a_seq_length_dim=*/a_seq_length_dim, - /*b_seq_length_dim=*/b_seq_length_dim); + /*output_ptr=*/output.get_float_ptr(), + 
/*a_input_ptr=*/input_a.get_float_ptr(), + /*b_input_ptr=*/input_b.get_float_ptr(), + /*m=*/m.int_from_positive_int(), + /*n=*/n.int_from_positive_int(), + /*k=*/k.int_from_positive_int(), + /*batch=*/batch.int_from_positive_int(), + /*seq_length=*/seq_length.int_from_positive_int(), + /*a_seq_length_dim=*/get_raw_seq_len(a_seq_length_dim), + /*b_seq_length_dim=*/get_raw_seq_len(b_seq_length_dim)); } else { ASSERT(stream.is_cpu()); ASSERT(handle.is_for_cpu()); cpu_forward_kernel( - /*output_ptr=*/output_ptr, - /*a_input_ptr=*/a_input_ptr, - /*b_input_ptr=*/b_input_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch, + /*output=*/output, + /*input_a=*/input_a, + /*input_b=*/input_b, /*seq_length=*/seq_length, /*a_seq_length_dim=*/a_seq_length_dim, /*b_seq_length_dim=*/b_seq_length_dim); @@ -49,44 +80,42 @@ void forward_kernel(device_stream_t const &stream, void backward_kernel(device_stream_t const &stream, device_handle_t const &handle, - float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch) { + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad) { + TensorShape input_a_shape = require_same(input_a.shape, input_a_grad.shape); + TensorShape input_b_shape = require_same(input_b.shape, input_b_grad.shape); + TensorShape output_shape = require_same(output.shape, output_grad.shape); + + auto [m, n, k, batch] = get_params(input_a_shape.dims, input_b_shape.dims, output_shape.dims); + if (stream.is_gpu()) { gpu_backward_kernel( /*stream=*/stream.require_gpu(), /*handle=*/handle.require_for_gpu(), - /*o_ptr=*/o_ptr, - /*o_grad_ptr=*/o_grad_ptr, - /*a_ptr=*/a_ptr, - /*a_grad_ptr=*/a_grad_ptr, - /*b_ptr=*/b_ptr, - /*b_grad_ptr=*/b_grad_ptr, - /*m=*/m, - 
/*n=*/n, - /*k=*/k, - /*batch=*/batch); + /*output_ptr=*/output.get_float_ptr(), + /*output_grad_ptr=*/output_grad.get_float_ptr(), + /*input_a_ptr=*/input_a.get_float_ptr(), + /*input_a_grad_ptr=*/input_a_grad.get_float_ptr(), + /*input_b_ptr=*/input_b.get_float_ptr(), + /*input_b_grad_ptr=*/input_b_grad.get_float_ptr(), + /*m=*/m.int_from_positive_int(), + /*n=*/n.int_from_positive_int(), + /*k=*/k.int_from_positive_int(), + /*batch=*/batch.int_from_positive_int()); } else { ASSERT(stream.is_cpu()); ASSERT(handle.is_for_cpu()); cpu_backward_kernel( - /*o_ptr=*/o_ptr, - /*o_grad_ptr=*/o_grad_ptr, - /*a_ptr=*/a_ptr, - /*a_grad_ptr=*/a_grad_ptr, - /*b_ptr=*/b_ptr, - /*b_grad_ptr=*/b_grad_ptr, - /*m=*/m, - /*n=*/n, - /*k=*/k, - /*batch=*/batch); + /*output=*/output, + /*output_grad=*/output_grad, + /*input_a=*/input_a, + /*input_a_grad=*/input_a_grad, + /*input_b=*/input_b, + /*input_b_grad=*/input_b_grad); } } diff --git a/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc b/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc index f139d42992..292841d19f 100644 --- a/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc +++ b/lib/kernels/src/kernels/batch_matmul_kernels_cpu.cc @@ -2,29 +2,21 @@ namespace FlexFlow::Kernels::BatchMatmul { -void cpu_forward_kernel(float *output_ptr, - float const *a_input_ptr, - float const *b_input_ptr, - int m, - int n, - int k, - int batch, - int seq_length, - int a_seq_length_dim, - int b_seq_length_dim) { +void cpu_forward_kernel(GenericTensorAccessorW const &output, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorR const &input_b, + positive_int seq_length, + std::optional a_seq_length_dim, + std::optional b_seq_length_dim) { NOT_IMPLEMENTED(); } -void cpu_backward_kernel(float const *o_ptr, - float const *o_grad_ptr, - float const *a_ptr, - float *a_grad_ptr, - float const *b_ptr, - float *b_grad_ptr, - int m, - int n, - int k, - int batch) { +void cpu_backward_kernel(GenericTensorAccessorR const &output, + 
GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &input_a, + GenericTensorAccessorW const &input_a_grad, + GenericTensorAccessorR const &input_b, + GenericTensorAccessorW const &input_b_grad) { NOT_IMPLEMENTED(); } diff --git a/lib/kernels/src/kernels/element_binary_kernels.cc b/lib/kernels/src/kernels/element_binary_kernels.cc index bea317dfec..1d8fbaaf77 100644 --- a/lib/kernels/src/kernels/element_binary_kernels.cc +++ b/lib/kernels/src/kernels/element_binary_kernels.cc @@ -1,6 +1,7 @@ #include "kernels/element_binary_kernels.h" #include "kernels/element_binary_kernels_cpu.h" #include "kernels/element_binary_kernels_gpu.h" +#include namespace FlexFlow::Kernels::ElementBinary { diff --git a/lib/kernels/src/kernels/pool_2d_kernels.cc b/lib/kernels/src/kernels/pool_2d_kernels.cc index 6ebfc68c86..f8f5571716 100644 --- a/lib/kernels/src/kernels/pool_2d_kernels.cc +++ b/lib/kernels/src/kernels/pool_2d_kernels.cc @@ -1,6 +1,7 @@ #include "kernels/pool_2d_kernels.h" #include "kernels/pool_2d_kernels_cpu.h" #include "kernels/pool_2d_kernels_gpu.h" +#include namespace FlexFlow::Kernels::Pool2D { diff --git a/lib/kernels/src/kernels/reduce_kernels.cc b/lib/kernels/src/kernels/reduce_kernels.cc index bd3d6a8cd1..284d07dd96 100644 --- a/lib/kernels/src/kernels/reduce_kernels.cc +++ b/lib/kernels/src/kernels/reduce_kernels.cc @@ -1,6 +1,7 @@ #include "kernels/reduce_kernels.h" #include "kernels/reduce_kernels_cpu.h" #include "kernels/reduce_kernels_gpu.h" +#include namespace FlexFlow::Kernels::Reduce { diff --git a/lib/kernels/src/perf_metrics.cc b/lib/kernels/src/perf_metrics.cc index 2036ddd35a..c9161e55b1 100644 --- a/lib/kernels/src/perf_metrics.cc +++ b/lib/kernels/src/perf_metrics.cc @@ -2,22 +2,6 @@ namespace FlexFlow { -PerfMetrics::PerfMetrics(double _start_time) - : start_time(_start_time), current_time(_start_time) {} - -PerfMetrics::PerfMetrics(int _train_all, - std::optional _train_correct, - std::optional _cce_loss, - std::optional 
_sparse_cce_loss, - std::optional _mse_loss, - std::optional _rmse_loss, - std::optional _mae_loss, - double _start_time_micro, - double _current_time_micro) - : train_all(_train_all), train_correct(_train_correct), cce_loss(_cce_loss), - mse_loss(_mse_loss), rmse_loss(_rmse_loss), mae_loss(_mae_loss), - start_time(_start_time_micro), current_time(_current_time_micro) {} - float get_throughput(PerfMetrics const &m) { return m.train_all / (m.current_time - m.start_time); } diff --git a/lib/local-execution/test/src/local-execution/local_cost_estimator.cc b/lib/local-execution/test/src/local-execution/local_cost_estimator.cc index 107b835383..dfee96dd93 100644 --- a/lib/local-execution/test/src/local-execution/local_cost_estimator.cc +++ b/lib/local-execution/test/src/local-execution/local_cost_estimator.cc @@ -1,5 +1,5 @@ #include "local-execution/local_cost_estimator.h" -#include "doctest/doctest.h" +#include #include "internal/test_utils.h" #include "kernels/device_handle_t.h" #include "kernels/managed_per_device_ff_handle.h" diff --git a/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc b/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc index 482795b278..add58ad77b 100644 --- a/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc +++ b/lib/local-execution/test/src/local-execution/local_task_argument_accessor.cc @@ -1,5 +1,5 @@ #include "local-execution/local_task_argument_accessor.h" -#include "doctest/doctest.h" +#include #include "kernels/local_cpu_allocator.h" #include "task-spec/task_signature_impl.h" #include "utils/fmt/variant.h" diff --git a/lib/local-execution/test/src/local-execution/local_task_registry.cc b/lib/local-execution/test/src/local-execution/local_task_registry.cc index 27cd74b2a6..dd7a6e4440 100644 --- a/lib/local-execution/test/src/local-execution/local_task_registry.cc +++ b/lib/local-execution/test/src/local-execution/local_task_registry.cc @@ -206,8 +206,12 @@ 
TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Partial task does not exist") { ComputationGraphOpAttrs bmm_attrs = ComputationGraphOpAttrs{ - BatchMatmulAttrs{/*a_seq_length_dim=*/10_n, - /*b_seq_length_dim=*/20_n}}; + BatchMatmulAttrs{ + /*a_seq_length_dim=*/10_p, + /*b_seq_length_dim=*/20_p, + }, + }; + LocalTaskRegistry task_registry = construct_local_task_registry_for_layers({ {layer_guid, LayerAttrs{bmm_attrs, std::nullopt}}, diff --git a/lib/local-execution/test/src/local-execution/loss_functions.cc b/lib/local-execution/test/src/local-execution/loss_functions.cc index e5fffb980c..7a18fec545 100644 --- a/lib/local-execution/test/src/local-execution/loss_functions.cc +++ b/lib/local-execution/test/src/local-execution/loss_functions.cc @@ -1,4 +1,4 @@ -#include "doctest/doctest.h" +#include #include "internal/test_utils.h" #include "kernels/local_cuda_allocator.h" #include "kernels/managed_ff_stream.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h deleted file mode 100644 index 5c47745209..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ /dev/null @@ -1,200 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H - -#include "op-attrs/ff_dim_t.dtg.h" -#include "op-attrs/relative_ff_dim_t.dtg.h" -#include "utils/containers/range.h" -#include "utils/fmt/vector.h" -#include "utils/stack_vector/stack_vector.h" -#include - -namespace FlexFlow { - -template -struct DimOrdered { - DimOrdered() {} - - DimOrdered(std::initializer_list const &l) - : contents(l.begin(), l.end()) {} - - DimOrdered(std::vector const &contents) - : contents(contents.begin(), contents.end()) {} - - template - DimOrdered(It begin, It end) : contents(begin, end) {} - - template - DimOrdered(stack_vector const &contents) - : contents(contents.begin(), contents.end()) {} - - T const &at(Idx idx) const { - nonnegative_int raw = 
idx.value; - return this->contents.at(raw.unwrap_nonnegative()); - } - - T &at(Idx idx) { - nonnegative_int raw = idx.value; - return this->contents.at(raw.unwrap_nonnegative()); - } - - T const &operator[](Idx idx) const { - return this->at(idx); - } - - T &operator[](Idx idx) { - return this->at(idx); - } - - bool idx_is_valid(Idx const &idx) const { - nonnegative_int raw = idx.value; - return (raw < this->contents.size()); - } - - bool operator==(DimOrdered const &other) const { - return this->contents == other.contents; - } - - bool operator!=(DimOrdered const &other) const { - return this->contents != other.contents; - } - - using iterator = typename stack_vector::iterator; - using const_iterator = - typename stack_vector::const_iterator; - using reverse_iterator = - typename stack_vector::reverse_iterator; - using const_reverse_iterator = - typename stack_vector::const_reverse_iterator; - using value_type = T; - using pointer = value_type *; - using const_pointer = value_type const *; - using reference = value_type &; - using const_reference = value_type const &; - - iterator begin() { - return this->contents.begin(); - } - - const_iterator begin() const { - return this->cbegin(); - } - - const_iterator cbegin() const { - return this->contents.cbegin(); - } - - iterator end() { - return this->contents.end(); - } - - const_iterator end() const { - return this->cend(); - } - - const_iterator cend() const { - return this->contents.cend(); - } - - reverse_iterator rbegin() { - return this->contents.rbegin(); - } - - const_reverse_iterator rbegin() const { - return this->crbegin(); - } - - const_reverse_iterator crbegin() const { - return this->contents.crbegin(); - } - - reverse_iterator rend() { - return this->contents.rend(); - } - - const_reverse_iterator rend() const { - return this->crend(); - } - - const_reverse_iterator crend() const { - return this->contents.crend(); - } - - size_t size() const { - return this->contents.size(); - } - - size_t empty() 
const { - return this->contents.empty(); - } - - size_t num_dims() const { - return this->size(); - } - - friend struct ::std::hash; - -private: - stack_vector contents; -}; - -template -auto operator<(DimOrdered const &lhs, DimOrdered const &rhs) - -> std::enable_if_t, bool> { - return std::lexicographical_compare( - lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend()); -} - -template -std::string format_as(DimOrdered const &v) { - std::vector as_vec(v.cbegin(), v.cend()); - return fmt::format("", as_vec); -} - -template -std::ostream &operator<<(std::ostream &s, DimOrdered const &v) { - return (s << fmt::to_string(v)); -} - -} // namespace FlexFlow - -namespace nlohmann { -template -struct adl_serializer<::FlexFlow::DimOrdered> { - static ::FlexFlow::DimOrdered from_json(nlohmann::json const &j) { - return {j.template get>()}; - } - - static void to_json(nlohmann::json &j, - ::FlexFlow::DimOrdered const &x) { - j = std::vector{x.cbegin(), x.cend()}; - } -}; -} // namespace nlohmann - -namespace std { - -template -struct hash<::FlexFlow::DimOrdered> { - size_t operator()(::FlexFlow::DimOrdered const &t) const { - static_assert(::FlexFlow::is_hashable::value, - "Elements must be hashable"); - - return get_std_hash(t.contents); - } -}; - -} // namespace std - -namespace rc { - -template -struct Arbitrary<::FlexFlow::DimOrdered> { - static Gen<::FlexFlow::DimOrdered> arbitrary() { - return gen::construct<::FlexFlow::DimOrdered>( - gen::arbitrary<::FlexFlow::stack_vector>()); - } -}; - -} // namespace rc - -#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h deleted file mode 100644 index 76526447be..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H - -#include "op-attrs/dim_ordered/dim_ordered.h" -#include 
"utils/containers/slice.h" -#include "utils/containers/transform.h" -#include "utils/containers/vector_of.h" -#include "utils/optional.h" - -namespace FlexFlow { - -template -DimOrdered nonoverloaded_slice(DimOrdered const &d, - std::optional const &start, - std::optional const &end) { - auto to_raw_idx = [](std::optional const &idx) -> std::optional { - return transform(idx, [](Idx const &i) { return i.value; }); - }; - - return DimOrdered{ - slice(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; -} -template -DimOrdered slice(DimOrdered const &d, - std::optional const &start = std::nullopt, - std::optional const &end = std::nullopt) { - return ff_dim_t_nonoverloaded_slice(d, start, end); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h deleted file mode 100644 index 4fd3df0abb..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H - -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "utils/containers/vector_of.h" -#include "utils/containers/vector_transform.h" - -namespace FlexFlow { - -template -DimOrdered> - transform(DimOrdered const &d, F f) { - using Out = std::invoke_result_t; - - return DimOrdered{vector_transform(vector_of(d), f)}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h deleted file mode 100644 index cc8b050f50..0000000000 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H - -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "utils/containers/vector_of.h" 
-#include "utils/containers/zip.h" - -namespace FlexFlow { - -template -DimOrdered> zip(DimOrdered const &lhs, - DimOrdered const &rhs) { - return DimOrdered>{ - zip(vector_of(lhs), vector_of(rhs))}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h index 0979201f67..1411886eee 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -11,6 +11,8 @@ relative_ff_dim_t relative_ff_dim_t_from_ff_dim_t(ff_dim_t ff_dim); ff_dim_t add_to_ff_dim(ff_dim_t ff_dim, int value); +std::vector ff_dim_range(nonnegative_int num_elements); + } // namespace FlexFlow namespace rc { diff --git a/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h index 5ff390d3fe..c4e4c4f55f 100644 --- a/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/ff_ordered/get_idxs.h @@ -3,14 +3,15 @@ #include "op-attrs/ff_dim_t.h" #include "op-attrs/ff_ordered/ff_ordered.h" -#include "utils/containers/count.h" +#include "utils/containers/range.h" #include "utils/containers/transform.h" +#include "utils/containers/set_of.h" namespace FlexFlow { template -std::vector get_idxs(FFOrdered const &d) { - return transform(count(d.size()), +std::set get_idxs(FFOrdered const &d) { + return transform(set_of(range(d.size())), [](int i) { return ff_dim_t{nonnegative_int{i}}; }); } diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index d94f7af4fb..a6a1c35ca7 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -1,5 +1,5 @@ -#ifndef _OPERATOR_PARAMS_H -#define _OPERATOR_PARAMS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_ATTRS_H #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_matmul.h" 
@@ -9,7 +9,6 @@ #include "op-attrs/ops/combine.h" #include "op-attrs/ops/concat.h" #include "op-attrs/ops/conv_2d.h" -#include "op-attrs/ops/core.h" #include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" diff --git a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h deleted file mode 100644 index 908f25aaa6..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_PARALLEL_TENSOR_SPACE_MAPPING_H - -#include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" -#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" - -namespace FlexFlow { - -OperatorSpaceParallelTensorSpaceMapping - get_identity_mapping(nonnegative_int num_dims); - -compute_ - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h new file mode 100644 index 0000000000..2697c24a6a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_TO_PARALLEL_TENSOR_SPACE_MAPPING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_SPACE_TO_PARALLEL_TENSOR_SPACE_MAPPING_H + +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" + +namespace FlexFlow { + +OperatorSpaceToParallelTensorSpaceMapping + get_identity_mapping(nonnegative_int num_shard_dims); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml similarity index 87% rename from lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml rename to lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml index 24c7676527..9226175c2e 100644 --- a/lib/op-attrs/include/op-attrs/operator_space_parallel_tensor_space_mapping.struct.toml +++ b/lib/op-attrs/include/op-attrs/operator_space_to_parallel_tensor_space_mapping.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "OperatorSpaceParallelTensorSpaceMapping" +name = "OperatorSpaceToParallelTensorSpaceMapping" features = [ "eq", "hash", diff --git a/lib/op-attrs/include/op-attrs/operator_task_space.h b/lib/op-attrs/include/op-attrs/operator_task_space.h index ceb0146f15..b57e19b6e8 100644 --- a/lib/op-attrs/include/op-attrs/operator_task_space.h +++ b/lib/op-attrs/include/op-attrs/operator_task_space.h @@ -1,11 +1,8 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H -#define _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TASK_SPACE_H -#include "pcg/operator_task_space.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/task_space_coordinate.dtg.h" -#include +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/task_space_coordinate.dtg.h" #include namespace FlexFlow { @@ -19,9 +16,6 @@ TaskSpaceCoordinate nonnegative_int num_dims(OperatorTaskSpace const &task); positive_int num_tasks(OperatorTaskSpace const &task); -OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &layer); - } // namespace FlexFlow 
#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 5ca237561f..9407cc6942 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,12 +1,11 @@ -#ifndef _FLEXFLOW_ATTENTION_ATTRS_H -#define _FLEXFLOW_ATTENTION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include @@ -123,7 +122,6 @@ tl::expected, std::string> get_initializers( std::optional const &output_bias_initializer = std::nullopt); -CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 333da4fa29..f17757ac85 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -2,15 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H #include "op-attrs/ops/batch_matmul_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); - bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml index 394dfb5fcc..0ec3f3e319 100644 --- 
a/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul_attrs.struct.toml @@ -11,7 +11,7 @@ features = [ ] includes = [ - "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", "", ] @@ -23,8 +23,8 @@ src_includes = [ [[fields]] name = "a_seq_length_dim" -type = "std::optional<::FlexFlow::nonnegative_int>" +type = "std::optional<::FlexFlow::positive_int>" [[fields]] name = "b_seq_length_dim" -type = "std::optional<::FlexFlow::nonnegative_int>" +type = "std::optional<::FlexFlow::positive_int>" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index bcf6794f38..35d6cb496d 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -4,10 +4,10 @@ #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { @@ -61,8 +61,6 @@ tl::expected, std::string> tl::expected, std::string> get_initializers(BatchNormAttrs const &attrs); -CHECK_VALID_OP_ATTR(BatchNormAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 4fd7d49234..9b6bd49418 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -2,15 +2,13 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H #include "op-attrs/ops/broadcast_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" +#include namespace FlexFlow { -CHECK_VALID_OP_ATTR(BroadcastAttrs); - 
RecordFormatter as_dot(BroadcastAttrs const &); tl::expected get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 30818f046d..38a1e87a76 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -1,8 +1,7 @@ -#ifndef _FLEXFLOW_CAST_ATTRS_H -#define _FLEXFLOW_CAST_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_H #include "op-attrs/ops/cast_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ -10,8 +9,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(CastAttrs); - RecordFormatter as_dot(CastAttrs const &); tl::expected get_output_shape(CastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index d9ca314c2b..6839bc12e1 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -2,15 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H #include "op-attrs/ops/combine_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" #include namespace FlexFlow { -CHECK_VALID_OP_ATTR(CombineAttrs); - RecordFormatter as_dot(CombineAttrs const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index f07f06df85..1647553b96 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -2,14 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H #include "op-attrs/ops/concat_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace 
FlexFlow { -CHECK_VALID_OP_ATTR(ConcatAttrs); - tl::expected get_output_shape(ConcatAttrs const &, std::vector const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index e4c7467de2..5ae4649571 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,17 +1,14 @@ -#ifndef _FLEXFLOW_CONV_2D_ATTRS_H -#define _FLEXFLOW_CONV_2D_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(Conv2DAttrs); - std::vector get_conv2d_incoming_tensor_roles(Conv2DAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/core.h b/lib/op-attrs/include/op-attrs/ops/core.h deleted file mode 100644 index 611b53def5..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/core.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OPS_CORE_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OPS_CORE_H - -#include "utils/type_traits.h" - -namespace FlexFlow { - -#define CHECK_VALID_OP_ATTR(TYPENAME) CHECK_WELL_BEHAVED_VALUE_TYPE(TYPENAME) - -template -using is_valid_opattr = is_well_behaved_value_type; - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 86e5db4d77..d5f3ae0c0d 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -1,10 +1,10 @@ -#ifndef _FLEXFLOW_DROPOUT_ATTRS_H -#define _FLEXFLOW_DROPOUT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_H -#include 
"op-attrs/ops/core.h" #include "op-attrs/ops/dropout_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { @@ -12,8 +12,6 @@ TensorShape get_output_shape(DropoutAttrs const &, TensorShape const &); tl::expected get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(DropoutAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index d51c3a3afa..970cf0f1b6 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -1,22 +1,18 @@ -#ifndef _FLEXFLOW_ELEMENT_BINARY_ATTRS_H -#define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include namespace FlexFlow { -tl::expected get_output_shape( +TensorShape get_output_shape( ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); -tl::expected +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(ElementBinaryAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 1a965b2c51..655310fd85 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_H -#include "op-attrs/ops/core.h" #include 
"op-attrs/ops/element_unary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -16,8 +15,6 @@ tl::expected tl::expected get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index d44adf5f54..8bebf23488 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -1,8 +1,7 @@ -#ifndef _FLEXFLOW_EMBEDDING_ATTRS_H -#define _FLEXFLOW_EMBEDDING_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_H #include "op-attrs/initializer_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" @@ -11,8 +10,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(EmbeddingAttrs); - RecordFormatter as_dot(EmbeddingAttrs const &); tl::expected get_output_shape(EmbeddingAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 710cbdb44b..34f19267c0 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_FLAT_ATTRS_H -#define _FLEXFLOW_FLAT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -9,13 +8,11 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(FlatAttrs); - TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -tl::expected +ParallelTensorDimDegrees get_output_parallel_dim_degrees(FlatAttrs 
const &, ParallelTensorDimDegrees const &); -tl::expected +ParallelTensorShape get_output_shape(FlatAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index 42efd13b60..3b67b9130b 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -1,14 +1,11 @@ -#ifndef _FLEXFLOW_GATHER_ATTRS_H -#define _FLEXFLOW_GATHER_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(GatherAttrs); - TensorShape get_output_shape(GatherAttrs const &, TensorShape const &input, TensorShape const &index); diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index fe92c77a52..cf9c49f231 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -1,15 +1,12 @@ #ifndef _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/input_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(InputAttrs); - TensorShape get_output_shape(InputAttrs const &); ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 4dcbeb665e..1d2cb14e99 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H -#define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H +#ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { @@ -44,8 +44,6 @@ tl::expected, std::string> */ std::vector get_initializers(LayerNormAttrs const &attrs); -CHECK_VALID_OP_ATTR(LayerNormAttrs); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index c1ec733dac..0c11090fd9 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,16 +1,15 @@ -#ifndef _FLEXFLOW_LINEAR_ATTRS_H -#define _FLEXFLOW_LINEAR_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_H #include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/initializer_attrs.dtg.h" -#include "op-attrs/ops/core.h" #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/record_formatter.h" #include -#include "op-attrs/operator_space_parallel_tensor_space_mapping.dtg.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.dtg.h" #include "op-attrs/parallel_tensor_space_mapping.dtg.h" namespace FlexFlow { @@ -18,8 +17,6 @@ namespace FlexFlow { std::vector get_linear_incoming_tensor_roles(LinearAttrs const &); -CHECK_VALID_OP_ATTR(LinearAttrs); - RecordFormatter as_dot(LinearAttrs const &); tl::expected @@ -65,9 +62,9 @@ tl::expected, std::string> get_initializers( tl::expected get_input_to_output_mapping(LinearAttrs const &attrs, nonnegative_int input_num_dims); -tl::expected 
+tl::expected get_operator_to_input_mapping(LinearAttrs const &attrs, nonnegative_int input_num_dims); -tl::expected +tl::expected get_operator_to_output_mapping(LinearAttrs const &attrs, nonnegative_int input_num_dims); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h index 657f8d91dc..c19d7f9e87 100644 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "op-attrs/ops/loss_functions/loss_function.dtg.h" #include "op-attrs/ops/loss_functions/nonconfigurable_loss_attrs.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 2c61dff886..8c8e191132 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_OP_ATTRS_OPS_NOOP_H -#define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/noop_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(NoopAttrs); - TensorShape get_output_shape(NoopAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 368250c957..016e632b33 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -1,16 +1,14 @@ -#ifndef _FLEXFLOW_POOL_2D_ATTRS_H -#define 
_FLEXFLOW_POOL_2D_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -CHECK_VALID_OP_ATTR(Pool2DAttrs); - tl::expected make_adaptive_pool2d_attrs(TensorDims const &input_dims, positive_int output_h, diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 04e44b4161..5595ab9df5 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -1,14 +1,11 @@ -#ifndef _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H -#define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reduce_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReduceAttrs); - ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index e8b2483cd5..b107178744 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_REDUCTION_ATTRS_H -#define _FLEXFLOW_REDUCTION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reduction_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReductionAttrs); - 
RecordFormatter as_dot(ReductionAttrs const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index b67486ed35..7733bc6989 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_PARTITION_ATTRS_H -#define _FLEXFLOW_PARTITION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/repartition_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(RepartitionAttrs); - RecordFormatter as_dot(RepartitionAttrs const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 10a4636d27..6a6ecd3d1e 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_REPLICATE_ATTRS_H -#define _FLEXFLOW_REPLICATE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "utils/record_formatter.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReplicateAttrs); - RecordFormatter as_dot(ReplicateAttrs const &); ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index e87ca5c750..c7b8863ed6 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -1,14 +1,11 @@ -#ifndef _FLEXFLOW_RESHAPE_ATTRS_H -#define _FLEXFLOW_RESHAPE_ATTRS_H +#ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reshape_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReshapeAttrs); - TensorShape get_output_shape(ReshapeAttrs const &attrs, TensorShape const &input_shape); ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 023e714c20..7b8ea7cbe5 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -1,15 +1,12 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/reverse_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(ReverseAttrs); - TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 6eacc66b78..63bd7f1736 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -1,15 +1,13 @@ -#ifndef _FLEXFLOW_SOFTMAX_ATTRS_H -#define _FLEXFLOW_SOFTMAX_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/softmax_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -CHECK_VALID_OP_ATTR(SoftmaxAttrs); - tl::expected get_output_shape(SoftmaxAttrs const &attrs, TensorShape const 
&input_shape); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index e6a08d6e77..b29a591b1b 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_SPLIT_ATTRS_H -#define _FLEXFLOW_SPLIT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(SplitAttrs); - std::vector get_output_shapes(SplitAttrs const &, TensorShape const &); std::vector diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index d6de90903a..cf28d0f8e9 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_TOPK_ATTRS_H -#define _FLEXFLOW_TOPK_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(TopKAttrs); - TensorShape get_output_shape(TopKAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 6de83ee414..71012cab0d 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -1,15 +1,12 @@ -#ifndef _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H -#define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H +#ifndef 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/transpose_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -CHECK_VALID_OP_ATTR(TransposeAttrs); - TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(TransposeAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/weight.h b/lib/op-attrs/include/op-attrs/ops/weight.h index 66eb0064ed..a5bcce7e5c 100644 --- a/lib/op-attrs/include/op-attrs/ops/weight.h +++ b/lib/op-attrs/include/op-attrs/ops/weight.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -9,8 +8,6 @@ namespace FlexFlow { -CHECK_VALID_OP_ATTR(WeightAttrs); - RecordFormatter as_dot(WeightAttrs const &); TensorShape get_output_shape(WeightAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h index 22128ca74e..ea4c84409a 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.h @@ -9,7 +9,7 @@ namespace FlexFlow { std::set get_nontrivial_parallel_tensor_dim_indices(ParallelTensorDimDegrees const &); -std::unordered_map +std::unordered_map get_parallel_tensor_degree_map(ParallelTensorDimDegrees const &); std::unordered_set diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index e366f99b8e..b83e3c9a98 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ 
b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -48,7 +48,6 @@ positive_int get_total_parallel_degree(ParallelTensorShape const &); bool is_valid(ParallelTensorShape const &); TensorShape require_not_parallel(ParallelTensorShape const &); -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); std::vector get_tensor_shapes_unsafe(std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h index d0da8033c1..a8cca2ff10 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.h @@ -7,7 +7,7 @@ namespace FlexFlow { ParallelTensorSpaceCoordinate - parallel_tensor_space_coord_from_map(std::unordered_map const &); + parallel_tensor_space_coord_from_map(std::unordered_map const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml index 359c4b96a9..d82156d32d 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_space_coordinate.struct.toml @@ -9,17 +9,18 @@ features = [ ] includes = [ - "op-attrs/dim_ordered/dim_ordered.h", + "op-attrs/ff_ordered/ff_ordered.h", + "utils/nonnegative_int/nonnegative_int.h", ] [[fields]] name = "sum_idx" -type = "int" +type = "::FlexFlow::nonnegative_int" [[fields]] name = "discard_copy_idx" -type = "int" +type = "::FlexFlow::nonnegative_int" [[fields]] name = "shard_idxs" -type = "::FlexFlow::FFOrdered" +type = "::FlexFlow::FFOrdered<::FlexFlow::nonnegative_int>" diff --git a/lib/op-attrs/include/op-attrs/task_space_coordinate.h b/lib/op-attrs/include/op-attrs/task_space_coordinate.h new file mode 100644 index 0000000000..55cb4a7a53 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/task_space_coordinate.h 
@@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TASK_SPACE_COORDINATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TASK_SPACE_COORDINATE_H + +#include "op-attrs/task_space_coordinate.dtg.h" + +namespace FlexFlow { + +TaskSpaceCoordinate make_task_space_coordinate(std::vector const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml b/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml index 508d0c21b6..4f8a281212 100644 --- a/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml +++ b/lib/op-attrs/include/op-attrs/task_space_coordinate.struct.toml @@ -10,9 +10,9 @@ features = [ ] includes = [ - "utils/orthotope/orthotope_coordinate.dtg.h", + "utils/orthotope/orthotope_coord.dtg.h", ] [[fields]] name = "orthotope_coord" -type = "::FlexFlow::OrthotopeCoordinate" +type = "::FlexFlow::OrthotopeCoord" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc b/lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc deleted file mode 100644 index 511c69d333..0000000000 --- a/lib/op-attrs/src/op-attrs/dim_ordered/dim_ordered.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/dim_ordered.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/slice.cc b/lib/op-attrs/src/op-attrs/dim_ordered/slice.cc deleted file mode 100644 index 8c3dbd7bbc..0000000000 --- a/lib/op-attrs/src/op-attrs/dim_ordered/slice.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/slice.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/transform.cc b/lib/op-attrs/src/op-attrs/dim_ordered/transform.cc deleted file mode 100644 index 73683eba94..0000000000 --- a/lib/op-attrs/src/op-attrs/dim_ordered/transform.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/transform.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc deleted file mode 100644 index 208fc4a719..0000000000 --- 
a/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/dim_ordered/zip.h" diff --git a/lib/op-attrs/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/src/op-attrs/ff_dim_t.cc index 63c783d909..b0cdd2de4d 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim_t.cc @@ -1,4 +1,6 @@ #include "op-attrs/ff_dim_t.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/containers/transform.h" namespace FlexFlow { @@ -11,6 +13,11 @@ ff_dim_t add_to_ff_dim(ff_dim_t ff_dim, int value) { return ff_dim_t{nonnegative_int{ff_dim.value.unwrap_nonnegative() + value}}; } +std::vector ff_dim_range(nonnegative_int num_elements) { + return transform(nonnegative_range(num_elements), + [](nonnegative_int idx) { return ff_dim_t{idx}; }); +} + } // namespace FlexFlow namespace rc { diff --git a/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc b/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc index 3da15bebba..7b93643735 100644 --- a/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc +++ b/lib/op-attrs/src/op-attrs/ff_ordered/get_idxs.cc @@ -5,6 +5,6 @@ namespace FlexFlow { using T = value_type<0>; -template std::vector get_idxs(FFOrdered const &); +template std::set get_idxs(FFOrdered const &); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc b/lib/op-attrs/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc similarity index 82% rename from lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc rename to lib/op-attrs/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc index 6db994afb1..4392914295 100644 --- a/lib/op-attrs/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc +++ b/lib/op-attrs/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc @@ -1,4 +1,4 @@ -#include "op-attrs/operator_space_parallel_tensor_space_mapping.h" +#include 
"op-attrs/operator_space_to_parallel_tensor_space_mapping.h" #include "op-attrs/parallel_tensor_dim_degrees.h" #include "op-attrs/parallel_tensor_dim_idx_t.h" #include "utils/nonnegative_int/range.h" @@ -9,7 +9,7 @@ namespace FlexFlow { -OperatorSpaceParallelTensorSpaceMapping +OperatorSpaceToParallelTensorSpaceMapping get_identity_mapping(nonnegative_int num_shard_dims) { std::set parallel_tensor_dim_indices @@ -22,7 +22,7 @@ OperatorSpaceParallelTensorSpaceMapping bidict raw_bidict = bidict_from_keys_and_values(vector_of(operator_space_dim_indices), vector_of(parallel_tensor_dim_indices)); - return OperatorSpaceParallelTensorSpaceMapping{DimProjection{EqProjection{raw_bidict}}}; + return OperatorSpaceToParallelTensorSpaceMapping{DimProjection{EqProjection{raw_bidict}}}; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_task_space.cc b/lib/op-attrs/src/op-attrs/operator_task_space.cc index d612680de6..8aa157d90c 100644 --- a/lib/op-attrs/src/op-attrs/operator_task_space.cc +++ b/lib/op-attrs/src/op-attrs/operator_task_space.cc @@ -1,10 +1,7 @@ -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "pcg/operator_task_space.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "utils/containers/cartesian_product.h" #include "utils/containers/extend.h" #include "utils/containers/maximum.h" @@ -32,7 +29,7 @@ std::unordered_set unordered_set_of(cartesian_product(coordinate_ranges)); std::unordered_set task_space_coordinates = transform(raw_coordinates, [](std::vector const &point) { - return TaskSpaceCoordinate{point}; + return TaskSpaceCoordinate{OrthotopeCoord{point}}; }); return 
task_space_coordinates; } @@ -49,17 +46,4 @@ nonnegative_int num_dims(OperatorTaskSpace const &task) { positive_int num_tasks(OperatorTaskSpace const &task) { return product(task.degrees); } - -OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &layer) { - parallel_tensor_guid_t out_tensor = get_layer_outputs(pcg, layer).at(0); - ParallelTensorShape shape = get_parallel_tensor_shape(pcg, out_tensor); - - std::vector degrees; - extend(degrees, vector_of(ff_ordered_shard_degrees(shape))); - degrees.push_back(get_sum_degree(shape)); - degrees.push_back(get_discard_copy_degree(shape)); - return OperatorTaskSpace{degrees}; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index cc6ef8cfac..d7744fa7db 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -5,8 +5,10 @@ #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/extend.h" +#include "utils/exception.h" #include "utils/expected.h" #include "utils/integer_conversions.h" +#include namespace FlexFlow { @@ -95,10 +97,9 @@ positive_int get_num_samples(MultiHeadAttentionInputs const &inputs) { } static void check_attrs(MultiHeadAttentionAttrs const &attrs) { - if (attrs.add_bias_kv) { - throw mk_runtime_error("add_bias_kv is not yet supported. If you need this " + ASSERT(!attrs.add_bias_kv, + "add_bias_kv is not yet supported. 
If you need this " "functionality, please create an issue."); - } } std::vector diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index 3c76561d17..3044e285e5 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -1,6 +1,7 @@ #include "op-attrs/ops/batch_matmul.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_dims.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index d84a9ee46e..6e0c535f91 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/broadcast.h" #include "op-attrs/tensor_dims.h" +#include "utils/exception.h" #include "utils/record_formatter.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 6ff1b8a06e..8c2e665454 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -4,6 +4,7 @@ #include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" #include "utils/fmt/optional.h" #include "utils/integer_conversions.h" +#include namespace FlexFlow { @@ -198,11 +199,7 @@ std::vector std::optional maybe_kernel_initializer, std::optional maybe_bias_initializer) { - if (!attrs.use_bias && maybe_bias_initializer.has_value()) { - throw mk_runtime_error(fmt::format( - "Unexpectedly received bias initializer while use_bias=false: {}", - maybe_bias_initializer)); - } + ASSERT(attrs.use_bias == maybe_bias_initializer.has_value()); TensorShape kernel_shape = get_kernel_shape(attrs, input_shape); diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc index 79bb14f2b2..7d8cd9aaff 100644 --- 
a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -1,6 +1,7 @@ #include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" +#include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc index 16957a036c..84da09cab6 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -1,8 +1,9 @@ #include "op-attrs/ops/element_binary.h" +#include "utils/exception.h" namespace FlexFlow { -tl::expected +TensorShape get_output_shape(ElementBinaryAttrs const &attrs, TensorShape const &input_lhs, TensorShape const &input_rhs) { @@ -13,18 +14,14 @@ tl::expected } else if (attrs.should_broadcast_rhs) { NOT_IMPLEMENTED(); } else { - if (input_lhs != input_rhs) { - return tl::unexpected(fmt::format( - "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", - input_lhs, - input_rhs)); - } + ASSERT(input_lhs == input_rhs, + "Expected input shapes to match"); return input_lhs; } } -tl::expected +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &attrs, ParallelTensorShape const &input_lhs, ParallelTensorShape const &input_rhs) { @@ -35,21 +32,13 @@ tl::expected } else if (attrs.should_broadcast_rhs) { NOT_IMPLEMENTED(); } else { - if (input_lhs != input_rhs) { - return tl::unexpected(fmt::format( - "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", - input_lhs, - input_rhs)); - } + ASSERT(input_lhs == input_rhs, + "Expected input shapes to match"); switch (attrs.type) { case OperatorType::EW_ADD: { - if (get_discard_copy_degree(input_lhs) != 1) { - return tl::unexpected( - fmt::format("Elementwise Add expected discard copy degree of " - "inputs to be 1, but receieved {}", - get_discard_copy_degree(input_lhs))); - } + ASSERT(get_discard_copy_degree(input_lhs) == 1, 
+ "Elementwise Add expected discard copy degree of inputs to be 1"); break; } @@ -64,8 +53,7 @@ tl::expected case OperatorType::EW_MIN: NOT_IMPLEMENTED(); default: - return tl::unexpected(fmt::format( - "Unexpected element-wise binary operator {}", attrs.type)); + PANIC("Unexpected element-wise binary operator", attrs.type); } return input_lhs; diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 14180cecf8..aca776f36f 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -34,7 +34,7 @@ TensorShape get_output_shape(FlatAttrs const &attrs, }; } -tl::expected +ParallelTensorDimDegrees get_output_parallel_dim_degrees( FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { FFOrdered flattened_dim_degrees = @@ -44,14 +44,8 @@ tl::expected return input_degrees; } - if (any_of(flattened_dim_degrees, - [](positive_int degree) { return degree != 1; })) { - return tl::unexpected( - fmt::format("get_output_parallel_dim_degrees for {} expected all shard " - "degrees of flattened dimensions to be 1, but received {}", - attrs, - input_degrees)); - } + ASSERT(!any_of(flattened_dim_degrees, [](positive_int degree) { return degree != 1; }), + "get_output_parallel_dim_degrees for {} expected all shard degrees of flattened dimensions to be 1"); return ParallelTensorDimDegrees{ /*sum_degree=*/input_degrees.sum_degree, @@ -65,20 +59,12 @@ tl::expected }; } -tl::expected +ParallelTensorShape get_output_shape(FlatAttrs const &attrs, ParallelTensorShape const &input_shape) { TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); - ParallelTensorDimDegrees degrees = ({ - tl::expected returned = - get_output_parallel_dim_degrees(attrs, - get_parallel_degrees(input_shape)); - if (!returned.has_value()) { - return tl::unexpected(returned.error()); - } - returned.value(); - }); + ParallelTensorDimDegrees degrees = get_output_parallel_dim_degrees(attrs,
get_parallel_degrees(input_shape)); return lift_to_parallel_with_degrees(unpar, degrees); } diff --git a/lib/op-attrs/src/op-attrs/ops/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc index 4b1053aee1..2c5a4bbdc0 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/gather.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index e0db1cdfe7..c58a2bba62 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -9,6 +9,7 @@ #include "utils/containers/contains.h" #include "utils/containers/extend.h" #include "utils/containers/filter.h" +#include "utils/containers/vector_of.h" #include "utils/expected.h" #include "utils/fmt/set.h" @@ -72,7 +73,7 @@ tl::expected } std::vector non_layer_norm_dim_idxs = filter( - get_idxs(input_shape.dims.ff_ordered), + vector_of(get_idxs(input_shape.dims.ff_ordered)), [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { @@ -180,7 +181,7 @@ tl::expected } std::vector non_layer_norm_dim_idxs = filter( - get_idxs(input_shape.dims.shard_dims), + vector_of(get_idxs(input_shape.dims.shard_dims)), [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_shard_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index ce8f7d5c0a..9da56de5cf 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -2,7 +2,10 @@ #include "op-attrs/ff_ordered/slice.h" #include "op-attrs/ff_ordered/transform.h" #include "op-attrs/initializers/kaiming_initializer_mode.h" +#include 
"op-attrs/operator_space_to_parallel_tensor_space_mapping.h" +#include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/relative_ff_dim_t.h" #include "op-attrs/tensor_dims.h" #include "op-attrs/tensor_shape.h" #include "utils/containers/product.h" @@ -256,7 +259,7 @@ tl::expected /*from=*/{discard_copy_dim_idx()}, /*onto=*/shard_dim_idx(output_channel_dim)); - for (ff_dim_t const &idx : ff_dim_range(nonnegative_int{input_num_dims.get_value() - 1})) { + for (ff_dim_t const &idx : ff_dim_range(nonnegative_int{input_num_dims.unwrap_nonnegative() - 1})) { project_dims(inp_to_out, /*from=*/{shard_dim_idx(idx)}, /*onto=*/shard_dim_idx(idx)); @@ -265,7 +268,7 @@ tl::expected return ParallelTensorSpaceMapping{DimProjection{inp_to_out}}; } -tl::expected +tl::expected get_operator_to_input_projection(LinearAttrs const &attrs, nonnegative_int input_num_dims) { @@ -277,14 +280,14 @@ tl::expected EqProjection op_to_out = throw_if_unexpected(get_operator_to_output_mapping(attrs, input_num_dims)).raw_projection.require_eq_proj(); - return OperatorSpaceParallelTensorSpaceMapping{ + return OperatorSpaceToParallelTensorSpaceMapping{ DimProjection{ compose_up_projections(up_from_eq_proj(op_to_out), out_to_inp), }, }; } -tl::expected +tl::expected get_operator_to_output_mapping(LinearAttrs const &attrs, nonnegative_int input_num_shard_dims) { nonnegative_int output_num_shard_dims = input_num_shard_dims; diff --git a/lib/op-attrs/src/op-attrs/ops/reduce.cc b/lib/op-attrs/src/op-attrs/ops/reduce.cc index 2a8bf06ecf..e5474ae124 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/reduce.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc index 6216ad8c6c..d8ea92d540 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape.cc +++ 
b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/reshape.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc index c38d7e4782..3a063d1af9 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/reverse.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc index a9fe691584..ed737b18d9 100644 --- a/lib/op-attrs/src/op-attrs/ops/split.cc +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/split.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc index 7a6868340b..179d6cfdd3 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/topk.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc index 50e6fb35f5..08276d7b21 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/transpose.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc index c458d4149d..ccacd9bc3e 100644 --- a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc @@ -3,6 +3,7 @@ #include "op-attrs/ops/reduction.h" #include "op-attrs/ops/repartition.h" #include "op-attrs/ops/replicate.h" +#include "utils/exception.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc 
b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc index aef224a31e..f082fb2514 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -1,5 +1,5 @@ #include "op-attrs/parallel_tensor_dim_degrees.h" -#include "op-attrs/dim_ordered/get_idxs.h" +#include "op-attrs/ff_ordered/get_idxs.h" #include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" #include "op-attrs/parallel_tensor_space_coordinate.h" #include "utils/containers/filtrans.h" @@ -12,6 +12,7 @@ #include "utils/containers/unordered_set_of.h" #include "utils/containers/transform.h" #include "utils/containers/set_union.h" +#include "utils/nonnegative_int/nonnegative_range.h" namespace FlexFlow { @@ -38,19 +39,19 @@ std::set get_nontrivial_parallel_tensor_dim_indices(P return set_union(nontrivial_replica_dims, nontrivial_shard_dims); } -std::unordered_map +std::unordered_map get_parallel_tensor_degree_map(ParallelTensorDimDegrees const &degrees) { - std::unordered_map replica_dim_degrees = { + std::unordered_map replica_dim_degrees = { {parallel_tensor_dim_idx_t{ReplicaType::SUM}, degrees.sum_degree.value}, {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, degrees.discard_copy_degree.value}, }; - std::unordered_map shard_dim_degrees = + std::unordered_map shard_dim_degrees = generate_map(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) { return degrees.shard_degrees.at(dim); }); - return merge_maps( + return merge_disjoint_maps( replica_dim_degrees, map_keys(shard_dim_degrees, [](ff_dim_t const &dim) { return parallel_tensor_dim_idx_t{dim}; })); } @@ -58,15 +59,15 @@ std::unordered_map std::unordered_set get_parallel_tensor_space_coordinates(ParallelTensorDimDegrees const &degrees) { - std::unordered_map degree_map = get_parallel_tensor_degree_map(degrees); + std::unordered_map degree_map = get_parallel_tensor_degree_map(degrees); std::unordered_map< parallel_tensor_dim_idx_t, - std::unordered_set> possible_per_dim_coords -
= map_values(degree_map, [](int degree) { return unordered_set_of(range(degree)); }); + std::unordered_set> possible_per_dim_coords + = map_values(degree_map, [](positive_int degree) { return unordered_set_of(nonnegative_range(degree)); }); return transform(get_all_assignments(possible_per_dim_coords), - [](std::unordered_map const &m) { + [](std::unordered_map const &m) { return parallel_tensor_space_coord_from_map(m); }); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 1b8f6f1dfa..31d7ba21cf 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -5,9 +5,11 @@ #include "utils/containers/product.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" +#include "utils/exception.h" #include "utils/hash-utils.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/overload.h" +#include namespace FlexFlow { @@ -97,21 +99,13 @@ ParallelTensorShape TensorShape require_not_parallel(ParallelTensorShape const &s) { positive_int total_degree = get_total_parallel_degree(s); - if (total_degree != 1_p) { - throw mk_runtime_error( - fmt::format("Error: require_not_parallel received a parallel tensor " - "shape with parallel degree {}: {}", - total_degree, - s)); - } + ASSERT(total_degree == 1_p, + "Error: require_not_parallel received a parallel tensor shape with non-zero parallel degree", + s); return get_reduced_shape(s); } -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} - TensorShape get_piece_shape(ParallelTensorShape const &s) { return get_reduced_shape(s); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc index ec6b117b4e..97dd6bda1a 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc +++
b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc @@ -1,13 +1,13 @@ #include "op-attrs/parallel_tensor_space_coordinate.h" -#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "op-attrs/ff_ordered/ff_ordered_from_map.h" #include "utils/containers/filtermap_keys.h" namespace FlexFlow { ParallelTensorSpaceCoordinate - parallel_tensor_space_coord_from_map(std::unordered_map const &m) { + parallel_tensor_space_coord_from_map(std::unordered_map const &m) { - std::unordered_map shard_map = filtermap_keys + std::unordered_map shard_map = filtermap_keys (m, [](parallel_tensor_dim_idx_t const &d) { return d.try_require_shard_dim(); }); return ParallelTensorSpaceCoordinate{ diff --git a/lib/op-attrs/src/op-attrs/shape_inference.cc b/lib/op-attrs/src/op-attrs/shape_inference.cc index 4a0ff72fb4..47255f07d2 100644 --- a/lib/op-attrs/src/op-attrs/shape_inference.cc +++ b/lib/op-attrs/src/op-attrs/shape_inference.cc @@ -69,7 +69,7 @@ std::vector [&](ElementBinaryAttrs const &attrs) -> std::vector { auto [i1, i2] = require_2(input_shapes); - return {throw_if_unexpected(get_output_shape(attrs, i1, i2))}; + return {get_output_shape(attrs, i1, i2)}; }, [&](ElementUnaryAttrs const &attrs) -> std::vector { return {throw_if_unexpected( @@ -203,7 +203,7 @@ std::vector [&](ElementBinaryAttrs const &attrs) -> std::vector { auto [i1, i2] = require_2(input_shapes); - return {throw_if_unexpected(get_output_shape(attrs, i1, i2))}; + return {get_output_shape(attrs, i1, i2)}; }, [&](ElementUnaryAttrs const &attrs) -> std::vector { return {throw_if_unexpected( @@ -214,8 +214,7 @@ std::vector get_output_shape(attrs, get_only(input_shapes)))}; }, [&](FlatAttrs const &attrs) -> std::vector { - return {throw_if_unexpected( - get_output_shape(attrs, get_only(input_shapes)))}; + return {get_output_shape(attrs, get_only(input_shapes))}; }, [&](GatherAttrs const &attrs) -> std::vector { return { diff --git a/lib/op-attrs/src/op-attrs/task_space_coordinate.cc 
b/lib/op-attrs/src/op-attrs/task_space_coordinate.cc new file mode 100644 index 0000000000..01f3798fd2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/task_space_coordinate.cc @@ -0,0 +1,9 @@ +#include "op-attrs/task_space_coordinate.h" + +namespace FlexFlow { + +TaskSpaceCoordinate make_task_space_coordinate(std::vector const &elems) { + return TaskSpaceCoordinate{OrthotopeCoord{elems}}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record.cc b/lib/op-attrs/src/parallel_dim_mapping_record.cc deleted file mode 100644 index 5e734e88cd..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record.cc +++ /dev/null @@ -1,60 +0,0 @@ -#include "parallel_dim_mapping_record.h" -#include - -namespace FlexFlow { - -ParallelDimMappingRecord::ParallelDimMappingRecord(MappingRecordType type) - : type(type), output_dim(-1), input_dim(-1), weight_dim(-1), output_idx(-1), - input_idx(-1), weight_idx(-1) {} - -/*static*/ -ParallelDimMappingRecord ParallelDimMappingRecord::input_output_record( - int input_idx, - int input_dim, - int output_idx, - int output_dim, - std::optional operation) { - ParallelDimMappingRecord r(MappingRecordType::INPUT_OUTPUT); - r.operation = operation; - - assert(output_idx >= 0); - assert(output_dim >= 0); - assert(input_idx >= 0); - assert(input_dim >= 0); - - r.output_idx = output_idx; - r.output_dim = output_dim; - r.input_idx = input_idx; - r.input_dim = input_dim; - - return r; -} - -/*static*/ -ParallelDimMappingRecord ParallelDimMappingRecord::input_weight_record( - int input_idx, - int input_dim, - int weight_idx, - int weight_dim, - std::optional operation) { - ParallelDimMappingRecord r(MappingRecordType::INPUT_WEIGHT); - r.operation = operation; - - assert(input_idx >= 0); - assert(input_dim >= 0); - assert(weight_idx >= 0); - assert(weight_dim >= 0); - - r.input_idx = input_idx; - r.input_dim = input_dim; - r.weight_idx = weight_idx; - r.weight_dim = weight_dim; - - return r; -} - -MappingRecordType 
ParallelDimMappingRecord::get_type() const { - return this->type; -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record.h b/lib/op-attrs/src/parallel_dim_mapping_record.h deleted file mode 100644 index c37ac79b40..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record.h +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_H -#define _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_H - -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -enum class MappingRecordType { INPUT_OUTPUT, INPUT_WEIGHT }; - -enum class MappingOperation { PARTITION, REPLICATE }; - -class ParallelDimMappingRecord { -private: - ParallelDimMappingRecord(MappingRecordType); - -public: - ParallelDimMappingRecord() = delete; - - static ParallelDimMappingRecord input_output_record( - int input_idx, - int input_dim, - int output_idx, - int output_dim, - std::optional operation = std::nullopt); - static ParallelDimMappingRecord input_weight_record( - int input_idx, - int input_dim, - int weight_idx, - int weight_dim, - std::optional operation = std::nullopt); - MappingRecordType get_type() const; - -public: - MappingRecordType type; - std::optional operation; - - int output_dim, input_dim, weight_dim; - int output_idx, input_idx, weight_idx; -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::ParallelDimMappingRecord, - type, - operation, - output_dim, - input_dim, - weight_dim, - output_idx, - input_idx, - weight_idx); - -#endif diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc deleted file mode 100644 index a5a261da25..0000000000 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/dim_ordered/dim_ordered.h" -#include "doctest/doctest.h" -#include "test/utils/rapidcheck.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE_TEMPLATE( 
- "Arbitrary> with T=", T, int, double, char) { - RC_SUBCASE([](DimOrdered) {}); - } -} diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc deleted file mode 100644 index b77bb8f71e..0000000000 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "op-attrs/dim_ordered/zip.h" -#include "op-attrs/ff_dim_t.dtg.h" -#include "test/utils/doctest/fmt/pair.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("zip(DimOrdered, DimOrdered)") { - DimOrdered lhs_input = {9, 9, 8, 9}; - DimOrdered rhs_input = {"m", "m", "k", "l", "m"}; - - SUBCASE("lhs is longer") { - DimOrdered> result = - zip(lhs_input, rhs_input); - - DimOrdered> correct = { - {9, "m"}, - {9, "m"}, - {8, "k"}, - {9, "l"}, - }; - - CHECK(result == correct); - } - - SUBCASE("rhs is longer") { - DimOrdered> result = - zip(rhs_input, lhs_input); - - DimOrdered> correct = { - {"m", 9}, - {"m", 9}, - {"k", 8}, - {"l", 9}, - }; - - CHECK(result == correct); - } - } -} diff --git a/lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc b/lib/op-attrs/test/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc similarity index 71% rename from lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc rename to lib/op-attrs/test/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc index bab662d9f0..e94df07fed 100644 --- a/lib/op-attrs/test/src/op-attrs/operator_space_parallel_tensor_space_mapping.cc +++ b/lib/op-attrs/test/src/op-attrs/operator_space_to_parallel_tensor_space_mapping.cc @@ -1,4 +1,4 @@ -#include "op-attrs/operator_space_parallel_tensor_space_mapping.h" +#include "op-attrs/operator_space_to_parallel_tensor_space_mapping.h" #include "op-attrs/parallel_tensor_dim_idx_t.h" #include @@ -16,8 +16,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_identity_mapping(ParallelTensorDimDegrees)") { nonnegative_int 
num_shard_dims = nonnegative_int{2}; - OperatorSpaceParallelTensorSpaceMapping result = get_identity_mapping(num_shard_dims); + OperatorSpaceToParallelTensorSpaceMapping result = get_identity_mapping(num_shard_dims); - CHECK(result == correct); + NOT_IMPLEMENTED(); + // CHECK(result == correct); } } diff --git a/lib/op-attrs/test/src/op-attrs/operator_task_space.cc b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc index 4b01ed02fb..56785a5065 100644 --- a/lib/op-attrs/test/src/op-attrs/operator_task_space.cc +++ b/lib/op-attrs/test/src/op-attrs/operator_task_space.cc @@ -1,4 +1,4 @@ -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.h" #include "utils/fmt/unordered_set.h" #include @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorTaskSpace task = OperatorTaskSpace{{}}; std::unordered_set correct = { - TaskSpaceCoordinate{{}}}; + TaskSpaceCoordinate{OrthotopeCoord{{}}}}; std::unordered_set result = get_task_space_coordinates(task); CHECK(correct == result); @@ -21,10 +21,10 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorTaskSpace task = OperatorTaskSpace{{2_p, 2_p}}; std::unordered_set correct = {{ - TaskSpaceCoordinate{{0_n, 0_n}}, - TaskSpaceCoordinate{{0_n, 1_n}}, - TaskSpaceCoordinate{{1_n, 0_n}}, - TaskSpaceCoordinate{{1_n, 1_n}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 1_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{1_n, 1_n}}}, }}; std::unordered_set result = get_task_space_coordinates(task); @@ -35,10 +35,10 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorTaskSpace task = OperatorTaskSpace{{1_p, 2_p, 2_p}}; std::unordered_set correct = {{ - TaskSpaceCoordinate{{0_n, 0_n, 0_n}}, - TaskSpaceCoordinate{{0_n, 0_n, 1_n}}, - TaskSpaceCoordinate{{0_n, 1_n, 0_n}}, - TaskSpaceCoordinate{{0_n, 1_n, 1_n}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 0_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 0_n, 1_n}}}, + 
TaskSpaceCoordinate{OrthotopeCoord{{0_n, 1_n, 0_n}}}, + TaskSpaceCoordinate{OrthotopeCoord{{0_n, 1_n, 1_n}}}, }}; std::unordered_set result = get_task_space_coordinates(task); @@ -50,7 +50,7 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorTaskSpace task = OperatorTaskSpace{{3_p, 2_p}}; - TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2_n, 1_n}}; + TaskSpaceCoordinate correct = TaskSpaceCoordinate{OrthotopeCoord{{2_n, 1_n}}}; TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); CHECK(correct == result); } @@ -58,7 +58,7 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorTaskSpace task = OperatorTaskSpace{{3_p, 2_p, 4_p}}; - TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2_n, 1_n, 3_n}}; + TaskSpaceCoordinate correct = TaskSpaceCoordinate{OrthotopeCoord{{2_n, 1_n, 3_n}}}; TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); CHECK(correct == result); } diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc index d251fb731d..1044c379f0 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc @@ -12,9 +12,9 @@ TEST_SUITE(FF_TEST_SUITE) { positive_int p = 10_p; BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/0_n, // TODO figure out if these arguments are + /*a_seq_length_dim=*/1_p, // TODO figure out if these arguments are // still relevant - /*b_seq_length_dim=*/0_n, + /*b_seq_length_dim=*/1_p, }; TensorShape input_lhs_shape = TensorShape{ @@ -106,9 +106,9 @@ TEST_SUITE(FF_TEST_SUITE) { positive_int o_sum = 11_p; BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/0_n, // TODO figure out if these arguments are + /*a_seq_length_dim=*/0_p, // TODO figure out if these arguments are // still relevant - /*b_seq_length_dim=*/0_n, + /*b_seq_length_dim=*/0_p, }; auto make_lhs = [&](SumDegree o_sum, diff --git a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc 
b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index 56407c03f1..7dcd24ca0d 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -1,5 +1,5 @@ #include "op-attrs/ops/conv_2d.h" -#include "doctest/doctest.h" +#include #include "utils/integer_conversions.h" using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc index c4fe8a5250..c4cbf33cd1 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -148,11 +148,7 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{1_p, 1_p, 2_p, 1_p}, }; - std::optional result = - optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = std::nullopt; - - CHECK(result == correct); + CHECK_THROWS(get_output_parallel_dim_degrees(attrs, input)); } SUBCASE("allows sum parallelism") { @@ -162,9 +158,9 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{1_p, 1_p, 1_p, 1_p}, }; - std::optional result = - optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = + ParallelTensorDimDegrees result = + get_output_parallel_dim_degrees(attrs, input); + ParallelTensorDimDegrees correct = ParallelTensorDimDegrees{ SumDegree{2_p}, DiscardCopyDegree{1_p}, @@ -181,9 +177,9 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{1_p, 1_p, 1_p, 1_p}, }; - std::optional result = - optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); - std::optional correct = + ParallelTensorDimDegrees result = + get_output_parallel_dim_degrees(attrs, input); + ParallelTensorDimDegrees correct = ParallelTensorDimDegrees{ SumDegree{1_p}, DiscardCopyDegree{2_p}, diff --git a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc index 314a7f2ae5..3f0015bd2d 100644 --- a/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc +++ 
b/lib/op-attrs/test/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -14,22 +14,22 @@ static parallel_tensor_dim_idx_t shard_dim_idx_from_raw(int idx) { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_parallel_tensor_degree_map") { ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ - SumDegree{3}, - DiscardCopyDegree{1}, - FFOrdered{ - 1, - 2, - 1, + SumDegree{3_p}, + DiscardCopyDegree{1_p}, + FFOrdered{ + 1_p, + 2_p, + 1_p, }, }; - std::unordered_map result = get_parallel_tensor_degree_map(degrees); - std::unordered_map correct = { - {parallel_tensor_dim_idx_t{ReplicaType::SUM}, 3}, - {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, 1}, - {shard_dim_idx_from_raw(0), 1}, - {shard_dim_idx_from_raw(1), 2}, - {shard_dim_idx_from_raw(2), 1}, + std::unordered_map result = get_parallel_tensor_degree_map(degrees); + std::unordered_map correct = { + {parallel_tensor_dim_idx_t{ReplicaType::SUM}, 3_p}, + {parallel_tensor_dim_idx_t{ReplicaType::DISCARD_COPY}, 1_p}, + {shard_dim_idx_from_raw(0), 1_p}, + {shard_dim_idx_from_raw(1), 2_p}, + {shard_dim_idx_from_raw(2), 1_p}, }; CHECK(result == correct); @@ -37,46 +37,46 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_parallel_tensor_space_coordinates") { ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ - SumDegree{3}, - DiscardCopyDegree{1}, - FFOrdered{ - 1, - 2, - 1, + SumDegree{3_p}, + DiscardCopyDegree{1_p}, + FFOrdered{ + 1_p, + 2_p, + 1_p, }, }; std::unordered_set result = get_parallel_tensor_space_coordinates(degrees); std::unordered_set correct = { ParallelTensorSpaceCoordinate{ - /*sum_idx=*/0, - /*discard_copy_idx=*/0, - /*shard_idxs=*/FFOrdered{0, 0, 0}, + /*sum_idx=*/0_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 0_n, 0_n}, }, ParallelTensorSpaceCoordinate{ - /*sum_idx=*/1, - /*discard_copy_idx=*/0, - /*shard_idxs=*/FFOrdered{0, 0, 0}, + /*sum_idx=*/1_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 0_n, 0_n}, }, ParallelTensorSpaceCoordinate{ - /*sum_idx=*/2, - 
/*discard_copy_idx=*/0, - /*shard_idxs=*/FFOrdered{0, 0, 0}, + /*sum_idx=*/2_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 0_n, 0_n}, }, ParallelTensorSpaceCoordinate{ - /*sum_idx=*/0, - /*discard_copy_idx=*/0, - /*shard_idxs=*/FFOrdered{0, 1, 0}, + /*sum_idx=*/0_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 1_n, 0_n}, }, ParallelTensorSpaceCoordinate{ - /*sum_idx=*/1, - /*discard_copy_idx=*/0, - /*shard_idxs=*/FFOrdered{0, 1, 0}, + /*sum_idx=*/1_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 1_n, 0_n}, }, ParallelTensorSpaceCoordinate{ - /*sum_idx=*/2, - /*discard_copy_idx=*/0, - /*shard_idxs=*/FFOrdered{0, 1, 0}, + /*sum_idx=*/2_n, + /*discard_copy_idx=*/0_n, + /*shard_idxs=*/FFOrdered{0_n, 1_n, 0_n}, }, }; @@ -86,9 +86,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nontrivial_parallel_tensor_dim_indices(ParallelTensorDimDegrees)") { SUBCASE("a replica dim has degree 1") { ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ - SumDegree{3}, - DiscardCopyDegree{1}, - FFOrdered{4, 2, 4}, + SumDegree{3_p}, + DiscardCopyDegree{1_p}, + FFOrdered{4_p, 2_p, 4_p}, }; std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); @@ -104,9 +104,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("a shard dim has degree 1") { ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ - SumDegree{3}, - DiscardCopyDegree{2}, - FFOrdered{1, 4, 1}, + SumDegree{3_p}, + DiscardCopyDegree{2_p}, + FFOrdered{1_p, 4_p, 1_p}, }; std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); @@ -121,9 +121,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("no dims have degree 1") { ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ - SumDegree{3}, - DiscardCopyDegree{2}, - FFOrdered{4, 2, 5}, + SumDegree{3_p}, + DiscardCopyDegree{2_p}, + FFOrdered{4_p, 2_p, 5_p}, }; std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); @@ -140,9 +140,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("all dims have degree 1") { 
ParallelTensorDimDegrees degrees = ParallelTensorDimDegrees{ - SumDegree{1}, - DiscardCopyDegree{1}, - FFOrdered{1, 1, 1}, + SumDegree{1_p}, + DiscardCopyDegree{1_p}, + FFOrdered{1_p, 1_p, 1_p}, }; std::set result = get_nontrivial_parallel_tensor_dim_indices(degrees); diff --git a/lib/pcg/include/pcg/model_compilation.h b/lib/pcg/include/pcg/model_compilation.h deleted file mode 100644 index 1ab66161ec..0000000000 --- a/lib/pcg/include/pcg/model_compilation.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MODEL_COMPILATION_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_MODEL_COMPILATION_H - -#include "pcg/computation_graph.h" -#include "pcg/optimizer.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/tensor_mapping.h" - -namespace FlexFlow { - -struct ModelCompilationInput { - ComputationGraph computation_graph; - Optimizer optimizer; -}; -FF_VISITABLE_STRUCT(ModelCompilationInput, computation_graph, optimizer); - -struct ModelCompilationResult { - ModelCompilationInput input; - ParallelComputationGraph pcg; - req tensor_mapping; -}; -FF_VISITABLE_STRUCT(ModelCompilationResult, input, pcg, tensor_mapping); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 3542e73dea..58b1bbfcaf 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H +#include "op-attrs/operator_task_space.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" @@ 
-27,6 +28,9 @@ ParallelLayerAddedResult add_parallel_layer( ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, TensorShape const &tensor_shape); +OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer); + std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, parallel_layer_guid_t const &, diff --git a/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc b/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc index 90ffb85c9b..54624815fc 100644 --- a/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc +++ b/lib/pcg/src/pcg/cg_operator_tensor_shape_signature.cc @@ -1,4 +1,5 @@ #include "pcg/cg_operator_tensor_shape_signature.h" +#include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc index 3511ccc269..d9a54f5530 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -1,5 +1,6 @@ #include "pcg/file_format/v1/v1_computation_graph.h" #include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" +#include "utils/bidict/algorithms/transform_values.h" namespace FlexFlow { @@ -17,7 +18,7 @@ std::pair> to_v1_including_node_numbering(cg.raw_graph); V1ComputationGraph v1_cg = V1ComputationGraph{raw.first}; bidict v1_node_ids = - map_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); + transform_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); return {v1_cg, v1_node_ids}; } diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index 0fbb021a55..56736a2c48 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -4,8 +4,8 @@ #include "pcg/machine_specification.h" #include "pcg/machine_specification_dimension.dtg.h" #include "pcg/machine_view_dimension.dtg.h" -#include 
"pcg/operator_task_space.dtg.h" -#include "pcg/operator_task_space.h" +#include "op-attrs/operator_task_space.dtg.h" +#include "op-attrs/operator_task_space.h" #include "pcg/stride_t.dtg.h" #include "utils/containers/contains.h" #include "utils/containers/count.h" @@ -98,7 +98,7 @@ std::optional get_machine_space_coordinate( }); std::vector coord_points = transform(dimension_indices, [&](nonnegative_int i) { - return coord.raw_coord.at(i.unwrap_nonnegative()); + return coord.orthotope_coord.raw.at(i.unwrap_nonnegative()); }); std::vector strides = transform(dimension_indices, [&](nonnegative_int i) { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc index e3caffe260..5bb5aedaf8 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc @@ -1,6 +1,7 @@ #include "pcg/parallel_computation_graph/generate_weight_transform.h" #include "op-attrs/ff_ordered/enumerate.h" #include "op-attrs/parallel_tensor_shape.h" +#include namespace FlexFlow { @@ -10,12 +11,8 @@ std::unordered_set std::unordered_set result; positive_int sum_degree = get_sum_degree(goal); - if (sum_degree != 1) { - throw mk_runtime_error( - fmt::format("generate_weight_transform currently only supports " - "sum_degree = 1, but received {}", - sum_degree)); - } + ASSERT(sum_degree == 1, + "generate_weight_transform currently only supports sum_degree = 1"); positive_int discard_copy_degree = get_discard_copy_degree(goal); if (discard_copy_degree != 1) { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 052d30df0f..6b2def07c8 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ 
b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,11 +1,13 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/pcg_operator_attrs.h" #include "op-attrs/shape_inference.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "utils/containers/concat_vectors.h" +#include "utils/containers/extend.h" #include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/repeat_element.h" @@ -117,6 +119,18 @@ ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, /*output_flags=*/std::vector{CreateGrad::NO}); } +OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + parallel_tensor_guid_t out_tensor = get_layer_outputs(pcg, layer).at(0); + ParallelTensorShape shape = get_parallel_tensor_shape(pcg, out_tensor); + + std::vector degrees; + extend(degrees, vector_of(ff_ordered_shard_degrees(shape))); + degrees.push_back(get_sum_degree(shape)); + degrees.push_back(get_discard_copy_degree(shape)); + return OperatorTaskSpace{degrees}; +} + std::unordered_set get_edges(ParallelComputationGraph const &pcg) { return transform(get_edges(pcg.raw_graph), [](DataflowEdge const &e) { diff --git a/lib/pcg/test/src/pcg/computation_graph_builder.cc b/lib/pcg/test/src/pcg/computation_graph_builder.cc index f7430b3403..a1e9d86659 100644 --- a/lib/pcg/test/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/computation_graph_builder.cc @@ -1,5 +1,5 @@ #include "pcg/computation_graph_builder.h" -#include "doctest/doctest.h" +#include #include "pcg/computation_graph.h" using namespace ::FlexFlow; diff --git 
a/lib/pcg/test/src/pcg/machine_view.cc b/lib/pcg/test/src/pcg/machine_view.cc index ecc196a118..dc3bbf267b 100644 --- a/lib/pcg/test/src/pcg/machine_view.cc +++ b/lib/pcg/test/src/pcg/machine_view.cc @@ -1,4 +1,6 @@ #include "pcg/machine_view.h" +#include "op-attrs/ff_ordered/ff_ordered.h" +#include "op-attrs/task_space_coordinate.h" #include "pcg/gpu_id_t.dtg.h" #include "test/utils/doctest/fmt/optional.h" #include "utils/containers/transform.h" @@ -57,7 +59,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*intra_node_bandwidth=*/0}; SUBCASE("Task with TaskSpaceCoordinate = (0,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -66,7 +68,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/3_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -75,7 +77,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (2,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({2_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/5_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -84,7 +86,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("TaskSpaceCoordinate is out of bounds") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{4_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({4_n}); std::optional result = get_machine_space_coordinate(task, mv, coord, ms); std::optional correct = std::nullopt; @@ -121,14 +123,16 @@ TEST_SUITE(FF_TEST_SUITE) { MachineViewDimension{stride_t{2_p}, 
MachineSpecificationDimension::INTRA_NODE}}}; MachineSpecification ms = - MachineSpecification{/*num_nodes=*/3_p, - /*num_cpus_per_node=*/5_p, - /*num_gpus_per_node=*/5_p, - /*inter_node_bandwidth=*/0, - /*intra_node_bandwidth=*/0}; + MachineSpecification{ + /*num_nodes=*/3_p, + /*num_cpus_per_node=*/5_p, + /*num_gpus_per_node=*/5_p, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 0_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/2_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -137,7 +141,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/4_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -146,7 +150,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/2_n, /*device_idx=*/2_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -155,7 +159,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/2_n, /*device_idx=*/4_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -195,7 +199,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*intra_node_bandwidth=*/0}; SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { - 
TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 0_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/0_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -204,7 +208,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/4_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -213,7 +217,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/1_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -222,7 +226,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/5_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -271,7 +275,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*intra_node_bandwidth=*/0}; SUBCASE("Task with TaskSpaceCoordinate = (0,0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n, 0_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/0_n, /*device_idx=*/3_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -280,7 +284,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n, 
1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n, 1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/5_n, DeviceType::GPU}; MachineSpaceCoordinate result = @@ -289,7 +293,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n, 1_n}); MachineSpaceCoordinate correct = MachineSpaceCoordinate{ /*node_idx=*/1_n, /*device_idx=*/7_n, DeviceType::GPU}; MachineSpaceCoordinate result = diff --git a/lib/pcg/test/src/pcg/start_invariant_machine_view.cc b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc index afd6ad6b33..bc1725544a 100644 --- a/lib/pcg/test/src/pcg/start_invariant_machine_view.cc +++ b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc @@ -1,4 +1,5 @@ #include "pcg/start_invariant_machine_view.h" +#include "op-attrs/task_space_coordinate.h" #include "utils/fmt/unordered_set.h" #include "utils/fmt/vector.h" #include @@ -108,7 +109,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_machine_space_offset") { SUBCASE("Task with TaskSpaceCoordinate = (0,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n}); MachineSpaceOffset correct = MachineSpaceOffset{0, 0, DeviceType::GPU}; MachineSpaceOffset result = @@ -117,7 +118,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n}); MachineSpaceOffset correct = MachineSpaceOffset{0, 2, DeviceType::GPU}; MachineSpaceOffset result = @@ -126,7 +127,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (2,)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({2_n}); MachineSpaceOffset 
correct = MachineSpaceOffset{0, 4, DeviceType::GPU}; MachineSpaceOffset result = @@ -178,7 +179,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_machine_space_offset") { SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 0_n}); MachineSpaceOffset correct = MachineSpaceOffset{0, 0, DeviceType::GPU}; MachineSpaceOffset result = @@ -187,7 +188,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({0_n, 1_n}); MachineSpaceOffset correct = MachineSpaceOffset{0, 2, DeviceType::GPU}; MachineSpaceOffset result = @@ -196,7 +197,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 0_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 0_n}); MachineSpaceOffset correct = MachineSpaceOffset{1, 0, DeviceType::GPU}; MachineSpaceOffset result = @@ -205,7 +206,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { - TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1_n, 1_n}}; + TaskSpaceCoordinate coord = make_task_space_coordinate({1_n, 1_n}); MachineSpaceOffset correct = MachineSpaceOffset{1, 2, DeviceType::GPU}; MachineSpaceOffset result = diff --git a/lib/runtime/src/fused_op_attrs.h b/lib/runtime/src/fused_op_attrs.h index a8ea524165..a1ab876167 100644 --- a/lib/runtime/src/fused_op_attrs.h +++ b/lib/runtime/src/fused_op_attrs.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_RUNTIME_SRC_FUSED_OP_ATTRS_H #include "op-attrs/get_op_type.h" -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "operator.h" #include "utils/visitable.h" diff --git a/lib/runtime/src/ops/fused_parallel_op_attrs.h b/lib/runtime/src/ops/fused_parallel_op_attrs.h index 
454b7caead..4dd64199d3 100644 --- a/lib/runtime/src/ops/fused_parallel_op_attrs.h +++ b/lib/runtime/src/ops/fused_parallel_op_attrs.h @@ -1,7 +1,6 @@ -#ifndef _FLEXFLOW_FUSED_PARALLEL_OP_ATTRS_H -#define _FLEXFLOW_FUSED_PARALLEL_OP_ATTRS_H +#ifndef _FLEXFLOW_LIB_RUNTIME_INCLUDE_OPS_FUSED_PARALLEL_OP_ATTRS_H +#define _FLEXFLOW_LIB_RUNTIME_INCLUDE_OPS_FUSED_PARALLEL_OP_ATTRS_H -#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "parallel_op_info.h" #include "utils/visitable.h" diff --git a/lib/runtime/test/src/main.cc b/lib/runtime/test/src/main.cc index 9522fa7fdb..0a3f254ea8 100644 --- a/lib/runtime/test/src/main.cc +++ b/lib/runtime/test/src/main.cc @@ -1,2 +1,2 @@ #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN -#include "doctest/doctest.h" +#include diff --git a/lib/runtime/test/src/test_op_task_spec.cc b/lib/runtime/test/src/test_op_task_spec.cc index bb0bee567c..5a99d52193 100644 --- a/lib/runtime/test/src/test_op_task_spec.cc +++ b/lib/runtime/test/src/test_op_task_spec.cc @@ -1,4 +1,4 @@ -#include "doctest/doctest.h" +#include #include "op_task_invocation.h" #include "op_task_signature.h" diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index 471f2a2709..b3cb749162 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -1,4 +1,4 @@ -#include "doctest/doctest.h" +#include #include "legion/legion_utilities.h" #include "op-attrs/ffconst.h" #include "serialization.h" diff --git a/lib/substitution-generator/test/substitution-generator/legacy_rules.cc b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc index 4dd9bb8cc4..19102d9670 100644 --- a/lib/substitution-generator/test/substitution-generator/legacy_rules.cc +++ b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc @@ -1,5 +1,5 @@ #include "substitution-generator/legacy_rules.h" -#include "doctest/doctest.h" +#include using namespace 
FlexFlow; using nlohmann::json; diff --git a/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc b/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc index 272c5f2dd5..bd926eb6b6 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/evaluate_substitution_output.cc @@ -2,6 +2,8 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "substitutions/output_graph/output_operator_attrs_assignment.h" #include "substitutions/sub_parallel_computation_graph.h" +#include "utils/bidict/algorithms/transform_keys.h" +#include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" @@ -47,7 +49,7 @@ std::pair }); bidict result_input_map = - map_keys(map_values(new_input_id_permutation, + transform_keys(transform_values(new_input_id_permutation, [](DataflowGraphInput const &i) { return OutputGraphExprInput{i}; }), @@ -55,8 +57,8 @@ std::pair return input_parallel_tensor_guid_t{i.raw_input}; }); - bidict result_node_map = map_keys( - map_values(new_node_id_permutation, + bidict result_node_map = transform_keys( + transform_values(new_node_id_permutation, [](Node const &n) { return OutputGraphExprNode{n}; }), [](NewNode const &n) { return parallel_layer_guid_t{n.raw_node}; }); diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc index 6f41772a9e..15f514a60e 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc @@ -4,6 +4,7 @@ #include "utils/containers/make.h" #include "utils/containers/transform.h" #include 
"utils/overload.h" +#include namespace FlexFlow { @@ -28,7 +29,7 @@ std::optional return transform(at_idx(v, acc.index), make()); } else { - throw mk_runtime_error("Invalid operand"); + PANIC("Invalid operand"); } }); } diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 194ae49255..96c33989fe 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -1,5 +1,6 @@ #include "substitutions/operator_pattern/satisfies_constraint.h" #include "substitutions/operator_pattern/operator_attribute_expr.h" +#include namespace FlexFlow { @@ -17,9 +18,7 @@ bool operator_satisfies_constraint( case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; default: - throw mk_runtime_error( - fmt::format("Unknown constraint type {}", - static_cast(constraint.constraint_type))); + PANIC("Unknown constraint type", constraint.constraint_type); } } diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc index cf5a1e17f9..a9ebeef9b8 100644 --- a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -1,6 +1,7 @@ #include "substitutions/output_graph/materialize_operator_from_attrs_map.h" #include "utils/containers/contains_key.h" #include "utils/fmt/unordered_map.h" +#include namespace FlexFlow { @@ -16,8 +17,9 @@ struct Accessor { if (contains_key(this->m, k)) { return this->m.at(k).get(); } else { - throw mk_runtime_error( - fmt::format("Could not find key {} in attrs map: {}", k, this->m)); + PANIC("Could not find key in attrs map", + k, + this->m); } } }; @@ -151,8 
+153,7 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::PIPELINE: case OperatorType::FUSED_PARALLEL: default: - throw mk_runtime_error( - fmt::format("Unsupported operator type {}", op_type)); + PANIC("Unsupported operator type", op_type); } } diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index a0af875848..784d0f0751 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -5,6 +5,7 @@ #include "substitutions/tensor_pattern/satisfies_pattern.h" #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/pattern_value.h" +#include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -47,8 +48,8 @@ std::vector auto pcg_match_from_unlabelled_match = [](UnlabelledDataflowGraphPatternMatch const &m) { return PCGPatternMatch{ - map_values(m.node_assignment, - [](Node const &n) { return parallel_layer_guid_t{n}; }), + transform_values(m.node_assignment, + [](Node const &n) { return parallel_layer_guid_t{n}; }), map_values(m.input_assignment, [](OpenDataflowValue const &i) { return open_parallel_tensor_guid_t{i}; diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index b701be65cf..2281d46514 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -3,6 +3,7 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/map_values.h" #include "utils/containers/zip.h" @@ -36,7 +37,7 @@ bidict 
UnlabelledDataflowGraphPatternMatch get_unlabelled_pattern_match(PCGPatternMatch const &match) { return UnlabelledDataflowGraphPatternMatch{ - map_values( + transform_values( match.node_assignment, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), map_values(match.input_assignment, diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc index 874700d303..801fcfe42a 100644 --- a/lib/substitutions/src/substitutions/substitution.cc +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -4,6 +4,7 @@ #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/map_values.h" +#include "utils/bidict/algorithms/transform.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc index 7bfb1f5e9e..b7b0e1fa83 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc @@ -2,6 +2,7 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "utils/containers/at_idx.h" #include "utils/overload.h" +#include namespace FlexFlow { @@ -14,8 +15,8 @@ TensorAttributeValue [&](std::vector const &v) -> TensorAttributeValue { return TensorAttributeValue{at_idx(v, acc.index).value()}; }, - [](auto &&) -> TensorAttributeValue { - throw mk_runtime_error("Invalid operand"); + [](auto &&x) -> TensorAttributeValue { + PANIC("Invalid operand", x); }, }); } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc 
b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index 974bfcabc0..e2f2e211fa 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -1,5 +1,6 @@ #include "substitutions/tensor_pattern/satisfies_constraint.h" #include "substitutions/tensor_pattern/tensor_attribute_expr.h" +#include namespace FlexFlow { @@ -13,9 +14,7 @@ bool parallel_tensor_satisfies_constraint( case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; default: - throw mk_runtime_error( - fmt::format("Unknown constraint type {}", - static_cast(constraint.constraint_type))); + PANIC("Unknown constraint type", constraint.constraint_type); } } diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 84e0d91fee..afa943922d 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -11,6 +11,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" +#include "utils/bidict/algorithms/transform.h" namespace FlexFlow { diff --git a/lib/task-spec/include/task-spec/arg_ref.h b/lib/task-spec/include/task-spec/arg_ref.h index 8d3402c578..4a6e4f56f8 100644 --- a/lib/task-spec/include/task-spec/arg_ref.h +++ b/lib/task-spec/include/task-spec/arg_ref.h @@ -1,10 +1,9 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H -#define _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ARG_REF_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_ARG_REF_H #include "kernels/ff_handle.h" -// #include "task-spec/serialization.h #include 
"utils/type_index.h" -#include "utils/visitable.h" +#include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/task-spec/include/task-spec/config.h b/lib/task-spec/include/task-spec/config.h deleted file mode 100644 index ff7c4af5a5..0000000000 --- a/lib/task-spec/include/task-spec/config.h +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_LOCAL_EXECUTION_CONFIG_H_ -#define _FLEXFLOW_LOCAL_EXECUTION_CONFIG_H_ - -#include "utils/fmt.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -enum class ComputationMode { - TRAINING, - INFERENCE, -}; - -// ======================================================== -// Define Runtime Constants -// ======================================================== -// Pre-assigned const flags -#define MAP_TO_FB_MEMORY 0xABCD0000 -#define MAP_TO_ZC_MEMORY 0xABCE0000 - -struct FFInitInfo : public use_visitable_cmp { - size_t workSpaceSize; - bool allowTensorOpMathConversion; -}; - -using legion_mapping_tag_id_t = unsigned long; - -struct FFConfig : public use_visitable_cmp { -public: - enum PreservedIDs { - InvalidID = 0, - DataParallelism_GPU = 1, - // DataParallelism_GPU_2D = 2, - // DataParallelism_GPU_3D = 3, - // DataParallelism_GPU_4D = 4, - // DataParallelism_GPU_5D = 5, - DataParallelism_CPU = 11, - // DataParallelism_CPU_2D = 12, - // DataParallelism_CPU_3D = 13, 
- // DataParallelism_CPU_4D = 14, - // DataParallelism_CPU_5D = 15, - }; - - FFConfig() = default; - static legion_mapping_tag_id_t get_hash_id(std::string const &pcname); - -public: - int epochs = 1; - int batchSize = 64; - int numNodes = 1; - int cpusPerNode = 0; - int workersPerNode = 0; - float learningRate = 0.01f; - float weightDecay = 0.0001f; - size_t workSpaceSize = (size_t)1 * 1024 * 1024 * 1024; // 2GB - bool profiling = false; - bool perform_fusion = false; - size_t simulator_work_space_size = (size_t)2 * 1024 * 1024 * 1024; // 2GB - size_t search_budget = -1; - float search_alpha = 1.2f; - bool search_overlap_backward_update = false; - ComputationMode computationMode = ComputationMode::TRAINING; - // Control parallelizable dimensions - bool only_data_parallel = false; - bool enable_parameter_parallel = false; - bool enable_inplace_optimizations = false; - // Control Tensor Op Math Conversion - bool allow_tensor_op_math_conversion = false; - std::optional dataset_path = std::nullopt; - std::optional export_strategy_computation_graph_file = - std::nullopt; - bool include_costs_dot_graph = false; - std::optional substitution_json_path = std::nullopt; - int machine_model_version = 0; - std::optional machine_model_file = std::nullopt; - int simulator_segment_size = 16777216; // 16 MB - int simulator_max_num_segments = 1; - std::optional search_num_nodes = std::nullopt; - std::optional search_num_workers = std::nullopt; - int base_optimize_threshold = 10; - bool enable_control_replication = true; - // The default python data loader type is 2 to enable control replication - int python_data_loader_type = 2; -}; - -struct FFIterationConfig { - FFIterationConfig() = delete; - void reset(); - int seq_length; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length); - -enum FieldIDs { - FID_DATA, -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::FFInitInfo, - workSpaceSize, - allowTensorOpMathConversion); 
-MAKE_VISIT_HASHABLE(::FlexFlow::FFInitInfo); - -VISITABLE_STRUCT(::FlexFlow::FFConfig, - epochs, - batchSize, - numNodes, - cpusPerNode, - workersPerNode, - learningRate, - weightDecay, - workSpaceSize, - profiling, - perform_fusion, - simulator_work_space_size, - search_budget, - search_alpha, - search_overlap_backward_update, - computationMode, - only_data_parallel, - enable_parameter_parallel, - enable_inplace_optimizations, - allow_tensor_op_math_conversion, - dataset_path, - export_strategy_computation_graph_file, - include_costs_dot_graph, - substitution_json_path, - machine_model_version, - machine_model_file, - simulator_segment_size, - simulator_max_num_segments, - search_num_nodes, - search_num_workers, - base_optimize_threshold, - enable_control_replication, - python_data_loader_type); - -namespace fmt { - -template <> -struct formatter<::FlexFlow::ComputationMode> : formatter { - template - auto format(::FlexFlow::ComputationMode m, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (m) { - case ComputationMode::TRAINING: - name = "Training"; - break; - case ComputationMode::INFERENCE: - name = "Inference"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/task-spec/include/task-spec/ff_config.struct.toml b/lib/task-spec/include/task-spec/ff_config.struct.toml new file mode 100644 index 0000000000..959e96092d --- /dev/null +++ b/lib/task-spec/include/task-spec/ff_config.struct.toml @@ -0,0 +1,115 @@ +namespace = "FlexFlow" +name = "FFConfig" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/positive_int/positive_int.h", + "utils/nonnegative_int/nonnegative_int.h", + "", + "", +] + +src_includes = [ + "utils/rapidcheck/optional.h", + "utils/json/optional.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "epochs" +type = "::FlexFlow::positive_int" + +[[fields]] +name 
= "batch_size" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "num_nodes" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "cpus_per_node" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "gpus_per_node" +type = "::FlexFlow::positive_int" + +[[fields]] +name = "learning_rate" +type = "float" + +[[fields]] +name = "weight_decay" +type = "float" + +[[fields]] +name = "workspace_size" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "enable_profiling" +type = "bool" + +[[fields]] +name = "perform_fusion" +type = "bool" + +[[fields]] +name = "simulator_workspace_size" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "search_budget" +type = "std::optional<::FlexFlow::positive_int>" + +[[fields]] +name = "search_alpha" +type = "float" + +[[fields]] +name = "search_overlap_backward_update" +type = "bool" + +[[fields]] +name = "only_data_parallel" +type = "bool" + +[[fields]] +name = "enable_parameter_parallel" +type = "bool" + +[[fields]] +name = "enable_inplace_optimizations" +type = "bool" + +[[fields]] +name = "allow_tensor_op_math_conversion" +type = "bool" + +[[fields]] +name = "dataset_path" +type = "std::optional" + +[[fields]] +name = "export_strategy_computation_graph_file" +type = "std::optional" + +[[fields]] +name = "include_costs_dot_graph" +type = "bool" + +[[fields]] +name = "substitution_json_path" +type = "std::optional" + +[[fields]] +name = "base_optimize_threshold" +type = "::FlexFlow::positive_int" diff --git a/lib/task-spec/include/task-spec/ff_init_info.struct.toml b/lib/task-spec/include/task-spec/ff_init_info.struct.toml new file mode 100644 index 0000000000..dbb8b76a1d --- /dev/null +++ b/lib/task-spec/include/task-spec/ff_init_info.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "FFInitInfo" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h", +] + +[[fields]] +name = "workspace_size" +type = 
"::FlexFlow::nonnegative_int" + +[[fields]] +name = "allow_tensor_op_math_conversion" +type = "bool" diff --git a/lib/task-spec/include/task-spec/ff_iteration_config.struct.toml b/lib/task-spec/include/task-spec/ff_iteration_config.struct.toml new file mode 100644 index 0000000000..6c9e59ba8a --- /dev/null +++ b/lib/task-spec/include/task-spec/ff_iteration_config.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "FFIterationConfig" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "utils/positive_int/positive_int.h", +] + +[[fields]] +name = "seq_length" +type = "::FlexFlow::positive_int" diff --git a/lib/task-spec/include/task-spec/op_task_signature.h b/lib/task-spec/include/task-spec/op_task_signature.h index eba0023906..d4769b4664 100644 --- a/lib/task-spec/include/task-spec/op_task_signature.h +++ b/lib/task-spec/include/task-spec/op_task_signature.h @@ -11,7 +11,6 @@ #include "utils/hash/unordered_map.h" #include "utils/hash/unordered_set.h" #include "utils/type_index.h" -#include "utils/visitable.h" namespace FlexFlow { @@ -89,13 +88,25 @@ struct OpTaskSignature { void set_arg_types(std::unordered_map const &); std::unordered_map get_arg_types() const; + bool operator==(OpTaskSignature const &) const; + bool operator!=(OpTaskSignature const &) const; + +public: OpTaskType type; std::optional return_value; std::unordered_map task_arg_types; std::unordered_set op_tensor_slots; + +private: + std::tuple< + decltype(type) const &, + decltype(return_value) const &, + decltype(task_arg_types) const &, + decltype(op_tensor_slots) const & + > tie() const; + + friend ::std::hash<::FlexFlow::OpTaskSignature>; }; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( - OpTaskSignature, type, return_value, task_arg_types, op_tensor_slots); std::string format_as(OpTaskSignature const &x); std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x); @@ -104,4 +115,13 @@ OpTaskSignature 
infer_bwd_signature(OpTaskSignature const &fwd); } // namespace FlexFlow +namespace std { + +template<> +struct hash<::FlexFlow::OpTaskSignature> { + size_t operator()(::FlexFlow::OpTaskSignature const &) const; +}; + +} + #endif diff --git a/lib/task-spec/include/task-spec/runtime_arg_ref.h b/lib/task-spec/include/task-spec/runtime_arg_ref.h index 532482f89e..9a0c43cec5 100644 --- a/lib/task-spec/include/task-spec/runtime_arg_ref.h +++ b/lib/task-spec/include/task-spec/runtime_arg_ref.h @@ -5,7 +5,8 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/device_type.dtg.h" #include "task-spec/arg_ref.h" -#include "task-spec/config.h" +#include "task-spec/ff_config.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" #include "task-spec/device_specific.h" #include "task-spec/runtime_arg_ref_type.dtg.h" diff --git a/lib/task-spec/include/task-spec/serialization.h b/lib/task-spec/include/task-spec/serialization.h index 2fc4b4b706..d5e11a8bb9 100644 --- a/lib/task-spec/include/task-spec/serialization.h +++ b/lib/task-spec/include/task-spec/serialization.h @@ -1,25 +1,13 @@ -#ifndef _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H -#define _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_SERIALIZATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_SERIALIZATION_H #include "kernels/device.h" #include "kernels/nccl.h" -#include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/ff_ordered/ff_ordered.h" #include "utils/required.h" -#include "utils/strong_typedef.h" #include "utils/type_traits.h" #include "utils/variant.h" -#include "utils/visitable.h" -namespace FlexFlow { - -struct InternalTestType { - int x; - float y; -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::InternalTestType, x, y); namespace FlexFlow { @@ -46,26 +34,12 @@ struct visit_trivially_serializable> { template <> struct visit_trivially_serializable<> : std::true_type {}; -template -struct is_trivially_serializable< - T, - typename 
std::enable_if< - visit_trivially_serializable>::value>::type> - : std::true_type {}; - template struct is_trivially_serializable< T, typename std::enable_if::value>::type> : std::true_type {}; -template -struct is_trivially_serializable>> - : is_trivially_serializable> {}; - -template -struct is_trivially_serializable> : is_trivially_serializable {}; - template <> struct is_trivially_serializable : std::true_type {}; template <> @@ -86,8 +60,8 @@ template struct is_trivially_serializable> : is_trivially_serializable {}; -template -struct is_trivially_serializable> +template +struct is_trivially_serializable> : is_trivially_serializable {}; template @@ -134,11 +108,6 @@ static_assert(is_trivially_serializable::value, ""); static_assert(is_trivially_serializable::value, ""); static_assert(is_trivially_serializable>::value, ""); -static_assert(std::is_same, - std::tuple>::value, - ""); -static_assert(visit_trivially_serializable::value, ""); -static_assert(is_trivially_serializable::value, ""); } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/op_task_signature.cc b/lib/task-spec/src/task-spec/op_task_signature.cc index 94ac16d092..ece7c5a061 100644 --- a/lib/task-spec/src/task-spec/op_task_signature.cc +++ b/lib/task-spec/src/task-spec/op_task_signature.cc @@ -2,6 +2,9 @@ #include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" #include "utils/fmt/unordered_set.h" +#include "utils/hash/unordered_set.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/tuple.h" namespace FlexFlow { @@ -148,6 +151,27 @@ std::unordered_map return this->task_arg_types; } +bool OpTaskSignature::operator==(OpTaskSignature const &other) const { + return this->tie() == other.tie(); +} + +bool OpTaskSignature::operator!=(OpTaskSignature const &other) const { + return this->tie() != other.tie(); +} + +std::tuple< + OpTaskType const &, + std::optional const &, + std::unordered_map const &, + std::unordered_set const & +> OpTaskSignature::tie() const { + 
return std::tie( + this->type, + this->return_value, + this->task_arg_types, + this->op_tensor_slots); +} + std::string format_as(OpTaskSignature const &x) { std::ostringstream oss; oss << "::operator()(::FlexFlow::OpTaskSignature const &x) const { + return get_std_hash(x.tie()); +} + +} diff --git a/lib/task-spec/src/task-spec/ops/batch_matmul.cc b/lib/task-spec/src/task-spec/ops/batch_matmul.cc index f8d6955b41..2638df6b73 100644 --- a/lib/task-spec/src/task-spec/ops/batch_matmul.cc +++ b/lib/task-spec/src/task-spec/ops/batch_matmul.cc @@ -77,48 +77,17 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { DeviceType kernel_device_type = acc.get_argument(KERNEL_DEVICE_TYPE); - positive_int m = dim_at_idx(b_input.shape.dims, legion_dim_t{0_n}); - ASSERT(m == dim_at_idx(output.shape.dims, legion_dim_t{0_n})); - positive_int n = dim_at_idx(a_input.shape.dims, legion_dim_t{1_n}); - ASSERT(n == dim_at_idx(output.shape.dims, legion_dim_t{1_n})); - positive_int k = dim_at_idx(a_input.shape.dims, legion_dim_t{0_n}); - ASSERT(k == dim_at_idx(b_input.shape.dims, legion_dim_t{1_n})); - - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(b_input.shape.dims)); - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(output.shape.dims)); - - positive_int batch = 1_p; - for (nonnegative_int i : - nonnegative_range(2_n, get_num_dims(a_input.shape.dims))) { - positive_int dim_size = dim_at_idx(a_input.shape.dims, legion_dim_t{i}); - ASSERT(dim_size == dim_at_idx(b_input.shape.dims, legion_dim_t{i})); - ASSERT(dim_size == dim_at_idx(output.shape.dims, legion_dim_t{i})); - batch *= dim_size; - } - - auto get_raw_seq_len = [](std::optional seq_len) -> int { - return transform(seq_len, - [](nonnegative_int x) { return x.unwrap_nonnegative(); }) - .value_or(-1); - }; - return profile(forward_kernel, profiling, kernel_device_type, "[BatchMatmul] forward_time = {:.2lf}ms\n", handle, - output.get_float_ptr(), - a_input.get_float_ptr(), - 
b_input.get_float_ptr(), - m.int_from_positive_int(), - n.int_from_positive_int(), - k.int_from_positive_int(), - batch.int_from_positive_int(), - get_raw_seq_len(attrs.a_seq_length_dim), - get_raw_seq_len(attrs.b_seq_length_dim), - iter_config.seq_length); + output, + a_input, + b_input, + iter_config.seq_length, + attrs.a_seq_length_dim, + attrs.b_seq_length_dim); } static std::optional @@ -143,42 +112,17 @@ static std::optional auto b_input_grad = acc.get_tensor_grad(B_INPUT); ASSERT(b_input.shape == b_input_grad.shape); - // check dins - positive_int m = dim_at_idx(b_input.shape.dims, legion_dim_t{0_n}); - ASSERT(m == dim_at_idx(output.shape.dims, legion_dim_t{0_n})); - positive_int n = dim_at_idx(a_input.shape.dims, legion_dim_t{1_n}); - ASSERT(n == dim_at_idx(output.shape.dims, legion_dim_t{1_n})); - positive_int k = dim_at_idx(a_input.shape.dims, legion_dim_t{0_n}); - ASSERT(k == dim_at_idx(b_input.shape.dims, legion_dim_t{1_n})); - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(b_input.shape.dims)); - ASSERT(get_num_elements(a_input.shape.dims) == - get_num_elements(output.shape.dims)); - - positive_int batch = 1_p; - for (nonnegative_int i : - nonnegative_range(2_n, get_num_dims(a_input.shape.dims))) { - positive_int dim_size = dim_at_idx(a_input.shape.dims, legion_dim_t{i}); - ASSERT(dim_size == dim_at_idx(b_input.shape.dims, legion_dim_t{i})); - ASSERT(dim_size == dim_at_idx(output.shape.dims, legion_dim_t{i})); - batch *= dim_size; - } - return profile(backward_kernel, profiling, kernel_device_type, "[BatchMatmul] backward_time = {:.2lf}ms\n", handle, - output.get_float_ptr(), - output_grad.get_float_ptr(), - a_input.get_float_ptr(), - a_input_grad.get_float_ptr(), - b_input.get_float_ptr(), - b_input_grad.get_float_ptr(), - m.int_from_positive_int(), - n.int_from_positive_int(), - k.int_from_positive_int(), - batch.int_from_positive_int()); + output, + output_grad, + a_input, + a_input_grad, + b_input, + b_input_grad); } 
TaskImplFunction get_batch_matmul_fwd_task_impl() { diff --git a/lib/task-spec/src/task-spec/training_layer_plus_context.cc b/lib/task-spec/src/task-spec/training_layer_plus_context.cc index 9adbc6b2a1..258be5de09 100644 --- a/lib/task-spec/src/task-spec/training_layer_plus_context.cc +++ b/lib/task-spec/src/task-spec/training_layer_plus_context.cc @@ -1,6 +1,7 @@ #include "task-spec/training_layer_plus_context.h" #include "task-spec/training_tensor_group_with_attrs.h" #include "utils/containers/transform.h" +#include namespace FlexFlow { diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index 6d8f22bc29..e2f7c433d6 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -9,7 +9,6 @@ ff_add_library( src/ DEPS expected - visit_struct fmt json cuda diff --git a/lib/utils/include/utils/containers/filter_idxs.h b/lib/utils/include/utils/containers/filter_idxs.h index e9d9777e74..ea5384c854 100644 --- a/lib/utils/include/utils/containers/filter_idxs.h +++ b/lib/utils/include/utils/containers/filter_idxs.h @@ -15,7 +15,7 @@ std::vector filter_idxs(std::vector const &input, std::function> +template > std::vector generate_vector(nonnegative_int length, F &&f) { std::vector result; for (nonnegative_int idx : range(length)) { diff --git a/lib/utils/include/utils/containers/require_same.h b/lib/utils/include/utils/containers/require_same.h index f638e1da1a..2f3439db32 100644 --- a/lib/utils/include/utils/containers/require_same.h +++ b/lib/utils/include/utils/containers/require_same.h @@ -8,14 +8,21 @@ namespace FlexFlow { template T const &require_same(T const &l, T const &r) { - if (l != r) { - throw mk_runtime_error( - fmt::format("require_same received non-equal inputs: {} != {}", l, r)); - } + ASSERT(l == r, "require_same received non-equal inputs"); return l; } +template +T const &require_same(T const &t1, T const &t2, T const &t3) { + return require_same(require_same(t1, t2), t3); +} + +template +T const &require_same(T const &t1, T const 
&t2, T const &t3, T const &t4) { + return require_same(require_same(require_same(t1, t2), t3), t4); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/slice.h b/lib/utils/include/utils/containers/slice.h index a82fb383b5..29a4858a0d 100644 --- a/lib/utils/include/utils/containers/slice.h +++ b/lib/utils/include/utils/containers/slice.h @@ -10,7 +10,7 @@ namespace FlexFlow { template std::vector slice(std::vector const &v, - int const &maybe_start, + int maybe_start, std::optional const &maybe_end) { auto begin_iter = v.cbegin(); auto end_iter = v.cend(); diff --git a/lib/utils/include/utils/containers/zip3_with.h b/lib/utils/include/utils/containers/zip3_with.h index 70ed2a73ba..938565bd58 100644 --- a/lib/utils/include/utils/containers/zip3_with.h +++ b/lib/utils/include/utils/containers/zip3_with.h @@ -5,14 +5,14 @@ namespace FlexFlow { -template > +template > std::vector zip3_with(std::vector const &v_a, std::vector const &v_b, std::vector const &v_c, F &&f) { std::vector result; - for (int i = 0; i < std::min(v_a.size(), v_b.size(), v_c.size()); i++) { - result.push_back(v_a.at(i), v_b.at(i), v_c.at(i)); + for (int i = 0; i < std::min(v_a.size(), std::min(v_b.size(), v_c.size())); i++) { + result.push_back(f(v_a.at(i), v_b.at(i), v_c.at(i))); } return result; diff --git a/lib/utils/include/utils/containers/zip3_with_strict.h b/lib/utils/include/utils/containers/zip3_with_strict.h index ae7239f5d8..793022efdf 100644 --- a/lib/utils/include/utils/containers/zip3_with_strict.h +++ b/lib/utils/include/utils/containers/zip3_with_strict.h @@ -2,20 +2,19 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_WITH_STRICT_H #include -#include "utils/exception.h" -#include "utils/fmt/vector.h" +#include #include "utils/containers/zip3_with.h" namespace FlexFlow { -template > +template > std::vector zip3_with_strict(std::vector const &v_a, std::vector const &v_b, std::vector const &v_c, F &&f) { - if (!(v_a.size() == v_b.size() && 
v_b.size() == v_c.size())) { - throw mk_runtime_error(fmt::format("zip3_with_strict requires inputs to have the same length, but received v_a = {} (length {}), v_b = {} (length {}), and v_c = {} (length {})", v_a, v_a.size(), v_b, v_b.size(), v_c, v_c.size())); - } + ASSERT(v_a.size() == v_b.size() && v_b.size() == v_c.size(), + "zip3_with_strict requires inputs to have the same length, but received mismatched lengths", + v_a.size(), v_b.size(), v_c.size()); return zip3_with(v_a, v_b, v_c, f); } diff --git a/lib/utils/include/utils/json/visitable.h b/lib/utils/include/utils/json/visitable.h deleted file mode 100644 index abc20065de..0000000000 --- a/lib/utils/include/utils/json/visitable.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H - -#include "utils/json/is_json_deserializable.h" -#include "utils/json/is_json_serializable.h" -#include "utils/json/is_jsonable.h" -#include "utils/json_core.h" -#include "utils/optional.h" -#include "utils/sequence.h" -#include "utils/type_traits.h" -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct json_serialization_visitor { - json_serialization_visitor() = delete; - json_serialization_visitor(json &j) : j(j) {} - - json &j; - - template - void operator()(char const *field_name, T const &field_value) { - j[field_name] = field_value; - } -}; - -struct json_deserialization_visitor { - json_deserialization_visitor() = delete; - json_deserialization_visitor(json const &j) : j(j) {} - - json const &j; - - template - void operator()(char const *field_name, T &field_value) { - j.at(field_name).get_to(field_value); - } -}; - -static_assert(std::is_same>, - std::tuple>::value, - ""); -static_assert(std::is_same>, - std::tuple<>>::value, - ""); - -template -typename std::enable_if<(idx >= std::tuple_size>::value), - std::tuple<>>::type - tuple_from_json_impl(json const &j) { - return std::tuple<>{}; -} - -template -struct 
TupleFromJson { - tuple_tail_t> operator()(json const &j) { - using FieldT = visit_struct::type_at; - - FieldT field = - j.at(visit_struct::get_name()).template get(); - - return std::tuple_cat(std::tuple(field), - TupleFromJson<(idx + 1), T>{}(j)); - } -}; - -template -struct TupleFromJson< - idx, - T, - typename std::enable_if<( - idx > std::tuple_size>::value)>::type> { - std::tuple<> operator()(json const &j) { - return {}; - } -}; - -template -visit_as_tuple_t tuple_from_json(json const &j) { - return TupleFromJson<0, T>{}(j); -} - -template -void visit_json_serialize(json &j, T const &t) { - static_assert(is_visitable::value, "Type must be visitable"); - static_assert(elements_satisfy::value, - "Elements must be deserializable"); - - json_serialization_visitor vis(j); - visit_struct::for_each(t, vis); -} - -template -void visit_json_deserialize(json const &j, T &t) { - static_assert(is_visitable::value, "Type must be visitable"); - static_assert(elements_satisfy::value, - "Elements must be deserializable"); - - json_deserialization_visitor vis(j); - visit_struct::for_each(t, vis); -} - -template -T moveonly_visit_json_deserialize(json const &j) { - static_assert(is_visitable::value, "Type must be visitable"); - static_assert(!std::is_default_constructible::value, ""); - static_assert(elements_satisfy::value, - "Elements must be deserializable"); - - return visitable_from_tuple(tuple_from_json(j)); -} - -} // namespace FlexFlow - -namespace nlohmann { - -template -struct adl_serializer< - T, - typename std::enable_if<::FlexFlow::conjunction< - ::FlexFlow::is_visitable, - ::FlexFlow::elements_satisfy<::FlexFlow::is_json_serializable, T>, - std::is_default_constructible>::value>::type> { - static void to_json(json &j, T const &t) { - ::FlexFlow::visit_json_serialize(j, t); - } - - static void from_json(json const &j, T &t) { - ::FlexFlow::visit_json_deserialize(j, t); - } -}; - -template -struct adl_serializer< - T, - typename 
std::enable_if<::FlexFlow::conjunction< - ::FlexFlow::is_visitable, - ::FlexFlow::elements_satisfy<::FlexFlow::is_json_serializable, T>, - ::FlexFlow::negation>, - std::is_move_constructible>::value>::type> { - static void to_json(json &j, T const &t) { - ::FlexFlow::visit_json_serialize(j, t); - } - - static T from_json(json const &j) { - return ::FlexFlow::moveonly_visit_json_deserialize(j); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/orthotope/dim_coord.h b/lib/utils/include/utils/orthotope/dim_coord.h index b57fd823d1..d07c1bc12c 100644 --- a/lib/utils/include/utils/orthotope/dim_coord.h +++ b/lib/utils/include/utils/orthotope/dim_coord.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_COORD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_COORD_H -#include "utils/containers/subvec.h" #include "utils/containers/zip_with_strict.h" #include "utils/orthotope/dim_coord.dtg.h" #include "utils/orthotope/dim_domain.dtg.h" @@ -15,6 +14,7 @@ #include "utils/containers/scanr.h" #include "utils/containers/product.h" #include "utils/containers/map_from_keys_and_values.h" +#include "utils/orthotope/dim_domain.h" namespace FlexFlow { @@ -40,16 +40,16 @@ OrthotopeCoord orthotope_coord_from_dim_coord(DimCoord const &coord) { template DimCoord dim_coord_from_orthotope_coord(OrthotopeCoord const &coord, DimDomain const &domain) { return DimCoord{ - map_from_keys_and_values(coord.raw, get_domain_dims(domain)), + map_from_keys_and_values(sorted(get_domain_dims(domain)), coord.raw), }; } template -nonnegative_int flatten_coord(DimCoord const &coord, +nonnegative_int flatten_dim_coord(DimCoord const &coord, DimDomain const &domain) { - if (get_coord_dims(coord) != get_dims_for_domain(domain)) { - throw mk_runtime_error(fmt::format("flatten_dims expected coord dimensions to match domain dimensions, but received coord={} and domain={}", coord, domain)); - } + ASSERT(get_coord_dims(coord) == get_domain_dims(domain), 
+ "flatten_dim_coord expected coord dimensions to match domain dimensions", + coord, domain); OrthotopeCoord orthotope_coord = orthotope_coord_from_dim_coord(coord); Orthotope orthotope_domain = orthotope_from_dim_domain(domain); @@ -58,7 +58,7 @@ nonnegative_int flatten_coord(DimCoord const &coord, } template -DimCoord unflatten_coord(nonnegative_int flattened, DimDomain const &domain) { +DimCoord unflatten_dim_coord(nonnegative_int flattened, DimDomain const &domain) { Orthotope orthotope_domain = orthotope_from_dim_domain(domain); OrthotopeCoord orthotope_coord = unflatten_orthotope_coord(flattened, orthotope_domain); diff --git a/lib/utils/include/utils/orthotope/dim_domain.h b/lib/utils/include/utils/orthotope/dim_domain.h index 2c24e11943..8ad51db42e 100644 --- a/lib/utils/include/utils/orthotope/dim_domain.h +++ b/lib/utils/include/utils/orthotope/dim_domain.h @@ -3,11 +3,15 @@ #include "utils/orthotope/dim_domain.dtg.h" #include "utils/orthotope/orthotope.dtg.h" +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/sorted.h" +#include "utils/containers/transform.h" namespace FlexFlow { template -std::set get_domain_dims(DimDomain const &domain) { +std::unordered_set get_domain_dims(DimDomain const &domain) { return keys(domain.dims); } diff --git a/lib/utils/include/utils/orthotope/dim_domain.struct.toml b/lib/utils/include/utils/orthotope/dim_domain.struct.toml index 86aa2fb7ff..d4db1a3efd 100644 --- a/lib/utils/include/utils/orthotope/dim_domain.struct.toml +++ b/lib/utils/include/utils/orthotope/dim_domain.struct.toml @@ -14,7 +14,7 @@ template_params = [ includes = [ "", - "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", ] src_includes = [ @@ -24,4 +24,4 @@ src_includes = [ [[fields]] name = "dims" -type = "std::unordered_map" +type = "std::unordered_map" diff --git a/lib/utils/include/utils/orthotope/orthotope.h b/lib/utils/include/utils/orthotope/orthotope.h index 
04d1f68f4b..599f479bc1 100644 --- a/lib/utils/include/utils/orthotope/orthotope.h +++ b/lib/utils/include/utils/orthotope/orthotope.h @@ -6,9 +6,9 @@ namespace FlexFlow { -nonnegative_int get_orthotope_num_dims(Orthotope const &); +nonnegative_int orthotope_get_num_dims(Orthotope const &); -nonnegative_int get_orthotope_volume(Orthotope const &); +positive_int orthotope_get_volume(Orthotope const &); std::unordered_set get_all_coords_in_orthotope(Orthotope const &); diff --git a/lib/utils/include/utils/orthotope/orthotope.struct.toml b/lib/utils/include/utils/orthotope/orthotope.struct.toml index a1fcb2a80e..2ffcb6960a 100644 --- a/lib/utils/include/utils/orthotope/orthotope.struct.toml +++ b/lib/utils/include/utils/orthotope/orthotope.struct.toml @@ -10,7 +10,7 @@ features = [ includes = [ "", - "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", ] src_includes = [ @@ -20,4 +20,4 @@ src_includes = [ [[fields]] name = "dims" -type = "std::vector<::FlexFlow::nonnegative_int>" +type = "std::vector<::FlexFlow::positive_int>" diff --git a/lib/utils/include/utils/orthotope/orthotope_coord.h b/lib/utils/include/utils/orthotope/orthotope_coord.h index cf105780f0..7cbb6a9de5 100644 --- a/lib/utils/include/utils/orthotope/orthotope_coord.h +++ b/lib/utils/include/utils/orthotope/orthotope_coord.h @@ -5,7 +5,7 @@ namespace FlexFlow { -OrthotopeCoord restrict_orthotope_coord_dims_to(OrthotopeCoord const &coord, std::set const &allowed_dims); +OrthotopeCoord restrict_orthotope_coord_to_dims(OrthotopeCoord const &coord, std::set const &allowed_dims); } // namespace FlexFlow diff --git a/lib/utils/include/utils/sequence.h b/lib/utils/include/utils/sequence.h index 07e4554299..26ed4a55f9 100644 --- a/lib/utils/include/utils/sequence.h +++ b/lib/utils/include/utils/sequence.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_SEQUENCE_H #include "utils/tuple.h" -#include "utils/visitable_core.h" #include #include @@ -79,13 +78,6 @@ using 
seq_enumerate_t = typename seq_enumerate::type; template struct seq_transform_type; -template -struct seq_transform_type> - : tuple_prepend_type< - visit_struct::traits::clean_t()( - std::declval>()))>, - typename seq_transform_type>::type> {}; - template struct seq_transform_type> { using type = std::tuple<>; diff --git a/lib/utils/include/utils/stack_vector/stack_vector.h b/lib/utils/include/utils/stack_vector/stack_vector.h index 64d005a10e..75a311eba2 100644 --- a/lib/utils/include/utils/stack_vector/stack_vector.h +++ b/lib/utils/include/utils/stack_vector/stack_vector.h @@ -12,6 +12,7 @@ #include #include #include +#include "utils/check_fmtable.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index 7abb3ffd5b..2f07050527 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -3,7 +3,6 @@ #include "utils/metafunction.h" #include "utils/type_traits_core.h" -#include "utils/visitable_core.h" #include #include @@ -81,13 +80,6 @@ struct elements_satisfy { "than 1 argument"); }; -template